model2 committed commit 50eec37 (1 parent: c735e82)

Add initial code

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .ci/update_windows/update.py +146 -0
  2. .ci/update_windows/update_comfyui.bat +8 -0
  3. .ci/update_windows/update_comfyui_stable.bat +8 -0
  4. .ci/windows_base_files/README_VERY_IMPORTANT.txt +31 -0
  5. .ci/windows_base_files/run_cpu.bat +2 -0
  6. .ci/windows_base_files/run_nvidia_gpu.bat +2 -0
  7. .ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat +2 -0
  8. .ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +2 -0
  9. .gitattributes +2 -0
  10. .github/ISSUE_TEMPLATE/bug-report.yml +48 -0
  11. .github/ISSUE_TEMPLATE/config.yml +11 -0
  12. .github/ISSUE_TEMPLATE/feature-request.yml +32 -0
  13. .github/ISSUE_TEMPLATE/user-support.yml +32 -0
  14. .github/workflows/pullrequest-ci-run.yml +53 -0
  15. .github/workflows/ruff.yml +23 -0
  16. .github/workflows/stable-release.yml +104 -0
  17. .github/workflows/stale-issues.yml +21 -0
  18. .github/workflows/test-build.yml +31 -0
  19. .github/workflows/test-ci.yml +96 -0
  20. .github/workflows/test-launch.yml +45 -0
  21. .github/workflows/test-unit.yml +30 -0
  22. .github/workflows/update-version.yml +58 -0
  23. .github/workflows/windows_release_dependencies.yml +71 -0
  24. .github/workflows/windows_release_nightly_pytorch.yml +91 -0
  25. .github/workflows/windows_release_package.yml +100 -0
  26. .gitignore +20 -171
  27. api_server/__init__.py +0 -0
  28. api_server/routes/__init__.py +0 -0
  29. api_server/routes/internal/README.md +3 -0
  30. api_server/routes/internal/__init__.py +0 -0
  31. api_server/routes/internal/internal_routes.py +73 -0
  32. api_server/services/__init__.py +0 -0
  33. api_server/services/terminal_service.py +60 -0
  34. api_server/utils/file_operations.py +42 -0
  35. app.py +230 -0
  36. app/__init__.py +0 -0
  37. app/app_settings.py +59 -0
  38. app/custom_node_manager.py +134 -0
  39. app/frontend_management.py +288 -0
  40. app/logger.py +98 -0
  41. app/model_manager.py +184 -0
  42. app/user_manager.py +330 -0
  43. comfy/checkpoint_pickle.py +13 -0
  44. comfy/cldm/cldm.py +433 -0
  45. comfy/cldm/control_types.py +10 -0
  46. comfy/cldm/dit_embedder.py +120 -0
  47. comfy/cldm/mmdit.py +81 -0
  48. comfy/cli_args.py +213 -0
  49. comfy/clip_config_bigg.json +23 -0
  50. comfy/clip_model.py +244 -0
.ci/update_windows/update.py ADDED
@@ -0,0 +1,146 @@
+ import pygit2
+ from datetime import datetime
+ import sys
+ import os
+ import shutil
+ import filecmp
+
+ def pull(repo, remote_name='origin', branch='master'):
+     for remote in repo.remotes:
+         if remote.name == remote_name:
+             remote.fetch()
+             remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
+             merge_result, _ = repo.merge_analysis(remote_master_id)
+             # Up to date, do nothing
+             if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
+                 return
+             # We can just fastforward
+             elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
+                 repo.checkout_tree(repo.get(remote_master_id))
+                 try:
+                     master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
+                     master_ref.set_target(remote_master_id)
+                 except KeyError:
+                     repo.create_branch(branch, repo.get(remote_master_id))
+                 repo.head.set_target(remote_master_id)
+             elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
+                 repo.merge(remote_master_id)
+
+                 if repo.index.conflicts is not None:
+                     for conflict in repo.index.conflicts:
+                         print('Conflicts found in:', conflict[0].path) # noqa: T201
+                     raise AssertionError('Conflicts, ahhhhh!!')
+
+                 user = repo.default_signature
+                 tree = repo.index.write_tree()
+                 repo.create_commit('HEAD',
+                                    user,
+                                    user,
+                                    'Merge!',
+                                    tree,
+                                    [repo.head.target, remote_master_id])
+                 # We need to do this or git CLI will think we are still merging.
+                 repo.state_cleanup()
+             else:
+                 raise AssertionError('Unknown merge analysis result')
+
+ pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0)
+ repo_path = str(sys.argv[1])
+ repo = pygit2.Repository(repo_path)
+ ident = pygit2.Signature('comfyui', 'comfy@ui')
+ try:
+     print("stashing current changes") # noqa: T201
+     repo.stash(ident)
+ except KeyError:
+     print("nothing to stash") # noqa: T201
+ backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
+ print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
+ try:
+     repo.branches.local.create(backup_branch_name, repo.head.peel())
+ except:
+     pass
+
+ print("checking out master branch") # noqa: T201
+ branch = repo.lookup_branch('master')
+ if branch is None:
+     ref = repo.lookup_reference('refs/remotes/origin/master')
+     repo.checkout(ref)
+     branch = repo.lookup_branch('master')
+     if branch is None:
+         repo.create_branch('master', repo.get(ref.target))
+ else:
+     ref = repo.lookup_reference(branch.name)
+     repo.checkout(ref)
+
+ print("pulling latest changes") # noqa: T201
+ pull(repo)
+
+ if "--stable" in sys.argv:
+     def latest_tag(repo):
+         versions = []
+         for k in repo.references:
+             try:
+                 prefix = "refs/tags/v"
+                 if k.startswith(prefix):
+                     version = list(map(int, k[len(prefix):].split(".")))
+                     versions.append((version[0] * 10000000000 + version[1] * 100000 + version[2], k))
+             except:
+                 pass
+         versions.sort()
+         if len(versions) > 0:
+             return versions[-1][1]
+         return None
+     latest_tag = latest_tag(repo)
+     if latest_tag is not None:
+         repo.checkout(latest_tag)
+
+ print("Done!") # noqa: T201
+
+ self_update = True
+ if len(sys.argv) > 2:
+     self_update = '--skip_self_update' not in sys.argv
+
+ update_py_path = os.path.realpath(__file__)
+ repo_update_py_path = os.path.join(repo_path, ".ci/update_windows/update.py")
+
+ cur_path = os.path.dirname(update_py_path)
+
+ req_path = os.path.join(cur_path, "current_requirements.txt")
+ repo_req_path = os.path.join(repo_path, "requirements.txt")
+
+ def files_equal(file1, file2):
+     try:
+         return filecmp.cmp(file1, file2, shallow=False)
+     except:
+         return False
+
+ def file_size(f):
+     try:
+         return os.path.getsize(f)
+     except:
+         return 0
+
+ if self_update and not files_equal(update_py_path, repo_update_py_path) and file_size(repo_update_py_path) > 10:
+     shutil.copy(repo_update_py_path, os.path.join(cur_path, "update_new.py"))
+     exit()
+
+ if not os.path.exists(req_path) or not files_equal(repo_req_path, req_path):
+     import subprocess
+     try:
+         subprocess.check_call([sys.executable, '-s', '-m', 'pip', 'install', '-r', repo_req_path])
+         shutil.copy(repo_req_path, req_path)
+     except:
+         pass
+
+ stable_update_script = os.path.join(repo_path, ".ci/update_windows/update_comfyui_stable.bat")
+ stable_update_script_to = os.path.join(cur_path, "update_comfyui_stable.bat")
+
+ try:
+     if not file_size(stable_update_script_to) > 10:
+         shutil.copy(stable_update_script, stable_update_script_to)
+ except:
+     pass
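
A note on the tag selection above: latest_tag() packs each vMAJOR.MINOR.PATCH tag into a single integer so that a plain numeric sort yields semantic-version order, where a lexicographic string sort would rank v0.3.9 above v0.3.10. A minimal, self-contained sketch of that scheme (the tag names below are hypothetical examples):

    def tag_sort_key(tag_ref: str) -> int:
        # Same packing as latest_tag(): major/minor/patch folded into one integer.
        major, minor, patch = map(int, tag_ref[len("refs/tags/v"):].split("."))
        return major * 10000000000 + minor * 100000 + patch

    tags = ["refs/tags/v0.3.10", "refs/tags/v0.2.7", "refs/tags/v0.3.9"]
    print(max(tags, key=tag_sort_key))  # refs/tags/v0.3.10, not v0.3.9 as a string sort would give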
.ci/update_windows/update_comfyui.bat ADDED
@@ -0,0 +1,8 @@
+ @echo off
+ ..\python_embeded\python.exe .\update.py ..\ComfyUI\
+ if exist update_new.py (
+   move /y update_new.py update.py
+   echo Running updater again since it got updated.
+   ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update
+ )
+ if "%~1"=="" pause
.ci/update_windows/update_comfyui_stable.bat ADDED
@@ -0,0 +1,8 @@
+ @echo off
+ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --stable
+ if exist update_new.py (
+   move /y update_new.py update.py
+   echo Running updater again since it got updated.
+   ..\python_embeded\python.exe .\update.py ..\ComfyUI\ --skip_self_update --stable
+ )
+ if "%~1"=="" pause
.ci/windows_base_files/README_VERY_IMPORTANT.txt ADDED
@@ -0,0 +1,31 @@
+ HOW TO RUN:
+
+ if you have a NVIDIA gpu:
+
+ run_nvidia_gpu.bat
+
+
+ To run it in slow CPU mode:
+
+ run_cpu.bat
+
+
+ IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
+
+ You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors
+
+
+ RECOMMENDED WAY TO UPDATE:
+ To update the ComfyUI code: update\update_comfyui.bat
+
+
+ To update ComfyUI with the python dependencies, note that you should ONLY run this if you have issues with python dependencies.
+ update\update_comfyui_and_python_dependencies.bat
+
+
+ TO SHARE MODELS BETWEEN COMFYUI AND ANOTHER UI:
+ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
+ Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
.ci/windows_base_files/run_cpu.bat ADDED
@@ -0,0 +1,2 @@
+ .\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
+ pause
.ci/windows_base_files/run_nvidia_gpu.bat ADDED
@@ -0,0 +1,2 @@
+ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
+ pause
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat ADDED
@@ -0,0 +1,2 @@
+ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
+ pause
.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat ADDED
@@ -0,0 +1,2 @@
+ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+ pause
.gitattributes CHANGED
@@ -1,2 +1,4 @@
+ /web/assets/** linguist-generated
+ /web/** linguist-vendored
  *.safetensors filter=lfs diff=lfs merge=lfs -text
  *.jpg filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug-report.yml ADDED
@@ -0,0 +1,48 @@
+ name: Bug Report
+ description: "Something is broken inside of ComfyUI. (Do not use this if you're just having issues and need help, or if the issue relates to a custom node)"
+ labels: ["Potential Bug"]
+ body:
+   - type: markdown
+     attributes:
+       value: |
+         Before submitting a **Bug Report**, please ensure the following:
+
+         - **1:** You are running the latest version of ComfyUI.
+         - **2:** You have looked at the existing bug reports and made sure this isn't already reported.
+         - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
+           the `--disable-all-custom-nodes` command line argument.
+         - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
+           steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
+
+         If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+   - type: textarea
+     attributes:
+       label: Expected Behavior
+       description: "What you expected to happen."
+     validations:
+       required: true
+   - type: textarea
+     attributes:
+       label: Actual Behavior
+       description: "What actually happened. Please include a screenshot of the issue if possible."
+     validations:
+       required: true
+   - type: textarea
+     attributes:
+       label: Steps to Reproduce
+       description: "Describe how to reproduce the issue. Please be sure to attach a workflow JSON or PNG, ideally one that doesn't require custom nodes to test. If the bug only happens when certain custom nodes are used, most likely that custom node is what has the bug rather than ComfyUI, in which case it should be reported to the node's author."
+     validations:
+       required: true
+   - type: textarea
+     attributes:
+       label: Debug Logs
+       description: "Please copy the output from your terminal logs here."
+       render: powershell
+     validations:
+       required: true
+   - type: textarea
+     attributes:
+       label: Other
+       description: "Any other additional information you think might be helpful."
+     validations:
+       required: false
.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,11 @@
+ blank_issues_enabled: true
+ contact_links:
+   - name: ComfyUI Frontend Issues
+     url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
+     about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
+   - name: ComfyUI Matrix Space
+     url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
+     about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
+   - name: Comfy Org Discord
+     url: https://discord.gg/comfyorg
+     about: The Comfy Org Discord is available for support and general discussion related to ComfyUI.
.github/ISSUE_TEMPLATE/feature-request.yml ADDED
@@ -0,0 +1,32 @@
+ name: Feature Request
+ description: "You have an idea for something new you would like to see added to ComfyUI's core."
+ labels: [ "Feature" ]
+ body:
+   - type: markdown
+     attributes:
+       value: |
+         Before submitting a **Feature Request**, please ensure the following:
+
+         **1:** You are running the latest version of ComfyUI.
+         **2:** You have looked to make sure there is not already a feature that does what you need, and there is not already a Feature Request listed for the same idea.
+         **3:** This is something that makes sense to add to ComfyUI Core, and wouldn't make more sense as a custom node.
+
+         If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+   - type: textarea
+     attributes:
+       label: Feature Idea
+       description: "Describe the feature you want to see."
+     validations:
+       required: true
+   - type: textarea
+     attributes:
+       label: Existing Solutions
+       description: "Please search through available custom nodes / extensions to see if there are existing custom solutions for this. If so, please link the options you found here as a reference."
+     validations:
+       required: false
+   - type: textarea
+     attributes:
+       label: Other
+       description: "Any other additional information you think might be helpful."
+     validations:
+       required: false
.github/ISSUE_TEMPLATE/user-support.yml ADDED
@@ -0,0 +1,32 @@
+ name: User Support
+ description: "Use this if you need help with something, or you're experiencing an issue."
+ labels: [ "User Support" ]
+ body:
+   - type: markdown
+     attributes:
+       value: |
+         Before submitting a **User Report** issue, please ensure the following:
+
+         **1:** You are running the latest version of ComfyUI.
+         **2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
+
+         If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+   - type: textarea
+     attributes:
+       label: Your question
+       description: "Post your question here. Please be as detailed as possible."
+     validations:
+       required: true
+   - type: textarea
+     attributes:
+       label: Logs
+       description: "If your question relates to an issue you're experiencing, please go to `Server` -> `Logs` -> potentially set `View Type` to `Debug` as well, then copypaste all the text into here."
+       render: powershell
+     validations:
+       required: false
+   - type: textarea
+     attributes:
+       label: Other
+       description: "Any other additional information you think might be helpful."
+     validations:
+       required: false
.github/workflows/pullrequest-ci-run.yml ADDED
@@ -0,0 +1,53 @@
+ # This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
+ # Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
+ name: Pull Request CI Workflow Runs
+ on:
+   pull_request_target:
+     types: [labeled]
+
+ jobs:
+   pr-test-stable:
+     if: ${{ github.event.label.name == 'Run-CI-Test' }}
+     strategy:
+       fail-fast: false
+       matrix:
+         os: [macos, linux, windows]
+         python_version: ["3.9", "3.10", "3.11", "3.12"]
+         cuda_version: ["12.1"]
+         torch_version: ["stable"]
+         include:
+           - os: macos
+             runner_label: [self-hosted, macOS]
+             flags: "--use-pytorch-cross-attention"
+           - os: linux
+             runner_label: [self-hosted, Linux]
+             flags: ""
+           - os: windows
+             runner_label: [self-hosted, Windows]
+             flags: ""
+     runs-on: ${{ matrix.runner_label }}
+     steps:
+       - name: Test Workflows
+         uses: comfy-org/comfy-action@main
+         with:
+           os: ${{ matrix.os }}
+           python_version: ${{ matrix.python_version }}
+           torch_version: ${{ matrix.torch_version }}
+           google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+           comfyui_flags: ${{ matrix.flags }}
+           use_prior_commit: 'true'
+   comment:
+     if: ${{ github.event.label.name == 'Run-CI-Test' }}
+     runs-on: ubuntu-latest
+     permissions:
+       pull-requests: write
+     steps:
+       - uses: actions/github-script@v6
+         with:
+           script: |
+             github.rest.issues.createComment({
+               issue_number: context.issue.number,
+               owner: context.repo.owner,
+               repo: context.repo.repo,
+               body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
+             })
.github/workflows/ruff.yml ADDED
@@ -0,0 +1,23 @@
+ name: Python Linting
+
+ on: [push, pull_request]
+
+ jobs:
+   ruff:
+     name: Run Ruff
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout repository
+         uses: actions/checkout@v4
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: 3.x
+
+       - name: Install Ruff
+         run: pip install ruff
+
+       - name: Run Ruff
+         run: ruff check .
.github/workflows/stable-release.yml ADDED
@@ -0,0 +1,104 @@
+ name: "Release Stable Version"
+
+ on:
+   workflow_dispatch:
+     inputs:
+       git_tag:
+         description: 'Git tag'
+         required: true
+         type: string
+       cu:
+         description: 'CUDA version'
+         required: true
+         type: string
+         default: "126"
+       python_minor:
+         description: 'Python minor version'
+         required: true
+         type: string
+         default: "12"
+       python_patch:
+         description: 'Python patch version'
+         required: true
+         type: string
+         default: "9"
+
+
+ jobs:
+   package_comfy_windows:
+     permissions:
+       contents: "write"
+       packages: "write"
+       pull-requests: "read"
+     runs-on: windows-latest
+     steps:
+       - uses: actions/checkout@v4
+         with:
+           ref: ${{ inputs.git_tag }}
+           fetch-depth: 0
+           persist-credentials: false
+       - uses: actions/cache/restore@v4
+         id: cache
+         with:
+           path: |
+             cu${{ inputs.cu }}_python_deps.tar
+             update_comfyui_and_python_dependencies.bat
+           key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
+       - shell: bash
+         run: |
+           mv cu${{ inputs.cu }}_python_deps.tar ../
+           mv update_comfyui_and_python_dependencies.bat ../
+           cd ..
+           tar xf cu${{ inputs.cu }}_python_deps.tar
+           pwd
+           ls
+
+       - shell: bash
+         run: |
+           cd ..
+           cp -r ComfyUI ComfyUI_copy
+           curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
+           unzip python_embeded.zip -d python_embeded
+           cd python_embeded
+           echo ${{ env.MINOR_VERSION }}
+           echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
+           curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+           ./python.exe get-pip.py
+           ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
+           sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
+           cd ..
+
+           git clone --depth 1 https://github.com/comfyanonymous/taesd
+           cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+
+           mkdir ComfyUI_windows_portable
+           mv python_embeded ComfyUI_windows_portable
+           mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
+
+           cd ComfyUI_windows_portable
+
+           mkdir update
+           cp -r ComfyUI/.ci/update_windows/* ./update/
+           cp -r ComfyUI/.ci/windows_base_files/* ./
+           cp ../update_comfyui_and_python_dependencies.bat ./update/
+
+           cd ..
+
+           "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+           mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z
+
+           cd ComfyUI_windows_portable
+           python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
+
+           ls
+
+       - name: Upload binaries to release
+         uses: svenstaro/upload-release-action@v2
+         with:
+           repo_token: ${{ secrets.GITHUB_TOKEN }}
+           file: ComfyUI_windows_portable_nvidia.7z
+           tag: ${{ inputs.git_tag }}
+           overwrite: true
+           prerelease: true
+           make_latest: false
.github/workflows/stale-issues.yml ADDED
@@ -0,0 +1,21 @@
+ name: 'Close stale issues'
+ on:
+   schedule:
+     # Run daily at 4:30 am PT
+     - cron: '30 11 * * *'
+ permissions:
+   issues: write
+
+ jobs:
+   stale:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/stale@v9
+         with:
+           stale-issue-message: "This issue is being marked stale because it has not had any activity for 30 days. Reply below within 7 days if your issue still isn't solved, and it will be left open. Otherwise, the issue will be closed automatically."
+           days-before-stale: 30
+           days-before-close: 7
+           stale-issue-label: 'Stale'
+           only-labels: 'User Support'
+           exempt-all-assignees: true
+           exempt-all-milestones: true
.github/workflows/test-build.yml ADDED
@@ -0,0 +1,31 @@
+ name: Build package
+
+ #
+ # This workflow is a test of the python package build.
+ # Install Python dependencies across different Python versions.
+ #
+
+ on:
+   push:
+     paths:
+       - "requirements.txt"
+       - ".github/workflows/test-build.yml"
+
+ jobs:
+   build:
+     name: Build Test
+     runs-on: ubuntu-latest
+     strategy:
+       fail-fast: false
+       matrix:
+         python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+     steps:
+       - uses: actions/checkout@v4
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python-version }}
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -r requirements.txt
.github/workflows/test-ci.yml ADDED
@@ -0,0 +1,96 @@
+ # This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
+ # Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
+ name: Full Comfy CI Workflow Runs
+ on:
+   push:
+     branches:
+       - master
+     paths-ignore:
+       - 'app/**'
+       - 'input/**'
+       - 'output/**'
+       - 'notebooks/**'
+       - 'script_examples/**'
+       - '.github/**'
+       - 'web/**'
+   workflow_dispatch:
+
+ jobs:
+   test-stable:
+     strategy:
+       fail-fast: false
+       matrix:
+         # os: [macos, linux, windows]
+         os: [macos, linux]
+         python_version: ["3.9", "3.10", "3.11", "3.12"]
+         cuda_version: ["12.1"]
+         torch_version: ["stable"]
+         include:
+           - os: macos
+             runner_label: [self-hosted, macOS]
+             flags: "--use-pytorch-cross-attention"
+           - os: linux
+             runner_label: [self-hosted, Linux]
+             flags: ""
+           # - os: windows
+           #   runner_label: [self-hosted, Windows]
+           #   flags: ""
+     runs-on: ${{ matrix.runner_label }}
+     steps:
+       - name: Test Workflows
+         uses: comfy-org/comfy-action@main
+         with:
+           os: ${{ matrix.os }}
+           python_version: ${{ matrix.python_version }}
+           torch_version: ${{ matrix.torch_version }}
+           google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+           comfyui_flags: ${{ matrix.flags }}
+
+   # test-win-nightly:
+   #   strategy:
+   #     fail-fast: true
+   #     matrix:
+   #       os: [windows]
+   #       python_version: ["3.9", "3.10", "3.11", "3.12"]
+   #       cuda_version: ["12.1"]
+   #       torch_version: ["nightly"]
+   #       include:
+   #         - os: windows
+   #           runner_label: [self-hosted, Windows]
+   #           flags: ""
+   #   runs-on: ${{ matrix.runner_label }}
+   #   steps:
+   #     - name: Test Workflows
+   #       uses: comfy-org/comfy-action@main
+   #       with:
+   #         os: ${{ matrix.os }}
+   #         python_version: ${{ matrix.python_version }}
+   #         torch_version: ${{ matrix.torch_version }}
+   #         google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+   #         comfyui_flags: ${{ matrix.flags }}
+
+   test-unix-nightly:
+     strategy:
+       fail-fast: false
+       matrix:
+         os: [macos, linux]
+         python_version: ["3.11"]
+         cuda_version: ["12.1"]
+         torch_version: ["nightly"]
+         include:
+           - os: macos
+             runner_label: [self-hosted, macOS]
+             flags: "--use-pytorch-cross-attention"
+           - os: linux
+             runner_label: [self-hosted, Linux]
+             flags: ""
+     runs-on: ${{ matrix.runner_label }}
+     steps:
+       - name: Test Workflows
+         uses: comfy-org/comfy-action@main
+         with:
+           os: ${{ matrix.os }}
+           python_version: ${{ matrix.python_version }}
+           torch_version: ${{ matrix.torch_version }}
+           google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
+           comfyui_flags: ${{ matrix.flags }}
.github/workflows/test-launch.yml ADDED
@@ -0,0 +1,45 @@
+ name: Test server launches without errors
+
+ on:
+   push:
+     branches: [ main, master ]
+   pull_request:
+     branches: [ main, master ]
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+     steps:
+       - name: Checkout ComfyUI
+         uses: actions/checkout@v4
+         with:
+           repository: "comfyanonymous/ComfyUI"
+           path: "ComfyUI"
+       - uses: actions/setup-python@v4
+         with:
+           python-version: '3.9'
+       - name: Install requirements
+         run: |
+           python -m pip install --upgrade pip
+           pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+           pip install -r requirements.txt
+           pip install wait-for-it
+         working-directory: ComfyUI
+       - name: Start ComfyUI server
+         run: |
+           python main.py --cpu 2>&1 | tee console_output.log &
+           wait-for-it --service 127.0.0.1:8188 -t 30
+         working-directory: ComfyUI
+       - name: Check for unhandled exceptions in server log
+         run: |
+           if grep -qE "Exception|Error" console_output.log; then
+             echo "Unhandled exception/error found in server log."
+             exit 1
+           fi
+         working-directory: ComfyUI
+       - uses: actions/upload-artifact@v4
+         if: always()
+         with:
+           name: console-output
+           path: ComfyUI/console_output.log
+           retention-days: 30
.github/workflows/test-unit.yml ADDED
@@ -0,0 +1,30 @@
+ name: Unit Tests
+
+ on:
+   push:
+     branches: [ main, master ]
+   pull_request:
+     branches: [ main, master ]
+
+ jobs:
+   test:
+     strategy:
+       matrix:
+         os: [ubuntu-latest, windows-latest, macos-latest]
+     runs-on: ${{ matrix.os }}
+     continue-on-error: true
+     steps:
+       - uses: actions/checkout@v4
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: '3.12'
+       - name: Install requirements
+         run: |
+           python -m pip install --upgrade pip
+           pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+           pip install -r requirements.txt
+       - name: Run Unit Tests
+         run: |
+           pip install -r tests-unit/requirements.txt
+           python -m pytest tests-unit
.github/workflows/update-version.yml ADDED
@@ -0,0 +1,58 @@
+ name: Update Version File
+
+ on:
+   pull_request:
+     paths:
+       - "pyproject.toml"
+     branches:
+       - master
+
+ jobs:
+   update-version:
+     runs-on: ubuntu-latest
+     # Don't run on fork PRs
+     if: github.event.pull_request.head.repo.full_name == github.repository
+     permissions:
+       pull-requests: write
+       contents: write
+
+     steps:
+       - name: Checkout repository
+         uses: actions/checkout@v4
+
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: "3.11"
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+
+       - name: Update comfyui_version.py
+         run: |
+           # Read version from pyproject.toml and update comfyui_version.py
+           python -c '
+           import tomllib
+
+           # Read version from pyproject.toml
+           with open("pyproject.toml", "rb") as f:
+               config = tomllib.load(f)
+           version = config["project"]["version"]
+
+           # Write version to comfyui_version.py
+           with open("comfyui_version.py", "w") as f:
+               f.write("# This file is automatically generated by the build process when version is\n")
+               f.write("# updated in pyproject.toml.\n")
+               f.write(f"__version__ = \"{version}\"\n")
+           '
+
+       - name: Commit changes
+         run: |
+           git config --local user.name "github-actions"
+           git config --local user.email "[email protected]"
+           git fetch origin ${{ github.head_ref }}
+           git checkout -B ${{ github.head_ref }} origin/${{ github.head_ref }}
+           git add comfyui_version.py
+           git diff --quiet && git diff --staged --quiet || git commit -m "chore: Update comfyui_version.py to match pyproject.toml"
+           git push origin HEAD:${{ github.head_ref }}
.github/workflows/windows_release_dependencies.yml ADDED
@@ -0,0 +1,71 @@
+ name: "Windows Release dependencies"
+
+ on:
+   workflow_dispatch:
+     inputs:
+       xformers:
+         description: 'xformers version'
+         required: false
+         type: string
+         default: ""
+       extra_dependencies:
+         description: 'extra dependencies'
+         required: false
+         type: string
+         default: ""
+       cu:
+         description: 'cuda version'
+         required: true
+         type: string
+         default: "126"
+
+       python_minor:
+         description: 'python minor version'
+         required: true
+         type: string
+         default: "12"
+
+       python_patch:
+         description: 'python patch version'
+         required: true
+         type: string
+         default: "9"
+ #  push:
+ #    branches:
+ #      - master
+
+ jobs:
+   build_dependencies:
+     runs-on: windows-latest
+     steps:
+       - uses: actions/checkout@v4
+       - uses: actions/setup-python@v5
+         with:
+           python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
+
+       - shell: bash
+         run: |
+           echo "@echo off
+           call update_comfyui.bat nopause
+           echo -
+           echo This will try to update pytorch and all python dependencies.
+           echo -
+           echo If you just want to update normally, close this and run update_comfyui.bat instead.
+           echo -
+           pause
+           ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
+           pause" > update_comfyui_and_python_dependencies.bat
+
+           python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
+           python -m pip install --no-cache-dir ./temp_wheel_dir/*
+           echo installed basic
+           ls -lah temp_wheel_dir
+           mv temp_wheel_dir cu${{ inputs.cu }}_python_deps
+           tar cf cu${{ inputs.cu }}_python_deps.tar cu${{ inputs.cu }}_python_deps
+
+       - uses: actions/cache/save@v4
+         with:
+           path: |
+             cu${{ inputs.cu }}_python_deps.tar
+             update_comfyui_and_python_dependencies.bat
+           key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
.github/workflows/windows_release_nightly_pytorch.yml ADDED
@@ -0,0 +1,91 @@
+ name: "Windows Release Nightly pytorch"
+
+ on:
+   workflow_dispatch:
+     inputs:
+       cu:
+         description: 'cuda version'
+         required: true
+         type: string
+         default: "128"
+
+       python_minor:
+         description: 'python minor version'
+         required: true
+         type: string
+         default: "13"
+
+       python_patch:
+         description: 'python patch version'
+         required: true
+         type: string
+         default: "2"
+ #  push:
+ #    branches:
+ #      - master
+
+ jobs:
+   build:
+     permissions:
+       contents: "write"
+       packages: "write"
+       pull-requests: "read"
+     runs-on: windows-latest
+     steps:
+       - uses: actions/checkout@v4
+         with:
+           fetch-depth: 30
+           persist-credentials: false
+       - uses: actions/setup-python@v5
+         with:
+           python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
+       - shell: bash
+         run: |
+           cd ..
+           cp -r ComfyUI ComfyUI_copy
+           curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
+           unzip python_embeded.zip -d python_embeded
+           cd python_embeded
+           echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
+           curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+           ./python.exe get-pip.py
+           python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+           ls ../temp_wheel_dir
+           ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
+           sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
+           cd ..
+
+           git clone --depth 1 https://github.com/comfyanonymous/taesd
+           cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+
+           mkdir ComfyUI_windows_portable_nightly_pytorch
+           mv python_embeded ComfyUI_windows_portable_nightly_pytorch
+           mv ComfyUI_copy ComfyUI_windows_portable_nightly_pytorch/ComfyUI
+
+           cd ComfyUI_windows_portable_nightly_pytorch
+
+           mkdir update
+           cp -r ComfyUI/.ci/update_windows/* ./update/
+           cp -r ComfyUI/.ci/windows_base_files/* ./
+           cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
+
+           echo "call update_comfyui.bat nopause
+           ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
+           pause" > ./update/update_comfyui_and_python_dependencies.bat
+           cd ..
+
+           "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
+           mv ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI/ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
+
+           cd ComfyUI_windows_portable_nightly_pytorch
+           python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
+
+           ls
+
+       - name: Upload binaries to release
+         uses: svenstaro/upload-release-action@v2
+         with:
+           repo_token: ${{ secrets.GITHUB_TOKEN }}
+           file: ComfyUI_windows_portable_nvidia_or_cpu_nightly_pytorch.7z
+           tag: "latest"
+           overwrite: true
.github/workflows/windows_release_package.yml ADDED
@@ -0,0 +1,100 @@
+ name: "Windows Release packaging"
+
+ on:
+   workflow_dispatch:
+     inputs:
+       cu:
+         description: 'cuda version'
+         required: true
+         type: string
+         default: "126"
+
+       python_minor:
+         description: 'python minor version'
+         required: true
+         type: string
+         default: "12"
+
+       python_patch:
+         description: 'python patch version'
+         required: true
+         type: string
+         default: "9"
+ #  push:
+ #    branches:
+ #      - master
+
+ jobs:
+   package_comfyui:
+     permissions:
+       contents: "write"
+       packages: "write"
+       pull-requests: "read"
+     runs-on: windows-latest
+     steps:
+       - uses: actions/cache/restore@v4
+         id: cache
+         with:
+           path: |
+             cu${{ inputs.cu }}_python_deps.tar
+             update_comfyui_and_python_dependencies.bat
+           key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
+       - shell: bash
+         run: |
+           mv cu${{ inputs.cu }}_python_deps.tar ../
+           mv update_comfyui_and_python_dependencies.bat ../
+           cd ..
+           tar xf cu${{ inputs.cu }}_python_deps.tar
+           pwd
+           ls
+
+       - uses: actions/checkout@v4
+         with:
+           fetch-depth: 0
+           persist-credentials: false
+       - shell: bash
+         run: |
+           cd ..
+           cp -r ComfyUI ComfyUI_copy
+           curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
+           unzip python_embeded.zip -d python_embeded
+           cd python_embeded
+           echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
+           curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+           ./python.exe get-pip.py
+           ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
+           sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
+           cd ..
+
+           git clone --depth 1 https://github.com/comfyanonymous/taesd
+           cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+
+           mkdir ComfyUI_windows_portable
+           mv python_embeded ComfyUI_windows_portable
+           mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
+
+           cd ComfyUI_windows_portable
+
+           mkdir update
+           cp -r ComfyUI/.ci/update_windows/* ./update/
+           cp -r ComfyUI/.ci/windows_base_files/* ./
+           cp ../update_comfyui_and_python_dependencies.bat ./update/
+
+           cd ..
+
+           "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+           mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
+
+           cd ComfyUI_windows_portable
+           python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
+
+           ls
+
+       - name: Upload binaries to release
+         uses: svenstaro/upload-release-action@v2
+         with:
+           repo_token: ${{ secrets.GITHUB_TOKEN }}
+           file: new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
+           tag: "latest"
+           overwrite: true
.gitignore CHANGED
@@ -1,174 +1,23 @@
- # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- build/
- develop-eggs/
- dist/
- downloads/
- eggs/
- .eggs/
- lib/
- lib64/
- parts/
- sdist/
- var/
- wheels/
- share/python-wheels/
- *.egg-info/
- .installed.cfg
- *.egg
- MANIFEST
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
- .pytest_cache/
- cover/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- .pybuilder/
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- # For a library or package, you might want to ignore these files since the code is
- # intended to run in multiple environments; otherwise, check them in:
- # .python-version
-
- # pipenv
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
- # install all needed dependencies.
- #Pipfile.lock
-
- # UV
- # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- #uv.lock
-
- # poetry
- # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
- # This is especially recommended for binary packages to ensure reproducibility, and is more
- # commonly ignored for libraries.
- # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
- #poetry.lock
-
- # pdm
- # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
- #pdm.lock
- # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
- # in version control.
- # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
- .pdm.toml
- .pdm-python
- .pdm-build/
-
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
- __pypackages__/
-
- # Celery stuff
- celerybeat-schedule
- celerybeat.pid
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
- .venv
- env/
+ /output/
+ /input/
+ !/input/example.png
+ /models/
+ /temp/
+ /custom_nodes/
+ !custom_nodes/example_node.py.example
+ extra_model_paths.yaml
+ /.vs
+ .vscode/
+ .idea/
  venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyproject
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pyre type checker
- .pyre/
-
- # pytype static type analyzer
- .pytype/
-
- # Cython debug symbols
- cython_debug/
-
- # PyCharm
- # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
- # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
- # and can be added to the global gitignore or merged into this file. For a more nuclear
- # option (not recommended) you can uncomment the following to ignore the entire idea folder.
- #.idea/
-
- # Ruff stuff:
- .ruff_cache/
-
- # PyPI configuration file
- .pypirc
+ .venv/
+ /web/extensions/*
+ !/web/extensions/logging.js.example
+ !/web/extensions/core/
+ /tests-ui/data/object_info.json
+ /user/
+ *.log
+ web_custom_versions/
+ .DS_Store
api_server/__init__.py ADDED
File without changes
api_server/routes/__init__.py ADDED
File without changes
api_server/routes/internal/README.md ADDED
@@ -0,0 +1,3 @@
+ # ComfyUI Internal Routes
+
+ All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications and may change at any time without notice.
api_server/routes/internal/__init__.py ADDED
File without changes
api_server/routes/internal/internal_routes.py ADDED
@@ -0,0 +1,73 @@
+ from aiohttp import web
+ from typing import Optional
+ from folder_paths import folder_names_and_paths, get_directory_by_type
+ from api_server.services.terminal_service import TerminalService
+ import app.logger
+ import os
+
+ class InternalRoutes:
+     '''
+     The top level web router for internal routes: /internal/*
+     The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
+     Check README.md for more information.
+     '''
+
+     def __init__(self, prompt_server):
+         self.routes: web.RouteTableDef = web.RouteTableDef()
+         self._app: Optional[web.Application] = None
+         self.prompt_server = prompt_server
+         self.terminal_service = TerminalService(prompt_server)
+
+     def setup_routes(self):
+         @self.routes.get('/logs')
+         async def get_logs(request):
+             return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
+
+         @self.routes.get('/logs/raw')
+         async def get_raw_logs(request):
+             self.terminal_service.update_size()
+             return web.json_response({
+                 "entries": list(app.logger.get_logs()),
+                 "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
+             })
+
+         @self.routes.patch('/logs/subscribe')
+         async def subscribe_logs(request):
+             json_data = await request.json()
+             client_id = json_data["clientId"]
+             enabled = json_data["enabled"]
+             if enabled:
+                 self.terminal_service.subscribe(client_id)
+             else:
+                 self.terminal_service.unsubscribe(client_id)
+
+             return web.Response(status=200)
+
+         @self.routes.get('/folder_paths')
+         async def get_folder_paths(request):
+             response = {}
+             for key in folder_names_and_paths:
+                 response[key] = folder_names_and_paths[key][0]
+             return web.json_response(response)
+
+         @self.routes.get('/files/{directory_type}')
+         async def get_files(request: web.Request) -> web.Response:
+             directory_type = request.match_info['directory_type']
+             if directory_type not in ("output", "input", "temp"):
+                 return web.json_response({"error": "Invalid directory type"}, status=400)
+
+             directory = get_directory_by_type(directory_type)
+             sorted_files = sorted(
+                 (entry for entry in os.scandir(directory) if entry.is_file()),
+                 key=lambda entry: -entry.stat().st_mtime
+             )
+             return web.json_response([entry.name for entry in sorted_files], status=200)
+
+     def get_app(self):
+         if self._app is None:
+             self._app = web.Application()
+             self.setup_routes()
+             self._app.add_routes(self.routes)
+         return self._app
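
For illustration, a minimal aiohttp client sketch that queries the /internal/files/{directory_type} route defined above, assuming a ComfyUI server listening on 127.0.0.1:8188 (the default port). Since these routes are explicitly internal and may change without notice, treat this purely as an example:

    import asyncio
    import aiohttp

    async def newest_output_files(limit: int = 5):
        # The route returns file names sorted newest-first by modification time.
        async with aiohttp.ClientSession("http://127.0.0.1:8188") as session:
            async with session.get("/internal/files/output") as resp:
                names = await resp.json()
                return names[:limit]

    print(asyncio.run(newest_output_files()))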
api_server/services/__init__.py ADDED
File without changes
api_server/services/terminal_service.py ADDED
@@ -0,0 +1,60 @@
+ from app.logger import on_flush
+ import os
+ import shutil
+
+
+ class TerminalService:
+     def __init__(self, server):
+         self.server = server
+         self.cols = None
+         self.rows = None
+         self.subscriptions = set()
+         on_flush(self.send_messages)
+
+     def get_terminal_size(self):
+         try:
+             size = os.get_terminal_size()
+             return (size.columns, size.lines)
+         except OSError:
+             try:
+                 size = shutil.get_terminal_size()
+                 return (size.columns, size.lines)
+             except OSError:
+                 return (80, 24)  # fallback to 80x24
+
+     def update_size(self):
+         columns, lines = self.get_terminal_size()
+         changed = False
+
+         if columns != self.cols:
+             self.cols = columns
+             changed = True
+
+         if lines != self.rows:
+             self.rows = lines
+             changed = True
+
+         if changed:
+             return {"cols": self.cols, "rows": self.rows}
+
+         return None
+
+     def subscribe(self, client_id):
+         self.subscriptions.add(client_id)
+
+     def unsubscribe(self, client_id):
+         self.subscriptions.discard(client_id)
+
+     def send_messages(self, entries):
+         if not len(entries) or not len(self.subscriptions):
+             return
+
+         new_size = self.update_size()
+
+         for client_id in self.subscriptions.copy():  # prevent: Set changed size during iteration
+             if client_id not in self.server.sockets:
+                 # Automatically unsub if the socket has disconnected
+                 self.unsubscribe(client_id)
+                 continue
+
+             self.server.send_sync("logs", {"entries": entries, "size": new_size}, client_id)
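
To make the flush path concrete, here is a small sketch using a stub server object (a hypothetical stand-in for ComfyUI's PromptServer; running it requires app.logger to be importable, i.e. a ComfyUI checkout on the path). Only subscribers whose client_id is still present in server.sockets receive the "logs" payload, and "size" is non-null only when the terminal dimensions changed since the last flush:

    class StubServer:
        # Pretend "client-a" has a live websocket; "client-b" does not.
        sockets = {"client-a": object()}

        def send_sync(self, event, data, sid):
            print(event, sid, data["size"])

    svc = TerminalService(StubServer())
    svc.subscribe("client-a")
    svc.subscribe("client-b")  # no socket entry, so it is auto-unsubscribed on flush
    svc.send_messages([{"t": "12:00", "m": "hello\n"}])  # prints once; size is set on the first call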
api_server/utils/file_operations.py ADDED
@@ -0,0 +1,42 @@
+ import os
+ from typing import List, Union, TypedDict, Literal
+ from typing_extensions import TypeGuard
+
+ class FileInfo(TypedDict):
+     name: str
+     path: str
+     type: Literal["file"]
+     size: int
+
+ class DirectoryInfo(TypedDict):
+     name: str
+     path: str
+     type: Literal["directory"]
+
+ FileSystemItem = Union[FileInfo, DirectoryInfo]
+
+ def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
+     return item["type"] == "file"
+
+ class FileSystemOperations:
+     @staticmethod
+     def walk_directory(directory: str) -> List[FileSystemItem]:
+         file_list: List[FileSystemItem] = []
+         for root, dirs, files in os.walk(directory):
+             for name in files:
+                 file_path = os.path.join(root, name)
+                 relative_path = os.path.relpath(file_path, directory)
+                 file_list.append({
+                     "name": name,
+                     "path": relative_path,
+                     "type": "file",
+                     "size": os.path.getsize(file_path)
+                 })
+             for name in dirs:
+                 dir_path = os.path.join(root, name)
+                 relative_path = os.path.relpath(dir_path, directory)
+                 file_list.append({
+                     "name": name,
+                     "path": relative_path,
+                     "type": "directory"
+                 })
+         return file_list
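
A short usage sketch of the helpers above: because is_file_info() is declared as a TypeGuard, static type checkers narrow FileSystemItem to FileInfo inside the filtered expression, so accessing item["size"] is well-typed. The "." directory argument is an arbitrary example:

    items = FileSystemOperations.walk_directory(".")
    # Only FileInfo entries carry "size"; the TypeGuard makes this access safe.
    total_bytes = sum(item["size"] for item in items if is_file_info(item))
    print(f"{len(items)} entries, {total_bytes} bytes of files")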
app.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import Any, Mapping, Sequence, Union
4
+
5
+ import gradio as gr
6
+ import torch
7
+ from huggingface_hub import hf_hub_download
8
+ from nodes import NODE_CLASS_MAPPINGS
9
+ import spaces
10
+ from comfy import model_management
11
+
12
+ @spaces.GPU(duration=60) #modify the duration for the average it takes for your worflow to run, in seconds
13
+
14
+
15
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
16
+ """Returns the value at the given index of a sequence or mapping.
17
+
18
+ If the object is a sequence (like list or string), returns the value at the given index.
19
+ If the object is a mapping (like a dictionary), returns the value at the index-th key.
20
+
21
+ Some return a dictionary, in these cases, we look for the "results" key
22
+
23
+ Args:
24
+ obj (Union[Sequence, Mapping]): The object to retrieve the value from.
25
+ index (int): The index of the value to retrieve.
26
+
27
+ Returns:
28
+ Any: The value at the given index.
29
+
30
+ Raises:
31
+ IndexError: If the index is out of bounds for the object and the object is not a mapping.
32
+ """
33
+ try:
34
+ return obj[index]
35
+ except KeyError:
36
+ return obj["result"][index]
37
+
38
+
39
+ def find_path(name: str, path: str = None) -> str:
40
+ """
41
+ Recursively looks at parent folders starting from the given path until it finds the given name.
42
+ Returns the path as a Path object if found, or None otherwise.
43
+ """
44
+ # If no path is given, use the current working directory
45
+ if path is None:
46
+ path = os.getcwd()
47
+
48
+ # Check if the current directory contains the name
49
+ if name in os.listdir(path):
50
+ path_name = os.path.join(path, name)
51
+ print(f"{name} found: {path_name}")
52
+ return path_name
53
+
54
+ # Get the parent directory
55
+ parent_directory = os.path.dirname(path)
56
+
57
+ # If the parent directory is the same as the current directory, we've reached the root and stop the search
58
+ if parent_directory == path:
59
+ return None
60
+
61
+ # Recursively call the function with the parent directory
62
+ return find_path(name, parent_directory)
63
+
64
+
65
+ def add_comfyui_directory_to_sys_path() -> None:
66
+ """
67
+ Add 'ComfyUI' to the sys.path
68
+ """
69
+ comfyui_path = find_path("ComfyUI")
70
+ if comfyui_path is not None and os.path.isdir(comfyui_path):
71
+ sys.path.append(comfyui_path)
72
+ print(f"'{comfyui_path}' added to sys.path")
73
+
74
+
75
+ def add_extra_model_paths() -> None:
76
+ """
77
+ Parse the optional extra_model_paths.yaml file and add the parsed paths to the sys.path.
78
+ """
79
+ try:
80
+ from app import load_extra_path_config
81
+ except ImportError:
82
+ print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
83
+ from utils.extra_config import load_extra_path_config
84
+ extra_model_paths = find_path("extra_model_paths.yaml")
85
+
86
+ if extra_model_paths is not None:
87
+ load_extra_path_config(extra_model_paths)
88
+ else:
89
+ print("Could not find the extra_model_paths config file.")
90
+
91
+
92
+ add_comfyui_directory_to_sys_path()
93
+ add_extra_model_paths()
94
+
95
+
96
+ def import_custom_nodes() -> None:
97
+ """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS
98
+
99
+ This function sets up a new asyncio event loop, initializes the PromptServer,
100
+ creates a PromptQueue, and initializes the custom nodes.
101
+ """
102
+ import asyncio
103
+ import execution
104
+ from nodes import init_extra_nodes
105
+ import server
106
+ # Creating a new event loop and setting it as the default loop
107
+ loop = asyncio.new_event_loop()
108
+ asyncio.set_event_loop(loop)
109
+
110
+ # Creating an instance of PromptServer with the loop
111
+ server_instance = server.PromptServer(loop)
112
+ execution.PromptQueue(server_instance)
113
+
114
+ # Initializing custom nodes
115
+ init_extra_nodes()
116
+
117
+
118
+ @spaces.GPU(duration=60)  # set duration to your workflow's average runtime, in seconds
+ def advance_blur(input_image):
119
+ import_custom_nodes()
120
+ with torch.inference_mode():
121
+ load_images_node = NODE_CLASS_MAPPINGS["LoadImagesFromFolderKJ"]()
122
+ source_images_batch = load_images_node.load_images(
123
+ folder="source_faces/",
124
+ width=1024,
125
+ height=1024,
126
+ keep_aspect_ratio="crop",
127
+ image_load_cap=0,
128
+ start_index=0,
129
+ include_subfolders=False,
130
+ )
131
+
132
+ loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
133
+ loaded_input_image = loadimage.load_image(
134
+ image=input_image,
135
+ )
136
+
137
+ upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
138
+ upscale_model = upscalemodelloader.load_model(
139
+ model_name="4x_NMKD-Siax_200k.pth"
140
+ )
141
+
142
+ reactorbuildfacemodel = NODE_CLASS_MAPPINGS["ReActorBuildFaceModel"]()
143
+ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
144
+ reactorfaceswap = NODE_CLASS_MAPPINGS["ReActorFaceSwap"]()
145
+ imageupscalewithmodel = NODE_CLASS_MAPPINGS["ImageUpscaleWithModel"]()
146
+ saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
147
+
148
+ for q in range(1):
149
+ face_model = reactorbuildfacemodel.blend_faces(
150
+ save_mode=True,
151
+ send_only=False,
152
+ face_model_name="default",
153
+ compute_method="Mean",
154
+ images=get_value_at_index(source_images_batch, 0),
155
+ )
156
+
157
+ resized_input_image = imageresize.execute(
158
+ width=2560,
159
+ height=2560,
160
+ interpolation="bicubic",
161
+ method="keep proportion",
162
+ condition="downscale if bigger",
163
+ multiple_of=0,
164
+ image=get_value_at_index(loaded_input_image, 0),
165
+ )
166
+
167
+ swapped_image = reactorfaceswap.execute(
168
+ enabled=True,
169
+ swap_model="inswapper_128.onnx",
170
+ facedetection="retinaface_resnet50",
171
+ face_restore_model="codeformer-v0.1.0.pth",
172
+ face_restore_visibility=1,
173
+ codeformer_weight=1,
174
+ detect_gender_input="no",
175
+ detect_gender_source="no",
176
+ input_faces_index="0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99",
177
+ source_faces_index="0",
178
+ console_log_level=2,
179
+ input_image=get_value_at_index(resized_input_image, 0),
180
+ face_model=get_value_at_index(face_model, 0),
181
+ )
182
+
183
+ upscaled_image = imageupscalewithmodel.upscale(
184
+ upscale_model=get_value_at_index(upscale_model, 0),
185
+ image=get_value_at_index(swapped_image, 0),
186
+ )
187
+
188
+ final_image = imageresize.execute(
189
+ width=2560,
190
+ height=2560,
191
+ interpolation="lanczos",
192
+ method="keep proportion",
193
+ condition="downscale if bigger",
194
+ multiple_of=0,
195
+ image=get_value_at_index(upscaled_image, 0),
196
+ )
197
+
198
+ saved_image = saveimage.save_images(
199
+ filename_prefix="advance_blur",
200
+ images=get_value_at_index(final_image, 0),
201
+ )
202
+
203
+ saved_path = f"output/{saved_image['ui']['images'][0]['filename']}"
204
+ return saved_path
205
+
206
+
207
+ if __name__ == "__main__":
208
+ # Start your Gradio app
209
+ with gr.Blocks() as app:
210
+ # Add a title
211
+ gr.Markdown("# Advance Blur")
212
+
213
+ with gr.Row():
214
+ with gr.Column():
215
+ input_image = gr.Image(label="Input Image", type="filepath")
216
+ generate_btn = gr.Button("Generate")
217
+
218
+ with gr.Column():
219
+ # The output image
220
+ output_image = gr.Image(label="Generated Image")
221
+
222
+ # Clicking the button triggers the `advance_blur` function with the given
223
+ # input image and displays the returned image as the output
224
+ generate_btn.click(
225
+ fn=advance_blur,
226
+ inputs=[input_image],
227
+ outputs=[output_image]
228
+ )
229
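+ # Note: share=True also creates a temporary public Gradio link; drop it to serve locally only.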
+ app.launch(share=True)
230
+
app/__init__.py ADDED
File without changes
app/app_settings.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ import json
3
+ from aiohttp import web
4
+ import logging
5
+
6
+
7
+ class AppSettings():
8
+ def __init__(self, user_manager):
9
+ self.user_manager = user_manager
10
+
11
+ def get_settings(self, request):
12
+ file = self.user_manager.get_request_user_filepath(
13
+ request, "comfy.settings.json")
14
+ if os.path.isfile(file):
15
+ try:
16
+ with open(file) as f:
17
+ return json.load(f)
18
+ except Exception:
19
+ logging.error(f"The user settings file is corrupted: {file}")
20
+ return {}
21
+ else:
22
+ return {}
23
+
24
+ def save_settings(self, request, settings):
25
+ file = self.user_manager.get_request_user_filepath(
26
+ request, "comfy.settings.json")
27
+ with open(file, "w") as f:
28
+ f.write(json.dumps(settings, indent=4))
29
+
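+ # The settings file is a flat JSON object merged key-by-key on POST, e.g.
+ # (illustrative keys): {"Comfy.UseNewMenu": true, "Comfy.Locale": "en"}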
30
+ def add_routes(self, routes):
31
+ @routes.get("/settings")
32
+ async def get_settings(request):
33
+ return web.json_response(self.get_settings(request))
34
+
35
+ @routes.get("/settings/{id}")
36
+ async def get_setting(request):
37
+ value = None
38
+ settings = self.get_settings(request)
39
+ setting_id = request.match_info.get("id", None)
40
+ if setting_id and setting_id in settings:
41
+ value = settings[setting_id]
42
+ return web.json_response(value)
43
+
44
+ @routes.post("/settings")
45
+ async def post_settings(request):
46
+ settings = self.get_settings(request)
47
+ new_settings = await request.json()
48
+ self.save_settings(request, {**settings, **new_settings})
49
+ return web.Response(status=200)
50
+
51
+ @routes.post("/settings/{id}")
52
+ async def post_setting(request):
53
+ setting_id = request.match_info.get("id", None)
54
+ if not setting_id:
55
+ return web.Response(status=400)
56
+ settings = self.get_settings(request)
57
+ settings[setting_id] = await request.json()
58
+ self.save_settings(request, settings)
59
+ return web.Response(status=200)
app/custom_node_manager.py ADDED
@@ -0,0 +1,134 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import folder_paths
5
+ import glob
6
+ from aiohttp import web
7
+ import json
8
+ import logging
9
+ from functools import lru_cache
10
+
11
+ from utils.json_util import merge_json_recursive
12
+
13
+
14
+ # Extra locale files to load into main.json
15
+ EXTRA_LOCALE_FILES = [
16
+ "nodeDefs.json",
17
+ "commands.json",
18
+ "settings.json",
19
+ ]
20
+
21
+
22
+ def safe_load_json_file(file_path: str) -> dict:
23
+ if not os.path.exists(file_path):
24
+ return {}
25
+
26
+ try:
27
+ with open(file_path, "r", encoding="utf-8") as f:
28
+ return json.load(f)
29
+ except json.JSONDecodeError:
30
+ logging.error(f"Error loading {file_path}")
31
+ return {}
32
+
33
+
34
+ class CustomNodeManager:
35
+ @lru_cache(maxsize=1)
36
+ def build_translations(self):
37
+ """Load all custom nodes translations during initialization. Translations are
38
+ expected to be loaded from `locales/` folder.
39
+
40
+ The folder structure is expected to be the following:
41
+ - custom_nodes/
42
+ - custom_node_1/
43
+ - locales/
44
+ - en/
45
+ - main.json
46
+ - commands.json
47
+ - settings.json
48
+
49
+ Returned translations are expected to be in the following format:
50
+ {
51
+ "en": {
52
+ "nodeDefs": {...},
53
+ "commands": {...},
54
+ "settings": {...},
55
+ ...{other main.json keys}
56
+ }
57
+ }
58
+ """
59
+
60
+ translations = {}
61
+
62
+ for folder in folder_paths.get_folder_paths("custom_nodes"):
63
+ # Sort glob results for deterministic ordering
64
+ for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
65
+ locales_dir = os.path.join(custom_node_dir, "locales")
66
+ if not os.path.exists(locales_dir):
67
+ continue
68
+
69
+ for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
70
+ lang_code = os.path.basename(os.path.dirname(lang_dir))
71
+
72
+ if lang_code not in translations:
73
+ translations[lang_code] = {}
74
+
75
+ # Load main.json
76
+ main_file = os.path.join(lang_dir, "main.json")
77
+ node_translations = safe_load_json_file(main_file)
78
+
79
+ # Load extra locale files
80
+ for extra_file in EXTRA_LOCALE_FILES:
81
+ extra_file_path = os.path.join(lang_dir, extra_file)
82
+ key = extra_file.split(".")[0]
83
+ json_data = safe_load_json_file(extra_file_path)
84
+ if json_data:
85
+ node_translations[key] = json_data
86
+
87
+ if node_translations:
88
+ translations[lang_code] = merge_json_recursive(
89
+ translations[lang_code], node_translations
90
+ )
91
+
92
+ return translations
93
+
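+ # Note: when two node packs ship the same language, merge_json_recursive
+ # combines their entries, so translations["en"] holds the union of both.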
94
+ def add_routes(self, routes, webapp, loadedModules):
95
+
96
+ @routes.get("/workflow_templates")
97
+ async def get_workflow_templates(request):
98
+ """Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
99
+ files = [
100
+ file
101
+ for folder in folder_paths.get_folder_paths("custom_nodes")
102
+ for file in glob.glob(
103
+ os.path.join(folder, "*/example_workflows/*.json")
104
+ )
105
+ ]
106
+ workflow_templates_dict = (
107
+ {}
108
+ ) # custom_nodes folder name -> example workflow names
109
+ for file in files:
110
+ custom_nodes_name = os.path.basename(
111
+ os.path.dirname(os.path.dirname(file))
112
+ )
113
+ workflow_name = os.path.splitext(os.path.basename(file))[0]
114
+ workflow_templates_dict.setdefault(custom_nodes_name, []).append(
115
+ workflow_name
116
+ )
117
+ return web.json_response(workflow_templates_dict)
118
+
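+ # Example response shape (hypothetical pack name):
+ # {"ComfyUI-SomePack": ["basic_workflow", "advanced_workflow"]}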
119
+ # Serve workflow templates from custom nodes.
120
+ for module_name, module_dir in loadedModules:
121
+ workflows_dir = os.path.join(module_dir, "example_workflows")
122
+ if os.path.exists(workflows_dir):
123
+ webapp.add_routes(
124
+ [
125
+ web.static(
126
+ "/api/workflow_templates/" + module_name, workflows_dir
127
+ )
128
+ ]
129
+ )
130
+
131
+ @routes.get("/i18n")
132
+ async def get_i18n(request):
133
+ """Returns translations from all custom nodes' locales folders."""
134
+ return web.json_response(self.build_translations())
app/frontend_management.py ADDED
@@ -0,0 +1,288 @@
1
+ from __future__ import annotations
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import re
6
+ import sys
7
+ import tempfile
8
+ import zipfile
9
+ import importlib.resources
10
+ from dataclasses import dataclass
11
+ from functools import cached_property
12
+ from pathlib import Path
13
+ from typing import TypedDict, Optional
14
+ from importlib.metadata import version
15
+
16
+ import requests
17
+ from typing_extensions import NotRequired
18
+
19
+ from comfy.cli_args import DEFAULT_VERSION_STRING
20
+ import app.logger
21
+
22
+ # The path to the requirements.txt file
23
+ req_path = Path(__file__).parents[1] / "requirements.txt"
24
+
25
+
26
+ def frontend_install_warning_message():
27
+ """The warning message to display when the frontend version is not up to date."""
28
+
29
+ extra = ""
30
+ if sys.flags.no_user_site:
31
+ extra = "-s "
32
+ return f"""
33
+ Please install the updated requirements.txt file by running:
34
+ {sys.executable} {extra}-m pip install -r {req_path}
35
+
36
+ This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
37
+
38
+ If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
39
+ """.strip()
40
+
41
+
42
+ def check_frontend_version():
43
+ """Check if the frontend version is up to date."""
44
+
45
+ def parse_version(version: str) -> tuple[int, int, int]:
46
+ return tuple(map(int, version.split(".")))
47
+
48
+ try:
49
+ frontend_version_str = version("comfyui-frontend-package")
50
+ frontend_version = parse_version(frontend_version_str)
51
+ with open(req_path, "r", encoding="utf-8") as f:
52
+ required_frontend = parse_version(f.readline().split("=")[-1])
53
+ if frontend_version < required_frontend:
54
+ app.logger.log_startup_warning(
55
+ f"""
56
+ ________________________________________________________________________
57
+ WARNING WARNING WARNING WARNING WARNING
58
+
59
+ Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
60
+
61
+ {frontend_install_warning_message()}
62
+ ________________________________________________________________________
63
+ """.strip()
64
+ )
65
+ else:
66
+ logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
67
+ except Exception as e:
68
+ logging.error(f"Failed to check frontend version: {e}")
69
+
70
+
71
+ REQUEST_TIMEOUT = 10 # seconds
72
+
73
+
74
+ class Asset(TypedDict):
75
+ url: str
76
+
77
+
78
+ class Release(TypedDict):
79
+ id: int
80
+ tag_name: str
81
+ name: str
82
+ prerelease: bool
83
+ created_at: str
84
+ published_at: str
85
+ body: str
86
+ assets: NotRequired[list[Asset]]
87
+
88
+
89
+ @dataclass
90
+ class FrontEndProvider:
91
+ owner: str
92
+ repo: str
93
+
94
+ @property
95
+ def folder_name(self) -> str:
96
+ return f"{self.owner}_{self.repo}"
97
+
98
+ @property
99
+ def release_url(self) -> str:
100
+ return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
101
+
102
+ @cached_property
103
+ def all_releases(self) -> list[Release]:
104
+ releases = []
105
+ api_url = self.release_url
106
+ while api_url:
107
+ response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
108
+ response.raise_for_status() # Raises an HTTPError if the response was an error
109
+ releases.extend(response.json())
110
+ # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
111
+ if "next" in response.links:
112
+ api_url = response.links["next"]["url"]
113
+ else:
114
+ api_url = None
115
+ return releases
116
+
117
+ @cached_property
118
+ def latest_release(self) -> Release:
119
+ latest_release_url = f"{self.release_url}/latest"
120
+ response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
121
+ response.raise_for_status() # Raises an HTTPError if the response was an error
122
+ return response.json()
123
+
124
+ def get_release(self, version: str) -> Release:
125
+ if version == "latest":
126
+ return self.latest_release
127
+ else:
128
+ for release in self.all_releases:
129
+ if release["tag_name"] in [version, f"v{version}"]:
130
+ return release
131
+ raise ValueError(f"Version {version} not found in releases")
132
+
133
+
134
+ def download_release_asset_zip(release: Release, destination_path: str) -> None:
135
+ """Download dist.zip from github release."""
136
+ asset_url = None
137
+ for asset in release.get("assets", []):
138
+ if asset["name"] == "dist.zip":
139
+ asset_url = asset["url"]
140
+ break
141
+
142
+ if not asset_url:
143
+ raise ValueError("dist.zip not found in the release assets")
144
+
145
+ # Use a temporary file to download the zip content
146
+ with tempfile.TemporaryFile() as tmp_file:
147
+ headers = {"Accept": "application/octet-stream"}
148
+ response = requests.get(
149
+ asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
150
+ )
151
+ response.raise_for_status() # Ensure we got a successful response
152
+
153
+ # Write the content to the temporary file
154
+ tmp_file.write(response.content)
155
+
156
+ # Go back to the beginning of the temporary file
157
+ tmp_file.seek(0)
158
+
159
+ # Extract the zip file content to the destination path
160
+ with zipfile.ZipFile(tmp_file, "r") as zip_ref:
161
+ zip_ref.extractall(destination_path)
162
+
163
+
164
+ class FrontendManager:
165
+ CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
166
+
167
+ @classmethod
168
+ def default_frontend_path(cls) -> str:
169
+ try:
170
+ import comfyui_frontend_package
171
+
172
+ return str(importlib.resources.files(comfyui_frontend_package) / "static")
173
+ except ImportError:
174
+ logging.error(
175
+ f"""
176
+ ********** ERROR ***********
177
+
178
+ comfyui-frontend-package is not installed.
179
+
180
+ {frontend_install_warning_message()}
181
+
182
+ ********** ERROR ***********
183
+ """.strip()
184
+ )
185
+ sys.exit(-1)
186
+
187
+ @classmethod
188
+ def parse_version_string(cls, value: str) -> tuple[str, str, str]:
189
+ """
190
+ Args:
191
+ value (str): The version string to parse.
192
+
193
+ Returns:
194
+ tuple[str, str, str]: A tuple of (repo owner, repo name, version).
195
+
196
+ Raises:
197
+ argparse.ArgumentTypeError: If the version string is invalid.
198
+ """
199
+ VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
200
+ match_result = re.match(VERSION_PATTERN, value)
201
+ if match_result is None:
202
+ raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
203
+
204
+ return match_result.group(1), match_result.group(2), match_result.group(3)
205
+
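+ # e.g. parse_version_string("Comfy-Org/ComfyUI_frontend@latest")
+ # -> ("Comfy-Org", "ComfyUI_frontend", "latest") (illustrative provider)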
206
+ @classmethod
207
+ def init_frontend_unsafe(
208
+ cls, version_string: str, provider: Optional[FrontEndProvider] = None
209
+ ) -> str:
210
+ """
211
+ Initializes the frontend for the specified version.
212
+
213
+ Args:
214
+ version_string (str): The version string.
215
+ provider (FrontEndProvider, optional): The provider to use. Defaults to None.
216
+
217
+ Returns:
218
+ str: The path to the initialized frontend.
219
+
220
+ Raises:
221
+ Exception: If there is an error during the initialization process.
222
+ main error source might be request timeout or invalid URL.
223
+ """
224
+ if version_string == DEFAULT_VERSION_STRING:
225
+ check_frontend_version()
226
+ return cls.default_frontend_path()
227
+
228
+ repo_owner, repo_name, version = cls.parse_version_string(version_string)
229
+
230
+ if version.startswith("v"):
231
+ expected_path = str(
232
+ Path(cls.CUSTOM_FRONTENDS_ROOT)
233
+ / f"{repo_owner}_{repo_name}"
234
+ / version.lstrip("v")
235
+ )
236
+ if os.path.exists(expected_path):
237
+ logging.info(
238
+ f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
239
+ )
240
+ return expected_path
241
+
242
+ logging.info(
243
+ f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
244
+ )
245
+
246
+ provider = provider or FrontEndProvider(repo_owner, repo_name)
247
+ release = provider.get_release(version)
248
+
249
+ semantic_version = release["tag_name"].lstrip("v")
250
+ web_root = str(
251
+ Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
252
+ )
253
+ if not os.path.exists(web_root):
254
+ try:
255
+ os.makedirs(web_root, exist_ok=True)
256
+ logging.info(
257
+ "Downloading frontend(%s) version(%s) to (%s)",
258
+ provider.folder_name,
259
+ semantic_version,
260
+ web_root,
261
+ )
262
+ logging.debug(release)
263
+ download_release_asset_zip(release, destination_path=web_root)
264
+ finally:
265
+ # Clean up the directory if it is empty, i.e. the download failed
266
+ if not os.listdir(web_root):
267
+ os.rmdir(web_root)
268
+
269
+ return web_root
270
+
271
+ @classmethod
272
+ def init_frontend(cls, version_string: str) -> str:
273
+ """
274
+ Initializes the frontend with the specified version string.
275
+
276
+ Args:
277
+ version_string (str): The version string to initialize the frontend with.
278
+
279
+ Returns:
280
+ str: The path of the initialized frontend.
281
+ """
282
+ try:
283
+ return cls.init_frontend_unsafe(version_string)
284
+ except Exception as e:
285
+ logging.error("Failed to initialize frontend: %s", e)
286
+ logging.info("Falling back to the default frontend.")
287
+ check_frontend_version()
288
+ return cls.default_frontend_path()
app/logger.py ADDED
@@ -0,0 +1,98 @@
1
+ from collections import deque
2
+ from datetime import datetime
3
+ import io
4
+ import logging
5
+ import sys
6
+ import threading
7
+
8
+ logs = None
9
+ stdout_interceptor = None
10
+ stderr_interceptor = None
11
+
12
+
13
+ class LogInterceptor(io.TextIOWrapper):
14
+ def __init__(self, stream, *args, **kwargs):
15
+ buffer = stream.buffer
16
+ encoding = stream.encoding
17
+ super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
18
+ self._lock = threading.Lock()
19
+ self._flush_callbacks = []
20
+ self._logs_since_flush = []
21
+
22
+ def write(self, data):
23
+ entry = {"t": datetime.now().isoformat(), "m": data}
24
+ with self._lock:
25
+ self._logs_since_flush.append(entry)
26
+
27
+ # Simple handling for CR: overwrite the last entry when it isn't a full line,
28
+ # otherwise the logs fill up with progress messages
29
+ if isinstance(data, str) and data.startswith("\r") and logs and not logs[-1]["m"].endswith("\n"):
30
+ logs.pop()
31
+ logs.append(entry)
32
+ super().write(data)
33
+
34
+ def flush(self):
35
+ super().flush()
36
+ for cb in self._flush_callbacks:
37
+ cb(self._logs_since_flush)
38
+ self._logs_since_flush = []
39
+
40
+ def on_flush(self, callback):
41
+ self._flush_callbacks.append(callback)
42
+
43
+
44
+ def get_logs():
45
+ return logs
46
+
47
+
48
+ def on_flush(callback):
49
+ if stdout_interceptor is not None:
50
+ stdout_interceptor.on_flush(callback)
51
+ if stderr_interceptor is not None:
52
+ stderr_interceptor.on_flush(callback)
53
+
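+ # Illustrative: on_flush(lambda entries: broadcast(entries)) pushes the entries
+ # buffered since the last flush to a consumer (broadcast is hypothetical).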
54
+ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
55
+ global logs
56
+ if logs:
57
+ return
58
+
59
+ # Override output streams and log to buffer
60
+ logs = deque(maxlen=capacity)
61
+
62
+ global stdout_interceptor
63
+ global stderr_interceptor
64
+ stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
65
+ stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)
66
+
67
+ # Setup default global logger
68
+ logger = logging.getLogger()
69
+ logger.setLevel(log_level)
70
+
71
+ stream_handler = logging.StreamHandler()
72
+ stream_handler.setFormatter(logging.Formatter("%(message)s"))
73
+
74
+ if use_stdout:
75
+ # Only errors and critical to stderr
76
+ stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)
77
+
78
+ # Lesser to stdout
79
+ stdout_handler = logging.StreamHandler(sys.stdout)
80
+ stdout_handler.setFormatter(logging.Formatter("%(message)s"))
81
+ stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
82
+ logger.addHandler(stdout_handler)
83
+
84
+ logger.addHandler(stream_handler)
85
+
86
+
87
+ STARTUP_WARNINGS = []
88
+
89
+
90
+ def log_startup_warning(msg):
91
+ logging.warning(msg)
92
+ STARTUP_WARNINGS.append(msg)
93
+
94
+
95
+ def print_startup_warnings():
96
+ for s in STARTUP_WARNINGS:
97
+ logging.warning(s)
98
+ STARTUP_WARNINGS.clear()
app/model_manager.py ADDED
@@ -0,0 +1,184 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import base64
5
+ import json
6
+ import time
7
+ import logging
8
+ import folder_paths
9
+ import glob
10
+ import comfy.utils
11
+ from aiohttp import web
12
+ from PIL import Image
13
+ from io import BytesIO
14
+ from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
15
+
16
+
17
+ class ModelFileManager:
18
+ def __init__(self) -> None:
19
+ self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
20
+
21
+ def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
22
+ return self.cache.get(key, default)
23
+
24
+ def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
25
+ self.cache[key] = value
26
+
27
+ def clear_cache(self):
28
+ self.cache.clear()
29
+
30
+ def add_routes(self, routes):
31
+ # NOTE: This is an experiment to replace `/models`
32
+ @routes.get("/experiment/models")
33
+ async def get_model_folders(request):
34
+ model_types = list(folder_paths.folder_names_and_paths.keys())
35
+ folder_black_list = ["configs", "custom_nodes"]
36
+ output_folders: list[dict] = []
37
+ for folder in model_types:
38
+ if folder in folder_black_list:
39
+ continue
40
+ output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
41
+ return web.json_response(output_folders)
42
+
43
+ # NOTE: This is an experiment to replace `/models/{folder}`
44
+ @routes.get("/experiment/models/{folder}")
45
+ async def get_all_models(request):
46
+ folder = request.match_info.get("folder", None)
47
+ if folder not in folder_paths.folder_names_and_paths:
48
+ return web.Response(status=404)
49
+ files = self.get_model_file_list(folder)
50
+ return web.json_response(files)
51
+
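+ # Example: GET /experiment/models/checkpoints might return
+ # [{"name": "sd_xl_base_1.0.safetensors", "pathIndex": 0}, ...] (illustrative)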
52
+ @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
53
+ async def get_model_preview(request):
54
+ folder_name = request.match_info.get("folder", None)
55
+ path_index = int(request.match_info.get("path_index", None))
56
+ filename = request.match_info.get("filename", None)
57
+
58
+ if folder_name not in folder_paths.folder_names_and_paths:
59
+ return web.Response(status=404)
60
+
61
+ folders = folder_paths.folder_names_and_paths[folder_name]
62
+ folder = folders[0][path_index]
63
+ full_filename = os.path.join(folder, filename)
64
+
65
+ previews = self.get_model_previews(full_filename)
66
+ default_preview = previews[0] if len(previews) > 0 else None
67
+ if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
68
+ return web.Response(status=404)
69
+
70
+ try:
71
+ with Image.open(default_preview) as img:
72
+ img_bytes = BytesIO()
73
+ img.save(img_bytes, format="WEBP")
74
+ img_bytes.seek(0)
75
+ return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
76
+ except Exception:
77
+ return web.Response(status=404)
78
+
79
+ def get_model_file_list(self, folder_name: str):
80
+ folder_name = map_legacy(folder_name)
81
+ folders = folder_paths.folder_names_and_paths[folder_name]
82
+ output_list: list[dict] = []
83
+
84
+ for index, folder in enumerate(folders[0]):
85
+ if not os.path.isdir(folder):
86
+ continue
87
+ out = self.cache_model_file_list_(folder)
88
+ if out is None:
89
+ out = self.recursive_search_models_(folder, index)
90
+ self.set_cache(folder, out)
91
+ output_list.extend(out[0])
92
+
93
+ return output_list
94
+
95
+ def cache_model_file_list_(self, folder: str):
96
+ model_file_list_cache = self.get_cache(folder)
97
+
98
+ if model_file_list_cache is None:
99
+ return None
100
+ if not os.path.isdir(folder):
101
+ return None
102
+ # model_file_list_cache is (files, dir_mtimes, timestamp); invalidate the
103
+ # cache when any tracked directory's mtime changed since it was built.
104
+ for tracked_dir, time_modified in model_file_list_cache[1].items():
105
+ if os.path.getmtime(tracked_dir) != time_modified:
106
+ return None
109
+
110
+ return model_file_list_cache
111
+
112
+ def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[dict], dict[str, float], float]:
113
+ if not os.path.isdir(directory):
114
+ return [], {}, time.perf_counter()
115
+
116
+ excluded_dir_names = [".git"]
117
+ # TODO use settings
118
+ include_hidden_files = False
119
+
120
+ result: list[str] = []
121
+ # Track the top directory's mtime as well, so files added directly under it
+ # invalidate the cache checked in cache_model_file_list_.
+ dirs: dict[str, float] = {directory: os.path.getmtime(directory)}
122
+
123
+ for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
124
+ subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
125
+ if not include_hidden_files:
126
+ subdirs[:] = [d for d in subdirs if not d.startswith(".")]
127
+ filenames = [f for f in filenames if not f.startswith(".")]
128
+
129
+ filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
130
+
131
+ for file_name in filenames:
132
+ try:
133
+ relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
134
+ result.append(relative_path)
135
+ except Exception:
136
+ logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
137
+ continue
138
+
139
+ for d in subdirs:
140
+ path: str = os.path.join(dirpath, d)
141
+ try:
142
+ dirs[path] = os.path.getmtime(path)
143
+ except FileNotFoundError:
144
+ logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
145
+ continue
146
+
147
+ return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()
148
+
149
+ def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
150
+ dirname = os.path.dirname(filepath)
151
+
152
+ if not os.path.exists(dirname):
153
+ return []
154
+
155
+ basename = os.path.splitext(filepath)[0]
156
+ match_files = glob.glob(f"{glob.escape(basename)}.*", recursive=False)
157
+ image_files = filter_files_content_types(match_files, "image")
158
+ safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
159
+ safetensors_metadata = {}
160
+
161
+ result: list[str | BytesIO] = []
162
+
163
+ for filename in image_files:
164
+ _basename = os.path.splitext(filename)[0]
165
+ if _basename == basename:
166
+ result.append(filename)
167
+ if _basename == f"{basename}.preview":
168
+ result.append(filename)
169
+
170
+ if safetensors_file:
171
+ safetensors_filepath = os.path.join(dirname, safetensors_file)
172
+ header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
173
+ if header:
174
+ safetensors_metadata = json.loads(header)
175
+ safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
176
+ if safetensors_images:
177
+ safetensors_images = json.loads(safetensors_images)
178
+ for image in safetensors_images:
179
+ result.append(BytesIO(base64.b64decode(image)))
180
+
181
+ return result
182
+
183
+ def __exit__(self, exc_type, exc_value, traceback):
184
+ self.clear_cache()
app/user_manager.py ADDED
@@ -0,0 +1,330 @@
1
+ from __future__ import annotations
2
+ import json
3
+ import os
4
+ import re
5
+ import uuid
6
+ import glob
7
+ import shutil
8
+ import logging
9
+ from aiohttp import web
10
+ from urllib import parse
11
+ from comfy.cli_args import args
12
+ import folder_paths
13
+ from .app_settings import AppSettings
14
+ from typing import TypedDict
15
+
16
+ default_user = "default"
17
+
18
+
19
+ class FileInfo(TypedDict):
20
+ path: str
21
+ size: int
22
+ modified: float
23
+
24
+
25
+ def get_file_info(path: str, relative_to: str) -> FileInfo:
26
+ return {
27
+ "path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
28
+ "size": os.path.getsize(path),
29
+ "modified": os.path.getmtime(path)
30
+ }
31
+
32
+
33
+ class UserManager():
34
+ def __init__(self):
35
+ user_directory = folder_paths.get_user_directory()
36
+
37
+ self.settings = AppSettings(self)
38
+ if not os.path.exists(user_directory):
39
+ os.makedirs(user_directory, exist_ok=True)
40
+ if not args.multi_user:
41
+ logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
42
+ logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
43
+
44
+ if args.multi_user:
45
+ if os.path.isfile(self.get_users_file()):
46
+ with open(self.get_users_file()) as f:
47
+ self.users = json.load(f)
48
+ else:
49
+ self.users = {}
50
+ else:
51
+ self.users = {"default": "default"}
52
+
53
+ def get_users_file(self):
54
+ return os.path.join(folder_paths.get_user_directory(), "users.json")
55
+
56
+ def get_request_user_id(self, request):
57
+ user = "default"
58
+ if args.multi_user and "comfy-user" in request.headers:
59
+ user = request.headers["comfy-user"]
60
+
61
+ if user not in self.users:
62
+ raise KeyError("Unknown user: " + user)
63
+
64
+ return user
65
+
66
+ def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
67
+ user_directory = folder_paths.get_user_directory()
68
+
69
+ if type == "userdata":
70
+ root_dir = user_directory
71
+ else:
72
+ raise KeyError("Unknown filepath type:" + type)
73
+
74
+ user = self.get_request_user_id(request)
75
+ path = user_root = os.path.abspath(os.path.join(root_dir, user))
76
+
77
+ # prevent leaving /{type}
78
+ if os.path.commonpath((root_dir, user_root)) != root_dir:
79
+ return None
80
+
81
+ if file is not None:
82
+ # Check if filename is url encoded
83
+ if "%" in file:
84
+ file = parse.unquote(file)
85
+
86
+ # prevent leaving /{type}/{user}
87
+ path = os.path.abspath(os.path.join(user_root, file))
88
+ if os.path.commonpath((user_root, path)) != user_root:
89
+ return None
90
+
91
+ parent = os.path.split(path)[0]
92
+
93
+ if create_dir and not os.path.exists(parent):
94
+ os.makedirs(parent, exist_ok=True)
95
+
96
+ return path
97
+
98
+ def add_user(self, name):
99
+ name = name.strip()
100
+ if not name:
101
+ raise ValueError("username not provided")
102
+ user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
103
+ user_id = user_id + "_" + str(uuid.uuid4())
104
+
105
+ self.users[user_id] = name
106
+
107
+ with open(self.get_users_file(), "w") as f:
108
+ json.dump(self.users, f)
109
+
110
+ return user_id
111
+
112
+ def add_routes(self, routes):
113
+ self.settings.add_routes(routes)
114
+
115
+ @routes.get("/users")
116
+ async def get_users(request):
117
+ if args.multi_user:
118
+ return web.json_response({"storage": "server", "users": self.users})
119
+ else:
120
+ user_dir = self.get_request_user_filepath(request, None, create_dir=False)
121
+ return web.json_response({
122
+ "storage": "server",
123
+ "migrated": os.path.exists(user_dir)
124
+ })
125
+
126
+ @routes.post("/users")
127
+ async def post_users(request):
128
+ body = await request.json()
129
+ username = body["username"]
130
+ if username in self.users.values():
131
+ return web.json_response({"error": "Duplicate username."}, status=400)
132
+
133
+ user_id = self.add_user(username)
134
+ return web.json_response(user_id)
135
+
136
+ @routes.get("/userdata")
137
+ async def listuserdata(request):
138
+ """
139
+ List user data files in a specified directory.
140
+
141
+ This endpoint allows listing files in a user's data directory, with options for recursion,
142
+ full file information, and path splitting.
143
+
144
+ Query Parameters:
145
+ - dir (required): The directory to list files from.
146
+ - recurse (optional): If "true", recursively list files in subdirectories.
147
+ - full_info (optional): If "true", return detailed file information (path, size, modified time).
148
+ - split (optional): If "true", split file paths into components (only applies when full_info is false).
149
+
150
+ Returns:
151
+ - 400: If 'dir' parameter is missing.
152
+ - 403: If the requested path is not allowed.
153
+ - 404: If the requested directory does not exist.
154
+ - 200: JSON response with the list of files or file information.
155
+
156
+ The response format depends on the query parameters:
157
+ - Default: List of relative file paths.
158
+ - full_info=true: List of dictionaries with file details.
159
+ - split=true (and full_info=false): List of lists, each containing path components.
160
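+
+ Example (illustrative):
+ GET /userdata?dir=workflows&recurse=true -> ["flows/a.json", "b.json"]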
+ """
161
+ directory = request.rel_url.query.get('dir', '')
162
+ if not directory:
163
+ return web.Response(status=400, text="Directory not provided")
164
+
165
+ path = self.get_request_user_filepath(request, directory)
166
+ if not path:
167
+ return web.Response(status=403, text="Invalid directory")
168
+
169
+ if not os.path.exists(path):
170
+ return web.Response(status=404, text="Directory not found")
171
+
172
+ recurse = request.rel_url.query.get('recurse', '').lower() == "true"
173
+ full_info = request.rel_url.query.get('full_info', '').lower() == "true"
174
+ split_path = request.rel_url.query.get('split', '').lower() == "true"
175
+
176
+ # Use different patterns based on whether we're recursing or not
177
+ if recurse:
178
+ pattern = os.path.join(glob.escape(path), '**', '*')
179
+ else:
180
+ pattern = os.path.join(glob.escape(path), '*')
181
+
182
+ def process_full_path(full_path: str) -> FileInfo | str | list[str]:
183
+ if full_info:
184
+ return get_file_info(full_path, path)
185
+
186
+ rel_path = os.path.relpath(full_path, path).replace(os.sep, '/')
187
+ if split_path:
188
+ return [rel_path] + rel_path.split('/')
189
+
190
+ return rel_path
191
+
192
+ results = [
193
+ process_full_path(full_path)
194
+ for full_path in glob.glob(pattern, recursive=recurse)
195
+ if os.path.isfile(full_path)
196
+ ]
197
+
198
+ return web.json_response(results)
199
+
200
+ def get_user_data_path(request, check_exists = False, param = "file"):
201
+ file = request.match_info.get(param, None)
202
+ if not file:
203
+ return web.Response(status=400)
204
+
205
+ path = self.get_request_user_filepath(request, file)
206
+ if not path:
207
+ return web.Response(status=403)
208
+
209
+ if check_exists and not os.path.exists(path):
210
+ return web.Response(status=404)
211
+
212
+ return path
213
+
214
+ @routes.get("/userdata/{file}")
215
+ async def getuserdata(request):
216
+ path = get_user_data_path(request, check_exists=True)
217
+ if not isinstance(path, str):
218
+ return path
219
+
220
+ return web.FileResponse(path)
221
+
222
+ @routes.post("/userdata/{file}")
223
+ async def post_userdata(request):
224
+ """
225
+ Upload or update a user data file.
226
+
227
+ This endpoint handles file uploads to a user's data directory, with options for
228
+ controlling overwrite behavior and response format.
229
+
230
+ Query Parameters:
231
+ - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
232
+ - full_info (optional): If "true", returns detailed file information (path, size, modified time).
233
+ If "false", returns only the relative file path.
234
+
235
+ Path Parameters:
236
+ - file: The target file path (URL encoded if necessary).
237
+
238
+ Returns:
239
+ - 400: If 'file' parameter is missing.
240
+ - 403: If the requested path is not allowed.
241
+ - 409: If overwrite=false and the file already exists.
242
+ - 200: JSON response with either:
243
+ - Full file information (if full_info=true)
244
+ - Relative file path (if full_info=false)
245
+
246
+ The request body should contain the raw file content to be written.
247
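+
+ Example (illustrative):
+ POST /userdata/notes.txt with body "hello" writes the file and returns
+ "notes.txt" (or a file-info object when full_info=true).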
+ """
248
+ path = get_user_data_path(request)
249
+ if not isinstance(path, str):
250
+ return path
251
+
252
+ overwrite = request.query.get("overwrite", 'true') != "false"
253
+ full_info = request.query.get('full_info', 'false').lower() == "true"
254
+
255
+ if not overwrite and os.path.exists(path):
256
+ return web.Response(status=409, text="File already exists")
257
+
258
+ body = await request.read()
259
+
260
+ with open(path, "wb") as f:
261
+ f.write(body)
262
+
263
+ user_path = self.get_request_user_filepath(request, None)
264
+ if full_info:
265
+ resp = get_file_info(path, user_path)
266
+ else:
267
+ resp = os.path.relpath(path, user_path)
268
+
269
+ return web.json_response(resp)
270
+
271
+ @routes.delete("/userdata/{file}")
272
+ async def delete_userdata(request):
273
+ path = get_user_data_path(request, check_exists=True)
274
+ if not isinstance(path, str):
275
+ return path
276
+
277
+ os.remove(path)
278
+
279
+ return web.Response(status=204)
280
+
281
+ @routes.post("/userdata/{file}/move/{dest}")
282
+ async def move_userdata(request):
283
+ """
284
+ Move or rename a user data file.
285
+
286
+ This endpoint handles moving or renaming files within a user's data directory, with options for
287
+ controlling overwrite behavior and response format.
288
+
289
+ Path Parameters:
290
+ - file: The source file path (URL encoded if necessary)
291
+ - dest: The destination file path (URL encoded if necessary)
292
+
293
+ Query Parameters:
294
+ - overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
295
+ - full_info (optional): If "true", returns detailed file information (path, size, modified time).
296
+ If "false", returns only the relative file path.
297
+
298
+ Returns:
299
+ - 400: If either 'file' or 'dest' parameter is missing
300
+ - 403: If either requested path is not allowed
301
+ - 404: If the source file does not exist
302
+ - 409: If overwrite=false and the destination file already exists
303
+ - 200: JSON response with either:
304
+ - Full file information (if full_info=true)
305
+ - Relative file path (if full_info=false)
306
+ """
307
+ source = get_user_data_path(request, check_exists=True)
308
+ if not isinstance(source, str):
309
+ return source
310
+
311
+ dest = get_user_data_path(request, check_exists=False, param="dest")
312
+ if not isinstance(dest, str):
313
+ return dest
314
+
315
+ overwrite = request.query.get("overwrite", 'true') != "false"
316
+ full_info = request.query.get('full_info', 'false').lower() == "true"
317
+
318
+ if not overwrite and os.path.exists(dest):
319
+ return web.Response(status=409, text="File already exists")
320
+
321
+ logging.info(f"moving '{source}' -> '{dest}'")
322
+ shutil.move(source, dest)
323
+
324
+ user_path = self.get_request_user_filepath(request, None)
325
+ if full_info:
326
+ resp = get_file_info(dest, user_path)
327
+ else:
328
+ resp = os.path.relpath(dest, user_path)
329
+
330
+ return web.json_response(resp)
comfy/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
1
+ import pickle
2
+
3
+ load = pickle.load
4
+
5
+ class Empty:
6
+ pass
7
+
8
+ class Unpickler(pickle.Unpickler):
9
+ def find_class(self, module, name):
10
+ #TODO: safe unpickle
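+ # pytorch_lightning classes are stubbed out with Empty so Lightning
+ # checkpoints can be unpickled without importing (or trusting) the library.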
11
+ if module.startswith("pytorch_lightning"):
12
+ return Empty
13
+ return super().find_class(module, name)
comfy/cldm/cldm.py ADDED
@@ -0,0 +1,433 @@
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..ldm.modules.diffusionmodules.util import (
8
+ timestep_embedding,
9
+ )
10
+
11
+ from ..ldm.modules.attention import SpatialTransformer
12
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
13
+ from ..ldm.util import exists
14
+ from .control_types import UNION_CONTROLNET_TYPES
15
+ from collections import OrderedDict
16
+ import comfy.ops
17
+ from comfy.ldm.modules.attention import optimized_attention
18
+
19
+ class OptimizedAttention(nn.Module):
20
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
21
+ super().__init__()
22
+ self.heads = nhead
23
+ self.c = c
24
+
25
+ self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
26
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
27
+
28
+ def forward(self, x):
29
+ x = self.in_proj(x)
30
+ q, k, v = x.split(self.c, dim=2)
31
+ out = optimized_attention(q, k, v, self.heads)
32
+ return self.out_proj(out)
33
+
34
+ class QuickGELU(nn.Module):
35
+ def forward(self, x: torch.Tensor):
36
+ return x * torch.sigmoid(1.702 * x)
37
+
38
+ class ResBlockUnionControlnet(nn.Module):
39
+ def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
40
+ super().__init__()
41
+ self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
42
+ self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
43
+ self.mlp = nn.Sequential(
44
+ OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
45
+ ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
46
+ self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
47
+
48
+ def attention(self, x: torch.Tensor):
49
+ return self.attn(x)
50
+
51
+ def forward(self, x: torch.Tensor):
52
+ x = x + self.attention(self.ln_1(x))
53
+ x = x + self.mlp(self.ln_2(x))
54
+ return x
55
+
56
+ class ControlledUnetModel(UNetModel):
57
+ #implemented in the ldm unet
58
+ pass
59
+
60
+ class ControlNet(nn.Module):
61
+ def __init__(
62
+ self,
63
+ image_size,
64
+ in_channels,
65
+ model_channels,
66
+ hint_channels,
67
+ num_res_blocks,
68
+ dropout=0,
69
+ channel_mult=(1, 2, 4, 8),
70
+ conv_resample=True,
71
+ dims=2,
72
+ num_classes=None,
73
+ use_checkpoint=False,
74
+ dtype=torch.float32,
75
+ num_heads=-1,
76
+ num_head_channels=-1,
77
+ num_heads_upsample=-1,
78
+ use_scale_shift_norm=False,
79
+ resblock_updown=False,
80
+ use_new_attention_order=False,
81
+ use_spatial_transformer=False, # custom transformer support
82
+ transformer_depth=1, # custom transformer support
83
+ context_dim=None, # custom transformer support
84
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
85
+ legacy=True,
86
+ disable_self_attentions=None,
87
+ num_attention_blocks=None,
88
+ disable_middle_self_attn=False,
89
+ use_linear_in_transformer=False,
90
+ adm_in_channels=None,
91
+ transformer_depth_middle=None,
92
+ transformer_depth_output=None,
93
+ attn_precision=None,
94
+ union_controlnet_num_control_type=None,
95
+ device=None,
96
+ operations=comfy.ops.disable_weight_init,
97
+ **kwargs,
98
+ ):
99
+ super().__init__()
100
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
101
+ if use_spatial_transformer:
102
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
103
+
104
+ if context_dim is not None:
105
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
106
+ # from omegaconf.listconfig import ListConfig
107
+ # if type(context_dim) == ListConfig:
108
+ # context_dim = list(context_dim)
109
+
110
+ if num_heads_upsample == -1:
111
+ num_heads_upsample = num_heads
112
+
113
+ if num_heads == -1:
114
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
115
+
116
+ if num_head_channels == -1:
117
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
118
+
119
+ self.dims = dims
120
+ self.image_size = image_size
121
+ self.in_channels = in_channels
122
+ self.model_channels = model_channels
123
+
124
+ if isinstance(num_res_blocks, int):
125
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
126
+ else:
127
+ if len(num_res_blocks) != len(channel_mult):
128
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
129
+ "as a list/tuple (per-level) with the same length as channel_mult")
130
+ self.num_res_blocks = num_res_blocks
131
+
132
+ if disable_self_attentions is not None:
133
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
134
+ assert len(disable_self_attentions) == len(channel_mult)
135
+ if num_attention_blocks is not None:
136
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
137
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
138
+
139
+ transformer_depth = transformer_depth[:]
140
+
141
+ self.dropout = dropout
142
+ self.channel_mult = channel_mult
143
+ self.conv_resample = conv_resample
144
+ self.num_classes = num_classes
145
+ self.use_checkpoint = use_checkpoint
146
+ self.dtype = dtype
147
+ self.num_heads = num_heads
148
+ self.num_head_channels = num_head_channels
149
+ self.num_heads_upsample = num_heads_upsample
150
+ self.predict_codebook_ids = n_embed is not None
151
+
152
+ time_embed_dim = model_channels * 4
153
+ self.time_embed = nn.Sequential(
154
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
155
+ nn.SiLU(),
156
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
157
+ )
158
+
159
+ if self.num_classes is not None:
160
+ if isinstance(self.num_classes, int):
161
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
162
+ elif self.num_classes == "continuous":
163
+ self.label_emb = nn.Linear(1, time_embed_dim)
164
+ elif self.num_classes == "sequential":
165
+ assert adm_in_channels is not None
166
+ self.label_emb = nn.Sequential(
167
+ nn.Sequential(
168
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
169
+ nn.SiLU(),
170
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
171
+ )
172
+ )
173
+ else:
174
+ raise ValueError()
175
+
176
+ self.input_blocks = nn.ModuleList(
177
+ [
178
+ TimestepEmbedSequential(
179
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
180
+ )
181
+ ]
182
+ )
183
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
184
+
185
+ self.input_hint_block = TimestepEmbedSequential(
186
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
187
+ nn.SiLU(),
188
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
189
+ nn.SiLU(),
190
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
191
+ nn.SiLU(),
192
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
193
+ nn.SiLU(),
194
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
195
+ nn.SiLU(),
196
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
197
+ nn.SiLU(),
198
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
199
+ nn.SiLU(),
200
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
201
+ )
202
+
203
+ self._feature_size = model_channels
204
+ input_block_chans = [model_channels]
205
+ ch = model_channels
206
+ ds = 1
207
+ for level, mult in enumerate(channel_mult):
208
+ for nr in range(self.num_res_blocks[level]):
209
+ layers = [
210
+ ResBlock(
211
+ ch,
212
+ time_embed_dim,
213
+ dropout,
214
+ out_channels=mult * model_channels,
215
+ dims=dims,
216
+ use_checkpoint=use_checkpoint,
217
+ use_scale_shift_norm=use_scale_shift_norm,
218
+ dtype=self.dtype,
219
+ device=device,
220
+ operations=operations,
221
+ )
222
+ ]
223
+ ch = mult * model_channels
224
+ num_transformers = transformer_depth.pop(0)
225
+ if num_transformers > 0:
226
+ if num_head_channels == -1:
227
+ dim_head = ch // num_heads
228
+ else:
229
+ num_heads = ch // num_head_channels
230
+ dim_head = num_head_channels
231
+ if legacy:
232
+ #num_heads = 1
233
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
234
+ if exists(disable_self_attentions):
235
+ disabled_sa = disable_self_attentions[level]
236
+ else:
237
+ disabled_sa = False
238
+
239
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
240
+ layers.append(
241
+ SpatialTransformer(
242
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
243
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
244
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
245
+ )
246
+ )
247
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
248
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
249
+ self._feature_size += ch
250
+ input_block_chans.append(ch)
251
+ if level != len(channel_mult) - 1:
252
+ out_ch = ch
253
+ self.input_blocks.append(
254
+ TimestepEmbedSequential(
255
+ ResBlock(
256
+ ch,
257
+ time_embed_dim,
258
+ dropout,
259
+ out_channels=out_ch,
260
+ dims=dims,
261
+ use_checkpoint=use_checkpoint,
262
+ use_scale_shift_norm=use_scale_shift_norm,
263
+ down=True,
264
+ dtype=self.dtype,
265
+ device=device,
266
+ operations=operations
267
+ )
268
+ if resblock_updown
269
+ else Downsample(
270
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
271
+ )
272
+ )
273
+ )
274
+ ch = out_ch
275
+ input_block_chans.append(ch)
276
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
277
+ ds *= 2
278
+ self._feature_size += ch
279
+
280
+ if num_head_channels == -1:
281
+ dim_head = ch // num_heads
282
+ else:
283
+ num_heads = ch // num_head_channels
284
+ dim_head = num_head_channels
285
+ if legacy:
286
+ #num_heads = 1
287
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
288
+ mid_block = [
289
+ ResBlock(
290
+ ch,
291
+ time_embed_dim,
292
+ dropout,
293
+ dims=dims,
294
+ use_checkpoint=use_checkpoint,
295
+ use_scale_shift_norm=use_scale_shift_norm,
296
+ dtype=self.dtype,
297
+ device=device,
298
+ operations=operations
299
+ )]
300
+ if transformer_depth_middle >= 0:
301
+ mid_block += [SpatialTransformer( # always uses a self-attn
302
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
303
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
304
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
305
+ ),
306
+ ResBlock(
307
+ ch,
308
+ time_embed_dim,
309
+ dropout,
310
+ dims=dims,
311
+ use_checkpoint=use_checkpoint,
312
+ use_scale_shift_norm=use_scale_shift_norm,
313
+ dtype=self.dtype,
314
+ device=device,
315
+ operations=operations
316
+ )]
317
+ self.middle_block = TimestepEmbedSequential(*mid_block)
318
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
319
+ self._feature_size += ch
320
+
321
+ if union_controlnet_num_control_type is not None:
322
+ self.num_control_type = union_controlnet_num_control_type
323
+ num_trans_channel = 320
324
+ num_trans_head = 8
325
+ num_trans_layer = 1
326
+ num_proj_channel = 320
327
+ # task_scale_factor = num_trans_channel ** 0.5
328
+ self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
329
+
330
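+ # ("layes" is kept as-is; renaming it would break pretrained state_dict keys)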
+ self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
331
+ self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
332
+ #-----------------------------------------------------------------------------------------------------
333
+
334
+ control_add_embed_dim = 256
335
+ class ControlAddEmbedding(nn.Module):
336
+ def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
337
+ super().__init__()
338
+ self.num_control_type = num_control_type
339
+ self.in_dim = in_dim
340
+ self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
341
+ self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
342
+ def forward(self, control_type, dtype, device):
343
+ c_type = torch.zeros((self.num_control_type,), device=device)
344
+ c_type[control_type] = 1.0
345
+ c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
346
+ return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
347
+
348
+ self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
349
+ else:
350
+ self.task_embedding = None
351
+ self.control_add_embedding = None
352
+
353
+ def union_controlnet_merge(self, hint, control_type, emb, context):
354
+ # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
355
+ inputs = []
356
+ condition_list = []
357
+
358
+ for idx in range(min(1, len(control_type))):
359
+ controlnet_cond = self.input_hint_block(hint[idx], emb, context)
360
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
361
+ if idx < len(control_type):
362
+ feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
363
+
364
+ inputs.append(feat_seq.unsqueeze(1))
365
+ condition_list.append(controlnet_cond)
366
+
367
+ x = torch.cat(inputs, dim=1)
368
+ x = self.transformer_layes(x)
369
+ controlnet_cond_fuser = None
370
+ for idx in range(len(control_type)):
371
+ alpha = self.spatial_ch_projs(x[:, idx])
372
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
373
+ o = condition_list[idx] + alpha
374
+ if controlnet_cond_fuser is None:
375
+ controlnet_cond_fuser = o
376
+ else:
377
+ controlnet_cond_fuser += o
378
+ return controlnet_cond_fuser
379
+
380
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
381
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
382
+
383
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
384
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
385
+ emb = self.time_embed(t_emb)
386
+
387
+ guided_hint = None
388
+ if self.control_add_embedding is not None: #Union Controlnet
389
+ control_type = kwargs.get("control_type", [])
390
+
391
+ if any([c >= self.num_control_type for c in control_type]):
392
+ max_type = max(control_type)
393
+ max_type_name = {
394
+ v: k for k, v in UNION_CONTROLNET_TYPES.items()
395
+ }[max_type]
396
+ raise ValueError(
397
+ f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
398
+ f"({self.num_control_type}) supported.\n" +
399
+ "Please consider using the ProMax ControlNet Union model.\n" +
400
+ "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
401
+ )
402
+
403
+ emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
404
+ if len(control_type) > 0:
405
+ if len(hint.shape) < 5:
406
+ hint = hint.unsqueeze(dim=0)
407
+ guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
408
+
409
+ if guided_hint is None:
410
+ guided_hint = self.input_hint_block(hint, emb, context)
411
+
412
+ out_output = []
413
+ out_middle = []
414
+
415
+ if self.num_classes is not None:
416
+ assert y.shape[0] == x.shape[0]
417
+ emb = emb + self.label_emb(y)
418
+
419
+ h = x
420
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
421
+ if guided_hint is not None:
422
+ h = module(h, emb, context)
423
+ h += guided_hint
424
+ guided_hint = None
425
+ else:
426
+ h = module(h, emb, context)
427
+ out_output.append(zero_conv(h, emb, context))
428
+
429
+ h = self.middle_block(h, emb, context)
430
+ out_middle.append(self.middle_block_out(h, emb, context))
431
+
432
+ return {"middle": out_middle, "output": out_output}
433
+
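Editor's note: for orientation, a minimal sketch of driving this ControlNet's forward pass once a model instance exists. In practice comfy/controlnet.py constructs and invokes it; the `control_model` name and all shapes below are illustrative assumptions, not part of the commit.

    import torch

    # `control_model` is assumed to be a built instance of the ControlNet class above.
    x = torch.randn(1, 4, 64, 64)        # noisy latent (assumed SD-style shape)
    hint = torch.randn(1, 3, 512, 512)   # conditioning image in pixel space
    timesteps = torch.tensor([10])
    context = torch.randn(1, 77, 768)    # text conditioning (assumed SD1.5-sized)

    out = control_model(x, hint, timesteps, context)
    # out["output"] holds one residual per input block and out["middle"] one for the
    # middle block; the sampler adds these to the matching UNet activations.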
comfy/cldm/control_types.py ADDED
@@ -0,0 +1,10 @@
+ UNION_CONTROLNET_TYPES = {
+     "openpose": 0,
+     "depth": 1,
+     "hed/pidi/scribble/ted": 2,
+     "canny/lineart/anime_lineart/mlsd": 3,
+     "normal": 4,
+     "segment": 5,
+     "tile": 6,
+     "repaint": 7,
+ }
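Editor's note: these indices are what the `control_type` list in the union ControlNet's forward pass refers to. A small sketch of turning readable names into that list (the `names` value is a hypothetical input):

    names = ["openpose", "depth"]                               # hypothetical user choice
    control_type = [UNION_CONTROLNET_TYPES[n] for n in names]   # -> [0, 1]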
comfy/cldm/dit_embedder.py ADDED
@@ -0,0 +1,120 @@
+ import math
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
+
+
+ class ControlNetEmbedder(nn.Module):
+
+     def __init__(
+         self,
+         img_size: int,
+         patch_size: int,
+         in_chans: int,
+         attention_head_dim: int,
+         num_attention_heads: int,
+         adm_in_channels: int,
+         num_layers: int,
+         main_model_double: int,
+         double_y_emb: bool,
+         device: torch.device,
+         dtype: torch.dtype,
+         pos_embed_max_size: Optional[int] = None,
+         operations = None,
+     ):
+         super().__init__()
+         self.main_model_double = main_model_double
+         self.dtype = dtype
+         self.hidden_size = num_attention_heads * attention_head_dim
+         self.patch_size = patch_size
+         self.x_embedder = PatchEmbed(
+             img_size=img_size,
+             patch_size=patch_size,
+             in_chans=in_chans,
+             embed_dim=self.hidden_size,
+             strict_img_size=pos_embed_max_size is None,
+             device=device,
+             dtype=dtype,
+             operations=operations,
+         )
+
+         self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
+
+         self.double_y_emb = double_y_emb
+         if self.double_y_emb:
+             self.orig_y_embedder = VectorEmbedder(
+                 adm_in_channels, self.hidden_size, dtype, device, operations=operations
+             )
+             self.y_embedder = VectorEmbedder(
+                 self.hidden_size, self.hidden_size, dtype, device, operations=operations
+             )
+         else:
+             self.y_embedder = VectorEmbedder(
+                 adm_in_channels, self.hidden_size, dtype, device, operations=operations
+             )
+
+         self.transformer_blocks = nn.ModuleList(
+             DismantledBlock(
+                 hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
+                 dtype=dtype, device=device, operations=operations
+             )
+             for _ in range(num_layers)
+         )
+
+         # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
+         # TODO double check this logic when 8b
+         self.use_y_embedder = True
+
+         self.controlnet_blocks = nn.ModuleList([])
+         for _ in range(len(self.transformer_blocks)):
+             controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+             self.controlnet_blocks.append(controlnet_block)
+
+         self.pos_embed_input = PatchEmbed(
+             img_size=img_size,
+             patch_size=patch_size,
+             in_chans=in_chans,
+             embed_dim=self.hidden_size,
+             strict_img_size=False,
+             device=device,
+             dtype=dtype,
+             operations=operations,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timesteps: torch.Tensor,
+         y: Optional[torch.Tensor] = None,
+         context: Optional[torch.Tensor] = None,
+         hint = None,
+     ) -> Tuple[Tensor, List[Tensor]]:
+         x_shape = list(x.shape)
+         x = self.x_embedder(x)
+         if not self.double_y_emb:
+             h = (x_shape[-2] + 1) // self.patch_size
+             w = (x_shape[-1] + 1) // self.patch_size
+             x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
+         c = self.t_embedder(timesteps, dtype=x.dtype)
+         if y is not None and self.y_embedder is not None:
+             if self.double_y_emb:
+                 y = self.orig_y_embedder(y)
+             y = self.y_embedder(y)
+             c = c + y
+
+         x = x + self.pos_embed_input(hint)
+
+         block_out = ()
+
+         repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
+         for i in range(len(self.transformer_blocks)):
+             out = self.transformer_blocks[i](x, c)
+             if not self.double_y_emb:
+                 x = out
+             block_out += (self.controlnet_blocks[i](out),) * repeat
+
+         return {"output": block_out}
comfy/cldm/mmdit.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ from typing import Optional
+ import comfy.ldm.modules.diffusionmodules.mmdit
+
+ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
+     def __init__(
+         self,
+         num_blocks = None,
+         control_latent_channels = None,
+         dtype = None,
+         device = None,
+         operations = None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
+         # controlnet_blocks
+         self.controlnet_blocks = torch.nn.ModuleList([])
+         for _ in range(len(self.joint_blocks)):
+             self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
+
+         if control_latent_channels is None:
+             control_latent_channels = self.in_channels
+
+         self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
+             None,
+             self.patch_size,
+             control_latent_channels,
+             self.hidden_size,
+             bias=True,
+             strict_img_size=False,
+             dtype=dtype,
+             device=device,
+             operations=operations
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timesteps: torch.Tensor,
+         y: Optional[torch.Tensor] = None,
+         context: Optional[torch.Tensor] = None,
+         hint = None,
+     ) -> torch.Tensor:
+
+         # weird sd3 controlnet specific stuff
+         y = torch.zeros_like(y)
+
+         if self.context_processor is not None:
+             context = self.context_processor(context)
+
+         hw = x.shape[-2:]
+         x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+         x += self.pos_embed_input(hint)
+
+         c = self.t_embedder(timesteps, dtype=x.dtype)
+         if y is not None and self.y_embedder is not None:
+             y = self.y_embedder(y)
+             c = c + y
+
+         if context is not None:
+             context = self.context_embedder(context)
+
+         output = []
+
+         blocks = len(self.joint_blocks)
+         for i in range(blocks):
+             context, x = self.joint_blocks[i](
+                 context,
+                 x,
+                 c=c,
+                 use_checkpoint=self.use_checkpoint,
+             )
+
+             out = self.controlnet_blocks[i](x)
+             count = self.depth // blocks
+             if i == blocks - 1:
+                 count -= 1
+             for j in range(count):
+                 output.append(out)
+
+         return {"output": output}
comfy/cli_args.py ADDED
@@ -0,0 +1,213 @@
+ import argparse
+ import enum
+ import os
+ import comfy.options
+
+
+ class EnumAction(argparse.Action):
+     """
+     Argparse action for handling Enums
+     """
+     def __init__(self, **kwargs):
+         # Pop off the type value
+         enum_type = kwargs.pop("type", None)
+
+         # Ensure an Enum subclass is provided
+         if enum_type is None:
+             raise ValueError("type must be assigned an Enum when using EnumAction")
+         if not issubclass(enum_type, enum.Enum):
+             raise TypeError("type must be an Enum when using EnumAction")
+
+         # Generate choices from the Enum
+         choices = tuple(e.value for e in enum_type)
+         kwargs.setdefault("choices", choices)
+         kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
+
+         super(EnumAction, self).__init__(**kwargs)
+
+         self._enum = enum_type
+
+     def __call__(self, parser, namespace, values, option_string=None):
+         # Convert value back into an Enum
+         value = self._enum(values)
+         setattr(namespace, self.dest, value)
+
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
+ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
+ parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
+ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
+ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
+ parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
+
+ parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
+ parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
+ parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
+ parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
+ parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
+ parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
+ parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
+ parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
+ cm_group = parser.add_mutually_exclusive_group()
+ cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
+ cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
+
+
+ fp_group = parser.add_mutually_exclusive_group()
+ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
+ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
+
+ fpunet_group = parser.add_mutually_exclusive_group()
+ fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
+ fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
+ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
+ fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16.")
+ fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
+ fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+
+ fpvae_group = parser.add_mutually_exclusive_group()
+ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
+ fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
+ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
+
+ parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
+
+ fpte_group = parser.add_mutually_exclusive_group()
+ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
+ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
+ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
+ fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+
+ parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
+
+ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
+
+ parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
+ parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables the default ipex.optimize call when loading models with Intel's Extension for Pytorch.")
+
+ class LatentPreviewMethod(enum.Enum):
+     NoPreviews = "none"
+     Auto = "auto"
+     Latent2RGB = "latent2rgb"
+     TAESD = "taesd"
+
+ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
+
+ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
+
+ cache_group = parser.add_mutually_exclusive_group()
+ cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
+ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
+
+ attn_group = parser.add_mutually_exclusive_group()
+ attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+ attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization. Ignored when xformers is used.")
+ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
+ attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+ attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
+
+ parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
+
+ upcast = parser.add_mutually_exclusive_group()
+ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
+ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
+
+
+ vram_group = parser.add_mutually_exclusive_group()
+ vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
+ vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
+ vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
+ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
+ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
+ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
+
+ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
+
+
+ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
+
+ parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can.")
+ parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
+
+ class PerformanceFeature(enum.Enum):
+     Fp16Accumulation = "fp16_accumulation"
+     Fp8MatrixMultiplication = "fp8_matrix_mult"
+
+ parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
+
+ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
+ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
+ parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
+
+ parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
+ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
+
+ parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
+
+ parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
+ parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
+
+ # The default built-in provider hosted under web/
+ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
+
+ parser.add_argument(
+     "--front-end-version",
+     type=str,
+     default=DEFAULT_VERSION_STRING,
+     help="""
+     Specifies the version of the frontend to be used. This command needs internet connectivity to query and
+     download available frontend implementations from GitHub releases.
+
+     The version string should be in the format of:
+     [repoOwner]/[repoName]@[version]
+     where version is one of: "latest" or a valid version number (e.g. "1.0.0")
+     """,
+ )
+
+ def is_valid_directory(path: str) -> str:
+     """Validate if the given path is a directory, and check permissions."""
+     if not os.path.exists(path):
+         raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
+     if not os.path.isdir(path):
+         raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
+     if not os.access(path, os.R_OK):
+         raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
+     return path
+
+ parser.add_argument(
+     "--front-end-root",
+     type=is_valid_directory,
+     default=None,
+     help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
+ )
+
+ parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
+
+ parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
+
+ if comfy.options.args_parsing:
+     args = parser.parse_args()
+ else:
+     args = parser.parse_args([])
+
+ if args.windows_standalone_build:
+     args.auto_launch = True
+
+ if args.disable_auto_launch:
+     args.auto_launch = False
+
+ if args.force_fp16:
+     args.fp16_unet = True
+
+
+ # '--fast' is not provided, use an empty set
+ if args.fast is None:
+     args.fast = set()
+ # '--fast' is provided with an empty list, enable all optimizations
+ elif args.fast == []:
+     args.fast = set(PerformanceFeature)
+ # '--fast' is provided with a list of performance features, use that list
+ else:
+     args.fast = set(args.fast)
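Editor's note: the `--fast` normalization above collapses three invocation shapes into one set-valued attribute. A hedged sketch of the behavior (re-parsing here is for illustration; the normalization itself only runs on the module-level `args`):

    # (no --fast)              -> args.fast == set()                    # nothing enabled
    # --fast                   -> args.fast == set(PerformanceFeature)  # everything enabled
    # --fast fp16_accumulation -> args.fast == {PerformanceFeature.Fp16Accumulation}
    ns = parser.parse_args(["--fast", "fp16_accumulation"])
    assert ns.fast == [PerformanceFeature.Fp16Accumulation]  # list, before the set() step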
comfy/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "CLIPTextModel"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "dropout": 0.0,
+   "eos_token_id": 49407,
+   "hidden_act": "gelu",
+   "hidden_size": 1280,
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "intermediate_size": 5120,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 77,
+   "model_type": "clip_text_model",
+   "num_attention_heads": 20,
+   "num_hidden_layers": 32,
+   "pad_token_id": 1,
+   "projection_dim": 1280,
+   "torch_dtype": "float32",
+   "vocab_size": 49408
+ }
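Editor's note: this is the text-encoder config for the bigG CLIP variant. A hedged sketch of feeding it to the CLIPTextModel defined in comfy/clip_model.py below; using comfy.ops.disable_weight_init as the operations provider is an assumption for illustration:

    import json
    import torch
    import comfy.ops
    from comfy.clip_model import CLIPTextModel

    with open("comfy/clip_config_bigg.json") as f:
        config = json.load(f)

    model = CLIPTextModel(config, dtype=torch.float32, device="cpu",
                          operations=comfy.ops.disable_weight_init)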
comfy/clip_model.py ADDED
@@ -0,0 +1,244 @@
+ import torch
+ from comfy.ldm.modules.attention import optimized_attention_for_device
+ import comfy.ops
+
+ class CLIPAttention(torch.nn.Module):
+     def __init__(self, embed_dim, heads, dtype, device, operations):
+         super().__init__()
+
+         self.heads = heads
+         self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+         self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+         self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+         self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x, mask=None, optimized_attention=None):
+         q = self.q_proj(x)
+         k = self.k_proj(x)
+         v = self.v_proj(x)
+
+         out = optimized_attention(q, k, v, self.heads, mask)
+         return self.out_proj(out)
+
+ ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+                "gelu": torch.nn.functional.gelu,
+                "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
+                }
+
+ class CLIPMLP(torch.nn.Module):
+     def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
+         super().__init__()
+         self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
+         self.activation = ACTIVATIONS[activation]
+         self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.activation(x)
+         x = self.fc2(x)
+         return x
+
+ class CLIPLayer(torch.nn.Module):
+     def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+         super().__init__()
+         self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+         self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
+         self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+         self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
+
+     def forward(self, x, mask=None, optimized_attention=None):
+         x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
+         x += self.mlp(self.layer_norm2(x))
+         return x
+
+
+ class CLIPEncoder(torch.nn.Module):
+     def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+         super().__init__()
+         self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
+
+     def forward(self, x, mask=None, intermediate_output=None):
+         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+
+         if intermediate_output is not None:
+             if intermediate_output < 0:
+                 intermediate_output = len(self.layers) + intermediate_output
+
+         intermediate = None
+         for i, l in enumerate(self.layers):
+             x = l(x, mask, optimized_attention)
+             if i == intermediate_output:
+                 intermediate = x.clone()
+         return x, intermediate
+
+ class CLIPEmbeddings(torch.nn.Module):
+     def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
+         super().__init__()
+         self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+         self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+     def forward(self, input_tokens, dtype=torch.float32):
+         return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
+
+
+ class CLIPTextModel_(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         num_layers = config_dict["num_hidden_layers"]
+         embed_dim = config_dict["hidden_size"]
+         heads = config_dict["num_attention_heads"]
+         intermediate_size = config_dict["intermediate_size"]
+         intermediate_activation = config_dict["hidden_act"]
+         num_positions = config_dict["max_position_embeddings"]
+         self.eos_token_id = config_dict["eos_token_id"]
+
+         super().__init__()
+         self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
+         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+     def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+         if embeds is not None:
+             x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
+         else:
+             x = self.embeddings(input_tokens, dtype=dtype)
+
+         mask = None
+         if attention_mask is not None:
+             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+             mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
+
+         causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
+
+         if mask is not None:
+             mask += causal_mask
+         else:
+             mask = causal_mask
+
+         x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
+         x = self.final_layer_norm(x)
+         if i is not None and final_layer_norm_intermediate:
+             i = self.final_layer_norm(i)
+
+         if num_tokens is not None:
+             pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
+         else:
+             pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
+         return x, i, pooled_output
+
+ class CLIPTextModel(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         self.num_layers = config_dict["num_hidden_layers"]
+         self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+         embed_dim = config_dict["hidden_size"]
+         self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+         self.dtype = dtype
+
+     def get_input_embeddings(self):
+         return self.text_model.embeddings.token_embedding
+
+     def set_input_embeddings(self, embeddings):
+         self.text_model.embeddings.token_embedding = embeddings
+
+     def forward(self, *args, **kwargs):
+         x = self.text_model(*args, **kwargs)
+         out = self.text_projection(x[2])
+         return (x[0], x[1], out, x[2])
+
+
+ class CLIPVisionEmbeddings(torch.nn.Module):
+     def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
+         super().__init__()
+
+         num_patches = (image_size // patch_size) ** 2
+         if model_type == "siglip_vision_model":
+             self.class_embedding = None
+             patch_bias = True
+         else:
+             num_patches = num_patches + 1
+             self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+             patch_bias = False
+
+         self.patch_embedding = operations.Conv2d(
+             in_channels=num_channels,
+             out_channels=embed_dim,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=patch_bias,
+             dtype=dtype,
+             device=device
+         )
+
+         self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
+
+     def forward(self, pixel_values):
+         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
+         if self.class_embedding is not None:
+             embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
+         return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
+
+
+ class CLIPVision(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         num_layers = config_dict["num_hidden_layers"]
+         embed_dim = config_dict["hidden_size"]
+         heads = config_dict["num_attention_heads"]
+         intermediate_size = config_dict["intermediate_size"]
+         intermediate_activation = config_dict["hidden_act"]
+         model_type = config_dict["model_type"]
+
+         self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
+         if model_type == "siglip_vision_model":
+             self.pre_layrnorm = lambda a: a
+             self.output_layernorm = True
+         else:
+             self.pre_layrnorm = operations.LayerNorm(embed_dim)
+             self.output_layernorm = False
+         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+         self.post_layernorm = operations.LayerNorm(embed_dim)
+
+     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+         x = self.embeddings(pixel_values)
+         x = self.pre_layrnorm(x)
+         #TODO: attention_mask?
+         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
+         if self.output_layernorm:
+             x = self.post_layernorm(x)
+             pooled_output = x
+         else:
+             pooled_output = self.post_layernorm(x[:, 0, :])
+         return x, i, pooled_output
+
+ class LlavaProjector(torch.nn.Module):
+     def __init__(self, in_dim, out_dim, dtype, device, operations):
+         super().__init__()
+         self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
+         self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
+
+     def forward(self, x):
+         return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
+
+ class CLIPVisionModelProjection(torch.nn.Module):
+     def __init__(self, config_dict, dtype, device, operations):
+         super().__init__()
+         self.vision_model = CLIPVision(config_dict, dtype, device, operations)
+         if "projection_dim" in config_dict:
+             self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+         else:
+             self.visual_projection = lambda a: a
+
+         if "llava3" == config_dict.get("projector_type", None):
+             self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
+         else:
+             self.multi_modal_projector = None
+
+     def forward(self, *args, **kwargs):
+         x = self.vision_model(*args, **kwargs)
+         out = self.visual_projection(x[2])
+         projected = None
+         if self.multi_modal_projector is not None:
+             projected = self.multi_modal_projector(x[1])
+
+         return (x[0], x[1], out, projected)
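Editor's note: a minimal sketch of running the text model, reusing the `model` built in the sketch after clip_config_bigg.json; the dummy token ids and the choice of intermediate_output=-2 are illustrative assumptions. The 4-tuple unpacking mirrors CLIPTextModel.forward above:

    import torch

    tokens = torch.zeros((1, 77), dtype=torch.long)   # dummy padded token ids (assumed length 77)
    last_hidden, intermediate, projected, pooled = model(tokens, intermediate_output=-2)
    # last_hidden: (1, 77, hidden_size). intermediate_output=-2 requests the
    # penultimate layer's hidden states, which SDXL-style pipelines typically read.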