Spaces:
Runtime error
Runtime error
Commit
·
45e7a3f
0
Parent(s):
Initial commit.
Browse files- .gitignore +288 -0
- LICENSE +77 -0
- NOTICE +173 -0
- README.md +12 -0
- app.py +632 -0
- ckpts/checkpoints-download.md +74 -0
- hyimage/common/config/__init__.py +4 -0
- hyimage/common/config/base_config.py +36 -0
- hyimage/common/config/lazy.py +69 -0
- hyimage/common/constants.py +7 -0
- hyimage/common/format_prompt.py +70 -0
- hyimage/diffusion/cfg_utils.py +140 -0
- hyimage/diffusion/pipelines/__init__.py +0 -0
- hyimage/diffusion/pipelines/hunyuanimage_pipeline.py +892 -0
- hyimage/diffusion/pipelines/hunyuanimage_refiner_pipeline.py +272 -0
- hyimage/models/hunyuan/__init__.py +0 -0
- hyimage/models/hunyuan/configs/hunyuanimage_config.py +51 -0
- hyimage/models/hunyuan/modules/__init__.py +0 -0
- hyimage/models/hunyuan/modules/activation_layers.py +23 -0
- hyimage/models/hunyuan/modules/embed_layers.py +189 -0
- hyimage/models/hunyuan/modules/flash_attn_no_pad.py +125 -0
- hyimage/models/hunyuan/modules/hunyuanimage_dit.py +556 -0
- hyimage/models/hunyuan/modules/mlp_layers.py +121 -0
- hyimage/models/hunyuan/modules/models.py +367 -0
- hyimage/models/hunyuan/modules/modulate_layers.py +154 -0
- hyimage/models/hunyuan/modules/norm_layers.py +81 -0
- hyimage/models/hunyuan/modules/posemb_layers.py +286 -0
- hyimage/models/hunyuan/modules/token_refiner.py +297 -0
- hyimage/models/hunyuan/utils/__init__.py +0 -0
- hyimage/models/hunyuan/utils/helpers.py +23 -0
- hyimage/models/model_zoo.py +143 -0
- hyimage/models/reprompt/__init__.py +1 -0
- hyimage/models/reprompt/reprompt.py +108 -0
- hyimage/models/text_encoder/__init__.py +469 -0
- hyimage/models/text_encoder/byT5/__init__.py +213 -0
- hyimage/models/vae/__init__.py +29 -0
- hyimage/models/vae/hunyuanimage_vae.py +779 -0
- requirements.txt +17 -0
.gitignore
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be added to the global gitignore or merged into this project gitignore. For a PyCharm
|
158 |
+
# project, it is recommended to include the following files:
|
159 |
+
# .idea/
|
160 |
+
# *.iml
|
161 |
+
# *.ipr
|
162 |
+
# *.iws
|
163 |
+
.idea/
|
164 |
+
*.iml
|
165 |
+
*.ipr
|
166 |
+
*.iws
|
167 |
+
|
168 |
+
# VS Code
|
169 |
+
.vscode/
|
170 |
+
*.code-workspace
|
171 |
+
|
172 |
+
# Local History for Visual Studio Code
|
173 |
+
.history/
|
174 |
+
|
175 |
+
# Built Visual Studio Code Extensions
|
176 |
+
*.vsix
|
177 |
+
|
178 |
+
# macOS
|
179 |
+
.DS_Store
|
180 |
+
.AppleDouble
|
181 |
+
.LSOverride
|
182 |
+
|
183 |
+
# Icon must end with two \r
|
184 |
+
Icon
|
185 |
+
|
186 |
+
# Thumbnails
|
187 |
+
._*
|
188 |
+
|
189 |
+
# Files that might appear in the root of a volume
|
190 |
+
.DocumentRevisions-V100
|
191 |
+
.fseventsd
|
192 |
+
.Spotlight-V100
|
193 |
+
.TemporaryItems
|
194 |
+
.Trashes
|
195 |
+
.VolumeIcon.icns
|
196 |
+
.com.apple.timemachine.donotpresent
|
197 |
+
|
198 |
+
# Directories potentially created on remote AFP share
|
199 |
+
.AppleDB
|
200 |
+
.AppleDesktop
|
201 |
+
Network Trash Folder
|
202 |
+
Temporary Items
|
203 |
+
.apdisk
|
204 |
+
|
205 |
+
# Windows
|
206 |
+
# Windows thumbnail cache files
|
207 |
+
Thumbs.db
|
208 |
+
Thumbs.db:encryptable
|
209 |
+
ehthumbs.db
|
210 |
+
ehthumbs_vista.db
|
211 |
+
|
212 |
+
# Dump file
|
213 |
+
*.stackdump
|
214 |
+
|
215 |
+
# Folder config file
|
216 |
+
[Dd]esktop.ini
|
217 |
+
|
218 |
+
# Recycle Bin used on file shares
|
219 |
+
$RECYCLE.BIN/
|
220 |
+
|
221 |
+
# Windows Installer files
|
222 |
+
*.cab
|
223 |
+
*.msi
|
224 |
+
*.msix
|
225 |
+
*.msm
|
226 |
+
*.msp
|
227 |
+
|
228 |
+
# Windows shortcuts
|
229 |
+
*.lnk
|
230 |
+
|
231 |
+
# Linux
|
232 |
+
*~
|
233 |
+
|
234 |
+
# temporary files which can be created if a process still has a handle open of a deleted file
|
235 |
+
.fuse_hidden*
|
236 |
+
|
237 |
+
# KDE directory preferences
|
238 |
+
.directory
|
239 |
+
|
240 |
+
# Linux trash folder which might appear on any partition or disk
|
241 |
+
.Trash-*
|
242 |
+
|
243 |
+
# .nfs files are created when an open file is removed but is still being accessed
|
244 |
+
.nfs*
|
245 |
+
|
246 |
+
# Project specific
|
247 |
+
# Output directories
|
248 |
+
outputs/
|
249 |
+
outputs_video/
|
250 |
+
states/
|
251 |
+
exp_logs/
|
252 |
+
my_exps/
|
253 |
+
vis/
|
254 |
+
|
255 |
+
# Data and model files
|
256 |
+
data_tools/
|
257 |
+
*.pkl
|
258 |
+
*.safetensors
|
259 |
+
*.pt
|
260 |
+
*.bin
|
261 |
+
*.h5
|
262 |
+
*.hdf5
|
263 |
+
|
264 |
+
# Environment files
|
265 |
+
scripts/env.sh
|
266 |
+
|
267 |
+
# Keep specific files
|
268 |
+
!assets/*.png
|
269 |
+
|
270 |
+
# Linting and formatting
|
271 |
+
.ruff_cache/
|
272 |
+
.black/
|
273 |
+
.isort.cfg
|
274 |
+
|
275 |
+
# Temporary files
|
276 |
+
*.tmp
|
277 |
+
*.temp
|
278 |
+
*.swp
|
279 |
+
*.swo
|
280 |
+
*~
|
281 |
+
|
282 |
+
# Logs
|
283 |
+
*.log
|
284 |
+
logs/
|
285 |
+
|
286 |
+
# Cache directories
|
287 |
+
.cache/
|
288 |
+
cache/
|
LICENSE
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
2 |
+
Tencent HunyuanImage 2.1 Release Date: September 8, 2025
|
3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
4 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
5 |
+
1. DEFINITIONS.
|
6 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
7 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
8 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
9 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
10 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
11 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
12 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
13 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
14 |
+
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
|
15 |
+
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanImage 2.1 released at [https://github.com/Tencent-Hunyuan/HunyuanImage-2.1/blob/master/LICENSE;https://huggingface.co/tencent/HunyuanImage-2.1/blob/main/LICENSE].
|
16 |
+
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
17 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
18 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
19 |
+
n. “including” shall mean including but not limited to.
|
20 |
+
2. GRANT OF RIGHTS.
|
21 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
22 |
+
3. DISTRIBUTION.
|
23 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
24 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
25 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
26 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
27 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
28 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
29 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
30 |
+
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
31 |
+
5. RULES OF USE.
|
32 |
+
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
33 |
+
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
|
34 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
35 |
+
6. INTELLECTUAL PROPERTY.
|
36 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
37 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
38 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
39 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
40 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
41 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
42 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
43 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
44 |
+
8. SURVIVAL AND TERMINATION.
|
45 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
46 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
47 |
+
9. GOVERNING LAW AND JURISDICTION.
|
48 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
49 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
50 |
+
|
51 |
+
EXHIBIT A
|
52 |
+
ACCEPTABLE USE POLICY
|
53 |
+
|
54 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
55 |
+
Last modified: November 5, 2024
|
56 |
+
|
57 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
58 |
+
1. Outside the Territory;
|
59 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
60 |
+
3. To harm Yourself or others;
|
61 |
+
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
62 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
63 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
64 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
65 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
66 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
67 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
68 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
69 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
70 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
71 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
72 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
73 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
74 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
75 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
76 |
+
19. For military purposes;
|
77 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
NOTICE
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Usage and Legal Notices:
|
2 |
+
|
3 |
+
Tencent is pleased to support the open source community by making Tencent HunyuanImage 2.1 available.
|
4 |
+
|
5 |
+
Copyright (C) 2025 Tencent. All rights reserved. The below model in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
|
6 |
+
|
7 |
+
Tencent HunyuanImage 2.1 is licensed under Tencent Hunyuan Community License Agreement, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent HunyuanImage 2.1 does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
8 |
+
|
9 |
+
For avoidance of doubts, Tencent HunyuanImage 2.1 means the large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Tencent in accordance with the Tencent Hunyuan Community License Agreement.
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
Other dependencies and licenses:
|
14 |
+
|
15 |
+
Open Source Software Licensed under the Apache License Version 2.0:
|
16 |
+
The below software in this distribution may have been modified by Tencent.
|
17 |
+
--------------------------------------------------------------------
|
18 |
+
1. Glyph-ByT5
|
19 |
+
Copyright (c) Glyph-ByT5 and its authors.
|
20 |
+
Please find the original component at following site: https://github.com/AIGText/Glyph-ByT5
|
21 |
+
|
22 |
+
|
23 |
+
Terms of the Apache License Version 2.0:
|
24 |
+
--------------------------------------------------------------------
|
25 |
+
Apache License
|
26 |
+
|
27 |
+
Version 2.0, January 2004
|
28 |
+
|
29 |
+
http://www.apache.org/licenses/
|
30 |
+
|
31 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
32 |
+
1. Definitions.
|
33 |
+
|
34 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
35 |
+
|
36 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
37 |
+
|
38 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
39 |
+
|
40 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
41 |
+
|
42 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
43 |
+
|
44 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
45 |
+
|
46 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
47 |
+
|
48 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
49 |
+
|
50 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
51 |
+
|
52 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
53 |
+
|
54 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
55 |
+
|
56 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
57 |
+
|
58 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
59 |
+
|
60 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
61 |
+
|
62 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
63 |
+
|
64 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
65 |
+
|
66 |
+
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
67 |
+
|
68 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
69 |
+
|
70 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
71 |
+
|
72 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
73 |
+
|
74 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
75 |
+
|
76 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
77 |
+
|
78 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
79 |
+
|
80 |
+
END OF TERMS AND CONDITIONS
|
81 |
+
|
82 |
+
--------------------------------------------------------------------
|
83 |
+
|
84 |
+
Open Source Software Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT and Other Licenses of the Third-Party Components therein:
|
85 |
+
The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
|
86 |
+
|
87 |
+
--------------------------------------------------------------------
|
88 |
+
1. HunyuanVideo
|
89 |
+
Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
90 |
+
|
91 |
+
|
92 |
+
Terms of the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT:
|
93 |
+
--------------------------------------------------------------------
|
94 |
+
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
95 |
+
Tencent HunyuanVideo Release Date: December 3, 2024
|
96 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
97 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
98 |
+
1. DEFINITIONS.
|
99 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
100 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
101 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
102 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
103 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
104 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
105 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
106 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
107 |
+
i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
|
108 |
+
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo released at [https://github.com/Tencent/HunyuanVideo].
|
109 |
+
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
110 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
111 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
112 |
+
n. “including” shall mean including but not limited to.
|
113 |
+
2. GRANT OF RIGHTS.
|
114 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
115 |
+
3. DISTRIBUTION.
|
116 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
117 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
118 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
119 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
120 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
121 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
122 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
123 |
+
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
124 |
+
5. RULES OF USE.
|
125 |
+
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
126 |
+
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
|
127 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
128 |
+
6. INTELLECTUAL PROPERTY.
|
129 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
130 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
131 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
132 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
133 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
134 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
135 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
136 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
137 |
+
8. SURVIVAL AND TERMINATION.
|
138 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
139 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
140 |
+
9. GOVERNING LAW AND JURISDICTION.
|
141 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
142 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
143 |
+
|
144 |
+
EXHIBIT A
|
145 |
+
ACCEPTABLE USE POLICY
|
146 |
+
|
147 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
148 |
+
Last modified: November 5, 2024
|
149 |
+
|
150 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
151 |
+
1. Outside the Territory;
|
152 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
153 |
+
3. To harm Yourself or others;
|
154 |
+
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
155 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
156 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
157 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
158 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
159 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
160 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
161 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
162 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
163 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
164 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
165 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
166 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
167 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
168 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
169 |
+
19. For military purposes;
|
170 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
171 |
+
|
172 |
+
For the license of other third party components, please refer to the following URL:
|
173 |
+
https://github.com/Tencent-Hunyuan/HunyuanVideo/blob/ff2dd59277b3177785d8279d4170968afa3b1d55/Notice
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: HunyuanImage 2.1
|
3 |
+
emoji: 🔥
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.44.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
IS_SPACE = True
|
3 |
+
|
4 |
+
if IS_SPACE:
|
5 |
+
import spaces
|
6 |
+
|
7 |
+
|
8 |
+
import sys
|
9 |
+
import warnings
|
10 |
+
import subprocess
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import Optional, Tuple, Dict
|
13 |
+
import torch
|
14 |
+
|
15 |
+
def space_context(duration: int):
|
16 |
+
if IS_SPACE:
|
17 |
+
return spaces.GPU(duration=duration)
|
18 |
+
return lambda x: x
|
19 |
+
|
20 |
+
@space_context(duration=120)
|
21 |
+
def test_env():
|
22 |
+
assert torch.cuda.is_available()
|
23 |
+
|
24 |
+
try:
|
25 |
+
import flash_attn
|
26 |
+
except ImportError:
|
27 |
+
print("Flash-attn not found, installing...")
|
28 |
+
os.system("pip install flash-attn==2.7.3 --no-build-isolation")
|
29 |
+
|
30 |
+
else:
|
31 |
+
print("Flash-attn found, skipping installation...")
|
32 |
+
test_env()
|
33 |
+
|
34 |
+
warnings.filterwarnings("ignore")
|
35 |
+
|
36 |
+
# Add the current directory to Python path
|
37 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
38 |
+
|
39 |
+
try:
|
40 |
+
import gradio as gr
|
41 |
+
from PIL import Image
|
42 |
+
from hyimage.diffusion.pipelines.hunyuanimage_pipeline import HunyuanImagePipeline
|
43 |
+
from huggingface_hub import snapshot_download
|
44 |
+
import modelscope
|
45 |
+
except ImportError as e:
|
46 |
+
print(f"Missing required dependencies: {e}")
|
47 |
+
print("Please install with: pip install -r requirements_gradio.txt")
|
48 |
+
print("For checkpoint downloads, also install: pip install -U 'huggingface_hub[cli]' modelscope")
|
49 |
+
sys.exit(1)
|
50 |
+
|
51 |
+
class CheckpointDownloader:
|
52 |
+
"""Handles downloading of all required checkpoints for HunyuanImage."""
|
53 |
+
|
54 |
+
def __init__(self, base_dir: str = "./ckpts"):
|
55 |
+
self.base_dir = Path(base_dir)
|
56 |
+
self.base_dir.mkdir(exist_ok=True)
|
57 |
+
|
58 |
+
# Define all required checkpoints
|
59 |
+
self.checkpoints = {
|
60 |
+
"main_model": {
|
61 |
+
"repo_id": "tencent/HunyuanImage-2.1",
|
62 |
+
"local_dir": self.base_dir,
|
63 |
+
},
|
64 |
+
"mllm_encoder": {
|
65 |
+
"repo_id": "Qwen/Qwen2.5-VL-7B-Instruct",
|
66 |
+
"local_dir": self.base_dir / "text_encoder" / "llm",
|
67 |
+
},
|
68 |
+
"byt5_encoder": {
|
69 |
+
"repo_id": "google/byt5-small",
|
70 |
+
"local_dir": self.base_dir / "text_encoder" / "byt5-small",
|
71 |
+
},
|
72 |
+
"glyph_encoder": {
|
73 |
+
"repo_id": "AI-ModelScope/Glyph-SDXL-v2",
|
74 |
+
"local_dir": self.base_dir / "text_encoder" / "Glyph-SDXL-v2",
|
75 |
+
"use_modelscope": True
|
76 |
+
}
|
77 |
+
}
|
78 |
+
|
79 |
+
def download_checkpoint(self, checkpoint_name: str, progress_callback=None) -> Tuple[bool, str]:
|
80 |
+
"""Download a specific checkpoint."""
|
81 |
+
if checkpoint_name not in self.checkpoints:
|
82 |
+
return False, f"Unknown checkpoint: {checkpoint_name}"
|
83 |
+
|
84 |
+
config = self.checkpoints[checkpoint_name]
|
85 |
+
local_dir = config["local_dir"]
|
86 |
+
local_dir.mkdir(parents=True, exist_ok=True)
|
87 |
+
|
88 |
+
try:
|
89 |
+
if config.get("use_modelscope", False):
|
90 |
+
# Use modelscope for Chinese models
|
91 |
+
return self._download_with_modelscope(config, progress_callback)
|
92 |
+
else:
|
93 |
+
# Use huggingface_hub for other models
|
94 |
+
return self._download_with_hf(config, progress_callback)
|
95 |
+
except Exception as e:
|
96 |
+
return False, f"Download failed: {str(e)}"
|
97 |
+
|
98 |
+
def _download_with_hf(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
|
99 |
+
"""Download using huggingface_hub."""
|
100 |
+
repo_id = config["repo_id"]
|
101 |
+
local_dir = config["local_dir"]
|
102 |
+
|
103 |
+
if progress_callback:
|
104 |
+
progress_callback(f"Downloading {repo_id}...")
|
105 |
+
|
106 |
+
try:
|
107 |
+
snapshot_download(
|
108 |
+
repo_id=repo_id,
|
109 |
+
local_dir=str(local_dir),
|
110 |
+
local_dir_use_symlinks=False,
|
111 |
+
resume_download=True
|
112 |
+
)
|
113 |
+
return True, f"Successfully downloaded {repo_id}"
|
114 |
+
except Exception as e:
|
115 |
+
return False, f"HF download failed: {str(e)}"
|
116 |
+
|
117 |
+
def _download_with_modelscope(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
|
118 |
+
"""Download using modelscope."""
|
119 |
+
repo_id = config["repo_id"]
|
120 |
+
local_dir = config["local_dir"]
|
121 |
+
|
122 |
+
if progress_callback:
|
123 |
+
progress_callback(f"Downloading {repo_id} via ModelScope...")
|
124 |
+
print(f"Downloading {repo_id} via ModelScope...")
|
125 |
+
|
126 |
+
try:
|
127 |
+
# Use subprocess to call modelscope CLI
|
128 |
+
cmd = [
|
129 |
+
"modelscope", "download",
|
130 |
+
"--model", repo_id,
|
131 |
+
"--local_dir", str(local_dir)
|
132 |
+
]
|
133 |
+
|
134 |
+
subprocess.run(cmd, capture_output=True, text=True, check=True)
|
135 |
+
return True, f"Successfully downloaded {repo_id} via ModelScope"
|
136 |
+
except subprocess.CalledProcessError as e:
|
137 |
+
return False, f"ModelScope download failed: {e.stderr}"
|
138 |
+
except FileNotFoundError:
|
139 |
+
return False, "ModelScope CLI not found. Install with: pip install modelscope"
|
140 |
+
|
141 |
+
def download_all_checkpoints(self, progress_callback=None) -> Tuple[bool, str, Dict[str, any]]:
|
142 |
+
"""Download all checkpoints."""
|
143 |
+
results = {}
|
144 |
+
for name, _ in self.checkpoints.items():
|
145 |
+
if progress_callback:
|
146 |
+
progress_callback(f"Starting download of {name}...")
|
147 |
+
|
148 |
+
success, message = self.download_checkpoint(name, progress_callback)
|
149 |
+
results[name] = {"success": success, "message": message}
|
150 |
+
|
151 |
+
if not success:
|
152 |
+
return False, f"Failed to download {name}: {message}", results
|
153 |
+
return True, "All checkpoints downloaded successfully", results
|
154 |
+
|
155 |
+
|
156 |
+
@space_context(duration=200)
|
157 |
+
def load_pipeline(use_distilled: bool = False, device: str = "cuda"):
|
158 |
+
"""Load the HunyuanImage pipeline (only load once, refiner and reprompt are accessed from it)."""
|
159 |
+
try:
|
160 |
+
assert not use_distilled # use_distilled is a placeholder for the future
|
161 |
+
|
162 |
+
print(f"Loading HunyuanImage pipeline (distilled={use_distilled})...")
|
163 |
+
model_name = "hunyuanimage-v2.1-distilled" if use_distilled else "hunyuanimage-v2.1"
|
164 |
+
pipeline = HunyuanImagePipeline.from_pretrained(
|
165 |
+
model_name=model_name,
|
166 |
+
device=device,
|
167 |
+
enable_dit_offloading=True,
|
168 |
+
enable_reprompt_model_offloading=True,
|
169 |
+
enable_refiner_offloading=True
|
170 |
+
)
|
171 |
+
print("✓ Pipeline loaded successfully")
|
172 |
+
return pipeline
|
173 |
+
except Exception as e:
|
174 |
+
error_msg = f"Error loading pipeline: {str(e)}"
|
175 |
+
print(f"✗ {error_msg}")
|
176 |
+
raise
|
177 |
+
|
178 |
+
|
179 |
+
if IS_SPACE:
|
180 |
+
downloader = CheckpointDownloader()
|
181 |
+
downloader.download_all_checkpoints()
|
182 |
+
|
183 |
+
pipeline = load_pipeline(use_distilled=False, device="cuda")
|
184 |
+
class HunyuanImageApp:
|
185 |
+
|
186 |
+
@space_context(duration=290)
|
187 |
+
def __init__(self, auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
|
188 |
+
"""Initialize the HunyuanImage Gradio app."""
|
189 |
+
global pipeline
|
190 |
+
|
191 |
+
self.pipeline = pipeline
|
192 |
+
self.current_use_distilled = None
|
193 |
+
|
194 |
+
|
195 |
+
def print_peak_memory(self):
|
196 |
+
import torch
|
197 |
+
stats = torch.cuda.memory_stats()
|
198 |
+
peak_bytes_requirement = stats["allocated_bytes.all.peak"]
|
199 |
+
print(f"Before refiner Peak memory requirement: {peak_bytes_requirement / 1024 ** 3:.2f} GB")
|
200 |
+
|
201 |
+
@space_context(duration=300)
|
202 |
+
def generate_image(self,
|
203 |
+
prompt: str,
|
204 |
+
negative_prompt: str,
|
205 |
+
width: int,
|
206 |
+
height: int,
|
207 |
+
num_inference_steps: int,
|
208 |
+
guidance_scale: float,
|
209 |
+
seed: int,
|
210 |
+
use_reprompt: bool,
|
211 |
+
use_refiner: bool,
|
212 |
+
# use_distilled: bool
|
213 |
+
) -> Tuple[Optional[Image.Image], str]:
|
214 |
+
"""Generate an image using the HunyuanImage pipeline."""
|
215 |
+
try:
|
216 |
+
|
217 |
+
if self.pipeline is None:
|
218 |
+
return None, "Pipeline not loaded. Please try again."
|
219 |
+
|
220 |
+
|
221 |
+
if hasattr(self.pipeline, '_refiner_pipeline'):
|
222 |
+
self.pipeline.refiner_pipeline.to('cpu')
|
223 |
+
self.pipeline.to('cuda')
|
224 |
+
|
225 |
+
# Generate image
|
226 |
+
image = self.pipeline(
|
227 |
+
prompt=prompt,
|
228 |
+
negative_prompt=negative_prompt,
|
229 |
+
width=width,
|
230 |
+
height=height,
|
231 |
+
num_inference_steps=num_inference_steps,
|
232 |
+
guidance_scale=guidance_scale,
|
233 |
+
seed=seed,
|
234 |
+
use_reprompt=use_reprompt,
|
235 |
+
use_refiner=use_refiner
|
236 |
+
)
|
237 |
+
self.print_peak_memory()
|
238 |
+
return image, "Image generated successfully!"
|
239 |
+
|
240 |
+
except Exception as e:
|
241 |
+
error_msg = f"Error generating image: {str(e)}"
|
242 |
+
print(f"✗ {error_msg}")
|
243 |
+
return None, error_msg
|
244 |
+
|
245 |
+
@space_context(duration=300)
|
246 |
+
def enhance_prompt(self, prompt: str, # use_distilled: bool
|
247 |
+
) -> Tuple[str, str]:
|
248 |
+
"""Enhance a prompt using the reprompt model."""
|
249 |
+
try:
|
250 |
+
# Load pipeline if needed
|
251 |
+
if self.pipeline is None:
|
252 |
+
return prompt, "Pipeline not loaded. Please try again."
|
253 |
+
|
254 |
+
self.pipeline.to('cpu')
|
255 |
+
if hasattr(self.pipeline, '_refiner_pipeline'):
|
256 |
+
self.pipeline.refiner_pipeline.to('cpu')
|
257 |
+
|
258 |
+
# Use reprompt model from the main pipeline
|
259 |
+
enhanced_prompt = self.pipeline.reprompt_model.predict(prompt)
|
260 |
+
self.print_peak_memory()
|
261 |
+
return enhanced_prompt, "Prompt enhanced successfully!"
|
262 |
+
|
263 |
+
except Exception as e:
|
264 |
+
error_msg = f"Error enhancing prompt: {str(e)}"
|
265 |
+
print(f"✗ {error_msg}")
|
266 |
+
return prompt, error_msg
|
267 |
+
|
268 |
+
@space_context(duration=300)
|
269 |
+
def refine_image(self,
|
270 |
+
image: Image.Image,
|
271 |
+
prompt: str,
|
272 |
+
negative_prompt: str,
|
273 |
+
width: int,
|
274 |
+
height: int,
|
275 |
+
num_inference_steps: int,
|
276 |
+
guidance_scale: float,
|
277 |
+
seed: int) -> Tuple[Optional[Image.Image], str]:
|
278 |
+
"""Refine an image using the refiner pipeline."""
|
279 |
+
try:
|
280 |
+
if image is None:
|
281 |
+
return None, "Please upload an image to refine."
|
282 |
+
|
283 |
+
# Resize image to target dimensions if needed
|
284 |
+
if image.size != (width, height):
|
285 |
+
image = image.resize((width, height), Image.Resampling.LANCZOS)
|
286 |
+
|
287 |
+
self.pipeline.to('cpu')
|
288 |
+
self.pipeline.refiner_pipeline.to('cuda')
|
289 |
+
|
290 |
+
# Use refiner from the main pipeline
|
291 |
+
refined_image = self.pipeline.refiner_pipeline(
|
292 |
+
image=image,
|
293 |
+
prompt=prompt,
|
294 |
+
negative_prompt=negative_prompt,
|
295 |
+
width=width,
|
296 |
+
height=height,
|
297 |
+
num_inference_steps=num_inference_steps,
|
298 |
+
guidance_scale=guidance_scale,
|
299 |
+
seed=seed
|
300 |
+
)
|
301 |
+
self.print_peak_memory()
|
302 |
+
return refined_image, "Image refined successfully!"
|
303 |
+
|
304 |
+
except Exception as e:
|
305 |
+
error_msg = f"Error refining image: {str(e)}"
|
306 |
+
print(f"✗ {error_msg}")
|
307 |
+
return None, error_msg
|
308 |
+
|
309 |
+
|
310 |
+
def download_single_checkpoint(self, checkpoint_name: str) -> Tuple[bool, str]:
|
311 |
+
"""Download a single checkpoint."""
|
312 |
+
try:
|
313 |
+
success, message = self.downloader.download_checkpoint(checkpoint_name)
|
314 |
+
return success, message
|
315 |
+
except Exception as e:
|
316 |
+
return False, f"Download error: {str(e)}"
|
317 |
+
|
318 |
+
def download_all_checkpoints(self) -> Tuple[bool, str, Dict[str, any]]:
|
319 |
+
"""Download all missing checkpoints."""
|
320 |
+
try:
|
321 |
+
success, message, results = self.downloader.download_all_checkpoints()
|
322 |
+
return success, message, results
|
323 |
+
except Exception as e:
|
324 |
+
return False, f"Download error: {str(e)}", {}
|
325 |
+
|
326 |
+
def create_interface(auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
|
327 |
+
"""Create the Gradio interface."""
|
328 |
+
app = HunyuanImageApp(auto_load=auto_load, use_distilled=use_distilled, device=device)
|
329 |
+
|
330 |
+
# Custom CSS for better styling
|
331 |
+
css = """
|
332 |
+
.gradio-container {
|
333 |
+
max-width: 1200px !important;
|
334 |
+
margin: auto !important;
|
335 |
+
}
|
336 |
+
.tab-nav {
|
337 |
+
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
|
338 |
+
border-radius: 10px;
|
339 |
+
padding: 10px;
|
340 |
+
margin-bottom: 20px;
|
341 |
+
}
|
342 |
+
.model-info {
|
343 |
+
background: #f8f9fa;
|
344 |
+
border: 1px solid #dee2e6;
|
345 |
+
border-radius: 8px;
|
346 |
+
padding: 15px;
|
347 |
+
margin-bottom: 20px;
|
348 |
+
}
|
349 |
+
"""
|
350 |
+
|
351 |
+
with gr.Blocks(css=css, title="HunyuanImage Pipeline", theme=gr.themes.Soft()) as demo:
|
352 |
+
gr.Markdown(
|
353 |
+
"""
|
354 |
+
# 🎨 HunyuanImage 2.1 Pipeline
|
355 |
+
**HunyuanImage-2.1: An Efficient Diffusion Model for High-Resolution (2K) Text-to-Image Generation**
|
356 |
+
|
357 |
+
This app provides three main functionalities:
|
358 |
+
1. **Text-to-Image Generation**: Generate high-quality images from text prompts
|
359 |
+
2. **Prompt Enhancement**: Improve your prompts using MLLM reprompting
|
360 |
+
3. **Image Refinement**: Enhance existing images with the refiner model (Refiner is not supported yet; coming soon.)
|
361 |
+
""",
|
362 |
+
elem_classes="model-info"
|
363 |
+
)
|
364 |
+
|
365 |
+
with gr.Tabs():
|
366 |
+
# Tab 1: Text-to-Image Generation
|
367 |
+
with gr.Tab("🖼️ Text-to-Image Generation"):
|
368 |
+
with gr.Row():
|
369 |
+
with gr.Column(scale=1):
|
370 |
+
gr.Markdown("### Generation Settings")
|
371 |
+
gr.Markdown("**Model**: HunyuanImage v2.1 (Non-distilled)")
|
372 |
+
|
373 |
+
# use_distilled = gr.Checkbox(
|
374 |
+
# label="Use Distilled Model",
|
375 |
+
# value=False,
|
376 |
+
# info="Faster generation with slightly lower quality"
|
377 |
+
# )
|
378 |
+
use_distilled = False
|
379 |
+
|
380 |
+
prompt = gr.Textbox(
|
381 |
+
label="Prompt",
|
382 |
+
placeholder="",
|
383 |
+
lines=3,
|
384 |
+
value="A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, wearing a red knitted scarf and a red beret with the word “Tencent” on it, holding a paintbrush with a focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
|
385 |
+
)
|
386 |
+
|
387 |
+
negative_prompt = gr.Textbox(
|
388 |
+
label="Negative Prompt",
|
389 |
+
placeholder="",
|
390 |
+
lines=2,
|
391 |
+
value=""
|
392 |
+
)
|
393 |
+
|
394 |
+
with gr.Row():
|
395 |
+
width = gr.Slider(
|
396 |
+
minimum=512, maximum=2048, step=64, value=2048,
|
397 |
+
label="Width", info="Image width in pixels"
|
398 |
+
)
|
399 |
+
height = gr.Slider(
|
400 |
+
minimum=512, maximum=2048, step=64, value=2048,
|
401 |
+
label="Height", info="Image height in pixels"
|
402 |
+
)
|
403 |
+
|
404 |
+
with gr.Row():
|
405 |
+
num_inference_steps = gr.Slider(
|
406 |
+
minimum=10, maximum=100, step=5, value=50,
|
407 |
+
label="Inference Steps", info="More steps = better quality, slower generation"
|
408 |
+
)
|
409 |
+
guidance_scale = gr.Slider(
|
410 |
+
minimum=1.0, maximum=10.0, step=0.1, value=3.5,
|
411 |
+
label="Guidance Scale", info="How closely to follow the prompt"
|
412 |
+
)
|
413 |
+
|
414 |
+
with gr.Row():
|
415 |
+
seed = gr.Number(
|
416 |
+
label="Seed", value=649151, precision=0,
|
417 |
+
info="Random seed for reproducibility"
|
418 |
+
)
|
419 |
+
use_reprompt = gr.Checkbox(
|
420 |
+
label="Use Reprompt", value=False,
|
421 |
+
info="Enhance prompt automatically"
|
422 |
+
)
|
423 |
+
use_refiner = gr.Checkbox(
|
424 |
+
label="Use Refiner", value=False,
|
425 |
+
info="Apply refiner after generation (Refiner is not supported yet; coming soon.)",
|
426 |
+
interactive=False
|
427 |
+
)
|
428 |
+
|
429 |
+
generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")
|
430 |
+
|
431 |
+
with gr.Column(scale=1):
|
432 |
+
gr.Markdown("### Generated Image")
|
433 |
+
generated_image = gr.Image(
|
434 |
+
label="Generated Image",
|
435 |
+
type="pil",
|
436 |
+
height=600
|
437 |
+
)
|
438 |
+
generation_status = gr.Textbox(
|
439 |
+
label="Status",
|
440 |
+
interactive=False,
|
441 |
+
value="Ready to generate"
|
442 |
+
)
|
443 |
+
|
444 |
+
# Tab 2: Prompt Enhancement
|
445 |
+
with gr.Tab("✨ Prompt Enhancement"):
|
446 |
+
with gr.Row():
|
447 |
+
with gr.Column(scale=1):
|
448 |
+
gr.Markdown("### Prompt Enhancement Settings")
|
449 |
+
gr.Markdown("**Model**: HunyuanImage v2.1 Reprompt Model")
|
450 |
+
|
451 |
+
# enhance_use_distilled = gr.Checkbox(
|
452 |
+
# label="Use Distilled Model",
|
453 |
+
# value=False,
|
454 |
+
# info="For loading the reprompt model"
|
455 |
+
# )
|
456 |
+
enhance_use_distilled = False
|
457 |
+
|
458 |
+
original_prompt = gr.Textbox(
|
459 |
+
label="Original Prompt",
|
460 |
+
placeholder="A cat sitting on a table",
|
461 |
+
lines=4,
|
462 |
+
value="A cat sitting on a table"
|
463 |
+
)
|
464 |
+
|
465 |
+
enhance_btn = gr.Button("✨ Enhance Prompt", variant="primary", size="lg")
|
466 |
+
|
467 |
+
with gr.Column(scale=1):
|
468 |
+
gr.Markdown("### Enhanced Prompt")
|
469 |
+
enhanced_prompt = gr.Textbox(
|
470 |
+
label="Enhanced Prompt",
|
471 |
+
lines=6,
|
472 |
+
interactive=False
|
473 |
+
)
|
474 |
+
enhancement_status = gr.Textbox(
|
475 |
+
label="Status",
|
476 |
+
interactive=False,
|
477 |
+
value="Ready to enhance"
|
478 |
+
)
|
479 |
+
|
480 |
+
# # Tab 3: Image Refinement
|
481 |
+
# with gr.Tab("🔧 Image Refinement"):
|
482 |
+
# with gr.Row():
|
483 |
+
# with gr.Column(scale=1):
|
484 |
+
# gr.Markdown("### Refinement Settings")
|
485 |
+
# gr.Markdown("**Model**: HunyuanImage v2.1 Refiner")
|
486 |
+
|
487 |
+
# input_image = gr.Image(
|
488 |
+
# label="Input Image",
|
489 |
+
# type="pil",
|
490 |
+
# height=300
|
491 |
+
# )
|
492 |
+
|
493 |
+
# refine_prompt = gr.Textbox(
|
494 |
+
# label="Refinement Prompt",
|
495 |
+
# placeholder="Make the image more detailed and high quality",
|
496 |
+
# lines=2,
|
497 |
+
# value="Make the image more detailed and high quality"
|
498 |
+
# )
|
499 |
+
|
500 |
+
# refine_negative_prompt = gr.Textbox(
|
501 |
+
# label="Negative Prompt",
|
502 |
+
# placeholder="",
|
503 |
+
# lines=2,
|
504 |
+
# value=""
|
505 |
+
# )
|
506 |
+
|
507 |
+
# with gr.Row():
|
508 |
+
# refine_width = gr.Slider(
|
509 |
+
# minimum=512, maximum=2048, step=64, value=2048,
|
510 |
+
# label="Width", info="Output width"
|
511 |
+
# )
|
512 |
+
# refine_height = gr.Slider(
|
513 |
+
# minimum=512, maximum=2048, step=64, value=2048,
|
514 |
+
# label="Height", info="Output height"
|
515 |
+
# )
|
516 |
+
|
517 |
+
# with gr.Row():
|
518 |
+
# refine_steps = gr.Slider(
|
519 |
+
# minimum=1, maximum=20, step=1, value=4,
|
520 |
+
# label="Refinement Steps", info="More steps = more refinement"
|
521 |
+
# )
|
522 |
+
# refine_guidance = gr.Slider(
|
523 |
+
# minimum=1.0, maximum=10.0, step=0.1, value=3.5,
|
524 |
+
# label="Guidance Scale", info="How strongly to follow the prompt"
|
525 |
+
# )
|
526 |
+
|
527 |
+
# refine_seed = gr.Number(
|
528 |
+
# label="Seed", value=649151, precision=0,
|
529 |
+
# info="Random seed for reproducibility"
|
530 |
+
# )
|
531 |
+
|
532 |
+
# refine_btn = gr.Button("🔧 Refine Image", variant="primary", size="lg")
|
533 |
+
|
534 |
+
# with gr.Column(scale=1):
|
535 |
+
# gr.Markdown("### Refined Image")
|
536 |
+
# refined_image = gr.Image(
|
537 |
+
# label="Refined Image",
|
538 |
+
# type="pil",
|
539 |
+
# height=600
|
540 |
+
# )
|
541 |
+
# refinement_status = gr.Textbox(
|
542 |
+
# label="Status",
|
543 |
+
# interactive=False,
|
544 |
+
# value="Ready to refine"
|
545 |
+
# )
|
546 |
+
|
547 |
+
# Event handlers
|
548 |
+
generate_btn.click(
|
549 |
+
fn=app.generate_image,
|
550 |
+
inputs=[
|
551 |
+
prompt, negative_prompt, width, height, num_inference_steps,
|
552 |
+
guidance_scale, seed, use_reprompt, use_refiner # , use_distilled
|
553 |
+
],
|
554 |
+
outputs=[generated_image, generation_status]
|
555 |
+
)
|
556 |
+
|
557 |
+
enhance_btn.click(
|
558 |
+
fn=app.enhance_prompt,
|
559 |
+
inputs=[original_prompt],
|
560 |
+
outputs=[enhanced_prompt, enhancement_status]
|
561 |
+
)
|
562 |
+
|
563 |
+
#refine_btn.click(
|
564 |
+
# fn=app.refine_image,
|
565 |
+
# inputs=[
|
566 |
+
# input_image, refine_prompt, refine_negative_prompt,
|
567 |
+
# refine_width, refine_height, refine_steps, refine_guidance, refine_seed
|
568 |
+
# ],
|
569 |
+
# outputs=[refined_image, refinement_status]
|
570 |
+
#)
|
571 |
+
|
572 |
+
# Additional info
|
573 |
+
gr.Markdown(
|
574 |
+
"""
|
575 |
+
### 📝 Usage Tips
|
576 |
+
|
577 |
+
**Text-to-Image Generation:**
|
578 |
+
- Use descriptive prompts with specific details
|
579 |
+
- Adjust guidance scale: higher values follow prompts more closely
|
580 |
+
- More inference steps generally produce better quality
|
581 |
+
- Enable reprompt for automatic prompt enhancement
|
582 |
+
- Enable refiner for additional quality improvement
|
583 |
+
|
584 |
+
**Prompt Enhancement:**
|
585 |
+
- Enter your basic prompt idea
|
586 |
+
- The AI will enhance it with better structure and details
|
587 |
+
- Enhanced prompts often produce better results
|
588 |
+
|
589 |
+
**Image Refinement:**
|
590 |
+
- Upload any image you want to improve
|
591 |
+
- Describe what improvements you want in the refinement prompt
|
592 |
+
- The refiner will enhance details and quality
|
593 |
+
- Works best with images generated by HunyuanImage
|
594 |
+
""",
|
595 |
+
elem_classes="model-info"
|
596 |
+
)
|
597 |
+
|
598 |
+
return demo
|
599 |
+
|
600 |
+
if __name__ == "__main__":
|
601 |
+
import argparse
|
602 |
+
|
603 |
+
# Parse command line arguments
|
604 |
+
parser = argparse.ArgumentParser(description="Launch HunyuanImage Gradio App")
|
605 |
+
parser.add_argument("--no-auto-load", action="store_true", help="Disable auto-loading pipeline on startup")
|
606 |
+
parser.add_argument("--use-distilled", action="store_true", help="Use distilled model")
|
607 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda/cpu)")
|
608 |
+
parser.add_argument("--port", type=int, default=8081, help="Port to run the app on")
|
609 |
+
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
610 |
+
|
611 |
+
args = parser.parse_args()
|
612 |
+
|
613 |
+
# Create and launch the interface
|
614 |
+
auto_load = not args.no_auto_load
|
615 |
+
demo = create_interface(auto_load=auto_load, use_distilled=args.use_distilled, device=args.device)
|
616 |
+
|
617 |
+
print("🚀 Starting HunyuanImage Gradio App...")
|
618 |
+
print(f"📱 The app will be available at: http://{args.host}:{args.port}")
|
619 |
+
print(f"🔧 Auto-load pipeline: {'Yes' if auto_load else 'No'}")
|
620 |
+
print(f"🎯 Model type: {'Distilled' if args.use_distilled else 'Non-distilled'}")
|
621 |
+
print(f"💻 Device: {args.device}")
|
622 |
+
print("⚠️ Make sure you have the required model checkpoints downloaded!")
|
623 |
+
|
624 |
+
demo.launch(
|
625 |
+
server_name=args.host,
|
626 |
+
# server_port=args.port,
|
627 |
+
share=False,
|
628 |
+
show_error=True,
|
629 |
+
quiet=False,
|
630 |
+
# max_threads=1, # Default: sequential processing (recommended for GPU apps)
|
631 |
+
# max_threads=4, # Enable parallel processing (requires more GPU memory)
|
632 |
+
)
|
ckpts/checkpoints-download.md
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Download the pretrained checkpoints:
|
3 |
+
|
4 |
+
First, make sure you have installed the huggingface CLI and modelscope CLI.
|
5 |
+
|
6 |
+
```bash
|
7 |
+
pip install -U "huggingface_hub[cli]"
|
8 |
+
pip install modelscope
|
9 |
+
```
|
10 |
+
|
11 |
+
|
12 |
+
### Download the pretrained DiT and VAE checkpoints:
|
13 |
+
```bash
|
14 |
+
hf download tencent/HunyuanImage-2.1 --local-dir ./ckpts
|
15 |
+
```
|
16 |
+
|
17 |
+
### Downloading TextEncoders
|
18 |
+
|
19 |
+
HunyuanImage uses an MLLM and a byT5 as text encoders.
|
20 |
+
|
21 |
+
* **MLLM**
|
22 |
+
|
23 |
+
HunyuanImage can be integrated with different MLLMs (including HunyuanMLLM and other open-source MLLM models).
|
24 |
+
|
25 |
+
At this stage, we have not yet released the latest HunyuanMLLM. We recommend the users in community to use an open-source alternative, such as Qwen2.5-VL-7B-Instruct provided by Qwen Team, which can be downloaded by the following command:
|
26 |
+
```bash
|
27 |
+
hf download Qwen/Qwen2.5-VL-7B-Instruct --local-dir ./ckpts/text_encoder/llm
|
28 |
+
```
|
29 |
+
|
30 |
+
* **ByT5 encoder**
|
31 |
+
|
32 |
+
We use [Glyph-SDXL-v2](https://modelscope.cn/models/AI-ModelScope/Glyph-SDXL-v2) as our [byT5](https://github.com/google-research/byt5) encoder, which can be downloaded by the following command:
|
33 |
+
|
34 |
+
```bash
|
35 |
+
hf download google/byt5-small --local-dir ./ckpts/text_encoder/byt5-small
|
36 |
+
modelscope download --model AI-ModelScope/Glyph-SDXL-v2 --local_dir ./ckpts/text_encoder/Glyph-SDXL-v2
|
37 |
+
```
|
38 |
+
You can also manually download the checkpoints from [here](https://modelscope.cn/models/AI-ModelScope/Glyph-SDXL-v2/files) and place them in the text_encoder folder like:
|
39 |
+
```
|
40 |
+
ckpts
|
41 |
+
├── text_encoder
|
42 |
+
│ ├── Glyph-SDXL-v2
|
43 |
+
│ │ ├── assets
|
44 |
+
│ │ │ ├── color_idx.json
|
45 |
+
│ │ │ ├── multilingual_10-lang_idx.json
|
46 |
+
│ │ │ └── ...
|
47 |
+
│ │ └── checkpoints
|
48 |
+
│ │ ├── byt5_model.pt
|
49 |
+
│ │ └── ...
|
50 |
+
│ └─ ...
|
51 |
+
└─ ...
|
52 |
+
```
|
53 |
+
|
54 |
+
<details>
|
55 |
+
|
56 |
+
<summary>💡Tips for using hf/huggingface-cli (network problem)</summary>
|
57 |
+
|
58 |
+
##### 1. Using HF-Mirror
|
59 |
+
|
60 |
+
If you encounter slow download speeds in China, you can try a mirror to speed up the download process:
|
61 |
+
|
62 |
+
```shell
|
63 |
+
HF_ENDPOINT=https://hf-mirror.com hf download tencent/HunyuanImage-2.1 --local-dir ./ckpts
|
64 |
+
```
|
65 |
+
|
66 |
+
##### 2. Resume Download
|
67 |
+
|
68 |
+
`huggingface-cli` supports resuming downloads. If the download is interrupted, you can just rerun the download
|
69 |
+
command to resume the download process.
|
70 |
+
|
71 |
+
Note: If an `No such file or directory: 'ckpts/.huggingface/.gitignore.lock'` like error occurs during the download
|
72 |
+
process, you can ignore the error and rerun the download command.
|
73 |
+
|
74 |
+
</details>
|
hyimage/common/config/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .lazy import LazyCall, instantiate, locate
|
2 |
+
from .base_config import DiTConfig, VAEConfig, TextEncoderConfig, RepromptConfig
|
3 |
+
|
4 |
+
__all__ = ["LazyCall", "instantiate", "locate", "DiTConfig", "VAEConfig", "TextEncoderConfig", "RepromptConfig"]
|
hyimage/common/config/base_config.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from hyimage.common.config.lazy import DictConfig
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass
|
8 |
+
class DiTConfig:
|
9 |
+
model: DictConfig
|
10 |
+
use_lora: bool = False
|
11 |
+
use_cpu_offload: bool = False
|
12 |
+
gradient_checkpointing: bool = False
|
13 |
+
load_from: Optional[str] = None
|
14 |
+
use_compile: bool = False
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class VAEConfig:
|
19 |
+
model: DictConfig
|
20 |
+
load_from: str
|
21 |
+
cpu_offload: bool = False
|
22 |
+
enable_tiling: bool = False
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class TextEncoderConfig:
|
27 |
+
model: DictConfig
|
28 |
+
load_from: str
|
29 |
+
prompt_template: Optional[str] = None
|
30 |
+
text_len: Optional[int] = None
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class RepromptConfig:
|
35 |
+
model: DictConfig
|
36 |
+
load_from: str
|
hyimage/common/config/lazy.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc as abc
|
2 |
+
import copy
|
3 |
+
import pydoc
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
|
7 |
+
class DictConfig(dict):
|
8 |
+
|
9 |
+
def __getattr__(self, item):
|
10 |
+
try:
|
11 |
+
return self[item]
|
12 |
+
except KeyError:
|
13 |
+
raise AttributeError(f"'AttrDict' object has no attribute '{item}'")
|
14 |
+
|
15 |
+
def __setattr__(self, key, value):
|
16 |
+
self[key] = value
|
17 |
+
|
18 |
+
def __delattr__(self, item):
|
19 |
+
try:
|
20 |
+
del self[item]
|
21 |
+
except KeyError:
|
22 |
+
raise AttributeError(f"'DictConfig' object has no attribute '{item}'")
|
23 |
+
|
24 |
+
|
25 |
+
def locate(name: str) -> Any:
|
26 |
+
"""
|
27 |
+
Locate and return an object using a string like {x.__module__}.{x.__qualname__}.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
name:Dotted path to the object
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
The located object
|
34 |
+
|
35 |
+
Raises:
|
36 |
+
ImportError if the object cannot be found
|
37 |
+
"""
|
38 |
+
return pydoc.locate(name)
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
class LazyObject:
|
43 |
+
|
44 |
+
def __init__(self, target, **kwargs):
|
45 |
+
self._target = target
|
46 |
+
self._kwargs = kwargs
|
47 |
+
|
48 |
+
def instantiate(self, **kwargs):
|
49 |
+
new_kwargs = copy.deepcopy(self._kwargs)
|
50 |
+
new_kwargs.update(kwargs)
|
51 |
+
return self._target(**new_kwargs)
|
52 |
+
|
53 |
+
|
54 |
+
class LazyCall:
|
55 |
+
|
56 |
+
def __init__(self, target):
|
57 |
+
if not callable(target):
|
58 |
+
raise ValueError(f"`target` of LazyCall must be a callable, got {target}")
|
59 |
+
self._target = target
|
60 |
+
|
61 |
+
def __call__(self, **kwargs):
|
62 |
+
return LazyObject(self._target, **kwargs)
|
63 |
+
|
64 |
+
|
65 |
+
def instantiate(config: LazyObject, **kwargs):
|
66 |
+
if config is None:
|
67 |
+
return None
|
68 |
+
return config.instantiate(**kwargs)
|
69 |
+
|
hyimage/common/constants.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
PRECISION_TO_TYPE = {
|
4 |
+
"fp32": torch.float32,
|
5 |
+
"fp16": torch.float16,
|
6 |
+
"bf16": torch.bfloat16,
|
7 |
+
}
|
hyimage/common/format_prompt.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
|
4 |
+
def closest_color(requested_color):
|
5 |
+
import webcolors
|
6 |
+
|
7 |
+
min_colors = {}
|
8 |
+
for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
|
9 |
+
|
10 |
+
r_c, g_c, b_c = webcolors.hex_to_rgb(key)
|
11 |
+
rd = (r_c - requested_color[0]) ** 2
|
12 |
+
gd = (g_c - requested_color[1]) ** 2
|
13 |
+
bd = (b_c - requested_color[2]) ** 2
|
14 |
+
min_colors[(rd + gd + bd)] = name
|
15 |
+
return min_colors[min(min_colors.keys())]
|
16 |
+
|
17 |
+
|
18 |
+
def convert_rgb_to_names(rgb_tuple):
|
19 |
+
try:
|
20 |
+
import webcolors
|
21 |
+
|
22 |
+
color_name = webcolors.rgb_to_name(rgb_tuple)
|
23 |
+
except ValueError:
|
24 |
+
color_name = closest_color(rgb_tuple)
|
25 |
+
return color_name
|
26 |
+
|
27 |
+
|
28 |
+
class MultilingualPromptFormat:
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
font_path: str = 'assets/glyph_sdxl_assets/multilingual_10-lang_idx.json',
|
33 |
+
color_path: str = 'assets/glyph_sdxl_assets/color_idx.json',
|
34 |
+
):
|
35 |
+
with open(font_path, 'r') as f:
|
36 |
+
self.font_dict = json.load(f)
|
37 |
+
with open(color_path, 'r') as f:
|
38 |
+
self.color_dict = json.load(f)
|
39 |
+
|
40 |
+
def format_prompt(self, texts, styles):
|
41 |
+
'''
|
42 |
+
Text "{text}" in {color}, {type}.
|
43 |
+
'''
|
44 |
+
|
45 |
+
prompt = ""
|
46 |
+
for text, style in zip(texts, styles):
|
47 |
+
text_prompt = f'Text "{text}"'
|
48 |
+
|
49 |
+
attr_list = []
|
50 |
+
|
51 |
+
# format color
|
52 |
+
if style["color"] is not None:
|
53 |
+
import webcolors
|
54 |
+
|
55 |
+
hex_color = style["color"]
|
56 |
+
rgb_color = webcolors.hex_to_rgb(hex_color)
|
57 |
+
color_name = convert_rgb_to_names(rgb_color)
|
58 |
+
attr_list.append(f"<color-{self.color_dict[color_name]}>")
|
59 |
+
|
60 |
+
# format font
|
61 |
+
if style["font-family"] is not None:
|
62 |
+
attr_list.append(f"<{style['font-family'][:2]}-font-{self.font_dict[style['font-family']]}>")
|
63 |
+
attr_suffix = ", ".join(attr_list)
|
64 |
+
text_prompt += " in " + attr_suffix
|
65 |
+
text_prompt += ". "
|
66 |
+
else:
|
67 |
+
text_prompt += ". "
|
68 |
+
|
69 |
+
prompt = prompt + text_prompt
|
70 |
+
return prompt
|
hyimage/diffusion/cfg_utils.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from typing import Dict, List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
6 |
+
r"""
|
7 |
+
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
8 |
+
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
9 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
10 |
+
|
11 |
+
Args:
|
12 |
+
noise_cfg (`torch.Tensor`):
|
13 |
+
The predicted noise tensor for the guided diffusion process.
|
14 |
+
noise_pred_text (`torch.Tensor`):
|
15 |
+
The predicted noise tensor for the text-guided diffusion process.
|
16 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
17 |
+
A rescale factor applied to the noise predictions.
|
18 |
+
Returns:
|
19 |
+
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
20 |
+
"""
|
21 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
22 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
23 |
+
# rescale the results from guidance (fixes overexposure)
|
24 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
25 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
26 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
27 |
+
return noise_cfg
|
28 |
+
|
29 |
+
class ClassifierFreeGuidance:
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
guidance_scale: float = 7.5,
|
33 |
+
guidance_rescale: float = 0.0,
|
34 |
+
use_original_formulation: bool = False,
|
35 |
+
start: float = 0.0,
|
36 |
+
stop: float = 1.0,
|
37 |
+
):
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
self.guidance_scale = guidance_scale
|
41 |
+
self.guidance_rescale = guidance_rescale
|
42 |
+
self.use_original_formulation = use_original_formulation
|
43 |
+
|
44 |
+
def __call__(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
45 |
+
|
46 |
+
shift = pred_cond - pred_uncond
|
47 |
+
pred = pred_cond if self.use_original_formulation else pred_uncond
|
48 |
+
pred = pred + self.guidance_scale * shift
|
49 |
+
|
50 |
+
if self.guidance_rescale > 0.0:
|
51 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
52 |
+
|
53 |
+
return pred
|
54 |
+
|
55 |
+
|
56 |
+
class MomentumBuffer:
|
57 |
+
def __init__(self, momentum: float):
|
58 |
+
self.momentum = momentum
|
59 |
+
self.running_average = 0
|
60 |
+
|
61 |
+
def update(self, update_value: torch.Tensor):
|
62 |
+
new_average = self.momentum * self.running_average
|
63 |
+
self.running_average = update_value + new_average
|
64 |
+
|
65 |
+
def normalized_guidance_apg(
|
66 |
+
pred_cond: torch.Tensor,
|
67 |
+
pred_uncond: torch.Tensor,
|
68 |
+
guidance_scale: float,
|
69 |
+
momentum_buffer: Optional[MomentumBuffer] = None,
|
70 |
+
eta: float = 1.0,
|
71 |
+
norm_threshold: float = 0.0,
|
72 |
+
use_original_formulation: bool = False,
|
73 |
+
):
|
74 |
+
diff = pred_cond - pred_uncond
|
75 |
+
dim = [-i for i in range(1, len(diff.shape))]
|
76 |
+
|
77 |
+
if momentum_buffer is not None:
|
78 |
+
momentum_buffer.update(diff)
|
79 |
+
diff = momentum_buffer.running_average
|
80 |
+
|
81 |
+
if norm_threshold > 0:
|
82 |
+
ones = torch.ones_like(diff)
|
83 |
+
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
84 |
+
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
85 |
+
diff = diff * scale_factor
|
86 |
+
|
87 |
+
v0, v1 = diff.double(), pred_cond.double()
|
88 |
+
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
89 |
+
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
90 |
+
v0_orthogonal = v0 - v0_parallel
|
91 |
+
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
92 |
+
|
93 |
+
normalized_update = diff_orthogonal + eta * diff_parallel
|
94 |
+
pred = pred_cond if use_original_formulation else pred_uncond
|
95 |
+
pred = pred + guidance_scale * normalized_update
|
96 |
+
|
97 |
+
return pred
|
98 |
+
|
99 |
+
class AdaptiveProjectedGuidance:
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
guidance_scale: float = 7.5,
|
103 |
+
adaptive_projected_guidance_momentum: Optional[float] = None,
|
104 |
+
adaptive_projected_guidance_rescale: float = 15.0,
|
105 |
+
# eta: float = 1.0,
|
106 |
+
eta: float = 0.0,
|
107 |
+
guidance_rescale: float = 0.0,
|
108 |
+
use_original_formulation: bool = False,
|
109 |
+
start: float = 0.0,
|
110 |
+
stop: float = 1.0,
|
111 |
+
):
|
112 |
+
super().__init__()
|
113 |
+
|
114 |
+
self.guidance_scale = guidance_scale
|
115 |
+
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
116 |
+
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
117 |
+
self.eta = eta
|
118 |
+
self.guidance_rescale = guidance_rescale
|
119 |
+
self.use_original_formulation = use_original_formulation
|
120 |
+
self.momentum_buffer = None
|
121 |
+
|
122 |
+
def __call__(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, step=None) -> torch.Tensor:
|
123 |
+
|
124 |
+
if step == 0 and self.adaptive_projected_guidance_momentum is not None:
|
125 |
+
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
126 |
+
|
127 |
+
pred = normalized_guidance_apg(
|
128 |
+
pred_cond,
|
129 |
+
pred_uncond,
|
130 |
+
self.guidance_scale,
|
131 |
+
self.momentum_buffer,
|
132 |
+
self.eta,
|
133 |
+
self.adaptive_projected_guidance_rescale,
|
134 |
+
self.use_original_formulation,
|
135 |
+
)
|
136 |
+
|
137 |
+
if self.guidance_rescale > 0.0:
|
138 |
+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
139 |
+
|
140 |
+
return pred
|
hyimage/diffusion/pipelines/__init__.py
ADDED
File without changes
|
hyimage/diffusion/pipelines/hunyuanimage_pipeline.py
ADDED
@@ -0,0 +1,892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
from sympy import N
|
7 |
+
from tqdm import tqdm
|
8 |
+
import loguru
|
9 |
+
import torch
|
10 |
+
from hyimage.common.config.lazy import DictConfig
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
from hyimage.common.config import instantiate
|
14 |
+
from hyimage.common.constants import PRECISION_TO_TYPE
|
15 |
+
from hyimage.common.format_prompt import MultilingualPromptFormat
|
16 |
+
from hyimage.models.text_encoder import PROMPT_TEMPLATE
|
17 |
+
from hyimage.models.model_zoo import HUNYUANIMAGE_REPROMPT
|
18 |
+
from hyimage.models.text_encoder.byT5 import load_glyph_byT5_v2
|
19 |
+
from hyimage.models.hunyuan.modules.hunyuanimage_dit import load_hunyuan_dit_state_dict
|
20 |
+
from hyimage.diffusion.cfg_utils import AdaptiveProjectedGuidance, rescale_noise_cfg
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class HunyuanImagePipelineConfig:
|
25 |
+
"""
|
26 |
+
Configuration class for HunyuanImage diffusion pipeline.
|
27 |
+
|
28 |
+
This dataclass consolidates all configuration parameters for the pipeline,
|
29 |
+
including model configurations (DiT, VAE, text encoder) and pipeline
|
30 |
+
parameters (sampling steps, guidance scale, etc.).
|
31 |
+
"""
|
32 |
+
|
33 |
+
# Model configurations
|
34 |
+
dit_config: DictConfig
|
35 |
+
vae_config: DictConfig
|
36 |
+
text_encoder_config: DictConfig
|
37 |
+
reprompt_config: DictConfig
|
38 |
+
refiner_model_name: str = "hunyuanimage-refiner"
|
39 |
+
|
40 |
+
enable_dit_offloading: bool = True
|
41 |
+
enable_reprompt_model_offloading: bool = True
|
42 |
+
enable_refiner_offloading: bool = True
|
43 |
+
|
44 |
+
cfg_mode: str = "MIX_mode_0"
|
45 |
+
guidance_rescale: float = 0.0
|
46 |
+
|
47 |
+
# Pipeline parameters
|
48 |
+
default_sampling_steps: int = 50
|
49 |
+
# Default guidance scale, will be overridden by the guidance_scale parameter in __call__
|
50 |
+
default_guidance_scale: float = 3.5
|
51 |
+
# Inference shift
|
52 |
+
shift: int = 4
|
53 |
+
torch_dtype: str = "bf16"
|
54 |
+
device: str = "cuda"
|
55 |
+
version: str = ""
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def create_default(cls, version: str = "v2.1", use_distilled: bool = False, **kwargs):
|
59 |
+
"""
|
60 |
+
Create a default configuration for specified HunyuanImage version.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
version: HunyuanImage version, only "v2.1" is supported
|
64 |
+
use_distilled: Whether to use distilled model
|
65 |
+
**kwargs: Additional configuration options
|
66 |
+
"""
|
67 |
+
if version == "v2.1":
|
68 |
+
from hyimage.models.model_zoo import (
|
69 |
+
HUNYUANIMAGE_V2_1_DIT,
|
70 |
+
HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL,
|
71 |
+
HUNYUANIMAGE_V2_1_VAE_32x,
|
72 |
+
HUNYUANIMAGE_V2_1_TEXT_ENCODER,
|
73 |
+
)
|
74 |
+
dit_config = HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL() if use_distilled else HUNYUANIMAGE_V2_1_DIT()
|
75 |
+
return cls(
|
76 |
+
dit_config=dit_config,
|
77 |
+
vae_config=HUNYUANIMAGE_V2_1_VAE_32x(),
|
78 |
+
text_encoder_config=HUNYUANIMAGE_V2_1_TEXT_ENCODER(),
|
79 |
+
reprompt_config=HUNYUANIMAGE_REPROMPT(),
|
80 |
+
version=version,
|
81 |
+
**kwargs
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
raise ValueError(f"Unsupported HunyuanImage version: {version}. Only 'v2.1' is supported")
|
85 |
+
|
86 |
+
|
87 |
+
class HunyuanImagePipeline:
|
88 |
+
"""
|
89 |
+
User-friendly pipeline for HunyuanImage text-to-image generation.
|
90 |
+
|
91 |
+
This pipeline provides a simple interface similar to diffusers library
|
92 |
+
for generating high-quality images from text prompts.
|
93 |
+
|
94 |
+
Supports HunyuanImage 2.1 version with automatic configuration.
|
95 |
+
Both default and distilled (CFG distillation) models are supported.
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(
|
99 |
+
self,
|
100 |
+
config: HunyuanImagePipelineConfig,
|
101 |
+
**kwargs
|
102 |
+
):
|
103 |
+
"""
|
104 |
+
Initialize the HunyuanImage diffusion pipeline.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
config: Configuration object containing all model and pipeline settings
|
108 |
+
**kwargs: Additional configuration options
|
109 |
+
"""
|
110 |
+
self.config = config
|
111 |
+
self.default_sampling_steps = config.default_sampling_steps
|
112 |
+
self.default_guidance_scale = config.default_guidance_scale
|
113 |
+
self.shift = config.shift
|
114 |
+
self.torch_dtype = PRECISION_TO_TYPE[config.torch_dtype]
|
115 |
+
self.device = config.device
|
116 |
+
self.execution_device = config.device
|
117 |
+
|
118 |
+
self.dit = None
|
119 |
+
self.text_encoder = None
|
120 |
+
self.vae = None
|
121 |
+
self.byt5_kwargs = None
|
122 |
+
self.prompt_format = None
|
123 |
+
|
124 |
+
self.enable_dit_offloading = config.enable_dit_offloading
|
125 |
+
self.enable_reprompt_model_offloading = config.enable_reprompt_model_offloading
|
126 |
+
self.enable_refiner_offloading = config.enable_refiner_offloading
|
127 |
+
|
128 |
+
|
129 |
+
self.cfg_mode = config.cfg_mode
|
130 |
+
self.guidance_rescale = config.guidance_rescale
|
131 |
+
|
132 |
+
if self.cfg_mode == "APG_mode_0":
|
133 |
+
self.cfg_guider = AdaptiveProjectedGuidance(guidance_scale=10.0, eta=0.0,
|
134 |
+
adaptive_projected_guidance_rescale=10.0,
|
135 |
+
adaptive_projected_guidance_momentum=-0.5)
|
136 |
+
self.apg_start_step = 10
|
137 |
+
elif self.cfg_mode == "MIX_mode_0":
|
138 |
+
self.cfg_guider_ocr = AdaptiveProjectedGuidance(guidance_scale=10.0, eta=0.0,
|
139 |
+
adaptive_projected_guidance_rescale=10.0,
|
140 |
+
adaptive_projected_guidance_momentum=-0.5)
|
141 |
+
self.apg_start_step_ocr = 75
|
142 |
+
|
143 |
+
self.cfg_guider_general = AdaptiveProjectedGuidance(guidance_scale=10.0, eta=0.0,
|
144 |
+
adaptive_projected_guidance_rescale=10.0,
|
145 |
+
adaptive_projected_guidance_momentum=-0.5)
|
146 |
+
self.apg_start_step_general = 10
|
147 |
+
|
148 |
+
self.ocr_mask = []
|
149 |
+
|
150 |
+
|
151 |
+
self._load_models()
|
152 |
+
|
153 |
+
def _load_dit(self):
|
154 |
+
try:
|
155 |
+
dit_config = self.config.dit_config
|
156 |
+
self.dit = instantiate(dit_config.model)
|
157 |
+
if dit_config.load_from:
|
158 |
+
load_hunyuan_dit_state_dict(self.dit, dit_config.load_from, strict=True)
|
159 |
+
else:
|
160 |
+
raise ValueError("Must provide checkpoint path for DiT model")
|
161 |
+
self.dit = self.dit.to(self.device, dtype=self.torch_dtype)
|
162 |
+
self.dit.eval()
|
163 |
+
if getattr(dit_config, "use_compile", False):
|
164 |
+
self.dit = torch.compile(self.dit)
|
165 |
+
loguru.logger.info("✓ DiT model loaded")
|
166 |
+
except Exception as e:
|
167 |
+
raise RuntimeError(f"Error loading DiT model: {e}") from e
|
168 |
+
|
169 |
+
def _load_text_encoder(self):
|
170 |
+
try:
|
171 |
+
text_encoder_config = self.config.text_encoder_config
|
172 |
+
if not text_encoder_config.load_from:
|
173 |
+
raise ValueError("Must provide checkpoint path for text encoder")
|
174 |
+
|
175 |
+
if text_encoder_config.prompt_template is not None:
|
176 |
+
prompt_template = PROMPT_TEMPLATE[text_encoder_config.prompt_template]
|
177 |
+
crop_start = prompt_template.get("crop_start", 0)
|
178 |
+
else:
|
179 |
+
crop_start = 0
|
180 |
+
prompt_template = None
|
181 |
+
max_length = text_encoder_config.text_len + crop_start
|
182 |
+
|
183 |
+
self.text_encoder = instantiate(
|
184 |
+
text_encoder_config.model,
|
185 |
+
max_length=max_length,
|
186 |
+
text_encoder_path=os.path.join(text_encoder_config.load_from, "llm"),
|
187 |
+
prompt_template=prompt_template,
|
188 |
+
logger=None,
|
189 |
+
device=self.device,
|
190 |
+
)
|
191 |
+
loguru.logger.info("✓ HunyuanImage text encoder loaded")
|
192 |
+
except Exception as e:
|
193 |
+
raise RuntimeError(f"Error loading text encoder: {e}") from e
|
194 |
+
|
195 |
+
def _load_vae(self):
|
196 |
+
try:
|
197 |
+
vae_config = self.config.vae_config
|
198 |
+
self.vae = instantiate(
|
199 |
+
vae_config.model,
|
200 |
+
vae_path=vae_config.load_from,
|
201 |
+
)
|
202 |
+
self.vae = self.vae.to(self.device)
|
203 |
+
loguru.logger.info("✓ VAE loaded")
|
204 |
+
except Exception as e:
|
205 |
+
raise RuntimeError(f"Error loading VAE: {e}") from e
|
206 |
+
|
207 |
+
def _load_reprompt_model(self):
|
208 |
+
try:
|
209 |
+
reprompt_config = self.config.reprompt_config
|
210 |
+
self._reprompt_model = instantiate(reprompt_config.model, models_root_path=reprompt_config.load_from, enable_offloading=self.enable_reprompt_model_offloading)
|
211 |
+
loguru.logger.info("✓ Reprompt model loaded")
|
212 |
+
except Exception as e:
|
213 |
+
raise RuntimeError(f"Error loading reprompt model: {e}") from e
|
214 |
+
|
215 |
+
@property
|
216 |
+
def refiner_pipeline(self):
|
217 |
+
"""
|
218 |
+
As the refiner model is an optional component, we load it on demand.
|
219 |
+
"""
|
220 |
+
if hasattr(self, '_refiner_pipeline') and self._refiner_pipeline is not None:
|
221 |
+
return self._refiner_pipeline
|
222 |
+
from hyimage.diffusion.pipelines.hunyuanimage_refiner_pipeline import HunYuanImageRefinerPipeline
|
223 |
+
self._refiner_pipeline = HunYuanImageRefinerPipeline.from_pretrained(self.config.refiner_model_name)
|
224 |
+
return self._refiner_pipeline
|
225 |
+
|
226 |
+
@property
|
227 |
+
def reprompt_model(self):
|
228 |
+
"""
|
229 |
+
As the reprompt model is an optional component, we load it on demand.
|
230 |
+
"""
|
231 |
+
if hasattr(self, '_reprompt_model') and self._reprompt_model is not None:
|
232 |
+
return self._reprompt_model
|
233 |
+
self._load_reprompt_model()
|
234 |
+
return self._reprompt_model
|
235 |
+
|
236 |
+
def _load_byt5(self):
|
237 |
+
|
238 |
+
assert self.dit is not None, "DiT model must be loaded before byT5"
|
239 |
+
|
240 |
+
if not self.use_byt5:
|
241 |
+
self.byt5_kwargs = None
|
242 |
+
self.prompt_format = None
|
243 |
+
return
|
244 |
+
|
245 |
+
try:
|
246 |
+
|
247 |
+
text_encoder_config = self.config.text_encoder_config
|
248 |
+
|
249 |
+
glyph_root = os.path.join(self.config.text_encoder_config.load_from, "Glyph-SDXL-v2")
|
250 |
+
if not os.path.exists(glyph_root):
|
251 |
+
raise RuntimeError(
|
252 |
+
f"Glyph checkpoint not found from '{glyph_root}'. \n"
|
253 |
+
"Please download from https://modelscope.cn/models/AI-ModelScope/Glyph-SDXL-v2/files.\n\n"
|
254 |
+
"- Required files:\n"
|
255 |
+
" Glyph-SDXL-v2\n"
|
256 |
+
" ├── assets\n"
|
257 |
+
" │ ├── color_idx.json\n"
|
258 |
+
" │ └── multilingual_10-lang_idx.json\n"
|
259 |
+
" └── checkpoints\n"
|
260 |
+
" └── byt5_model.pt\n"
|
261 |
+
)
|
262 |
+
|
263 |
+
|
264 |
+
byT5_google_path = os.path.join(text_encoder_config.load_from, "byt5-small")
|
265 |
+
if not os.path.exists(byT5_google_path):
|
266 |
+
loguru.logger.warning(f"ByT5 google path not found from: {byT5_google_path}. Try downloading from https://huggingface.co/google/byt5-small.")
|
267 |
+
byT5_google_path = "google/byt5-small"
|
268 |
+
|
269 |
+
|
270 |
+
multilingual_prompt_format_color_path = os.path.join(glyph_root, "assets/color_idx.json")
|
271 |
+
multilingual_prompt_format_font_path = os.path.join(glyph_root, "assets/multilingual_10-lang_idx.json")
|
272 |
+
|
273 |
+
byt5_args = dict(
|
274 |
+
byT5_google_path=byT5_google_path,
|
275 |
+
byT5_ckpt_path=os.path.join(glyph_root, "checkpoints/byt5_model.pt"),
|
276 |
+
multilingual_prompt_format_color_path=multilingual_prompt_format_color_path,
|
277 |
+
multilingual_prompt_format_font_path=multilingual_prompt_format_font_path,
|
278 |
+
byt5_max_length=128
|
279 |
+
)
|
280 |
+
|
281 |
+
self.byt5_kwargs = load_glyph_byT5_v2(byt5_args, device=self.device)
|
282 |
+
self.prompt_format = MultilingualPromptFormat(
|
283 |
+
font_path=multilingual_prompt_format_font_path,
|
284 |
+
color_path=multilingual_prompt_format_color_path
|
285 |
+
)
|
286 |
+
loguru.logger.info("✓ byT5 glyph processor loaded")
|
287 |
+
except Exception as e:
|
288 |
+
raise RuntimeError("Error loading byT5 glyph processor") from e
|
289 |
+
|
290 |
+
def _load_models(self):
|
291 |
+
"""
|
292 |
+
Load all model components.
|
293 |
+
"""
|
294 |
+
loguru.logger.info("Loading HunyuanImage models...")
|
295 |
+
self._load_vae()
|
296 |
+
self._load_dit()
|
297 |
+
self._load_byt5()
|
298 |
+
self._load_text_encoder()
|
299 |
+
|
300 |
+
|
301 |
+
def _encode_text(self, prompt: str, data_type: str = "image"):
|
302 |
+
"""
|
303 |
+
Encode text prompt to embeddings.
|
304 |
+
|
305 |
+
Args:
|
306 |
+
prompt: The text prompt
|
307 |
+
data_type: The type of data ("image" by default)
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
Tuple of (text_emb, text_mask)
|
311 |
+
"""
|
312 |
+
text_inputs = self.text_encoder.text2tokens(prompt)
|
313 |
+
with torch.no_grad():
|
314 |
+
text_outputs = self.text_encoder.encode(
|
315 |
+
text_inputs,
|
316 |
+
data_type=data_type,
|
317 |
+
)
|
318 |
+
text_emb = text_outputs.hidden_state
|
319 |
+
text_mask = text_outputs.attention_mask
|
320 |
+
return text_emb, text_mask
|
321 |
+
|
322 |
+
def _encode_glyph(self, prompt: str):
|
323 |
+
"""
|
324 |
+
Encode glyph information using byT5.
|
325 |
+
|
326 |
+
Args:
|
327 |
+
prompt: The text prompt
|
328 |
+
|
329 |
+
Returns:
|
330 |
+
Tuple of (byt5_emb, byt5_mask)
|
331 |
+
"""
|
332 |
+
if not self.use_byt5:
|
333 |
+
return None, None
|
334 |
+
|
335 |
+
if not prompt:
|
336 |
+
return (
|
337 |
+
torch.zeros((1, self.byt5_kwargs["byt5_max_length"], 1472), device=self.device),
|
338 |
+
torch.zeros((1, self.byt5_kwargs["byt5_max_length"]), device=self.device, dtype=torch.int64)
|
339 |
+
)
|
340 |
+
|
341 |
+
try:
|
342 |
+
text_prompt_texts = []
|
343 |
+
pattern_quote_single = r'\'(.*?)\''
|
344 |
+
pattern_quote_double = r'\"(.*?)\"'
|
345 |
+
pattern_quote_chinese_single = r'‘(.*?)’'
|
346 |
+
pattern_quote_chinese_double = r'“(.*?)”'
|
347 |
+
|
348 |
+
matches_quote_single = re.findall(pattern_quote_single, prompt)
|
349 |
+
matches_quote_double = re.findall(pattern_quote_double, prompt)
|
350 |
+
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
|
351 |
+
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
|
352 |
+
|
353 |
+
text_prompt_texts.extend(matches_quote_single)
|
354 |
+
text_prompt_texts.extend(matches_quote_double)
|
355 |
+
text_prompt_texts.extend(matches_quote_chinese_single)
|
356 |
+
text_prompt_texts.extend(matches_quote_chinese_double)
|
357 |
+
|
358 |
+
if not text_prompt_texts:
|
359 |
+
self.ocr_mask = [False]
|
360 |
+
return (
|
361 |
+
torch.zeros((1, self.byt5_kwargs["byt5_max_length"], 1472), device=self.device),
|
362 |
+
torch.zeros((1, self.byt5_kwargs["byt5_max_length"]), device=self.device, dtype=torch.int64)
|
363 |
+
)
|
364 |
+
self.ocr_mask = [True]
|
365 |
+
|
366 |
+
text_prompt_style_list = [{'color': None, 'font-family': None} for _ in range(len(text_prompt_texts))]
|
367 |
+
glyph_text_formatted = self.prompt_format.format_prompt(text_prompt_texts, text_prompt_style_list)
|
368 |
+
|
369 |
+
byt5_text_ids, byt5_text_mask = self._get_byt5_text_tokens(
|
370 |
+
self.byt5_kwargs["byt5_tokenizer"],
|
371 |
+
self.byt5_kwargs["byt5_max_length"],
|
372 |
+
glyph_text_formatted
|
373 |
+
)
|
374 |
+
|
375 |
+
byt5_text_ids = byt5_text_ids.to(device=self.device)
|
376 |
+
byt5_text_mask = byt5_text_mask.to(device=self.device)
|
377 |
+
|
378 |
+
byt5_prompt_embeds = self.byt5_kwargs["byt5_model"](
|
379 |
+
byt5_text_ids, attention_mask=byt5_text_mask.float()
|
380 |
+
)
|
381 |
+
byt5_emb = byt5_prompt_embeds[0]
|
382 |
+
|
383 |
+
return byt5_emb, byt5_text_mask
|
384 |
+
except Exception as e:
|
385 |
+
loguru.logger.warning(f"Warning: Error in glyph encoding, using fallback: {e}")
|
386 |
+
return (
|
387 |
+
torch.zeros((1, self.byt5_kwargs["byt5_max_length"], 1472), device=self.device),
|
388 |
+
torch.zeros((1, self.byt5_kwargs["byt5_max_length"]), device=self.device, dtype=torch.int64)
|
389 |
+
)
|
390 |
+
|
391 |
+
def _get_byt5_text_tokens(self, tokenizer, max_length, text_list):
|
392 |
+
"""
|
393 |
+
Get byT5 text tokens.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
tokenizer: The tokenizer object
|
397 |
+
max_length: Maximum token length
|
398 |
+
text_list: List or string of text
|
399 |
+
|
400 |
+
Returns:
|
401 |
+
Tuple of (byt5_text_ids, byt5_text_mask)
|
402 |
+
"""
|
403 |
+
if isinstance(text_list, list):
|
404 |
+
text_prompt = " ".join(text_list)
|
405 |
+
else:
|
406 |
+
text_prompt = text_list
|
407 |
+
|
408 |
+
byt5_text_inputs = tokenizer(
|
409 |
+
text_prompt,
|
410 |
+
padding="max_length",
|
411 |
+
max_length=max_length,
|
412 |
+
truncation=True,
|
413 |
+
add_special_tokens=True,
|
414 |
+
return_tensors="pt",
|
415 |
+
)
|
416 |
+
|
417 |
+
byt5_text_ids = byt5_text_inputs.input_ids
|
418 |
+
byt5_text_mask = byt5_text_inputs.attention_mask
|
419 |
+
|
420 |
+
return byt5_text_ids, byt5_text_mask
|
421 |
+
|
422 |
+
def _prepare_latents(self, width: int, height: int, generator: torch.Generator, batch_size: int = 1):
|
423 |
+
"""
|
424 |
+
Prepare initial noise latents.
|
425 |
+
|
426 |
+
Args:
|
427 |
+
width: Image width
|
428 |
+
height: Image height
|
429 |
+
generator: Torch random generator
|
430 |
+
batch_size: Batch size
|
431 |
+
|
432 |
+
Returns:
|
433 |
+
Latent tensor
|
434 |
+
"""
|
435 |
+
vae_downsampling_factor = 32
|
436 |
+
assert width % vae_downsampling_factor == 0 and height % vae_downsampling_factor == 0, (
|
437 |
+
f"width and height must be divisible by {vae_downsampling_factor}, but got {width} and {height}"
|
438 |
+
)
|
439 |
+
latent_width = width // vae_downsampling_factor
|
440 |
+
latent_height = height // vae_downsampling_factor
|
441 |
+
latent_channels = 64
|
442 |
+
|
443 |
+
if len(self.dit.patch_size) == 3:
|
444 |
+
latent_shape = (batch_size, latent_channels, 1, latent_height, latent_width)
|
445 |
+
elif len(self.dit.patch_size) == 2:
|
446 |
+
latent_shape = (batch_size, latent_channels, latent_height, latent_width)
|
447 |
+
else:
|
448 |
+
raise ValueError(f"Unsupported patch_size: {self.dit.patch_size}")
|
449 |
+
|
450 |
+
|
451 |
+
# Generate random noise with shape latent_shape
|
452 |
+
latents = torch.randn(
|
453 |
+
latent_shape,
|
454 |
+
device=generator.device,
|
455 |
+
dtype=self.torch_dtype,
|
456 |
+
generator=generator,
|
457 |
+
).to(device=self.device)
|
458 |
+
|
459 |
+
return latents
|
460 |
+
|
461 |
+
def _denoise_step(self, latents, timesteps, text_emb, text_mask, byt5_emb, byt5_mask, guidance_scale: float = 1.0, timesteps_r=None):
|
462 |
+
"""
|
463 |
+
Perform one denoising step.
|
464 |
+
|
465 |
+
Args:
|
466 |
+
latents: Latent tensor
|
467 |
+
timesteps: Timesteps tensor
|
468 |
+
text_emb: Text embedding
|
469 |
+
text_mask: Text mask
|
470 |
+
byt5_emb: byT5 embedding
|
471 |
+
byt5_mask: byT5 mask
|
472 |
+
guidance_scale: Guidance scale
|
473 |
+
timesteps_r: Optional next timestep
|
474 |
+
|
475 |
+
Returns:
|
476 |
+
Noise prediction tensor
|
477 |
+
"""
|
478 |
+
if byt5_emb is not None and byt5_mask is not None:
|
479 |
+
extra_kwargs = {
|
480 |
+
"byt5_text_states": byt5_emb,
|
481 |
+
"byt5_text_mask": byt5_mask,
|
482 |
+
}
|
483 |
+
else:
|
484 |
+
if self.use_byt5:
|
485 |
+
raise ValueError("Must provide byt5_emb and byt5_mask for HunyuanImage 2.1")
|
486 |
+
extra_kwargs = {}
|
487 |
+
|
488 |
+
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
|
489 |
+
if hasattr(self.dit, 'guidance_embed') and self.dit.guidance_embed:
|
490 |
+
guidance_expand = torch.tensor(
|
491 |
+
[guidance_scale] * latents.shape[0],
|
492 |
+
dtype=torch.float32,
|
493 |
+
device=latents.device
|
494 |
+
).to(latents.dtype) * 1000
|
495 |
+
else:
|
496 |
+
guidance_expand = None
|
497 |
+
|
498 |
+
noise_pred = self.dit(
|
499 |
+
latents,
|
500 |
+
timesteps,
|
501 |
+
text_states=text_emb,
|
502 |
+
encoder_attention_mask=text_mask,
|
503 |
+
guidance=guidance_expand,
|
504 |
+
return_dict=False,
|
505 |
+
extra_kwargs=extra_kwargs,
|
506 |
+
timesteps_r=timesteps_r,
|
507 |
+
)[0]
|
508 |
+
|
509 |
+
return noise_pred
|
510 |
+
|
511 |
+
def _apply_classifier_free_guidance(self, noise_pred, guidance_scale: float, i: int):
|
512 |
+
"""
|
513 |
+
Apply classifier-free guidance.
|
514 |
+
|
515 |
+
Args:
|
516 |
+
noise_pred: Noise prediction tensor
|
517 |
+
guidance_scale: Guidance scale
|
518 |
+
|
519 |
+
Returns:
|
520 |
+
Guided noise prediction tensor
|
521 |
+
"""
|
522 |
+
if guidance_scale == 1.0:
|
523 |
+
return noise_pred
|
524 |
+
|
525 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
526 |
+
|
527 |
+
|
528 |
+
if self.cfg_mode.startswith("APG_mode_"):
|
529 |
+
if i <= self.apg_start_step:
|
530 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
531 |
+
noise_pred_text - noise_pred_uncond
|
532 |
+
)
|
533 |
+
_ = self.cfg_guider(noise_pred_text, noise_pred_uncond, step=i)
|
534 |
+
else:
|
535 |
+
noise_pred = self.cfg_guider(noise_pred_text, noise_pred_uncond, step=i)
|
536 |
+
elif self.cfg_mode.startswith("MIX_mode_"):
|
537 |
+
|
538 |
+
ocr_mask_bool = torch.tensor(self.ocr_mask, dtype=torch.bool)
|
539 |
+
|
540 |
+
true_idx = torch.where(ocr_mask_bool)[0]
|
541 |
+
false_idx = torch.where(~ocr_mask_bool)[0]
|
542 |
+
|
543 |
+
noise_pred_text_true = noise_pred_text[true_idx] if len(true_idx) > 0 else \
|
544 |
+
torch.empty((0, noise_pred_text.size(1)), dtype=noise_pred_text.dtype, device=noise_pred_text.device)
|
545 |
+
noise_pred_text_false = noise_pred_text[false_idx] if len(false_idx) > 0 else \
|
546 |
+
torch.empty((0, noise_pred_text.size(1)), dtype=noise_pred_text.dtype, device=noise_pred_text.device)
|
547 |
+
|
548 |
+
noise_pred_uncond_true = noise_pred_uncond[true_idx] if len(true_idx) > 0 else \
|
549 |
+
torch.empty((0, noise_pred_uncond.size(1)), dtype=noise_pred_uncond.dtype, device=noise_pred_uncond.device)
|
550 |
+
noise_pred_uncond_false = noise_pred_uncond[false_idx] if len(false_idx) > 0 else \
|
551 |
+
torch.empty((0, noise_pred_uncond.size(1)), dtype=noise_pred_uncond.dtype, device=noise_pred_uncond.device)
|
552 |
+
|
553 |
+
if len(noise_pred_text_true) > 0:
|
554 |
+
if i <= self.apg_start_step_ocr:
|
555 |
+
noise_pred_true = noise_pred_uncond_true + guidance_scale * (
|
556 |
+
noise_pred_text_true - noise_pred_uncond_true
|
557 |
+
)
|
558 |
+
_ = self.cfg_guider_ocr(noise_pred_text_true, noise_pred_uncond_true, step=i)
|
559 |
+
else:
|
560 |
+
noise_pred_true = self.cfg_guider_ocr(noise_pred_text_true, noise_pred_uncond_true, step=i)
|
561 |
+
else:
|
562 |
+
noise_pred_true = noise_pred_text_true
|
563 |
+
|
564 |
+
if len(noise_pred_text_false) > 0:
|
565 |
+
if i <= self.apg_start_step_general:
|
566 |
+
noise_pred_false = noise_pred_uncond_false + guidance_scale * (
|
567 |
+
noise_pred_text_false - noise_pred_uncond_false
|
568 |
+
)
|
569 |
+
_ = self.cfg_guider_general(noise_pred_text_false, noise_pred_uncond_false, step=i)
|
570 |
+
else:
|
571 |
+
noise_pred_false = self.cfg_guider_general(noise_pred_text_false, noise_pred_uncond_false, step=i)
|
572 |
+
else:
|
573 |
+
noise_pred_false = noise_pred_text_false
|
574 |
+
|
575 |
+
noise_pred = torch.empty_like(noise_pred_text)
|
576 |
+
if len(true_idx) > 0:
|
577 |
+
noise_pred[true_idx] = noise_pred_true
|
578 |
+
if len(false_idx) > 0:
|
579 |
+
noise_pred[false_idx] = noise_pred_false
|
580 |
+
|
581 |
+
else:
|
582 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
583 |
+
noise_pred_text - noise_pred_uncond
|
584 |
+
)
|
585 |
+
|
586 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
587 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
588 |
+
noise_pred = rescale_noise_cfg(
|
589 |
+
noise_pred,
|
590 |
+
noise_pred_text,
|
591 |
+
guidance_rescale=self.guidance_rescale,
|
592 |
+
)
|
593 |
+
|
594 |
+
|
595 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
596 |
+
return noise_pred
|
597 |
+
|
598 |
+
def _decode_latents(self, latents):
|
599 |
+
"""
|
600 |
+
Decode latents to images using VAE.
|
601 |
+
|
602 |
+
Args:
|
603 |
+
latents: Latent tensor
|
604 |
+
|
605 |
+
Returns:
|
606 |
+
Image tensor
|
607 |
+
"""
|
608 |
+
if hasattr(self.vae.config, "shift_factor") and self.vae.config.shift_factor:
|
609 |
+
latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
610 |
+
else:
|
611 |
+
latents = latents / self.vae.config.scaling_factor
|
612 |
+
|
613 |
+
if latents.ndim == 5:
|
614 |
+
latents = latents.squeeze(2)
|
615 |
+
if latents.ndim == 4:
|
616 |
+
latents = latents.unsqueeze(2)
|
617 |
+
|
618 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
|
619 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
620 |
+
|
621 |
+
# Post-process image - remove frame dimension and normalize
|
622 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
623 |
+
image = image[:, :, 0] # Remove frame dimension for images
|
624 |
+
image = image.cpu().float()
|
625 |
+
|
626 |
+
return image
|
627 |
+
|
628 |
+
def get_timesteps_sigmas(self, sampling_steps: int, shift):
|
629 |
+
sigmas = torch.linspace(1, 0, sampling_steps + 1)
|
630 |
+
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
|
631 |
+
sigmas = sigmas.to(torch.float32)
|
632 |
+
timesteps = (sigmas[:-1] * 1000).to(dtype=torch.float32, device=self.device)
|
633 |
+
return timesteps, sigmas
|
634 |
+
|
635 |
+
def step(self, latents, noise_pred, sigmas, step_i):
|
636 |
+
return latents.float() - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred.float()
|
637 |
+
|
638 |
+
@torch.no_grad()
|
639 |
+
def __call__(
|
640 |
+
self,
|
641 |
+
prompt: str,
|
642 |
+
shift: int = 4,
|
643 |
+
negative_prompt: str = "",
|
644 |
+
width: int = 2048,
|
645 |
+
height: int = 2048,
|
646 |
+
use_reprompt: bool = False,
|
647 |
+
use_refiner: bool = False,
|
648 |
+
num_inference_steps: Optional[int] = None,
|
649 |
+
guidance_scale: Optional[float] = None,
|
650 |
+
seed: Optional[int] = 42,
|
651 |
+
**kwargs
|
652 |
+
) -> Image.Image:
|
653 |
+
"""
|
654 |
+
Generate an image from a text prompt.
|
655 |
+
|
656 |
+
Args:
|
657 |
+
prompt: Text prompt describing the image
|
658 |
+
negative_prompt: Negative prompt for guidance
|
659 |
+
width: Image width
|
660 |
+
height: Image height
|
661 |
+
use_reprompt: Whether to use reprompt model
|
662 |
+
use_refiner: Whether to use refiner pipeline
|
663 |
+
num_inference_steps: Number of denoising steps (overrides config if provided)
|
664 |
+
guidance_scale: Strength of classifier-free guidance (overrides config if provided)
|
665 |
+
seed: Random seed for reproducibility
|
666 |
+
**kwargs: Additional arguments
|
667 |
+
|
668 |
+
Returns:
|
669 |
+
Generated PIL Image
|
670 |
+
"""
|
671 |
+
if seed is not None:
|
672 |
+
generator = torch.Generator(device='cpu').manual_seed(seed)
|
673 |
+
else:
|
674 |
+
generator = None
|
675 |
+
|
676 |
+
sampling_steps = num_inference_steps if num_inference_steps is not None else self.default_sampling_steps
|
677 |
+
guidance_scale = guidance_scale if guidance_scale is not None else self.default_guidance_scale
|
678 |
+
shift = shift if shift is not None else self.shift
|
679 |
+
|
680 |
+
user_prompt = prompt
|
681 |
+
if use_reprompt:
|
682 |
+
if self.enable_dit_offloading:
|
683 |
+
self.to('cpu')
|
684 |
+
prompt = self.reprompt_model.predict(prompt)
|
685 |
+
if self.enable_dit_offloading:
|
686 |
+
self.to(self.execution_device)
|
687 |
+
|
688 |
+
print("=" * 60)
|
689 |
+
print("🖼️ HunyuanImage Generation Task")
|
690 |
+
print("-" * 60)
|
691 |
+
print(f"Prompt: {user_prompt}")
|
692 |
+
if use_reprompt:
|
693 |
+
print(f"Reprompt: {prompt}")
|
694 |
+
if not self.cfg_distilled:
|
695 |
+
print(f"Negative Prompt: {negative_prompt if negative_prompt else '(none)'}")
|
696 |
+
print(f"Guidance Scale: {guidance_scale}")
|
697 |
+
print(f"CFG Mode: {self.cfg_mode}")
|
698 |
+
print(f"Guidance Rescale: {self.guidance_rescale}")
|
699 |
+
print(f"Shift: {self.shift}")
|
700 |
+
print(f"Seed: {seed}")
|
701 |
+
print(f"Use MeanFlow: {self.use_meanflow}")
|
702 |
+
print(f"Use byT5: {self.use_byt5}")
|
703 |
+
print(f"Image Size: {width} x {height}")
|
704 |
+
print(f"Sampling Steps: {sampling_steps}")
|
705 |
+
print("=" * 60)
|
706 |
+
|
707 |
+
pos_text_emb, pos_text_mask = self._encode_text(prompt)
|
708 |
+
neg_text_emb, neg_text_mask = self._encode_text(negative_prompt)
|
709 |
+
|
710 |
+
pos_byt5_emb, pos_byt5_mask = self._encode_glyph(prompt)
|
711 |
+
neg_byt5_emb, neg_byt5_mask = self._encode_glyph(negative_prompt)
|
712 |
+
|
713 |
+
latents = self._prepare_latents(width, height, generator=generator)
|
714 |
+
|
715 |
+
do_classifier_free_guidance = (not self.cfg_distilled) and guidance_scale > 1
|
716 |
+
if do_classifier_free_guidance:
|
717 |
+
text_emb = torch.cat([neg_text_emb, pos_text_emb])
|
718 |
+
text_mask = torch.cat([neg_text_mask, pos_text_mask])
|
719 |
+
|
720 |
+
if self.use_byt5 and pos_byt5_emb is not None and neg_byt5_emb is not None:
|
721 |
+
byt5_emb = torch.cat([neg_byt5_emb, pos_byt5_emb])
|
722 |
+
byt5_mask = torch.cat([neg_byt5_mask, pos_byt5_mask])
|
723 |
+
else:
|
724 |
+
byt5_emb = pos_byt5_emb
|
725 |
+
byt5_mask = pos_byt5_mask
|
726 |
+
else:
|
727 |
+
text_emb = pos_text_emb
|
728 |
+
text_mask = pos_text_mask
|
729 |
+
byt5_emb = pos_byt5_emb
|
730 |
+
byt5_mask = pos_byt5_mask
|
731 |
+
|
732 |
+
timesteps, sigmas = self.get_timesteps_sigmas(sampling_steps, shift)
|
733 |
+
|
734 |
+
for i, t in enumerate(tqdm(timesteps, desc="Denoising", total=len(timesteps))):
|
735 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
736 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
737 |
+
if self.use_meanflow:
|
738 |
+
if i == len(timesteps) - 1:
|
739 |
+
timesteps_r = torch.tensor([0.0], device=self.device)
|
740 |
+
else:
|
741 |
+
timesteps_r = timesteps[i + 1]
|
742 |
+
timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
|
743 |
+
else:
|
744 |
+
timesteps_r = None
|
745 |
+
|
746 |
+
if self.cfg_distilled:
|
747 |
+
noise_pred = self._denoise_step(
|
748 |
+
latent_model_input, t_expand, text_emb, text_mask, byt5_emb, byt5_mask, guidance_scale, timesteps_r=timesteps_r,
|
749 |
+
)
|
750 |
+
else:
|
751 |
+
noise_pred = self._denoise_step(
|
752 |
+
latent_model_input, t_expand, text_emb, text_mask, byt5_emb, byt5_mask, timesteps_r=timesteps_r,
|
753 |
+
)
|
754 |
+
|
755 |
+
if do_classifier_free_guidance:
|
756 |
+
noise_pred = self._apply_classifier_free_guidance(noise_pred, guidance_scale, i)
|
757 |
+
|
758 |
+
latents = self.step(latents, noise_pred, sigmas, i)
|
759 |
+
|
760 |
+
|
761 |
+
image = self._decode_latents(latents)
|
762 |
+
image = (image.squeeze(0).permute(1, 2, 0) * 255).byte().numpy()
|
763 |
+
pil_image = Image.fromarray(image)
|
764 |
+
|
765 |
+
if use_refiner:
|
766 |
+
if self.enable_dit_offloading:
|
767 |
+
self.to('cpu')
|
768 |
+
if self.enable_refiner_offloading:
|
769 |
+
self.refiner_pipeline.to(self.execution_device)
|
770 |
+
pil_image = self.refiner_pipeline(
|
771 |
+
image=pil_image,
|
772 |
+
prompt=prompt,
|
773 |
+
negative_prompt=negative_prompt,
|
774 |
+
width=width,
|
775 |
+
height=height,
|
776 |
+
use_reprompt=False,
|
777 |
+
use_refiner=False,
|
778 |
+
num_inference_steps=4,
|
779 |
+
guidance_scale=guidance_scale,
|
780 |
+
generator=generator,
|
781 |
+
)
|
782 |
+
if self.enable_refiner_offloading:
|
783 |
+
self.refiner_pipeline.to('cpu')
|
784 |
+
if self.enable_dit_offloading:
|
785 |
+
self.to(self.execution_device)
|
786 |
+
|
787 |
+
return pil_image
|
788 |
+
|
789 |
+
@property
|
790 |
+
def use_meanflow(self):
|
791 |
+
return getattr(self.dit, 'use_meanflow', False)
|
792 |
+
|
793 |
+
@property
|
794 |
+
def use_byt5(self):
|
795 |
+
return getattr(self.dit, 'glyph_byT5_v2', False)
|
796 |
+
|
797 |
+
@property
|
798 |
+
def cfg_distilled(self):
|
799 |
+
return getattr(self.dit, 'guidance_embed', False)
|
800 |
+
|
801 |
+
def to(self, device: str | torch.device):
|
802 |
+
"""
|
803 |
+
Move pipeline to specified device.
|
804 |
+
|
805 |
+
Args:
|
806 |
+
device: Target device string
|
807 |
+
|
808 |
+
Returns:
|
809 |
+
Self
|
810 |
+
"""
|
811 |
+
self.device = device
|
812 |
+
if self.dit is not None:
|
813 |
+
self.dit = self.dit.to(device, non_blocking=True)
|
814 |
+
if self.text_encoder is not None:
|
815 |
+
self.text_encoder = self.text_encoder.to(device, non_blocking=True)
|
816 |
+
if self.vae is not None:
|
817 |
+
self.vae = self.vae.to(device, non_blocking=True)
|
818 |
+
return self
|
819 |
+
|
820 |
+
def update_config(self, **kwargs):
|
821 |
+
"""
|
822 |
+
Update configuration parameters.
|
823 |
+
|
824 |
+
Args:
|
825 |
+
**kwargs: Key-value pairs to update
|
826 |
+
|
827 |
+
Returns:
|
828 |
+
Self
|
829 |
+
"""
|
830 |
+
for key, value in kwargs.items():
|
831 |
+
if hasattr(self.config, key):
|
832 |
+
setattr(self.config, key, value)
|
833 |
+
if hasattr(self, key):
|
834 |
+
setattr(self, key, value)
|
835 |
+
return self
|
836 |
+
|
837 |
+
@classmethod
|
838 |
+
def from_pretrained(cls, model_name: str = "hunyuanimage-v2.1", use_distilled: bool = False, **kwargs):
|
839 |
+
"""
|
840 |
+
Create pipeline from pretrained model.
|
841 |
+
|
842 |
+
Args:
|
843 |
+
model_name: Model name, supports "hunyuanimage-v2.1", "hunyuanimage-v2.1-distilled"
|
844 |
+
use_distilled: Whether to use distilled model (overrides model_name if specified)
|
845 |
+
**kwargs: Additional configuration options
|
846 |
+
|
847 |
+
Returns:
|
848 |
+
HunyuanImagePipeline instance
|
849 |
+
"""
|
850 |
+
if model_name == "hunyuanimage-v2.1":
|
851 |
+
version = "v2.1"
|
852 |
+
use_distilled = False
|
853 |
+
elif model_name == "hunyuanimage-v2.1-distilled":
|
854 |
+
version = "v2.1"
|
855 |
+
use_distilled = True
|
856 |
+
else:
|
857 |
+
raise ValueError(
|
858 |
+
f"Unsupported model name: {model_name}. Supported names: 'hunyuanimage-v2.1', 'hunyuanimage-v2.1-distilled'"
|
859 |
+
)
|
860 |
+
|
861 |
+
config = HunyuanImagePipelineConfig.create_default(
|
862 |
+
version=version, use_distilled=use_distilled, **kwargs
|
863 |
+
)
|
864 |
+
return cls(config=config)
|
865 |
+
|
866 |
+
@classmethod
|
867 |
+
def from_config(cls, config: HunyuanImagePipelineConfig):
|
868 |
+
"""
|
869 |
+
Create pipeline from configuration object.
|
870 |
+
|
871 |
+
Args:
|
872 |
+
config: HunyuanImagePipelineConfig instance
|
873 |
+
|
874 |
+
Returns:
|
875 |
+
HunyuanImagePipeline instance
|
876 |
+
"""
|
877 |
+
return cls(config=config)
|
878 |
+
|
879 |
+
|
880 |
+
def DiffusionPipeline(model_name: str = "hunyuanimage-v2.1", use_distilled: bool = False, **kwargs):
|
881 |
+
"""
|
882 |
+
Factory function to create HunyuanImagePipeline.
|
883 |
+
|
884 |
+
Args:
|
885 |
+
model_name: Model name, supports "hunyuanimage-v2.1", "hunyuanimage-v2.1-distilled"
|
886 |
+
use_distilled: Whether to use distilled model (overrides model_name if specified)
|
887 |
+
**kwargs: Additional configuration options
|
888 |
+
|
889 |
+
Returns:
|
890 |
+
HunyuanImagePipeline instance
|
891 |
+
"""
|
892 |
+
return HunyuanImagePipeline.from_pretrained(model_name, use_distilled=use_distilled, **kwargs)
|
hyimage/diffusion/pipelines/hunyuanimage_refiner_pipeline.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
import torchvision.transforms as T
|
8 |
+
|
9 |
+
from .hunyuanimage_pipeline import HunyuanImagePipeline, HunyuanImagePipelineConfig
|
10 |
+
|
11 |
+
from hyimage.models.model_zoo import (
|
12 |
+
HUNYUANIMAGE_REFINER_DIT,
|
13 |
+
HUNYUANIMAGE_REFINER_VAE_32x,
|
14 |
+
HUNYUANIMAGE_REFINER_TEXT_ENCODER,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class HunYuanImageRefinerPipelineConfig(HunyuanImagePipelineConfig):
|
20 |
+
"""
|
21 |
+
Configuration class for HunyuanImage refiner pipeline.
|
22 |
+
|
23 |
+
Inherits from HunyuanImagePipelineConfig and overrides specific parameters
|
24 |
+
for the refiner functionality.
|
25 |
+
"""
|
26 |
+
|
27 |
+
default_sampling_steps: int = 4
|
28 |
+
shift: int = 1
|
29 |
+
version: str = "v1.0"
|
30 |
+
cfg_mode: str = ""
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def create_default(
|
34 |
+
cls,
|
35 |
+
version: str = "v1.0",
|
36 |
+
use_distilled: bool = False,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
dit_config = HUNYUANIMAGE_REFINER_DIT()
|
40 |
+
vae_config = HUNYUANIMAGE_REFINER_VAE_32x()
|
41 |
+
text_encoder_config = HUNYUANIMAGE_REFINER_TEXT_ENCODER()
|
42 |
+
|
43 |
+
return cls(
|
44 |
+
dit_config=dit_config,
|
45 |
+
vae_config=vae_config,
|
46 |
+
text_encoder_config=text_encoder_config,
|
47 |
+
reprompt_config=None,
|
48 |
+
version=version,
|
49 |
+
**kwargs,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
class HunYuanImageRefinerPipeline(HunyuanImagePipeline):
|
54 |
+
"""A refiner pipeline for HunyuanImage that inherits from the main pipeline.
|
55 |
+
|
56 |
+
This pipeline refines existing images using the same model architecture
|
57 |
+
but with different default parameters and an image input.
|
58 |
+
"""
|
59 |
+
|
60 |
+
def __init__(self, config: HunYuanImageRefinerPipelineConfig, **kwargs):
|
61 |
+
"""Initialize the refiner pipeline.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
config: Refiner-specific configuration
|
65 |
+
**kwargs: Additional arguments passed to parent class
|
66 |
+
"""
|
67 |
+
assert isinstance(config, HunYuanImageRefinerPipelineConfig)
|
68 |
+
super().__init__(config, **kwargs)
|
69 |
+
assert self.cfg_distilled
|
70 |
+
|
71 |
+
def _condition_aug(self, latents, noise=None, strength=0.3):
|
72 |
+
"""Apply conditioning augmentation for refiner.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
latents: Input latents tensor
|
76 |
+
noise: Optional noise tensor, if None will be generated
|
77 |
+
strength: Augmentation strength factor
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
Augmented latents tensor
|
81 |
+
"""
|
82 |
+
if noise is None:
|
83 |
+
noise = torch.randn_like(latents)
|
84 |
+
return strength * noise + (1 - strength) * latents
|
85 |
+
|
86 |
+
@torch.no_grad()
|
87 |
+
def __call__(
|
88 |
+
self,
|
89 |
+
prompt: str,
|
90 |
+
negative_prompt: str = "",
|
91 |
+
width: int = 2048,
|
92 |
+
height: int = 2048,
|
93 |
+
use_reprompt: bool = False,
|
94 |
+
num_inference_steps: Optional[int] = None,
|
95 |
+
guidance_scale: Optional[float] = None,
|
96 |
+
shift: int = 4,
|
97 |
+
seed: Optional[int] = 42,
|
98 |
+
image: Optional[Image.Image] = None,
|
99 |
+
**kwargs,
|
100 |
+
) -> Image.Image:
|
101 |
+
"""Refine an existing image using text guidance.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
prompt: Text prompt describing the desired refinement
|
105 |
+
negative_prompt: Negative prompt for guidance
|
106 |
+
width: Image width
|
107 |
+
height: Image height
|
108 |
+
use_reprompt: Whether to use reprompt (ignored for refiner)
|
109 |
+
num_inference_steps: Number of denoising steps (overrides config if provided)
|
110 |
+
guidance_scale: Strength of classifier-free guidance (overrides config if provided)
|
111 |
+
seed: Random seed for reproducibility
|
112 |
+
image: Image to be refined (required for refiner)
|
113 |
+
**kwargs: Additional arguments
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
Refined PIL Image
|
117 |
+
"""
|
118 |
+
if image is None:
|
119 |
+
raise ValueError("Image parameter is required for refiner pipeline")
|
120 |
+
|
121 |
+
if seed is not None:
|
122 |
+
generator = torch.Generator(device='cpu').manual_seed(seed)
|
123 |
+
else:
|
124 |
+
generator = None
|
125 |
+
|
126 |
+
sampling_steps = (
|
127 |
+
num_inference_steps
|
128 |
+
if num_inference_steps is not None
|
129 |
+
else self.default_sampling_steps
|
130 |
+
)
|
131 |
+
guidance_scale = (
|
132 |
+
guidance_scale if guidance_scale is not None else self.default_guidance_scale
|
133 |
+
)
|
134 |
+
shift = shift if shift is not None else self.shift
|
135 |
+
|
136 |
+
# Print log about current refinement task
|
137 |
+
print("=" * 60)
|
138 |
+
print("🔧 HunyuanImage Refinement Task")
|
139 |
+
print("-" * 60)
|
140 |
+
print(f"Prompt: {prompt}")
|
141 |
+
print(f"Guidance Scale: {guidance_scale}")
|
142 |
+
print(f"Shift: {self.shift}")
|
143 |
+
print(f"Seed: {seed}")
|
144 |
+
print(f"Image Size: {width} x {height}")
|
145 |
+
print(f"Sampling Steps: {sampling_steps}")
|
146 |
+
print("=" * 60)
|
147 |
+
|
148 |
+
# Encode prompts
|
149 |
+
pos_text_emb, pos_text_mask = self._encode_text(prompt)
|
150 |
+
|
151 |
+
latents = self._prepare_latents(width, height, generator=generator)
|
152 |
+
|
153 |
+
_pil_to_tensor = T.Compose(
|
154 |
+
[
|
155 |
+
T.ToTensor(), # convert to tensor and normalize to [0, 1]
|
156 |
+
T.Normalize([0.5], [0.5]), # transform to [-1, 1]
|
157 |
+
]
|
158 |
+
)
|
159 |
+
|
160 |
+
image_tensor = (
|
161 |
+
_pil_to_tensor(image).unsqueeze(0).to("cuda", dtype=self.vae.dtype)
|
162 |
+
)
|
163 |
+
|
164 |
+
cond_latents = self.vae.encode(
|
165 |
+
image_tensor.to(self.device, dtype=self.vae.dtype)
|
166 |
+
).latent_dist.sample()
|
167 |
+
|
168 |
+
if (
|
169 |
+
hasattr(self.vae.config, "shift_factor")
|
170 |
+
and self.vae.config.shift_factor
|
171 |
+
):
|
172 |
+
cond_latents.sub_(self.vae.config.shift_factor).mul_(
|
173 |
+
self.vae.config.scaling_factor
|
174 |
+
)
|
175 |
+
else:
|
176 |
+
cond_latents.mul_(self.vae.config.scaling_factor)
|
177 |
+
|
178 |
+
# Add frame dimension for refiner model
|
179 |
+
cond_latents = cond_latents.unsqueeze(2) # (b c 1 h w)
|
180 |
+
|
181 |
+
# Apply conditioning augmentation
|
182 |
+
cond_latents = self._condition_aug(cond_latents)
|
183 |
+
|
184 |
+
timesteps, sigmas = self.get_timesteps_sigmas(sampling_steps, shift)
|
185 |
+
|
186 |
+
text_emb = pos_text_emb
|
187 |
+
text_mask = pos_text_mask
|
188 |
+
|
189 |
+
for i, t in enumerate(tqdm(timesteps, desc="Refining", total=len(timesteps))):
|
190 |
+
# Concatenate noise latents with condition latents for refiner input
|
191 |
+
latent_model_input = torch.cat([latents, cond_latents], dim=1)
|
192 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
193 |
+
|
194 |
+
# Predict noise with guidance
|
195 |
+
noise_pred = self._denoise_step(
|
196 |
+
latent_model_input,
|
197 |
+
t_expand,
|
198 |
+
text_emb,
|
199 |
+
text_mask,
|
200 |
+
None,
|
201 |
+
None,
|
202 |
+
guidance_scale,
|
203 |
+
timesteps_r=None,
|
204 |
+
)
|
205 |
+
|
206 |
+
latents = self.step(latents, noise_pred, sigmas, i)
|
207 |
+
|
208 |
+
refined_image = self._decode_latents(latents)
|
209 |
+
|
210 |
+
# Convert to PIL Image
|
211 |
+
refined_image = (refined_image.squeeze(0).permute(1, 2, 0) * 255).byte().numpy()
|
212 |
+
pil_image = Image.fromarray(refined_image)
|
213 |
+
|
214 |
+
return pil_image
|
215 |
+
|
216 |
+
@classmethod
|
217 |
+
def from_pretrained(
|
218 |
+
cls,
|
219 |
+
model_name: str = "hunyuanimage-refiner",
|
220 |
+
use_distilled: bool = False,
|
221 |
+
**kwargs,
|
222 |
+
):
|
223 |
+
"""Create refiner pipeline from pretrained model.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
model_name: Model name, currently only supports "hunyuanimage-refiner"
|
227 |
+
use_distilled: Whether to use distilled model (unused for refiner)
|
228 |
+
**kwargs: Additional configuration options
|
229 |
+
"""
|
230 |
+
if model_name == "hunyuanimage-refiner":
|
231 |
+
version = "v1.0"
|
232 |
+
else:
|
233 |
+
raise ValueError(
|
234 |
+
f"Unsupported refiner model name: {model_name}. Supported names: 'hunyuanimage-refiner'"
|
235 |
+
)
|
236 |
+
|
237 |
+
config = HunYuanImageRefinerPipelineConfig.create_default(
|
238 |
+
version=version, **kwargs
|
239 |
+
)
|
240 |
+
|
241 |
+
return cls(config=config)
|
242 |
+
|
243 |
+
@classmethod
|
244 |
+
def from_config(cls, config: Union[HunYuanImageRefinerPipelineConfig, HunyuanImagePipelineConfig]):
|
245 |
+
"""Create refiner pipeline from configuration object.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
config: Configuration object for the pipeline
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
Initialized refiner pipeline instance
|
252 |
+
"""
|
253 |
+
return cls(config=config)
|
254 |
+
|
255 |
+
|
256 |
+
# Convenience function for easy access
|
257 |
+
def RefinerPipeline(
|
258 |
+
model_name: str = "hunyuanimage-refiner",
|
259 |
+
**kwargs,
|
260 |
+
):
|
261 |
+
"""Factory function to create HunYuanImageRefinerPipeline.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
model_name: Model name, currently only supports "hunyuanimage-refiner"
|
265 |
+
**kwargs: Additional configuration options
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
Initialized refiner pipeline instance
|
269 |
+
"""
|
270 |
+
return HunYuanImageRefinerPipeline.from_pretrained(
|
271 |
+
model_name, **kwargs
|
272 |
+
)
|
hyimage/models/hunyuan/__init__.py
ADDED
File without changes
|
hyimage/models/hunyuan/configs/hunyuanimage_config.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from hyimage.common.config import LazyCall as L
|
2 |
+
from hyimage.models.hunyuan.modules.hunyuanimage_dit import HYImageDiffusionTransformer
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
hunyuanimage_refiner_cfg = L(HYImageDiffusionTransformer)(
|
8 |
+
in_channels=128,
|
9 |
+
out_channels=64,
|
10 |
+
mm_double_blocks_depth=20,
|
11 |
+
mm_single_blocks_depth=40,
|
12 |
+
rope_dim_list=[16, 56, 56],
|
13 |
+
hidden_size=3328,
|
14 |
+
heads_num=26,
|
15 |
+
mlp_width_ratio=4,
|
16 |
+
patch_size=[1, 1, 1],
|
17 |
+
text_states_dim=3584,
|
18 |
+
guidance_embed=True,
|
19 |
+
use_meanflow=True,
|
20 |
+
)
|
21 |
+
|
22 |
+
hunyuanimage_v2_1_cfg = L(HYImageDiffusionTransformer)(
|
23 |
+
in_channels=64,
|
24 |
+
out_channels=64,
|
25 |
+
mm_double_blocks_depth=20,
|
26 |
+
mm_single_blocks_depth=40,
|
27 |
+
rope_dim_list=[64, 64],
|
28 |
+
hidden_size=3584,
|
29 |
+
heads_num=28,
|
30 |
+
mlp_width_ratio=4,
|
31 |
+
patch_size=[1, 1],
|
32 |
+
text_states_dim=3584,
|
33 |
+
glyph_byT5_v2=True,
|
34 |
+
guidance_embed=False,
|
35 |
+
)
|
36 |
+
|
37 |
+
hunyuanimage_v2_1_distilled_cfg = L(HYImageDiffusionTransformer)(
|
38 |
+
in_channels=64,
|
39 |
+
out_channels=64,
|
40 |
+
mm_double_blocks_depth=20,
|
41 |
+
mm_single_blocks_depth=40,
|
42 |
+
rope_dim_list=[64, 64],
|
43 |
+
hidden_size=3584,
|
44 |
+
heads_num=28,
|
45 |
+
mlp_width_ratio=4,
|
46 |
+
patch_size=[1, 1],
|
47 |
+
text_states_dim=3584,
|
48 |
+
glyph_byT5_v2=True,
|
49 |
+
guidance_embed=True,
|
50 |
+
use_meanflow=True,
|
51 |
+
)
|
hyimage/models/hunyuan/modules/__init__.py
ADDED
File without changes
|
hyimage/models/hunyuan/modules/activation_layers.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def get_activation_layer(act_type):
|
5 |
+
"""get activation layer
|
6 |
+
|
7 |
+
Args:
|
8 |
+
act_type (str): the activation type
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
torch.nn.functional: the activation layer
|
12 |
+
"""
|
13 |
+
if act_type == "gelu":
|
14 |
+
return lambda: nn.GELU()
|
15 |
+
elif act_type == "gelu_tanh":
|
16 |
+
# Approximate `tanh` requires torch >= 1.13
|
17 |
+
return lambda: nn.GELU(approximate="tanh")
|
18 |
+
elif act_type == "relu":
|
19 |
+
return nn.ReLU
|
20 |
+
elif act_type == "silu":
|
21 |
+
return nn.SiLU
|
22 |
+
else:
|
23 |
+
raise ValueError(f"Unknown activation type: {act_type}")
|
hyimage/models/hunyuan/modules/embed_layers.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from ..utils.helpers import to_2tuple
|
7 |
+
|
8 |
+
|
9 |
+
class PatchEmbed2D(nn.Module):
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
patch_size=16,
|
14 |
+
in_chans=3,
|
15 |
+
embed_dim=768,
|
16 |
+
norm_layer=None,
|
17 |
+
flatten=True,
|
18 |
+
bias=True,
|
19 |
+
dtype=None,
|
20 |
+
device=None,
|
21 |
+
):
|
22 |
+
super().__init__()
|
23 |
+
patch_size = to_2tuple(patch_size)
|
24 |
+
self.patch_size = patch_size
|
25 |
+
self.flatten = flatten
|
26 |
+
|
27 |
+
self.proj = nn.Conv2d(
|
28 |
+
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype
|
29 |
+
)
|
30 |
+
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
|
31 |
+
if bias:
|
32 |
+
nn.init.zeros_(self.proj.bias)
|
33 |
+
|
34 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = self.proj(x)
|
38 |
+
if self.flatten:
|
39 |
+
x = x.flatten(2).transpose(1, 2)
|
40 |
+
x = self.norm(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class PatchEmbed(nn.Module):
|
45 |
+
"""2D Image to Patch Embedding
|
46 |
+
|
47 |
+
Image to Patch Embedding using Conv2d
|
48 |
+
|
49 |
+
A convolution based approach to patchifying a 2D image w/ embedding projection.
|
50 |
+
|
51 |
+
Based on the impl in https://github.com/google-research/vision_transformer
|
52 |
+
|
53 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
54 |
+
|
55 |
+
Remove the _assert function in forward function to be compatible with multi-resolution images.
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
patch_size=16,
|
61 |
+
in_chans=3,
|
62 |
+
embed_dim=768,
|
63 |
+
norm_layer=None,
|
64 |
+
flatten=True,
|
65 |
+
bias=True,
|
66 |
+
dtype=None,
|
67 |
+
device=None,
|
68 |
+
):
|
69 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
70 |
+
super().__init__()
|
71 |
+
patch_size = to_2tuple(patch_size)
|
72 |
+
self.patch_size = patch_size
|
73 |
+
self.flatten = flatten
|
74 |
+
|
75 |
+
self.proj = nn.Conv3d(
|
76 |
+
in_chans,
|
77 |
+
embed_dim,
|
78 |
+
kernel_size=patch_size,
|
79 |
+
stride=patch_size,
|
80 |
+
bias=bias,
|
81 |
+
**factory_kwargs
|
82 |
+
)
|
83 |
+
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
|
84 |
+
if bias:
|
85 |
+
nn.init.zeros_(self.proj.bias)
|
86 |
+
|
87 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
x = self.proj(x)
|
91 |
+
if self.flatten:
|
92 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
93 |
+
x = self.norm(x)
|
94 |
+
return x
|
95 |
+
|
96 |
+
|
97 |
+
class TextProjection(nn.Module):
|
98 |
+
"""
|
99 |
+
Projects text embeddings. Also handles dropout for classifier-free guidance.
|
100 |
+
|
101 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
102 |
+
"""
|
103 |
+
|
104 |
+
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
|
105 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
106 |
+
super().__init__()
|
107 |
+
self.linear_1 = nn.Linear(
|
108 |
+
in_features=in_channels,
|
109 |
+
out_features=hidden_size,
|
110 |
+
bias=True,
|
111 |
+
**factory_kwargs
|
112 |
+
)
|
113 |
+
self.act_1 = act_layer()
|
114 |
+
self.linear_2 = nn.Linear(
|
115 |
+
in_features=hidden_size,
|
116 |
+
out_features=hidden_size,
|
117 |
+
bias=True,
|
118 |
+
**factory_kwargs
|
119 |
+
)
|
120 |
+
|
121 |
+
def forward(self, caption):
|
122 |
+
hidden_states = self.linear_1(caption)
|
123 |
+
hidden_states = self.act_1(hidden_states)
|
124 |
+
hidden_states = self.linear_2(hidden_states)
|
125 |
+
return hidden_states
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
def timestep_embedding(t, dim, max_period=10000):
|
130 |
+
"""
|
131 |
+
Create sinusoidal timestep embeddings.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
135 |
+
dim (int): the dimension of the output.
|
136 |
+
max_period (int): controls the minimum frequency of the embeddings.
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
140 |
+
|
141 |
+
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
142 |
+
"""
|
143 |
+
half = dim // 2
|
144 |
+
freqs = torch.exp(
|
145 |
+
-math.log(max_period)
|
146 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
147 |
+
/ half
|
148 |
+
).to(device=t.device)
|
149 |
+
args = t[:, None].float() * freqs[None]
|
150 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
151 |
+
if dim % 2:
|
152 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
153 |
+
return embedding
|
154 |
+
|
155 |
+
|
156 |
+
class TimestepEmbedder(nn.Module):
|
157 |
+
"""
|
158 |
+
Embeds scalar timesteps into vector representations.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
hidden_size,
|
164 |
+
act_layer,
|
165 |
+
frequency_embedding_size=256,
|
166 |
+
max_period=10000,
|
167 |
+
out_size=None,
|
168 |
+
dtype=None,
|
169 |
+
device=None,
|
170 |
+
):
|
171 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
172 |
+
super().__init__()
|
173 |
+
self.frequency_embedding_size = frequency_embedding_size
|
174 |
+
self.max_period = max_period
|
175 |
+
if out_size is None:
|
176 |
+
out_size = hidden_size
|
177 |
+
|
178 |
+
self.mlp = nn.Sequential(
|
179 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
|
180 |
+
act_layer(),
|
181 |
+
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
182 |
+
)
|
183 |
+
nn.init.normal_(self.mlp[0].weight, std=0.02)
|
184 |
+
nn.init.normal_(self.mlp[2].weight, std=0.02)
|
185 |
+
|
186 |
+
def forward(self, t):
|
187 |
+
t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
|
188 |
+
t_emb = self.mlp(t_freq)
|
189 |
+
return t_emb
|
hyimage/models/hunyuan/modules/flash_attn_no_pad.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
try:
|
5 |
+
from flash_attn_interface import flash_attn_varlen_func
|
6 |
+
|
7 |
+
print("Using FlashAttention v3.")
|
8 |
+
except ImportError:
|
9 |
+
print("FlashAttention v3 not found, falling back to v2.")
|
10 |
+
from flash_attn import flash_attn_varlen_func
|
11 |
+
|
12 |
+
from flash_attn import flash_attn_varlen_qkvpacked_func
|
13 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
14 |
+
|
15 |
+
|
16 |
+
def get_cu_seqlens(text_mask: torch.Tensor, img_len: int):
|
17 |
+
"""
|
18 |
+
Compute cumulative sequence lengths (cu_seqlens) for FlashAttention.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
text_mask (torch.Tensor): Boolean mask of shape (batch_size, text_seq_len).
|
22 |
+
img_len (int): Length of image sequence.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
cu_seqlens (torch.Tensor): 1D tensor of cumulative sequence lengths for each segment.
|
26 |
+
max_len (int): Maximum sequence length (text + image).
|
27 |
+
"""
|
28 |
+
batch_size = text_mask.shape[0]
|
29 |
+
text_len = text_mask.sum(dim=1)
|
30 |
+
max_len = text_mask.shape[1] + img_len
|
31 |
+
|
32 |
+
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device)
|
33 |
+
for i in range(batch_size):
|
34 |
+
s = text_len[i] + img_len
|
35 |
+
s1 = i * max_len + s
|
36 |
+
s2 = (i + 1) * max_len
|
37 |
+
cu_seqlens[2 * i + 1] = s1
|
38 |
+
cu_seqlens[2 * i + 2] = s2
|
39 |
+
|
40 |
+
return cu_seqlens, max_len
|
41 |
+
|
42 |
+
|
43 |
+
def flash_attn_v3(
|
44 |
+
q: torch.Tensor,
|
45 |
+
k: torch.Tensor,
|
46 |
+
v: torch.Tensor,
|
47 |
+
cu_seqlens: torch.Tensor,
|
48 |
+
max_s: int,
|
49 |
+
causal: bool = False,
|
50 |
+
deterministic: bool = False,
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
FlashAttention v3 wrapper.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
q, k, v (torch.Tensor): Query, key, value tensors of shape (batch, seq, nheads, head_dim).
|
57 |
+
cu_seqlens (torch.Tensor): Cumulative sequence lengths.
|
58 |
+
max_s (int): Maximum sequence length.
|
59 |
+
causal (bool): Whether to apply causal masking.
|
60 |
+
deterministic (bool): Deterministic computation.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
torch.Tensor: Output tensor of shape (batch, seq, nheads, head_dim).
|
64 |
+
"""
|
65 |
+
batch_size, seqlen = q.shape[:2]
|
66 |
+
q = q.reshape(-1, *q.shape[2:])
|
67 |
+
k = k.reshape(-1, *k.shape[2:])
|
68 |
+
v = v.reshape(-1, *v.shape[2:])
|
69 |
+
output = flash_attn_varlen_func(
|
70 |
+
q, k, v, cu_seqlens, cu_seqlens, max_s, max_s, causal=causal, deterministic=deterministic
|
71 |
+
)
|
72 |
+
output = output.view(batch_size, seqlen, *output.shape[-2:])
|
73 |
+
return output
|
74 |
+
|
75 |
+
|
76 |
+
def flash_attn_no_pad(
|
77 |
+
qkv: torch.Tensor,
|
78 |
+
key_padding_mask: torch.Tensor,
|
79 |
+
causal: bool = False,
|
80 |
+
dropout_p: float = 0.0,
|
81 |
+
softmax_scale=None,
|
82 |
+
deterministic: bool = False,
|
83 |
+
):
|
84 |
+
"""
|
85 |
+
FlashAttention for packed QKV input without padding.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
qkv (torch.Tensor): Input tensor of shape (batch, seq, 3, nheads, head_dim).
|
89 |
+
key_padding_mask (torch.Tensor): Boolean mask of shape (batch, seq).
|
90 |
+
causal (bool): Whether to apply causal masking.
|
91 |
+
dropout_p (float): Dropout probability.
|
92 |
+
softmax_scale (float, optional): Softmax scaling factor.
|
93 |
+
deterministic (bool): Deterministic computation.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
torch.Tensor: Output tensor of shape (batch, seq, nheads, head_dim).
|
97 |
+
"""
|
98 |
+
batch_size, seqlen, _, nheads, head_dim = qkv.shape
|
99 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
100 |
+
|
101 |
+
# Unpad input for FlashAttention, drop `used_seqlens_in_batch` for version compatibility
|
102 |
+
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)[:4]
|
103 |
+
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
|
104 |
+
|
105 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
106 |
+
x_unpad,
|
107 |
+
cu_seqlens,
|
108 |
+
max_s,
|
109 |
+
dropout_p,
|
110 |
+
softmax_scale=softmax_scale,
|
111 |
+
causal=causal,
|
112 |
+
deterministic=deterministic,
|
113 |
+
)
|
114 |
+
if isinstance(output_unpad, tuple):
|
115 |
+
output_unpad = output_unpad[0]
|
116 |
+
|
117 |
+
# Pad output back to original shape
|
118 |
+
output = pad_input(
|
119 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
120 |
+
indices,
|
121 |
+
batch_size,
|
122 |
+
seqlen,
|
123 |
+
)
|
124 |
+
output = rearrange(output, "b s (h d) -> b s h d", h=nheads)
|
125 |
+
return output
|
hyimage/models/hunyuan/modules/hunyuanimage_dit.py
ADDED
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict, List, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
7 |
+
from diffusers.models import ModelMixin
|
8 |
+
|
9 |
+
from hyimage.models.hunyuan.modules.posemb_layers import get_nd_rotary_pos_embed
|
10 |
+
from hyimage.models.hunyuan.modules.flash_attn_no_pad import get_cu_seqlens
|
11 |
+
|
12 |
+
from .activation_layers import get_activation_layer
|
13 |
+
from .embed_layers import PatchEmbed, PatchEmbed2D, TextProjection, TimestepEmbedder
|
14 |
+
from .mlp_layers import FinalLayer
|
15 |
+
from .models import MMDoubleStreamBlock, MMSingleStreamBlock
|
16 |
+
from .token_refiner import SingleTokenRefiner
|
17 |
+
|
18 |
+
from hyimage.models.text_encoder.byT5 import ByT5Mapper
|
19 |
+
|
20 |
+
|
21 |
+
def convert_hunyuan_dict_for_tensor_parallel(state_dict):
|
22 |
+
"""
|
23 |
+
Convert a Hunyuan model state dict to be compatible with tensor parallel architectures.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
state_dict: Original state dict
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
new_dict: Converted state dict
|
30 |
+
"""
|
31 |
+
new_dict = {}
|
32 |
+
for k, w in state_dict.items():
|
33 |
+
if k.startswith("double_blocks") and "attn_qkv.weight" in k:
|
34 |
+
hidden_size = w.shape[1]
|
35 |
+
k1 = k.replace("attn_qkv.weight", "attn_q.weight")
|
36 |
+
w1 = w[:hidden_size, :]
|
37 |
+
new_dict[k1] = w1
|
38 |
+
k2 = k.replace("attn_qkv.weight", "attn_k.weight")
|
39 |
+
w2 = w[hidden_size : 2 * hidden_size, :]
|
40 |
+
new_dict[k2] = w2
|
41 |
+
k3 = k.replace("attn_qkv.weight", "attn_v.weight")
|
42 |
+
w3 = w[-hidden_size:, :]
|
43 |
+
new_dict[k3] = w3
|
44 |
+
elif k.startswith("double_blocks") and "attn_qkv.bias" in k:
|
45 |
+
hidden_size = w.shape[0] // 3
|
46 |
+
k1 = k.replace("attn_qkv.bias", "attn_q.bias")
|
47 |
+
w1 = w[:hidden_size]
|
48 |
+
new_dict[k1] = w1
|
49 |
+
k2 = k.replace("attn_qkv.bias", "attn_k.bias")
|
50 |
+
w2 = w[hidden_size : 2 * hidden_size]
|
51 |
+
new_dict[k2] = w2
|
52 |
+
k3 = k.replace("attn_qkv.bias", "attn_v.bias")
|
53 |
+
w3 = w[-hidden_size:]
|
54 |
+
new_dict[k3] = w3
|
55 |
+
elif k.startswith("single_blocks") and "linear1" in k:
|
56 |
+
hidden_size = state_dict[k.replace("linear1", "linear2")].shape[0]
|
57 |
+
k1 = k.replace("linear1", "linear1_q")
|
58 |
+
w1 = w[:hidden_size]
|
59 |
+
new_dict[k1] = w1
|
60 |
+
k2 = k.replace("linear1", "linear1_k")
|
61 |
+
w2 = w[hidden_size : 2 * hidden_size]
|
62 |
+
new_dict[k2] = w2
|
63 |
+
k3 = k.replace("linear1", "linear1_v")
|
64 |
+
w3 = w[2 * hidden_size : 3 * hidden_size]
|
65 |
+
new_dict[k3] = w3
|
66 |
+
k4 = k.replace("linear1", "linear1_mlp")
|
67 |
+
w4 = w[3 * hidden_size :]
|
68 |
+
new_dict[k4] = w4
|
69 |
+
elif k.startswith("single_blocks") and "linear2" in k:
|
70 |
+
k1 = k.replace("linear2", "linear2.fc")
|
71 |
+
new_dict[k1] = w
|
72 |
+
else:
|
73 |
+
new_dict[k] = w
|
74 |
+
return new_dict
|
75 |
+
|
76 |
+
|
77 |
+
def load_hunyuan_dit_state_dict(model, dit_model_name_or_path, strict=True, assign=False):
|
78 |
+
"""
|
79 |
+
Load a state dict for a Hunyuan model, handling both safetensors and torch formats.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
model: Model instance to load weights into
|
83 |
+
dit_model_name_or_path: Path to the checkpoint file
|
84 |
+
strict: Whether to strictly enforce that the keys in state_dict match the model's keys
|
85 |
+
assign: If True, assign weights directly without copying
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
model: The model with loaded weights
|
89 |
+
"""
|
90 |
+
from safetensors.torch import load_file as safetensors_load_file
|
91 |
+
|
92 |
+
if not os.path.exists(dit_model_name_or_path):
|
93 |
+
return
|
94 |
+
|
95 |
+
if dit_model_name_or_path.endswith(".safetensors"):
|
96 |
+
state_dict = safetensors_load_file(dit_model_name_or_path)
|
97 |
+
else:
|
98 |
+
state_dict = torch.load(
|
99 |
+
dit_model_name_or_path,
|
100 |
+
map_location="cpu",
|
101 |
+
weights_only=True,
|
102 |
+
)
|
103 |
+
try:
|
104 |
+
state_dict = convert_hunyuan_dict_for_tensor_parallel(state_dict)
|
105 |
+
except Exception:
|
106 |
+
pass
|
107 |
+
model.load_state_dict(state_dict, strict=strict, assign=assign)
|
108 |
+
return model
|
109 |
+
|
110 |
+
|
111 |
+
class HYImageDiffusionTransformer(ModelMixin, ConfigMixin):
|
112 |
+
|
113 |
+
@register_to_config
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
patch_size: list = [1, 2, 2],
|
117 |
+
in_channels: int = 4,
|
118 |
+
out_channels: int = None,
|
119 |
+
hidden_size: int = 3072,
|
120 |
+
heads_num: int = 24,
|
121 |
+
mlp_width_ratio: float = 4.0,
|
122 |
+
mlp_act_type: str = "gelu_tanh",
|
123 |
+
mm_double_blocks_depth: int = 20,
|
124 |
+
mm_single_blocks_depth: int = 40,
|
125 |
+
rope_dim_list: List[int] = [16, 56, 56],
|
126 |
+
qkv_bias: bool = True,
|
127 |
+
qk_norm: bool = True,
|
128 |
+
qk_norm_type: str = "rms",
|
129 |
+
guidance_embed: bool = False,
|
130 |
+
text_projection: str = "single_refiner",
|
131 |
+
use_attention_mask: bool = True,
|
132 |
+
dtype: Optional[torch.dtype] = None,
|
133 |
+
device: Optional[torch.device] = None,
|
134 |
+
text_states_dim: int = 4096,
|
135 |
+
rope_theta: int = 256,
|
136 |
+
glyph_byT5_v2: bool = False,
|
137 |
+
use_meanflow: bool = False,
|
138 |
+
):
|
139 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
140 |
+
super().__init__()
|
141 |
+
|
142 |
+
self.patch_size = patch_size
|
143 |
+
self.in_channels = in_channels
|
144 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
145 |
+
self.unpatchify_channels = self.out_channels
|
146 |
+
self.guidance_embed = guidance_embed
|
147 |
+
self.rope_dim_list = rope_dim_list
|
148 |
+
self.rope_theta = rope_theta
|
149 |
+
self.use_attention_mask = use_attention_mask
|
150 |
+
self.text_projection = text_projection
|
151 |
+
|
152 |
+
if hidden_size % heads_num != 0:
|
153 |
+
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
|
154 |
+
pe_dim = hidden_size // heads_num
|
155 |
+
if sum(rope_dim_list) != pe_dim:
|
156 |
+
raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
|
157 |
+
self.hidden_size = hidden_size
|
158 |
+
self.heads_num = heads_num
|
159 |
+
|
160 |
+
self.glyph_byT5_v2 = glyph_byT5_v2
|
161 |
+
if self.glyph_byT5_v2:
|
162 |
+
self.byt5_in = ByT5Mapper(
|
163 |
+
in_dim=1472,
|
164 |
+
out_dim=2048,
|
165 |
+
hidden_dim=2048,
|
166 |
+
out_dim1=hidden_size,
|
167 |
+
use_residual=False
|
168 |
+
)
|
169 |
+
|
170 |
+
# Image projection
|
171 |
+
if len(self.patch_size) == 3:
|
172 |
+
self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
|
173 |
+
elif len(self.patch_size) == 2:
|
174 |
+
self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
|
175 |
+
else:
|
176 |
+
raise ValueError(f"Unsupported patch_size: {self.patch_size}")
|
177 |
+
|
178 |
+
# Text projection
|
179 |
+
if self.text_projection == "linear":
|
180 |
+
self.txt_in = TextProjection(
|
181 |
+
text_states_dim,
|
182 |
+
self.hidden_size,
|
183 |
+
get_activation_layer("silu"),
|
184 |
+
**factory_kwargs,
|
185 |
+
)
|
186 |
+
elif self.text_projection == "single_refiner":
|
187 |
+
self.txt_in = SingleTokenRefiner(
|
188 |
+
text_states_dim,
|
189 |
+
hidden_size,
|
190 |
+
heads_num,
|
191 |
+
depth=2,
|
192 |
+
**factory_kwargs,
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
196 |
+
|
197 |
+
# Time modulation
|
198 |
+
self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
|
199 |
+
|
200 |
+
# MeanFlow support: only create time_r_in when needed
|
201 |
+
self.time_r_in = (
|
202 |
+
TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
|
203 |
+
if use_meanflow
|
204 |
+
else None
|
205 |
+
)
|
206 |
+
self.use_meanflow = use_meanflow
|
207 |
+
|
208 |
+
# Guidance modulation
|
209 |
+
self.guidance_in = (
|
210 |
+
TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
|
211 |
+
if guidance_embed
|
212 |
+
else None
|
213 |
+
)
|
214 |
+
|
215 |
+
# Double blocks
|
216 |
+
self.double_blocks = nn.ModuleList(
|
217 |
+
[
|
218 |
+
MMDoubleStreamBlock(
|
219 |
+
self.hidden_size,
|
220 |
+
self.heads_num,
|
221 |
+
mlp_width_ratio=mlp_width_ratio,
|
222 |
+
mlp_act_type=mlp_act_type,
|
223 |
+
qk_norm=qk_norm,
|
224 |
+
qk_norm_type=qk_norm_type,
|
225 |
+
qkv_bias=qkv_bias,
|
226 |
+
**factory_kwargs,
|
227 |
+
)
|
228 |
+
for _ in range(mm_double_blocks_depth)
|
229 |
+
]
|
230 |
+
)
|
231 |
+
|
232 |
+
# Single blocks
|
233 |
+
self.single_blocks = nn.ModuleList(
|
234 |
+
[
|
235 |
+
MMSingleStreamBlock(
|
236 |
+
self.hidden_size,
|
237 |
+
self.heads_num,
|
238 |
+
mlp_width_ratio=mlp_width_ratio,
|
239 |
+
mlp_act_type=mlp_act_type,
|
240 |
+
qk_norm=qk_norm,
|
241 |
+
qk_norm_type=qk_norm_type,
|
242 |
+
**factory_kwargs,
|
243 |
+
)
|
244 |
+
for _ in range(mm_single_blocks_depth)
|
245 |
+
]
|
246 |
+
)
|
247 |
+
|
248 |
+
self.final_layer = FinalLayer(
|
249 |
+
self.hidden_size,
|
250 |
+
self.patch_size,
|
251 |
+
self.out_channels,
|
252 |
+
get_activation_layer("silu"),
|
253 |
+
**factory_kwargs,
|
254 |
+
)
|
255 |
+
|
256 |
+
def enable_deterministic(self):
|
257 |
+
"""Enable deterministic mode for all transformer blocks."""
|
258 |
+
for block in self.double_blocks:
|
259 |
+
block.enable_deterministic()
|
260 |
+
for block in self.single_blocks:
|
261 |
+
block.enable_deterministic()
|
262 |
+
|
263 |
+
def disable_deterministic(self):
|
264 |
+
"""Disable deterministic mode for all transformer blocks."""
|
265 |
+
for block in self.double_blocks:
|
266 |
+
block.disable_deterministic()
|
267 |
+
for block in self.single_blocks:
|
268 |
+
block.disable_deterministic()
|
269 |
+
|
270 |
+
def get_rotary_pos_embed(self, rope_sizes):
|
271 |
+
"""
|
272 |
+
Get rotary position embeddings for the given sizes.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
rope_sizes: Sizes for each rotary dimension.
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
freqs_cos, freqs_sin: Cosine and sine frequencies for rotary embedding.
|
279 |
+
"""
|
280 |
+
target_ndim = 3
|
281 |
+
head_dim = self.hidden_size // self.heads_num
|
282 |
+
rope_dim_list = self.rope_dim_list
|
283 |
+
if rope_dim_list is None:
|
284 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
285 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
286 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
287 |
+
rope_dim_list,
|
288 |
+
rope_sizes,
|
289 |
+
theta=self.rope_theta,
|
290 |
+
use_real=True,
|
291 |
+
theta_rescale_factor=1,
|
292 |
+
)
|
293 |
+
return freqs_cos, freqs_sin
|
294 |
+
|
295 |
+
def reorder_txt_token(self, byt5_txt, txt, byt5_text_mask, text_mask):
|
296 |
+
"""
|
297 |
+
Reorder text tokens for ByT5 integration.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
byt5_txt: ByT5 text embeddings.
|
301 |
+
txt: Text embeddings.
|
302 |
+
byt5_text_mask: Mask for ByT5 tokens.
|
303 |
+
text_mask: Mask for text tokens.
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
reorder_txt: Reordered text embeddings.
|
307 |
+
reorder_mask: Reordered mask.
|
308 |
+
"""
|
309 |
+
reorder_txt = []
|
310 |
+
reorder_mask = []
|
311 |
+
|
312 |
+
for i in range(text_mask.shape[0]):
|
313 |
+
byt5_text_mask_i = byt5_text_mask[i].bool()
|
314 |
+
text_mask_i = text_mask[i].bool()
|
315 |
+
byt5_txt_i = byt5_txt[i]
|
316 |
+
txt_i = txt[i]
|
317 |
+
reorder_txt_i = torch.cat([
|
318 |
+
byt5_txt_i[byt5_text_mask_i],
|
319 |
+
txt_i[text_mask_i],
|
320 |
+
byt5_txt_i[~byt5_text_mask_i],
|
321 |
+
txt_i[~text_mask_i]
|
322 |
+
], dim=0)
|
323 |
+
|
324 |
+
reorder_mask_i = torch.cat([
|
325 |
+
byt5_text_mask_i[byt5_text_mask_i],
|
326 |
+
text_mask_i[text_mask_i],
|
327 |
+
byt5_text_mask_i[~byt5_text_mask_i],
|
328 |
+
text_mask_i[~text_mask_i]
|
329 |
+
], dim=0)
|
330 |
+
|
331 |
+
reorder_txt.append(reorder_txt_i)
|
332 |
+
reorder_mask.append(reorder_mask_i)
|
333 |
+
|
334 |
+
reorder_txt = torch.stack(reorder_txt)
|
335 |
+
reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64)
|
336 |
+
|
337 |
+
return reorder_txt, reorder_mask
|
338 |
+
|
339 |
+
def forward(
|
340 |
+
self,
|
341 |
+
hidden_states: torch.Tensor,
|
342 |
+
timestep: torch.LongTensor,
|
343 |
+
text_states: torch.Tensor,
|
344 |
+
encoder_attention_mask: torch.Tensor,
|
345 |
+
output_features: bool = False,
|
346 |
+
output_features_stride: int = 8,
|
347 |
+
freqs_cos: Optional[torch.Tensor] = None,
|
348 |
+
freqs_sin: Optional[torch.Tensor] = None,
|
349 |
+
return_dict: bool = False,
|
350 |
+
guidance=None,
|
351 |
+
extra_kwargs=None,
|
352 |
+
*,
|
353 |
+
timesteps_r: Optional[torch.LongTensor] = None,
|
354 |
+
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
355 |
+
"""
|
356 |
+
Forward pass for the transformer.
|
357 |
+
|
358 |
+
Parameters
|
359 |
+
----------
|
360 |
+
hidden_states : torch.Tensor
|
361 |
+
Input image tensor.
|
362 |
+
timestep : torch.LongTensor
|
363 |
+
Timestep tensor.
|
364 |
+
text_states : torch.Tensor
|
365 |
+
Text embeddings.
|
366 |
+
encoder_attention_mask : torch.Tensor
|
367 |
+
Attention mask for text.
|
368 |
+
output_features : bool, optional
|
369 |
+
Whether to output intermediate features.
|
370 |
+
output_features_stride : int, optional
|
371 |
+
Stride for outputting features.
|
372 |
+
freqs_cos, freqs_sin : torch.Tensor, optional
|
373 |
+
Precomputed rotary embeddings.
|
374 |
+
return_dict : bool, optional
|
375 |
+
Not supported.
|
376 |
+
guidance : torch.Tensor, optional
|
377 |
+
Guidance vector for distillation.
|
378 |
+
extra_kwargs : dict, optional
|
379 |
+
Extra arguments for ByT5.
|
380 |
+
timesteps_r : torch.LongTensor, optional
|
381 |
+
Additional timestep for MeanFlow.
|
382 |
+
|
383 |
+
Returns
|
384 |
+
-------
|
385 |
+
tuple
|
386 |
+
(img, features_list, shape)
|
387 |
+
"""
|
388 |
+
if guidance is None:
|
389 |
+
guidance = torch.tensor([6016.0], device=hidden_states.device, dtype=torch.bfloat16)
|
390 |
+
img = x = hidden_states
|
391 |
+
text_mask = encoder_attention_mask
|
392 |
+
t = timestep
|
393 |
+
txt = text_states
|
394 |
+
input_shape = x.shape
|
395 |
+
|
396 |
+
# Calculate spatial dimensions and get rotary embeddings
|
397 |
+
if len(input_shape) == 5:
|
398 |
+
_, _, ot, oh, ow = x.shape
|
399 |
+
tt, th, tw = (
|
400 |
+
ot // self.patch_size[0],
|
401 |
+
oh // self.patch_size[1],
|
402 |
+
ow // self.patch_size[2],
|
403 |
+
)
|
404 |
+
if freqs_cos is None or freqs_sin is None:
|
405 |
+
freqs_cos, freqs_sin = self.get_rotary_pos_embed((tt, th, tw))
|
406 |
+
elif len(input_shape) == 4:
|
407 |
+
_, _, oh, ow = x.shape
|
408 |
+
th, tw = (
|
409 |
+
oh // self.patch_size[0],
|
410 |
+
ow // self.patch_size[1],
|
411 |
+
)
|
412 |
+
if freqs_cos is None or freqs_sin is None:
|
413 |
+
assert freqs_cos is None and freqs_sin is None, "freqs_cos and freqs_sin must be both None or both not None"
|
414 |
+
freqs_cos, freqs_sin = self.get_rotary_pos_embed((th, tw))
|
415 |
+
else:
|
416 |
+
raise ValueError(f"Unsupported hidden_states shape: {x.shape}")
|
417 |
+
|
418 |
+
img = self.img_in(img)
|
419 |
+
|
420 |
+
# Prepare modulation vectors
|
421 |
+
vec = self.time_in(t)
|
422 |
+
|
423 |
+
# MeanFlow support: merge timestep and timestep_r if available
|
424 |
+
if self.use_meanflow:
|
425 |
+
assert self.time_r_in is not None, "use_meanflow is True but time_r_in is None"
|
426 |
+
if timesteps_r is not None:
|
427 |
+
assert self.time_r_in is not None, "timesteps_r is not None but time_r_in is None"
|
428 |
+
vec_r = self.time_r_in(timesteps_r)
|
429 |
+
vec = (vec + vec_r) / 2
|
430 |
+
|
431 |
+
# Guidance modulation
|
432 |
+
if self.guidance_embed:
|
433 |
+
if guidance is None:
|
434 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
435 |
+
vec = vec + self.guidance_in(guidance)
|
436 |
+
|
437 |
+
# Embed image and text
|
438 |
+
if self.text_projection == "linear":
|
439 |
+
txt = self.txt_in(txt)
|
440 |
+
elif self.text_projection == "single_refiner":
|
441 |
+
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
|
442 |
+
else:
|
443 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
444 |
+
|
445 |
+
if self.glyph_byT5_v2:
|
446 |
+
byt5_text_states = extra_kwargs["byt5_text_states"]
|
447 |
+
byt5_text_mask = extra_kwargs["byt5_text_mask"]
|
448 |
+
byt5_txt = self.byt5_in(byt5_text_states)
|
449 |
+
txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
|
450 |
+
|
451 |
+
txt_seq_len = txt.shape[1]
|
452 |
+
img_seq_len = img.shape[1]
|
453 |
+
|
454 |
+
# Calculate cu_seqlens and max_s for flash attention
|
455 |
+
cu_seqlens, max_s = get_cu_seqlens(text_mask, img_seq_len)
|
456 |
+
|
457 |
+
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
|
458 |
+
|
459 |
+
# Pass through double stream blocks
|
460 |
+
for block in self.double_blocks:
|
461 |
+
double_block_args = [img, txt, vec, freqs_cis, text_mask, cu_seqlens, max_s]
|
462 |
+
img, txt = block(*double_block_args)
|
463 |
+
|
464 |
+
# Merge txt and img to pass through single stream blocks
|
465 |
+
x = torch.cat((img, txt), 1)
|
466 |
+
features_list = [] if output_features else None
|
467 |
+
|
468 |
+
if len(self.single_blocks) > 0:
|
469 |
+
for index, block in enumerate(self.single_blocks):
|
470 |
+
single_block_args = [
|
471 |
+
x,
|
472 |
+
vec,
|
473 |
+
txt_seq_len,
|
474 |
+
(freqs_cos, freqs_sin),
|
475 |
+
text_mask,
|
476 |
+
cu_seqlens,
|
477 |
+
max_s,
|
478 |
+
]
|
479 |
+
x = block(*single_block_args)
|
480 |
+
if output_features and index % output_features_stride == 0:
|
481 |
+
features_list.append(x[:, :img_seq_len, ...])
|
482 |
+
|
483 |
+
img = x[:, :img_seq_len, ...]
|
484 |
+
|
485 |
+
# Final layer
|
486 |
+
img = self.final_layer(img, vec)
|
487 |
+
|
488 |
+
# Unpatchify based on input shape
|
489 |
+
if len(input_shape) == 5:
|
490 |
+
img = self.unpatchify(img, tt, th, tw)
|
491 |
+
shape = (tt, th, tw)
|
492 |
+
elif len(input_shape) == 4:
|
493 |
+
img = self.unpatchify_2d(img, th, tw)
|
494 |
+
shape = (th, tw)
|
495 |
+
else:
|
496 |
+
raise ValueError(f"Unsupported input_shape: {input_shape}")
|
497 |
+
|
498 |
+
assert not return_dict, "return_dict is not supported."
|
499 |
+
|
500 |
+
if output_features:
|
501 |
+
features_list = torch.stack(features_list, dim=0)
|
502 |
+
else:
|
503 |
+
features_list = None
|
504 |
+
|
505 |
+
return (img, features_list, shape)
|
506 |
+
|
507 |
+
def unpatchify(self, x, t, h, w):
|
508 |
+
"""
|
509 |
+
Unpatchify 3D tensor.
|
510 |
+
|
511 |
+
Parameters
|
512 |
+
----------
|
513 |
+
x: torch.Tensor
|
514 |
+
Input tensor of shape (N, T, patch_size**2 * C)
|
515 |
+
t, h, w: int
|
516 |
+
Temporal and spatial dimensions
|
517 |
+
|
518 |
+
Returns
|
519 |
+
-------
|
520 |
+
torch.Tensor
|
521 |
+
Unpatchified tensor of shape (N, C, T*pt, H*ph, W*pw)
|
522 |
+
"""
|
523 |
+
c = self.unpatchify_channels
|
524 |
+
pt, ph, pw = self.patch_size
|
525 |
+
assert t * h * w == x.shape[1]
|
526 |
+
|
527 |
+
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
|
528 |
+
x = torch.einsum("nthwcopq->nctohpwq", x)
|
529 |
+
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
530 |
+
|
531 |
+
return imgs
|
532 |
+
|
533 |
+
def unpatchify_2d(self, x, h, w):
|
534 |
+
"""
|
535 |
+
Unpatchify 2D tensor.
|
536 |
+
|
537 |
+
Parameters
|
538 |
+
----------
|
539 |
+
x: torch.Tensor
|
540 |
+
Input tensor of shape (N, T, patch_size**2 * C)
|
541 |
+
h, w: int
|
542 |
+
Spatial dimensions
|
543 |
+
|
544 |
+
Returns
|
545 |
+
-------
|
546 |
+
torch.Tensor
|
547 |
+
Unpatchified tensor of shape (N, C, H*ph, W*pw)
|
548 |
+
"""
|
549 |
+
c = self.unpatchify_channels
|
550 |
+
ph, pw = self.patch_size
|
551 |
+
assert h * w == x.shape[1]
|
552 |
+
|
553 |
+
x = x.reshape(shape=(x.shape[0], h, w, c, ph, pw))
|
554 |
+
x = torch.einsum('nhwcpq->nchpwq', x)
|
555 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * ph, w * pw))
|
556 |
+
return imgs
|
hyimage/models/hunyuan/modules/mlp_layers.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from timm library:
|
2 |
+
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
|
3 |
+
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from ..utils.helpers import to_2tuple
|
10 |
+
from .modulate_layers import modulate
|
11 |
+
|
12 |
+
|
13 |
+
class MLP(nn.Module):
|
14 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
in_channels,
|
19 |
+
hidden_channels=None,
|
20 |
+
out_features=None,
|
21 |
+
act_layer=nn.GELU,
|
22 |
+
norm_layer=None,
|
23 |
+
bias=True,
|
24 |
+
drop=0.0,
|
25 |
+
use_conv=False,
|
26 |
+
device=None,
|
27 |
+
dtype=None,
|
28 |
+
):
|
29 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
30 |
+
super().__init__()
|
31 |
+
out_features = out_features or in_channels
|
32 |
+
hidden_channels = hidden_channels or in_channels
|
33 |
+
bias = to_2tuple(bias)
|
34 |
+
drop_probs = to_2tuple(drop)
|
35 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
36 |
+
|
37 |
+
self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
|
38 |
+
self.act = act_layer()
|
39 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
40 |
+
self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
|
41 |
+
self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
|
42 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.fc1(x)
|
46 |
+
x = self.act(x)
|
47 |
+
x = self.drop1(x)
|
48 |
+
x = self.norm(x)
|
49 |
+
x = self.fc2(x)
|
50 |
+
x = self.drop2(x)
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
class LinearWarpforSingle(nn.Module):
|
55 |
+
def __init__(self, in_dim: int, out_dim: int, bias=False, device=None, dtype=None):
|
56 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
57 |
+
super().__init__()
|
58 |
+
self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)
|
59 |
+
|
60 |
+
def forward(self, x, y):
|
61 |
+
input = torch.cat([x.contiguous(), y.contiguous()], dim=2).contiguous()
|
62 |
+
return self.fc(input)
|
63 |
+
|
64 |
+
|
65 |
+
#
|
66 |
+
class MLPEmbedder(nn.Module):
|
67 |
+
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
|
68 |
+
|
69 |
+
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
|
70 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
71 |
+
super().__init__()
|
72 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
|
73 |
+
self.silu = nn.SiLU()
|
74 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
|
75 |
+
|
76 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
77 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
78 |
+
|
79 |
+
|
80 |
+
class FinalLayer(nn.Module):
|
81 |
+
"""The final layer of DiT."""
|
82 |
+
|
83 |
+
def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
|
84 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
85 |
+
super().__init__()
|
86 |
+
|
87 |
+
# Just use LayerNorm for the final layer
|
88 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
89 |
+
if isinstance(patch_size, int):
|
90 |
+
self.linear = nn.Linear(
|
91 |
+
hidden_size,
|
92 |
+
patch_size * patch_size * out_channels,
|
93 |
+
bias=True,
|
94 |
+
**factory_kwargs,
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
out_size = (
|
98 |
+
patch_size[0] * patch_size[1] * patch_size[2] if len(patch_size) == 3 else patch_size[0] * patch_size[1]
|
99 |
+
) * out_channels
|
100 |
+
self.linear = nn.Linear(
|
101 |
+
hidden_size,
|
102 |
+
out_size,
|
103 |
+
bias=True,
|
104 |
+
)
|
105 |
+
nn.init.zeros_(self.linear.weight)
|
106 |
+
nn.init.zeros_(self.linear.bias)
|
107 |
+
|
108 |
+
# Here we don't distinguish between the modulate types. Just use the simple one.
|
109 |
+
self.adaLN_modulation = nn.Sequential(
|
110 |
+
act_layer(),
|
111 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
112 |
+
)
|
113 |
+
# Zero-initialize the modulation
|
114 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
115 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
116 |
+
|
117 |
+
def forward(self, x, c):
|
118 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
119 |
+
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
120 |
+
x = self.linear(x)
|
121 |
+
return x
|
hyimage/models/hunyuan/modules/models.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from hyimage.models.hunyuan.modules.flash_attn_no_pad import flash_attn_no_pad
|
9 |
+
|
10 |
+
from .activation_layers import get_activation_layer
|
11 |
+
from .mlp_layers import MLP, LinearWarpforSingle
|
12 |
+
from .modulate_layers import ModulateDiT, apply_gate, modulate
|
13 |
+
from .norm_layers import get_norm_layer
|
14 |
+
from .posemb_layers import apply_rotary_emb
|
15 |
+
|
16 |
+
|
17 |
+
@torch.compiler.disable
|
18 |
+
def attention(
|
19 |
+
q,
|
20 |
+
k,
|
21 |
+
v,
|
22 |
+
attn_mode="flash",
|
23 |
+
text_mask=None,
|
24 |
+
):
|
25 |
+
"""Multi-modal attention function that processes image and text sequences."""
|
26 |
+
query, encoder_query = q
|
27 |
+
key, encoder_key = k
|
28 |
+
value, encoder_value = v
|
29 |
+
|
30 |
+
assert attn_mode == "flash" # Only flash attention is implemented for now
|
31 |
+
sequence_length = query.size(1)
|
32 |
+
encoder_sequence_length = encoder_query.size(1)
|
33 |
+
|
34 |
+
query = torch.cat([query, encoder_query], dim=1)
|
35 |
+
key = torch.cat([key, encoder_key], dim=1)
|
36 |
+
value = torch.cat([value, encoder_value], dim=1)
|
37 |
+
|
38 |
+
# Stack query, key, value: B, S, 3, H, D
|
39 |
+
qkv = torch.stack([query, key, value], dim=2)
|
40 |
+
|
41 |
+
attn_mask = torch.nn.functional.pad(text_mask, (sequence_length, 0), value=True)
|
42 |
+
hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None)
|
43 |
+
|
44 |
+
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
|
45 |
+
(sequence_length, encoder_sequence_length), dim=1
|
46 |
+
)
|
47 |
+
|
48 |
+
hidden_states = hidden_states.to(query.dtype)
|
49 |
+
encoder_hidden_states = encoder_hidden_states.to(query.dtype)
|
50 |
+
|
51 |
+
attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
52 |
+
|
53 |
+
b, s, a, d = attn.shape
|
54 |
+
attn = attn.reshape(b, s, -1)
|
55 |
+
|
56 |
+
return attn
|
57 |
+
|
58 |
+
|
59 |
+
class MMDoubleStreamBlock(nn.Module):
|
60 |
+
"""
|
61 |
+
A multimodal DiT block with separate modulation for text and image/video.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
hidden_size: int,
|
67 |
+
heads_num: int,
|
68 |
+
mlp_width_ratio: float,
|
69 |
+
mlp_act_type: str = "gelu_tanh",
|
70 |
+
qk_norm: bool = True,
|
71 |
+
qk_norm_type: str = "rms",
|
72 |
+
qkv_bias: bool = False,
|
73 |
+
dtype: Optional[torch.dtype] = None,
|
74 |
+
device: Optional[torch.device] = None,
|
75 |
+
):
|
76 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
77 |
+
super().__init__()
|
78 |
+
|
79 |
+
self.deterministic = False
|
80 |
+
self.heads_num = heads_num
|
81 |
+
head_dim = hidden_size // heads_num
|
82 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
83 |
+
|
84 |
+
# Image stream components
|
85 |
+
self.img_mod = ModulateDiT(
|
86 |
+
hidden_size,
|
87 |
+
factor=6,
|
88 |
+
act_layer=get_activation_layer("silu"),
|
89 |
+
**factory_kwargs,
|
90 |
+
)
|
91 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
92 |
+
|
93 |
+
self.img_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
94 |
+
self.img_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
95 |
+
self.img_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
96 |
+
|
97 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
98 |
+
self.img_attn_q_norm = (
|
99 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
100 |
+
)
|
101 |
+
self.img_attn_k_norm = (
|
102 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
103 |
+
)
|
104 |
+
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
105 |
+
|
106 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
107 |
+
self.img_mlp = MLP(
|
108 |
+
hidden_size,
|
109 |
+
mlp_hidden_dim,
|
110 |
+
act_layer=get_activation_layer(mlp_act_type),
|
111 |
+
bias=True,
|
112 |
+
**factory_kwargs,
|
113 |
+
)
|
114 |
+
|
115 |
+
# Text stream components
|
116 |
+
self.txt_mod = ModulateDiT(
|
117 |
+
hidden_size,
|
118 |
+
factor=6,
|
119 |
+
act_layer=get_activation_layer("silu"),
|
120 |
+
**factory_kwargs,
|
121 |
+
)
|
122 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
123 |
+
|
124 |
+
self.txt_attn_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
125 |
+
self.txt_attn_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
126 |
+
self.txt_attn_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
127 |
+
self.txt_attn_q_norm = (
|
128 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
129 |
+
)
|
130 |
+
self.txt_attn_k_norm = (
|
131 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
132 |
+
)
|
133 |
+
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
134 |
+
|
135 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
136 |
+
self.txt_mlp = MLP(
|
137 |
+
hidden_size,
|
138 |
+
mlp_hidden_dim,
|
139 |
+
act_layer=get_activation_layer(mlp_act_type),
|
140 |
+
bias=True,
|
141 |
+
**factory_kwargs,
|
142 |
+
)
|
143 |
+
self.core_attn = attention
|
144 |
+
|
145 |
+
def enable_deterministic(self):
|
146 |
+
self.deterministic = True
|
147 |
+
|
148 |
+
def disable_deterministic(self):
|
149 |
+
self.deterministic = False
|
150 |
+
|
151 |
+
def forward(
|
152 |
+
self,
|
153 |
+
img: torch.Tensor,
|
154 |
+
txt: torch.Tensor,
|
155 |
+
vec: torch.Tensor,
|
156 |
+
freqs_cis: tuple = None,
|
157 |
+
text_mask: torch.Tensor = None,
|
158 |
+
cu_seqlens=None,
|
159 |
+
max_s=None,
|
160 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
161 |
+
# Extract modulation parameters for image and text streams
|
162 |
+
(
|
163 |
+
img_mod1_shift,
|
164 |
+
img_mod1_scale,
|
165 |
+
img_mod1_gate,
|
166 |
+
img_mod2_shift,
|
167 |
+
img_mod2_scale,
|
168 |
+
img_mod2_gate,
|
169 |
+
) = self.img_mod(vec).chunk(6, dim=-1)
|
170 |
+
(
|
171 |
+
txt_mod1_shift,
|
172 |
+
txt_mod1_scale,
|
173 |
+
txt_mod1_gate,
|
174 |
+
txt_mod2_shift,
|
175 |
+
txt_mod2_scale,
|
176 |
+
txt_mod2_gate,
|
177 |
+
) = self.txt_mod(vec).chunk(6, dim=-1)
|
178 |
+
|
179 |
+
# Process image stream for attention
|
180 |
+
img_modulated = self.img_norm1(img)
|
181 |
+
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
|
182 |
+
|
183 |
+
img_q = self.img_attn_q(img_modulated)
|
184 |
+
img_k = self.img_attn_k(img_modulated)
|
185 |
+
img_v = self.img_attn_v(img_modulated)
|
186 |
+
|
187 |
+
img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num)
|
188 |
+
img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num)
|
189 |
+
img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num)
|
190 |
+
|
191 |
+
# Apply QK-Norm if enabled
|
192 |
+
img_q = self.img_attn_q_norm(img_q).to(img_v)
|
193 |
+
img_k = self.img_attn_k_norm(img_k).to(img_v)
|
194 |
+
|
195 |
+
# Apply RoPE if provided
|
196 |
+
if freqs_cis is not None:
|
197 |
+
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
198 |
+
assert (
|
199 |
+
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
|
200 |
+
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
|
201 |
+
img_q, img_k = img_qq, img_kk
|
202 |
+
|
203 |
+
# Process text stream for attention
|
204 |
+
txt_modulated = self.txt_norm1(txt)
|
205 |
+
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
|
206 |
+
|
207 |
+
txt_q = self.txt_attn_q(txt_modulated)
|
208 |
+
txt_k = self.txt_attn_k(txt_modulated)
|
209 |
+
txt_v = self.txt_attn_v(txt_modulated)
|
210 |
+
|
211 |
+
txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num)
|
212 |
+
txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num)
|
213 |
+
txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num)
|
214 |
+
|
215 |
+
# Apply QK-Norm if enabled
|
216 |
+
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
|
217 |
+
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
|
218 |
+
|
219 |
+
# Compute cross-modal attention
|
220 |
+
attn = self.core_attn(
|
221 |
+
(img_q, txt_q),
|
222 |
+
(img_k, txt_k),
|
223 |
+
(img_v, txt_v),
|
224 |
+
text_mask=text_mask,
|
225 |
+
)
|
226 |
+
|
227 |
+
# Split attention outputs for image and text streams
|
228 |
+
img_attn, txt_attn = (
|
229 |
+
attn[:, : img_q.shape[1]].contiguous(),
|
230 |
+
attn[:, img_q.shape[1] :].contiguous(),
|
231 |
+
)
|
232 |
+
|
233 |
+
# Apply attention projection and residual connection for image stream
|
234 |
+
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
|
235 |
+
|
236 |
+
# Apply MLP and residual connection for image stream
|
237 |
+
img = img + apply_gate(
|
238 |
+
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
|
239 |
+
gate=img_mod2_gate,
|
240 |
+
)
|
241 |
+
|
242 |
+
# Apply attention projection and residual connection for text stream
|
243 |
+
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
|
244 |
+
|
245 |
+
# Apply MLP and residual connection for text stream
|
246 |
+
txt = txt + apply_gate(
|
247 |
+
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
|
248 |
+
gate=txt_mod2_gate,
|
249 |
+
)
|
250 |
+
|
251 |
+
return img, txt
|
252 |
+
|
253 |
+
|
254 |
+
class MMSingleStreamBlock(nn.Module):
|
255 |
+
"""
|
256 |
+
A DiT block with parallel linear layers for multimodal processing.
|
257 |
+
"""
|
258 |
+
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
hidden_size: int,
|
262 |
+
heads_num: int,
|
263 |
+
mlp_width_ratio: float = 4.0,
|
264 |
+
mlp_act_type: str = "gelu_tanh",
|
265 |
+
qk_norm: bool = True,
|
266 |
+
qk_norm_type: str = "rms",
|
267 |
+
qk_scale: float = None,
|
268 |
+
dtype: Optional[torch.dtype] = None,
|
269 |
+
device: Optional[torch.device] = None,
|
270 |
+
):
|
271 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
272 |
+
super().__init__()
|
273 |
+
|
274 |
+
self.deterministic = False
|
275 |
+
self.hidden_size = hidden_size
|
276 |
+
self.heads_num = heads_num
|
277 |
+
head_dim = hidden_size // heads_num
|
278 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
279 |
+
self.mlp_hidden_dim = mlp_hidden_dim
|
280 |
+
self.scale = qk_scale or head_dim**-0.5
|
281 |
+
|
282 |
+
# Separate linear layers for Q, K, V, and MLP input
|
283 |
+
self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
284 |
+
self.linear1_k = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
285 |
+
self.linear1_v = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
|
286 |
+
self.linear1_mlp = nn.Linear(hidden_size, mlp_hidden_dim, **factory_kwargs)
|
287 |
+
|
288 |
+
# Output projection layer
|
289 |
+
self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True, **factory_kwargs)
|
290 |
+
|
291 |
+
# QK normalization layers
|
292 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
293 |
+
self.q_norm = (
|
294 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
295 |
+
)
|
296 |
+
self.k_norm = (
|
297 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
298 |
+
)
|
299 |
+
|
300 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
301 |
+
|
302 |
+
self.mlp_act = get_activation_layer(mlp_act_type)()
|
303 |
+
self.modulation = ModulateDiT(
|
304 |
+
hidden_size,
|
305 |
+
factor=3,
|
306 |
+
act_layer=get_activation_layer("silu"),
|
307 |
+
**factory_kwargs,
|
308 |
+
)
|
309 |
+
self.core_attn = attention
|
310 |
+
|
311 |
+
def enable_deterministic(self):
|
312 |
+
self.deterministic = True
|
313 |
+
|
314 |
+
def disable_deterministic(self):
|
315 |
+
self.deterministic = False
|
316 |
+
|
317 |
+
def forward(
|
318 |
+
self,
|
319 |
+
x: torch.Tensor,
|
320 |
+
vec: torch.Tensor,
|
321 |
+
txt_len: int,
|
322 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
323 |
+
text_mask: torch.Tensor = None,
|
324 |
+
cu_seqlens=None,
|
325 |
+
max_s=None,
|
326 |
+
) -> torch.Tensor:
|
327 |
+
# Extract modulation parameters
|
328 |
+
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
329 |
+
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
330 |
+
|
331 |
+
# Compute Q, K, V, and MLP input
|
332 |
+
q = self.linear1_q(x_mod)
|
333 |
+
k = self.linear1_k(x_mod)
|
334 |
+
v = self.linear1_v(x_mod)
|
335 |
+
|
336 |
+
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
|
337 |
+
k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num)
|
338 |
+
v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num)
|
339 |
+
mlp = self.linear1_mlp(x_mod)
|
340 |
+
|
341 |
+
# Apply QK-Norm if enabled
|
342 |
+
q = self.q_norm(q).to(v)
|
343 |
+
k = self.k_norm(k).to(v)
|
344 |
+
|
345 |
+
# Split into image and text sequences
|
346 |
+
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
347 |
+
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
348 |
+
img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
|
349 |
+
|
350 |
+
# Apply RoPE to image sequence
|
351 |
+
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
352 |
+
assert (
|
353 |
+
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
|
354 |
+
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
|
355 |
+
img_q, img_k = img_qq, img_kk
|
356 |
+
|
357 |
+
# Compute cross-modal attention
|
358 |
+
attn = self.core_attn(
|
359 |
+
(img_q, txt_q),
|
360 |
+
(img_k, txt_k),
|
361 |
+
(img_v, txt_v),
|
362 |
+
text_mask=text_mask,
|
363 |
+
)
|
364 |
+
|
365 |
+
# Combine attention output with MLP activation and apply final projection
|
366 |
+
output = self.linear2(attn, self.mlp_act(mlp))
|
367 |
+
return x + apply_gate(output, gate=mod_gate)
|
hyimage/models/hunyuan/modules/modulate_layers.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class ModulateDiT(nn.Module):
|
8 |
+
"""Modulation layer for DiT."""
|
9 |
+
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
hidden_size: int,
|
13 |
+
factor: int,
|
14 |
+
act_layer: Callable,
|
15 |
+
dtype=None,
|
16 |
+
device=None,
|
17 |
+
):
|
18 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
19 |
+
super().__init__()
|
20 |
+
self.act = act_layer()
|
21 |
+
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
|
22 |
+
# Zero-initialize the modulation
|
23 |
+
nn.init.zeros_(self.linear.weight)
|
24 |
+
nn.init.zeros_(self.linear.bias)
|
25 |
+
|
26 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
27 |
+
return self.linear(self.act(x))
|
28 |
+
|
29 |
+
|
30 |
+
def modulate(x, shift=None, scale=None):
|
31 |
+
"""modulate by shift and scale
|
32 |
+
|
33 |
+
Args:
|
34 |
+
x (torch.Tensor): input tensor.
|
35 |
+
shift (torch.Tensor, optional): shift tensor. Defaults to None.
|
36 |
+
scale (torch.Tensor, optional): scale tensor. Defaults to None.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
torch.Tensor: the output tensor after modulate.
|
40 |
+
"""
|
41 |
+
if scale is None and shift is None:
|
42 |
+
return x
|
43 |
+
elif shift is None:
|
44 |
+
return x * (1 + scale.unsqueeze(1))
|
45 |
+
elif scale is None:
|
46 |
+
return x + shift.unsqueeze(1)
|
47 |
+
else:
|
48 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
49 |
+
|
50 |
+
|
51 |
+
def apply_gate(x, gate=None, tanh=False):
|
52 |
+
"""AI is creating summary for apply_gate
|
53 |
+
|
54 |
+
Args:
|
55 |
+
x (torch.Tensor): input tensor.
|
56 |
+
gate (torch.Tensor, optional): gate tensor. Defaults to None.
|
57 |
+
tanh (bool, optional): whether to use tanh function. Defaults to False.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
torch.Tensor: the output tensor after apply gate.
|
61 |
+
"""
|
62 |
+
if gate is None:
|
63 |
+
return x
|
64 |
+
if tanh:
|
65 |
+
return x * gate.unsqueeze(1).tanh()
|
66 |
+
else:
|
67 |
+
return x * gate.unsqueeze(1)
|
68 |
+
|
69 |
+
|
70 |
+
def ckpt_wrapper(module):
|
71 |
+
def ckpt_forward(*inputs):
|
72 |
+
outputs = module(*inputs)
|
73 |
+
return outputs
|
74 |
+
|
75 |
+
return ckpt_forward
|
76 |
+
|
77 |
+
|
78 |
+
import torch
|
79 |
+
import torch.nn as nn
|
80 |
+
|
81 |
+
|
82 |
+
class RMSNorm(nn.Module):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
dim: int,
|
86 |
+
elementwise_affine=True,
|
87 |
+
eps: float = 1e-6,
|
88 |
+
device=None,
|
89 |
+
dtype=None,
|
90 |
+
):
|
91 |
+
"""
|
92 |
+
Initialize the RMSNorm normalization layer.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
dim (int): The dimension of the input tensor.
|
96 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
97 |
+
|
98 |
+
Attributes:
|
99 |
+
eps (float): A small value added to the denominator for numerical stability.
|
100 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
101 |
+
|
102 |
+
"""
|
103 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
104 |
+
super().__init__()
|
105 |
+
self.eps = eps
|
106 |
+
if elementwise_affine:
|
107 |
+
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
108 |
+
|
109 |
+
def _norm(self, x):
|
110 |
+
"""
|
111 |
+
Apply the RMSNorm normalization to the input tensor.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
x (torch.Tensor): The input tensor.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
torch.Tensor: The normalized tensor.
|
118 |
+
|
119 |
+
"""
|
120 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
"""
|
124 |
+
Forward pass through the RMSNorm layer.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
x (torch.Tensor): The input tensor.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
131 |
+
|
132 |
+
"""
|
133 |
+
output = self._norm(x.float()).type_as(x)
|
134 |
+
if hasattr(self, "weight"):
|
135 |
+
output = output * self.weight
|
136 |
+
return output
|
137 |
+
|
138 |
+
|
139 |
+
def get_norm_layer(norm_layer):
|
140 |
+
"""
|
141 |
+
Get the normalization layer.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
norm_layer (str): The type of normalization layer.
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
norm_layer (nn.Module): The normalization layer.
|
148 |
+
"""
|
149 |
+
if norm_layer == "layer":
|
150 |
+
return nn.LayerNorm
|
151 |
+
elif norm_layer == "rms":
|
152 |
+
return RMSNorm
|
153 |
+
else:
|
154 |
+
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
hyimage/models/hunyuan/modules/norm_layers.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class RMSNorm(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
dim: int,
|
9 |
+
elementwise_affine=True,
|
10 |
+
eps: float = 1e-6,
|
11 |
+
device=None,
|
12 |
+
dtype=None,
|
13 |
+
):
|
14 |
+
"""
|
15 |
+
Initialize the RMSNorm normalization layer.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
dim (int): The dimension of the input tensor.
|
19 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
20 |
+
|
21 |
+
Attributes:
|
22 |
+
eps (float): A small value added to the denominator for numerical stability.
|
23 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
24 |
+
|
25 |
+
"""
|
26 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
27 |
+
super().__init__()
|
28 |
+
self.eps = eps
|
29 |
+
if elementwise_affine:
|
30 |
+
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
31 |
+
|
32 |
+
def _norm(self, x):
|
33 |
+
"""
|
34 |
+
Apply the RMSNorm normalization to the input tensor.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
x (torch.Tensor): The input tensor.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: The normalized tensor.
|
41 |
+
|
42 |
+
"""
|
43 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
44 |
+
|
45 |
+
def reset_parameters(self):
|
46 |
+
if hasattr(self, "weight"):
|
47 |
+
self.weight.fill_(1)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""
|
51 |
+
Forward pass through the RMSNorm layer.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
x (torch.Tensor): The input tensor.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
58 |
+
|
59 |
+
"""
|
60 |
+
output = self._norm(x.float()).type_as(x)
|
61 |
+
if hasattr(self, "weight"):
|
62 |
+
output = output * self.weight
|
63 |
+
return output
|
64 |
+
|
65 |
+
|
66 |
+
def get_norm_layer(norm_layer):
|
67 |
+
"""
|
68 |
+
Get the normalization layer.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
norm_layer (str): The type of normalization layer.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
norm_layer (nn.Module): The normalization layer.
|
75 |
+
"""
|
76 |
+
if norm_layer == "layer":
|
77 |
+
return nn.LayerNorm
|
78 |
+
elif norm_layer == "rms":
|
79 |
+
return RMSNorm
|
80 |
+
else:
|
81 |
+
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
hyimage/models/hunyuan/modules/posemb_layers.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def _to_tuple(x, dim=2):
|
7 |
+
if isinstance(x, int):
|
8 |
+
return (x,) * dim
|
9 |
+
elif len(x) == dim:
|
10 |
+
return x
|
11 |
+
else:
|
12 |
+
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
13 |
+
|
14 |
+
|
15 |
+
def get_meshgrid_nd(start, *args, dim=2):
|
16 |
+
"""
|
17 |
+
Get n-D meshgrid with start, stop and num.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
21 |
+
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
22 |
+
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
23 |
+
n-tuples.
|
24 |
+
*args: See above.
|
25 |
+
dim (int): Dimension of the meshgrid. Defaults to 2.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
grid (np.ndarray): [dim, ...]
|
29 |
+
"""
|
30 |
+
if len(args) == 0:
|
31 |
+
# start is grid_size
|
32 |
+
num = _to_tuple(start, dim=dim)
|
33 |
+
start = (0,) * dim
|
34 |
+
stop = num
|
35 |
+
elif len(args) == 1:
|
36 |
+
# start is start, args[0] is stop, step is 1
|
37 |
+
start = _to_tuple(start, dim=dim)
|
38 |
+
stop = _to_tuple(args[0], dim=dim)
|
39 |
+
num = [stop[i] - start[i] for i in range(dim)]
|
40 |
+
elif len(args) == 2:
|
41 |
+
# start is start, args[0] is stop, args[1] is num
|
42 |
+
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
43 |
+
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
44 |
+
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
45 |
+
else:
|
46 |
+
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
47 |
+
|
48 |
+
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
49 |
+
axis_grid = []
|
50 |
+
for i in range(dim):
|
51 |
+
a, b, n = start[i], stop[i], num[i]
|
52 |
+
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
53 |
+
axis_grid.append(g)
|
54 |
+
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
55 |
+
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
56 |
+
|
57 |
+
return grid
|
58 |
+
|
59 |
+
|
60 |
+
#################################################################################
|
61 |
+
# Rotary Positional Embedding Functions #
|
62 |
+
#################################################################################
|
63 |
+
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
|
64 |
+
|
65 |
+
|
66 |
+
def reshape_for_broadcast(
|
67 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
68 |
+
x: torch.Tensor,
|
69 |
+
head_first=False,
|
70 |
+
):
|
71 |
+
"""
|
72 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
73 |
+
|
74 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
75 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
76 |
+
|
77 |
+
Notes:
|
78 |
+
When using FlashMHAModified, head_first should be False.
|
79 |
+
When using Attention, head_first should be True.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
|
83 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
84 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
torch.Tensor: Reshaped frequency tensor.
|
88 |
+
|
89 |
+
Raises:
|
90 |
+
AssertionError: If the frequency tensor doesn't match the expected shape.
|
91 |
+
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
92 |
+
"""
|
93 |
+
ndim = x.ndim
|
94 |
+
assert 0 <= 1 < ndim
|
95 |
+
|
96 |
+
if isinstance(freqs_cis, tuple):
|
97 |
+
# freqs_cis: (cos, sin) in real space
|
98 |
+
if head_first:
|
99 |
+
assert freqs_cis[0].shape == (
|
100 |
+
x.shape[-2],
|
101 |
+
x.shape[-1],
|
102 |
+
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
103 |
+
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
104 |
+
else:
|
105 |
+
assert freqs_cis[0].shape == (
|
106 |
+
x.shape[1],
|
107 |
+
x.shape[-1],
|
108 |
+
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
109 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
110 |
+
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
111 |
+
else:
|
112 |
+
# freqs_cis: values in complex space
|
113 |
+
if head_first:
|
114 |
+
assert freqs_cis.shape == (
|
115 |
+
x.shape[-2],
|
116 |
+
x.shape[-1],
|
117 |
+
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
118 |
+
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
119 |
+
else:
|
120 |
+
assert freqs_cis.shape == (
|
121 |
+
x.shape[1],
|
122 |
+
x.shape[-1],
|
123 |
+
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
124 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
125 |
+
return freqs_cis.view(*shape)
|
126 |
+
|
127 |
+
|
128 |
+
def rotate_half(x):
|
129 |
+
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
130 |
+
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
131 |
+
|
132 |
+
|
133 |
+
def apply_rotary_emb(
|
134 |
+
xq: torch.Tensor,
|
135 |
+
xk: torch.Tensor,
|
136 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
137 |
+
head_first: bool = False,
|
138 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
139 |
+
"""
|
140 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
141 |
+
|
142 |
+
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
143 |
+
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
144 |
+
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
145 |
+
returned as real tensors.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
|
149 |
+
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
|
150 |
+
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
|
151 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
155 |
+
|
156 |
+
"""
|
157 |
+
xk_out = None
|
158 |
+
if isinstance(freqs_cis, tuple):
|
159 |
+
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
160 |
+
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
161 |
+
# real * cos - imag * sin
|
162 |
+
# imag * cos + real * sin
|
163 |
+
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
164 |
+
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
165 |
+
else:
|
166 |
+
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
167 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
168 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
169 |
+
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
170 |
+
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
171 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
172 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
173 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
174 |
+
|
175 |
+
return xq_out, xk_out
|
176 |
+
|
177 |
+
|
178 |
+
def get_nd_rotary_pos_embed(
|
179 |
+
rope_dim_list,
|
180 |
+
start,
|
181 |
+
*args,
|
182 |
+
theta=10000.0,
|
183 |
+
use_real=False,
|
184 |
+
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
185 |
+
interpolation_factor: Union[float, List[float]] = 1.0,
|
186 |
+
):
|
187 |
+
"""
|
188 |
+
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
192 |
+
sum(rope_dim_list) should equal to head_dim of attention layer.
|
193 |
+
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
194 |
+
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
195 |
+
*args: See above.
|
196 |
+
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
197 |
+
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
198 |
+
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
199 |
+
part and an imaginary part separately.
|
200 |
+
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
pos_embed (torch.Tensor): [HW, D/2]
|
204 |
+
"""
|
205 |
+
|
206 |
+
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
|
207 |
+
|
208 |
+
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
209 |
+
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
210 |
+
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
211 |
+
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
212 |
+
assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
213 |
+
|
214 |
+
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
215 |
+
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
216 |
+
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
217 |
+
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
218 |
+
assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
219 |
+
|
220 |
+
# use 1/ndim of dimensions to encode grid_axis
|
221 |
+
embs = []
|
222 |
+
for i in range(len(rope_dim_list)):
|
223 |
+
emb = get_1d_rotary_pos_embed(
|
224 |
+
rope_dim_list[i],
|
225 |
+
grid[i].reshape(-1),
|
226 |
+
theta,
|
227 |
+
use_real=use_real,
|
228 |
+
theta_rescale_factor=theta_rescale_factor[i],
|
229 |
+
interpolation_factor=interpolation_factor[i],
|
230 |
+
) # 2 x [WHD, rope_dim_list[i]]
|
231 |
+
embs.append(emb)
|
232 |
+
|
233 |
+
if use_real:
|
234 |
+
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
235 |
+
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
236 |
+
return cos, sin
|
237 |
+
else:
|
238 |
+
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
239 |
+
return emb
|
240 |
+
|
241 |
+
|
242 |
+
def get_1d_rotary_pos_embed(
|
243 |
+
dim: int,
|
244 |
+
pos: Union[torch.FloatTensor, int],
|
245 |
+
theta: float = 10000.0,
|
246 |
+
use_real: bool = False,
|
247 |
+
theta_rescale_factor: float = 1.0,
|
248 |
+
interpolation_factor: float = 1.0,
|
249 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
250 |
+
"""
|
251 |
+
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
252 |
+
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
253 |
+
|
254 |
+
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
255 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
256 |
+
The returned tensor contains complex values in complex64 data type.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
dim (int): Dimension of the frequency tensor.
|
260 |
+
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
261 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
262 |
+
use_real (bool, optional): If True, return real part and imaginary part separately.
|
263 |
+
Otherwise, return complex numbers.
|
264 |
+
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
265 |
+
|
266 |
+
Returns:
|
267 |
+
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
268 |
+
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
269 |
+
"""
|
270 |
+
if isinstance(pos, int):
|
271 |
+
pos = torch.arange(pos).float()
|
272 |
+
|
273 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
274 |
+
# has some connection to NTK literature
|
275 |
+
if theta_rescale_factor != 1.0:
|
276 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
277 |
+
|
278 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
279 |
+
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
280 |
+
if use_real:
|
281 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
282 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
283 |
+
return freqs_cos, freqs_sin
|
284 |
+
else:
|
285 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
286 |
+
return freqs_cis
|
hyimage/models/hunyuan/modules/token_refiner.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
from hyimage.models.hunyuan.modules.flash_attn_no_pad import flash_attn_no_pad
|
8 |
+
from .activation_layers import get_activation_layer
|
9 |
+
from .embed_layers import TextProjection, TimestepEmbedder
|
10 |
+
from .mlp_layers import MLP
|
11 |
+
from .modulate_layers import apply_gate
|
12 |
+
from .norm_layers import get_norm_layer
|
13 |
+
|
14 |
+
|
15 |
+
@torch.compiler.disable
|
16 |
+
def attention(
|
17 |
+
q: torch.Tensor,
|
18 |
+
k: torch.Tensor,
|
19 |
+
v: torch.Tensor,
|
20 |
+
drop_rate: float = 0.0,
|
21 |
+
attn_mask: Optional[torch.Tensor] = None,
|
22 |
+
causal: bool = False,
|
23 |
+
) -> torch.Tensor:
|
24 |
+
"""
|
25 |
+
Compute attention using flash_attn_no_pad.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
q: Query tensor of shape [B, L, H, D]
|
29 |
+
k: Key tensor of shape [B, L, H, D]
|
30 |
+
v: Value tensor of shape [B, L, H, D]
|
31 |
+
drop_rate: Dropout rate for attention weights.
|
32 |
+
attn_mask: Optional attention mask of shape [B, L].
|
33 |
+
causal: Whether to apply causal masking.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
Output tensor after attention of shape [B, L, H*D]
|
37 |
+
"""
|
38 |
+
qkv = torch.stack([q, k, v], dim=2)
|
39 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
40 |
+
attn_mask = attn_mask.bool()
|
41 |
+
x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
|
42 |
+
b, s, a, d = x.shape
|
43 |
+
out = x.reshape(b, s, -1)
|
44 |
+
return out
|
45 |
+
|
46 |
+
|
47 |
+
class IndividualTokenRefinerBlock(nn.Module):
|
48 |
+
"""
|
49 |
+
A single block for token refinement with self-attention and MLP.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
hidden_size: Hidden dimension size.
|
53 |
+
heads_num: Number of attention heads.
|
54 |
+
mlp_width_ratio: Expansion ratio for MLP hidden size.
|
55 |
+
mlp_drop_rate: Dropout rate for MLP.
|
56 |
+
act_type: Activation function type.
|
57 |
+
qk_norm: Whether to use QK normalization.
|
58 |
+
qk_norm_type: Type of QK normalization.
|
59 |
+
qkv_bias: Whether to use bias in QKV projections.
|
60 |
+
dtype: Optional torch dtype.
|
61 |
+
device: Optional torch device.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
hidden_size: int,
|
67 |
+
heads_num: int,
|
68 |
+
mlp_width_ratio: float = 4.0,
|
69 |
+
mlp_drop_rate: float = 0.0,
|
70 |
+
act_type: str = "silu",
|
71 |
+
qk_norm: bool = False,
|
72 |
+
qk_norm_type: str = "layer",
|
73 |
+
qkv_bias: bool = True,
|
74 |
+
dtype: Optional[torch.dtype] = None,
|
75 |
+
device: Optional[torch.device] = None,
|
76 |
+
):
|
77 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
78 |
+
super().__init__()
|
79 |
+
self.heads_num = heads_num
|
80 |
+
head_dim = hidden_size // heads_num
|
81 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
82 |
+
|
83 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
84 |
+
self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
85 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
86 |
+
self.self_attn_q_norm = (
|
87 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
88 |
+
)
|
89 |
+
self.self_attn_k_norm = (
|
90 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
91 |
+
)
|
92 |
+
self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
93 |
+
|
94 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
95 |
+
act_layer = get_activation_layer(act_type)
|
96 |
+
self.mlp = MLP(
|
97 |
+
in_channels=hidden_size,
|
98 |
+
hidden_channels=mlp_hidden_dim,
|
99 |
+
act_layer=act_layer,
|
100 |
+
drop=mlp_drop_rate,
|
101 |
+
**factory_kwargs,
|
102 |
+
)
|
103 |
+
|
104 |
+
self.adaLN_modulation = nn.Sequential(
|
105 |
+
act_layer(),
|
106 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
107 |
+
)
|
108 |
+
# Zero-initialize the modulation
|
109 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
110 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
x: torch.Tensor,
|
115 |
+
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
116 |
+
attn_mask: Optional[torch.Tensor] = None,
|
117 |
+
) -> torch.Tensor:
|
118 |
+
"""
|
119 |
+
Forward pass for IndividualTokenRefinerBlock.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
x: Input tensor of shape [B, L, C].
|
123 |
+
c: Conditioning tensor of shape [B, C].
|
124 |
+
attn_mask: Optional attention mask of shape [B, L].
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
Refined tensor of shape [B, L, C].
|
128 |
+
"""
|
129 |
+
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
130 |
+
norm_x = self.norm1(x)
|
131 |
+
qkv = self.self_attn_qkv(norm_x)
|
132 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
133 |
+
q = self.self_attn_q_norm(q).to(v)
|
134 |
+
k = self.self_attn_k_norm(k).to(v)
|
135 |
+
attn = attention(q, k, v, attn_mask=attn_mask)
|
136 |
+
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
137 |
+
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class IndividualTokenRefiner(nn.Module):
|
142 |
+
"""
|
143 |
+
Stacks multiple IndividualTokenRefinerBlock modules.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
hidden_size: Hidden dimension size.
|
147 |
+
heads_num: Number of attention heads.
|
148 |
+
depth: Number of blocks.
|
149 |
+
mlp_width_ratio: Expansion ratio for MLP hidden size.
|
150 |
+
mlp_drop_rate: Dropout rate for MLP.
|
151 |
+
act_type: Activation function type.
|
152 |
+
qk_norm: Whether to use QK normalization.
|
153 |
+
qk_norm_type: Type of QK normalization.
|
154 |
+
qkv_bias: Whether to use bias in QKV projections.
|
155 |
+
dtype: Optional torch dtype.
|
156 |
+
device: Optional torch device.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
hidden_size: int,
|
162 |
+
heads_num: int,
|
163 |
+
depth: int,
|
164 |
+
mlp_width_ratio: float = 4.0,
|
165 |
+
mlp_drop_rate: float = 0.0,
|
166 |
+
act_type: str = "silu",
|
167 |
+
qk_norm: bool = False,
|
168 |
+
qk_norm_type: str = "layer",
|
169 |
+
qkv_bias: bool = True,
|
170 |
+
dtype: Optional[torch.dtype] = None,
|
171 |
+
device: Optional[torch.device] = None,
|
172 |
+
):
|
173 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
174 |
+
super().__init__()
|
175 |
+
self.blocks = nn.ModuleList(
|
176 |
+
[
|
177 |
+
IndividualTokenRefinerBlock(
|
178 |
+
hidden_size=hidden_size,
|
179 |
+
heads_num=heads_num,
|
180 |
+
mlp_width_ratio=mlp_width_ratio,
|
181 |
+
mlp_drop_rate=mlp_drop_rate,
|
182 |
+
act_type=act_type,
|
183 |
+
qk_norm=qk_norm,
|
184 |
+
qk_norm_type=qk_norm_type,
|
185 |
+
qkv_bias=qkv_bias,
|
186 |
+
**factory_kwargs,
|
187 |
+
)
|
188 |
+
for _ in range(depth)
|
189 |
+
]
|
190 |
+
)
|
191 |
+
|
192 |
+
def forward(
|
193 |
+
self,
|
194 |
+
x: torch.Tensor,
|
195 |
+
c: torch.LongTensor,
|
196 |
+
mask: Optional[torch.Tensor] = None,
|
197 |
+
) -> torch.Tensor:
|
198 |
+
"""
|
199 |
+
Forward pass for IndividualTokenRefiner.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
x: Input tensor of shape [B, L, C].
|
203 |
+
c: Conditioning tensor of shape [B, C].
|
204 |
+
mask: Optional mask tensor of shape [B, L].
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
Refined tensor of shape [B, L, C].
|
208 |
+
"""
|
209 |
+
if mask is not None:
|
210 |
+
mask = mask.clone().bool()
|
211 |
+
mask[:, 0] = True # Prevent attention weights from becoming NaN
|
212 |
+
for block in self.blocks:
|
213 |
+
x = block(x, c, mask)
|
214 |
+
return x
|
215 |
+
|
216 |
+
|
217 |
+
class SingleTokenRefiner(nn.Module):
|
218 |
+
"""
|
219 |
+
Single token refiner block for LLM text embedding refinement.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
in_channels: Input feature dimension.
|
223 |
+
hidden_size: Hidden dimension size.
|
224 |
+
heads_num: Number of attention heads.
|
225 |
+
depth: Number of blocks.
|
226 |
+
mlp_width_ratio: Expansion ratio for MLP hidden size.
|
227 |
+
mlp_drop_rate: Dropout rate for MLP.
|
228 |
+
act_type: Activation function type.
|
229 |
+
qk_norm: Whether to use QK normalization.
|
230 |
+
qk_norm_type: Type of QK normalization.
|
231 |
+
qkv_bias: Whether to use bias in QKV projections.
|
232 |
+
dtype: Optional torch dtype.
|
233 |
+
device: Optional torch device.
|
234 |
+
"""
|
235 |
+
|
236 |
+
def __init__(
|
237 |
+
self,
|
238 |
+
in_channels: int,
|
239 |
+
hidden_size: int,
|
240 |
+
heads_num: int,
|
241 |
+
depth: int,
|
242 |
+
mlp_width_ratio: float = 4.0,
|
243 |
+
mlp_drop_rate: float = 0.0,
|
244 |
+
act_type: str = "silu",
|
245 |
+
qk_norm: bool = False,
|
246 |
+
qk_norm_type: str = "layer",
|
247 |
+
qkv_bias: bool = True,
|
248 |
+
dtype: Optional[torch.dtype] = None,
|
249 |
+
device: Optional[torch.device] = None,
|
250 |
+
):
|
251 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
252 |
+
super().__init__()
|
253 |
+
self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
|
254 |
+
act_layer = get_activation_layer(act_type)
|
255 |
+
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
|
256 |
+
self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
|
257 |
+
self.individual_token_refiner = IndividualTokenRefiner(
|
258 |
+
hidden_size=hidden_size,
|
259 |
+
heads_num=heads_num,
|
260 |
+
depth=depth,
|
261 |
+
mlp_width_ratio=mlp_width_ratio,
|
262 |
+
mlp_drop_rate=mlp_drop_rate,
|
263 |
+
act_type=act_type,
|
264 |
+
qk_norm=qk_norm,
|
265 |
+
qk_norm_type=qk_norm_type,
|
266 |
+
qkv_bias=qkv_bias,
|
267 |
+
**factory_kwargs,
|
268 |
+
)
|
269 |
+
|
270 |
+
def forward(
|
271 |
+
self,
|
272 |
+
x: torch.Tensor,
|
273 |
+
t: torch.LongTensor,
|
274 |
+
mask: Optional[torch.LongTensor] = None,
|
275 |
+
) -> torch.Tensor:
|
276 |
+
"""
|
277 |
+
Forward pass for SingleTokenRefiner.
|
278 |
+
|
279 |
+
Args:
|
280 |
+
x: Input tensor of shape [B, L, in_channels].
|
281 |
+
t: Timestep tensor of shape [B].
|
282 |
+
mask: Optional mask tensor of shape [B, L].
|
283 |
+
|
284 |
+
Returns:
|
285 |
+
Refined tensor of shape [B, L, hidden_size].
|
286 |
+
"""
|
287 |
+
timestep_aware_representations = self.t_embedder(t)
|
288 |
+
if mask is None:
|
289 |
+
context_aware_representations = x.mean(dim=1)
|
290 |
+
else:
|
291 |
+
mask_float = mask.unsqueeze(-1) # [B, L, 1]
|
292 |
+
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
293 |
+
context_aware_representations = self.c_embedder(context_aware_representations)
|
294 |
+
c = timestep_aware_representations + context_aware_representations
|
295 |
+
x = self.input_embedder(x)
|
296 |
+
x = self.individual_token_refiner(x, c, mask)
|
297 |
+
return x
|
hyimage/models/hunyuan/utils/__init__.py
ADDED
File without changes
|
hyimage/models/hunyuan/utils/helpers.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc
|
2 |
+
from itertools import repeat
|
3 |
+
|
4 |
+
def _ntuple(n):
|
5 |
+
"""
|
6 |
+
Returns a function that converts input to a tuple of length n.
|
7 |
+
If input is an iterable (except str), it is converted to a tuple.
|
8 |
+
If the tuple has length 1, it is repeated n times.
|
9 |
+
Otherwise, the input is repeated n times to form the tuple.
|
10 |
+
"""
|
11 |
+
def parse(x):
|
12 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
13 |
+
x = tuple(x)
|
14 |
+
if len(x) == 1:
|
15 |
+
x = tuple(repeat(x[0], n))
|
16 |
+
return x
|
17 |
+
return tuple(repeat(x, n))
|
18 |
+
return parse
|
19 |
+
|
20 |
+
to_1tuple = _ntuple(1)
|
21 |
+
to_2tuple = _ntuple(2)
|
22 |
+
to_3tuple = _ntuple(3)
|
23 |
+
to_4tuple = _ntuple(4)
|
hyimage/models/model_zoo.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
|
4 |
+
from hyimage.common.config import LazyCall as L
|
5 |
+
from hyimage.models.hunyuan.configs.hunyuanimage_config import (
|
6 |
+
hunyuanimage_v2_1_cfg,
|
7 |
+
hunyuanimage_v2_1_distilled_cfg,
|
8 |
+
hunyuanimage_refiner_cfg,
|
9 |
+
)
|
10 |
+
from hyimage.models.vae import load_vae
|
11 |
+
from hyimage.common.config.base_config import (
|
12 |
+
DiTConfig,
|
13 |
+
RepromptConfig,
|
14 |
+
TextEncoderConfig,
|
15 |
+
VAEConfig,
|
16 |
+
)
|
17 |
+
from hyimage.models.text_encoder import TextEncoder
|
18 |
+
|
19 |
+
HUNYUANIMAGE_V2_1_MODEL_ROOT = os.environ.get("HUNYUANIMAGE_V2_1_MODEL_ROOT", "./ckpts")
|
20 |
+
|
21 |
+
# =============================================================================
|
22 |
+
# MODEL CONFIGURATIONS
|
23 |
+
# =============================================================================
|
24 |
+
|
25 |
+
# =============================================================================
|
26 |
+
# V2.1 MODELS
|
27 |
+
# =============================================================================
|
28 |
+
|
29 |
+
def HUNYUANIMAGE_V2_1_TEXT_ENCODER(**kwargs):
|
30 |
+
return TextEncoderConfig(
|
31 |
+
model=L(TextEncoder)(
|
32 |
+
text_encoder_type="llm",
|
33 |
+
max_length=1000,
|
34 |
+
text_encoder_precision='fp16',
|
35 |
+
tokenizer_type="llm",
|
36 |
+
text_encoder_path=None,
|
37 |
+
prompt_template=None,
|
38 |
+
prompt_template_video=None,
|
39 |
+
hidden_state_skip_layer=2,
|
40 |
+
apply_final_norm=False,
|
41 |
+
reproduce=False,
|
42 |
+
logger=None,
|
43 |
+
device=None,
|
44 |
+
),
|
45 |
+
prompt_template="dit-llm-encode-v2",
|
46 |
+
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/text_encoder",
|
47 |
+
text_len=1000,
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
def HUNYUANIMAGE_V2_1_VAE_32x(**kwargs):
|
52 |
+
return VAEConfig(
|
53 |
+
model=L(load_vae)(
|
54 |
+
vae_path=None,
|
55 |
+
device="cuda",
|
56 |
+
),
|
57 |
+
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/vae/vae_2_1",
|
58 |
+
cpu_offload=False,
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def HUNYUANIMAGE_V2_1_DIT(**kwargs):
|
63 |
+
return DiTConfig(
|
64 |
+
model=copy.deepcopy(hunyuanimage_v2_1_cfg),
|
65 |
+
use_lora=False,
|
66 |
+
use_cpu_offload=False,
|
67 |
+
gradient_checkpointing=True,
|
68 |
+
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1.safetensors",
|
69 |
+
use_compile=True,
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
def HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL(**kwargs):
|
74 |
+
return DiTConfig(
|
75 |
+
model=copy.deepcopy(hunyuanimage_v2_1_distilled_cfg),
|
76 |
+
use_lora=False,
|
77 |
+
use_cpu_offload=False,
|
78 |
+
gradient_checkpointing=True,
|
79 |
+
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1-distilled.safetensors",
|
80 |
+
use_compile=True,
|
81 |
+
)
|
82 |
+
|
83 |
+
# =============================================================================
|
84 |
+
# REFINER MODELS
|
85 |
+
# =============================================================================
|
86 |
+
|
87 |
+
def HUNYUANIMAGE_REFINER_DIT(**kwargs):
|
88 |
+
return DiTConfig(
|
89 |
+
model=copy.deepcopy(hunyuanimage_refiner_cfg),
|
90 |
+
use_lora=False,
|
91 |
+
use_cpu_offload=False,
|
92 |
+
gradient_checkpointing=True,
|
93 |
+
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage-refiner.safetensors",
|
94 |
+
use_compile=True,
|
95 |
+
)
|
96 |
+
|
97 |
+
def HUNYUANIMAGE_REFINER_VAE_32x(**kwargs):
|
98 |
+
return VAEConfig(
|
99 |
+
model=L(load_vae)(
|
100 |
+
vae_path=None,
|
101 |
+
device="cuda",
|
102 |
+
),
|
103 |
+
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/vae/vae_refiner",
|
104 |
+
cpu_offload=False,
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
def HUNYUANIMAGE_REFINER_TEXT_ENCODER(**kwargs):
|
109 |
+
return TextEncoderConfig(
|
110 |
+
model=L(TextEncoder)(
|
111 |
+
text_encoder_type="llm",
|
112 |
+
max_length=1000,
|
113 |
+
text_encoder_precision='fp16',
|
114 |
+
tokenizer_type="llm",
|
115 |
+
text_encoder_path=None,
|
116 |
+
prompt_template=None,
|
117 |
+
prompt_template_video=None,
|
118 |
+
hidden_state_skip_layer=2,
|
119 |
+
apply_final_norm=False,
|
120 |
+
reproduce=False,
|
121 |
+
logger=None,
|
122 |
+
device=None,
|
123 |
+
),
|
124 |
+
prompt_template="dit-llm-encode",
|
125 |
+
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/text_encoder",
|
126 |
+
text_len=256,
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
# =============================================================================
|
131 |
+
# SPECIALIZED MODELS
|
132 |
+
# =============================================================================
|
133 |
+
|
134 |
+
def HUNYUANIMAGE_REPROMPT(**kwargs):
|
135 |
+
from hyimage.models.reprompt import RePrompt
|
136 |
+
|
137 |
+
return RepromptConfig(
|
138 |
+
model=L(RePrompt)(
|
139 |
+
models_root_path=None,
|
140 |
+
device_map="auto",
|
141 |
+
),
|
142 |
+
load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/reprompt",
|
143 |
+
)
|
hyimage/models/reprompt/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .reprompt import RePrompt
|
hyimage/models/reprompt/reprompt.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import loguru
|
3 |
+
import torch
|
4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
5 |
+
from accelerate import cpu_offload_with_hook
|
6 |
+
|
7 |
+
"""
|
8 |
+
English translation of the System prompt:
|
9 |
+
----------------------------------------
|
10 |
+
You are an expert in writing image generation prompts. Please rewrite the user's prompt according to the following requirements:
|
11 |
+
1. The main subject/action/quantity/style/layout/relationship/attribute/text in the rewritten prompt must be consistent with the original intention;
|
12 |
+
2. The rewritten prompt should follow the "overall-detail-conclusion" structure, ensuring the clarity of information hierarchy;
|
13 |
+
3. The rewritten prompt should be objective and neutral, avoiding subjective judgment and emotional evaluation;
|
14 |
+
4. The rewritten prompt should be from the main to the secondary, always describing the most important elements first, and then the secondary and background elements;
|
15 |
+
5. The rewritten prompt should be logically clear, strictly follow the spatial logic or main-secondary logic, allowing the reader to reconstruct the image in the brain;
|
16 |
+
6. The rewritten prompt should end with a summary sentence, summarizing the overall style or type of the image.
|
17 |
+
"""
|
18 |
+
|
19 |
+
SYSTEM_PROMPT = (
|
20 |
+
"你是一位图像生成提示词撰写专家,请根据用户输入的提示词,改写生成新的提示词,改写后的提示词要求:"
|
21 |
+
"1 改写后提示词包含的主体/动作/数量/风格/布局/关系/属性/文字等 必须和改写前的意图一致; "
|
22 |
+
"2 在宏观上遵循“总-分-总”的结构,确保信息的层次清晰;"
|
23 |
+
"3 客观中立,避免主观臆断和情感评价;"
|
24 |
+
"4 由主到次,始终先描述最重要的元素,再描述次要和背景元素;"
|
25 |
+
"5 逻辑清晰,严格遵循空间逻辑或主次逻辑,使读者能在大脑中重建画面;"
|
26 |
+
"6 结尾点题,必须用一句话总结图像的整体风格或类型。"
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def replace_single_quotes(text):
|
31 |
+
"""
|
32 |
+
Replace single quotes within words with double quotes, and convert
|
33 |
+
curly single quotes to curly double quotes for consistency.
|
34 |
+
"""
|
35 |
+
pattern = r"\B'([^']*)'\B"
|
36 |
+
replaced_text = re.sub(pattern, r'"\1"', text)
|
37 |
+
replaced_text = replaced_text.replace("’", "”")
|
38 |
+
replaced_text = replaced_text.replace("‘", "“")
|
39 |
+
return replaced_text
|
40 |
+
|
41 |
+
|
42 |
+
class RePrompt:
|
43 |
+
|
44 |
+
def __init__(self, models_root_path, device_map="auto", enable_offloading=True):
|
45 |
+
"""
|
46 |
+
Initialize the RePrompt class with model and processor.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
models_root_path (str): Path to the pretrained model.
|
50 |
+
device_map (str): Device mapping for model loading.
|
51 |
+
"""
|
52 |
+
if enable_offloading:
|
53 |
+
device_map = None
|
54 |
+
self.model = AutoModelForCausalLM.from_pretrained(models_root_path, device_map=device_map, trust_remote_code=True)
|
55 |
+
self.tokenizer = AutoTokenizer.from_pretrained(models_root_path, trust_remote_code=True)
|
56 |
+
self.enable_offloading = enable_offloading
|
57 |
+
|
58 |
+
if enable_offloading:
|
59 |
+
_, self.offload_hook = cpu_offload_with_hook(self.model, execution_device=torch.device('cuda'))
|
60 |
+
self.device_map = device_map
|
61 |
+
self.original_device_map = getattr(self.model, 'hf_device_map', None)
|
62 |
+
|
63 |
+
@torch.inference_mode()
|
64 |
+
def predict(
|
65 |
+
self,
|
66 |
+
prompt_cot,
|
67 |
+
sys_prompt=SYSTEM_PROMPT,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Generate a rewritten prompt using the model.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
prompt_cot (str): The original prompt to be rewritten.
|
74 |
+
sys_prompt (str): System prompt to guide the rewriting.
|
75 |
+
temperature (float): Sampling temperature.
|
76 |
+
device (str): Device for inference.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
str: The rewritten prompt, or the original if generation fails.
|
80 |
+
"""
|
81 |
+
org_prompt_cot = prompt_cot
|
82 |
+
try:
|
83 |
+
messages = [
|
84 |
+
{"role": "system", "content": sys_prompt},
|
85 |
+
{"role": "user", "content": org_prompt_cot},
|
86 |
+
]
|
87 |
+
tokenized_chat = self.tokenizer.apply_chat_template(
|
88 |
+
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", enable_thinking=False # Toggle thinking mode (default: True)
|
89 |
+
)
|
90 |
+
if self.model.device != torch.device('meta'):
|
91 |
+
tokenized_chat = tokenized_chat.to(self.model.device)
|
92 |
+
outputs = self.model.generate(tokenized_chat, max_new_tokens=2048, temperature=0.0, do_sample=False, top_k=5, top_p=0.9)
|
93 |
+
if self.enable_offloading:
|
94 |
+
self.offload_hook.offload()
|
95 |
+
output_res = self.tokenizer.decode(outputs[0])
|
96 |
+
answer_pattern = r'<answer>(.*?)</answer>'
|
97 |
+
answer_matches = re.findall(answer_pattern, output_res, re.DOTALL)
|
98 |
+
prompt_cot = [match.strip() for match in answer_matches][0]
|
99 |
+
prompt_cot = replace_single_quotes(prompt_cot)
|
100 |
+
except Exception as e:
|
101 |
+
prompt_cot = org_prompt_cot
|
102 |
+
loguru.logger.error(f"✗ Re-prompting failed, fall back to generate prompt. Cause: {e}")
|
103 |
+
|
104 |
+
return prompt_cot
|
105 |
+
|
106 |
+
def to(self, device, *args, **kwargs):
|
107 |
+
self.model = self.model.to(device, *args, **kwargs)
|
108 |
+
return self
|
hyimage/models/text_encoder/__init__.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
from copy import deepcopy
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from transformers import AutoModelForVision2Seq, AutoTokenizer
|
8 |
+
|
9 |
+
from transformers.utils import ModelOutput
|
10 |
+
|
11 |
+
|
12 |
+
def use_default(value, default):
|
13 |
+
"""Utility: return value if not None, else default."""
|
14 |
+
return value if value is not None else default
|
15 |
+
|
16 |
+
# Prompt templates for different models and tasks
|
17 |
+
PROMPT_TEMPLATE_ENCODE = (
|
18 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
19 |
+
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
20 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
21 |
+
)
|
22 |
+
PROMPT_TEMPLATE_ENCODE_V2 = (
|
23 |
+
"<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, "
|
24 |
+
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
25 |
+
"<|im_start|>user\n{}<|im_end|>"
|
26 |
+
)
|
27 |
+
|
28 |
+
NEGATIVE_PROMPT = (
|
29 |
+
"Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, "
|
30 |
+
"bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
31 |
+
)
|
32 |
+
|
33 |
+
PROMPT_TEMPLATE = {
|
34 |
+
"dit-llm-encode": {
|
35 |
+
"template": PROMPT_TEMPLATE_ENCODE,
|
36 |
+
"crop_start": 36,
|
37 |
+
},
|
38 |
+
"dit-llm-encode-v2": {
|
39 |
+
"template": PROMPT_TEMPLATE_ENCODE_V2,
|
40 |
+
"crop_start": 34,
|
41 |
+
},
|
42 |
+
}
|
43 |
+
|
44 |
+
def load_text_encoder(
|
45 |
+
text_encoder_type,
|
46 |
+
text_encoder_precision=None,
|
47 |
+
text_encoder_path=None,
|
48 |
+
infer_mode="encoder",
|
49 |
+
logger=None,
|
50 |
+
device=None
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
Load a text encoder model from pretrained weights.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
text_encoder_type (str): Type of text encoder.
|
57 |
+
text_encoder_precision (str, optional): Precision for model weights.
|
58 |
+
text_encoder_path (str, optional): Path to pretrained weights.
|
59 |
+
infer_mode (str): "encoder" or "decoder".
|
60 |
+
logger (logging.Logger, optional): Logger for info.
|
61 |
+
device (torch.device, optional): Device to move model to.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
model (nn.Module): Loaded text encoder.
|
65 |
+
model_path (str): Path to model.
|
66 |
+
"""
|
67 |
+
if logger is not None:
|
68 |
+
logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
|
69 |
+
|
70 |
+
if text_encoder_type == 'llm':
|
71 |
+
text_encoder = AutoModelForVision2Seq.from_pretrained(
|
72 |
+
text_encoder_path,
|
73 |
+
torch_dtype="auto"
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
|
77 |
+
|
78 |
+
text_encoder.requires_grad_(False)
|
79 |
+
|
80 |
+
if logger is not None:
|
81 |
+
logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
|
82 |
+
|
83 |
+
if device is not None:
|
84 |
+
text_encoder = text_encoder.to(device)
|
85 |
+
|
86 |
+
return text_encoder, text_encoder_path
|
87 |
+
|
88 |
+
def load_tokenizer(
|
89 |
+
tokenizer_type,
|
90 |
+
tokenizer_path=None,
|
91 |
+
padding_side="right",
|
92 |
+
logger=None
|
93 |
+
):
|
94 |
+
"""
|
95 |
+
Load a tokenizer from pretrained weights.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
tokenizer_type (str): Type of tokenizer.
|
99 |
+
tokenizer_path (str, optional): Path to pretrained tokenizer.
|
100 |
+
padding_side (str): Padding side for tokenizer.
|
101 |
+
logger (logging.Logger, optional): Logger for info.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
tokenizer: Loaded tokenizer.
|
105 |
+
tokenizer_path (str): Path to tokenizer.
|
106 |
+
"""
|
107 |
+
if logger is not None:
|
108 |
+
logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
|
109 |
+
|
110 |
+
if tokenizer_type == "llm":
|
111 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
112 |
+
tokenizer_path, use_fast=False, padding_side=padding_side, trust_remote_code=True)
|
113 |
+
else:
|
114 |
+
raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
|
115 |
+
|
116 |
+
return tokenizer, tokenizer_path
|
117 |
+
|
118 |
+
@dataclass
|
119 |
+
class TextEncoderModelOutput(ModelOutput):
|
120 |
+
"""
|
121 |
+
Output for text encoder models.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
hidden_state (torch.FloatTensor): Output hidden states of the last layer.
|
125 |
+
attention_mask (torch.LongTensor, optional): Attention mask for valid tokens.
|
126 |
+
hidden_states_list (tuple(torch.FloatTensor), optional): All hidden states if requested.
|
127 |
+
text_outputs (list, optional): Decoded texts if requested.
|
128 |
+
"""
|
129 |
+
hidden_state: torch.FloatTensor = None
|
130 |
+
attention_mask: Optional[torch.LongTensor] = None
|
131 |
+
hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
|
132 |
+
text_outputs: Optional[list] = None
|
133 |
+
|
134 |
+
class TextEncoder(nn.Module):
|
135 |
+
"""
|
136 |
+
TextEncoder wraps a pretrained text encoder and tokenizer for flexible text encoding.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
text_encoder_type (str): Type of text encoder.
|
140 |
+
max_length (int): Maximum sequence length.
|
141 |
+
text_encoder_precision (str, optional): Precision for model weights.
|
142 |
+
text_encoder_path (str, optional): Path to pretrained weights.
|
143 |
+
tokenizer_type (str, optional): Type of tokenizer.
|
144 |
+
tokenizer_path (str, optional): Path to pretrained tokenizer.
|
145 |
+
output_key (str, optional): Output key for model output.
|
146 |
+
use_attention_mask (bool): Whether to use attention mask.
|
147 |
+
infer_mode (str): "encoder" or "decoder".
|
148 |
+
input_max_length (int, optional): Max input length.
|
149 |
+
prompt_template (dict, optional): Prompt template for image.
|
150 |
+
prompt_template_video (dict, optional): Prompt template for video.
|
151 |
+
hidden_state_skip_layer (int, optional): Skip layers from last for hidden state.
|
152 |
+
apply_final_norm (bool): Whether to apply final layer norm.
|
153 |
+
reproduce (bool): Deterministic output if True.
|
154 |
+
logger (logging.Logger, optional): Logger for info.
|
155 |
+
device (torch.device, optional): Device to move model to.
|
156 |
+
"""
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
text_encoder_type: str,
|
160 |
+
max_length: int,
|
161 |
+
text_encoder_precision: Optional[str] = None,
|
162 |
+
text_encoder_path: Optional[str] = None,
|
163 |
+
tokenizer_type: Optional[str] = None,
|
164 |
+
tokenizer_path: Optional[str] = None,
|
165 |
+
output_key: Optional[str] = None,
|
166 |
+
use_attention_mask: bool = True,
|
167 |
+
infer_mode: str = "encoder",
|
168 |
+
input_max_length: Optional[int] = None,
|
169 |
+
prompt_template: Optional[dict] = None,
|
170 |
+
prompt_template_video: Optional[dict] = None,
|
171 |
+
hidden_state_skip_layer: Optional[int] = None,
|
172 |
+
apply_final_norm: bool = False,
|
173 |
+
reproduce: bool = False,
|
174 |
+
logger=None,
|
175 |
+
device=None,
|
176 |
+
):
|
177 |
+
super().__init__()
|
178 |
+
self.text_encoder_type = text_encoder_type
|
179 |
+
self.max_length = max_length
|
180 |
+
self.precision = text_encoder_precision
|
181 |
+
self.model_path = text_encoder_path
|
182 |
+
self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
|
183 |
+
self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
|
184 |
+
self.use_attention_mask = use_attention_mask
|
185 |
+
self.input_max_length = input_max_length if input_max_length is not None else max_length
|
186 |
+
self.prompt_template = dict(prompt_template) if prompt_template is not None else None
|
187 |
+
self.prompt_template_video = dict(prompt_template_video) if prompt_template_video is not None else None
|
188 |
+
self.hidden_state_skip_layer = hidden_state_skip_layer
|
189 |
+
self.apply_final_norm = apply_final_norm
|
190 |
+
self.infer_mode = infer_mode
|
191 |
+
self.reproduce = reproduce
|
192 |
+
self.logger = logger
|
193 |
+
|
194 |
+
self.use_template = self.prompt_template is not None
|
195 |
+
if self.use_template:
|
196 |
+
assert isinstance(self.prompt_template, dict) and "template" in self.prompt_template, (
|
197 |
+
f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
|
198 |
+
)
|
199 |
+
if self.prompt_template_video is not None:
|
200 |
+
assert isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video, (
|
201 |
+
f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
|
202 |
+
)
|
203 |
+
assert '{}' in str(self.prompt_template["template"]), (
|
204 |
+
"`prompt_template['template']` must contain a placeholder `{}` for the input text, "
|
205 |
+
f"got {self.prompt_template['template']}"
|
206 |
+
)
|
207 |
+
|
208 |
+
if infer_mode == "decoder":
|
209 |
+
assert text_encoder_type in ["llava-llama-3-8b"], (
|
210 |
+
f"Unsupported text encoder type for infer_mode='decoder': {text_encoder_type}"
|
211 |
+
)
|
212 |
+
assert self.prompt_template is not None and hidden_state_skip_layer is not None, (
|
213 |
+
f"`prompt_template` and `hidden_state_skip_layer` must be provided for infer_mode='decoder', "
|
214 |
+
f"got prompt_template={self.prompt_template}, hidden_state_skip_layer={self.hidden_state_skip_layer}"
|
215 |
+
)
|
216 |
+
|
217 |
+
if "t5" in text_encoder_type:
|
218 |
+
self.output_key = output_key or "last_hidden_state"
|
219 |
+
elif "clip" in text_encoder_type:
|
220 |
+
self.output_key = output_key or "pooler_output"
|
221 |
+
elif any(x in text_encoder_type for x in ["llm"]):
|
222 |
+
self.output_key = output_key or ("last_hidden_state" if infer_mode == "encoder" else None)
|
223 |
+
else:
|
224 |
+
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
|
225 |
+
|
226 |
+
self.model, self.model_path = load_text_encoder(
|
227 |
+
text_encoder_type=self.text_encoder_type,
|
228 |
+
text_encoder_precision=self.precision,
|
229 |
+
text_encoder_path=self.model_path,
|
230 |
+
infer_mode=self.infer_mode,
|
231 |
+
logger=self.logger,
|
232 |
+
device=device
|
233 |
+
)
|
234 |
+
self.dtype = self.model.dtype
|
235 |
+
self.device = self.model.device
|
236 |
+
|
237 |
+
padding_side = "right" if self.infer_mode == "encoder" else "left"
|
238 |
+
self.tokenizer, self.tokenizer_path = load_tokenizer(
|
239 |
+
tokenizer_type=self.tokenizer_type,
|
240 |
+
tokenizer_path=self.tokenizer_path,
|
241 |
+
padding_side=padding_side,
|
242 |
+
logger=self.logger
|
243 |
+
)
|
244 |
+
|
245 |
+
def __repr__(self):
|
246 |
+
return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
def apply_text_to_template(text, template, prevent_empty_text=True):
|
250 |
+
"""
|
251 |
+
Apply text to a prompt template.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
text (str): Input text.
|
255 |
+
template (str or list): Template string or list of chat conversation.
|
256 |
+
prevent_empty_text (bool): If True, prevent empty user text by adding a space.
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
str or list: Text with template applied.
|
260 |
+
"""
|
261 |
+
if isinstance(template, str):
|
262 |
+
return template.format(text)
|
263 |
+
elif isinstance(template, list):
|
264 |
+
conversation = deepcopy(template)
|
265 |
+
for message in conversation:
|
266 |
+
if '{}' in message.get("content", ""):
|
267 |
+
filled_text = message["content"].format(text)
|
268 |
+
if prevent_empty_text and len(filled_text) == 0:
|
269 |
+
filled_text = ' '
|
270 |
+
message["content"] = filled_text
|
271 |
+
break # Only one placeholder per conversation
|
272 |
+
return conversation
|
273 |
+
else:
|
274 |
+
raise TypeError(f"Unsupported template type: {type(template)}")
|
275 |
+
|
276 |
+
def text2tokens(self, text, data_type='image'):
|
277 |
+
"""
|
278 |
+
Tokenize the input text, optionally applying a prompt template.
|
279 |
+
|
280 |
+
Args:
|
281 |
+
text (str or list): Input text.
|
282 |
+
data_type (str): 'image' or 'video'.
|
283 |
+
|
284 |
+
Returns:
|
285 |
+
dict: Tokenized input.
|
286 |
+
"""
|
287 |
+
tokenize_input_type = 'str'
|
288 |
+
if self.use_template:
|
289 |
+
if data_type == 'image':
|
290 |
+
prompt_template = self.prompt_template["template"]
|
291 |
+
elif data_type == 'video':
|
292 |
+
prompt_template = self.prompt_template_video["template"]
|
293 |
+
else:
|
294 |
+
raise ValueError(f"Unsupported data type: {data_type}")
|
295 |
+
if isinstance(text, (list, tuple)):
|
296 |
+
text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
|
297 |
+
if isinstance(text[0], list):
|
298 |
+
tokenize_input_type = 'list'
|
299 |
+
elif isinstance(text, str):
|
300 |
+
text = self.apply_text_to_template(text, prompt_template)
|
301 |
+
if isinstance(text, list):
|
302 |
+
tokenize_input_type = 'list'
|
303 |
+
else:
|
304 |
+
raise TypeError(f"Unsupported text type: {type(text)}")
|
305 |
+
kwargs = dict(truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
|
306 |
+
if tokenize_input_type == 'str':
|
307 |
+
return self.tokenizer(
|
308 |
+
text,
|
309 |
+
return_length=False,
|
310 |
+
return_overflowing_tokens=False,
|
311 |
+
return_attention_mask=True,
|
312 |
+
**kwargs,
|
313 |
+
)
|
314 |
+
elif tokenize_input_type == 'list':
|
315 |
+
return self.tokenizer.apply_chat_template(
|
316 |
+
text,
|
317 |
+
add_generation_prompt=True,
|
318 |
+
tokenize=True,
|
319 |
+
return_dict=True,
|
320 |
+
**kwargs,
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
|
324 |
+
|
325 |
+
def encode(
|
326 |
+
self,
|
327 |
+
batch_encoding,
|
328 |
+
use_attention_mask=None,
|
329 |
+
output_hidden_states=False,
|
330 |
+
do_sample=None,
|
331 |
+
hidden_state_skip_layer=None,
|
332 |
+
return_texts=False,
|
333 |
+
data_type='image',
|
334 |
+
device=None
|
335 |
+
):
|
336 |
+
"""
|
337 |
+
Encode tokenized input using the text encoder.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
batch_encoding (dict): Batch encoding from tokenizer.
|
341 |
+
use_attention_mask (bool, optional): Whether to use attention mask.
|
342 |
+
output_hidden_states (bool): Whether to output all hidden states.
|
343 |
+
do_sample (bool, optional): Whether to sample from the model (for decoder-only LLMs).
|
344 |
+
hidden_state_skip_layer (int, optional): Number of layers to skip from last for hidden state.
|
345 |
+
return_texts (bool): Whether to return decoded texts.
|
346 |
+
data_type (str): 'image' or 'video'.
|
347 |
+
device (torch.device, optional): Device to use.
|
348 |
+
|
349 |
+
Returns:
|
350 |
+
TextEncoderModelOutput: Encoded output.
|
351 |
+
"""
|
352 |
+
use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
|
353 |
+
hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
|
354 |
+
do_sample = use_default(do_sample, not self.reproduce)
|
355 |
+
|
356 |
+
if self.infer_mode == "encoder":
|
357 |
+
attention_mask = batch_encoding["attention_mask"].to(self.model.device) if use_attention_mask else None
|
358 |
+
if 'Gemma2' in self.text_encoder_type:
|
359 |
+
input_ids = batch_encoding["input_ids"].to(self.model.device)
|
360 |
+
_, inputs_embeds, labels, attention_mask = self.model.merge_multimodal(
|
361 |
+
text_input_ids=input_ids,
|
362 |
+
text_attention_masks=attention_mask,
|
363 |
+
text_labels=None,
|
364 |
+
pixel_values=[None]
|
365 |
+
)
|
366 |
+
outputs = self.model.llm(inputs_embeds=inputs_embeds, labels=labels, attention_mask=attention_mask)
|
367 |
+
else:
|
368 |
+
outputs = self.model(
|
369 |
+
input_ids=batch_encoding["input_ids"].to(self.model.device),
|
370 |
+
attention_mask=attention_mask,
|
371 |
+
output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
|
372 |
+
)
|
373 |
+
if hidden_state_skip_layer is not None:
|
374 |
+
last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
375 |
+
# Apply final norm for intermediate layers if requested
|
376 |
+
if hidden_state_skip_layer > 0 and self.apply_final_norm:
|
377 |
+
last_hidden_state = self.model.final_layer_norm(last_hidden_state)
|
378 |
+
else:
|
379 |
+
last_hidden_state = outputs[self.output_key]
|
380 |
+
|
381 |
+
# Remove hidden states of instruction tokens, only keep prompt tokens.
|
382 |
+
if self.use_template:
|
383 |
+
if data_type == 'image':
|
384 |
+
crop_start = self.prompt_template.get("crop_start", -1)
|
385 |
+
elif data_type == 'video':
|
386 |
+
crop_start = self.prompt_template_video.get("crop_start", -1)
|
387 |
+
else:
|
388 |
+
raise ValueError(f"Unsupported data type: {data_type}")
|
389 |
+
if crop_start > 0:
|
390 |
+
last_hidden_state = last_hidden_state[:, crop_start:]
|
391 |
+
attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
|
392 |
+
|
393 |
+
if output_hidden_states:
|
394 |
+
return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
|
395 |
+
return TextEncoderModelOutput(last_hidden_state, attention_mask)
|
396 |
+
|
397 |
+
elif self.infer_mode == "decoder":
|
398 |
+
# Remove leading padding tokens
|
399 |
+
input_max_valid_tokens = batch_encoding["attention_mask"].sum(dim=1).max().item()
|
400 |
+
if input_max_valid_tokens < batch_encoding["attention_mask"].shape[1]:
|
401 |
+
batch_encoding = {
|
402 |
+
"input_ids": batch_encoding["input_ids"][:, -input_max_valid_tokens:],
|
403 |
+
"attention_mask": batch_encoding["attention_mask"][:, -input_max_valid_tokens:],
|
404 |
+
}
|
405 |
+
|
406 |
+
# Generate text from the model.
|
407 |
+
outputs = self.model.generate(
|
408 |
+
input_ids=batch_encoding["input_ids"].to(self.model.device),
|
409 |
+
attention_mask=batch_encoding["attention_mask"].to(self.model.device) if use_attention_mask else None,
|
410 |
+
max_new_tokens=self.max_length,
|
411 |
+
do_sample=do_sample,
|
412 |
+
return_dict_in_generate=True,
|
413 |
+
output_hidden_states=True,
|
414 |
+
stop_strings='<|eot_id|>', tokenizer=self.tokenizer,
|
415 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
416 |
+
)
|
417 |
+
|
418 |
+
# Concatenate hidden states from all generated tokens.
|
419 |
+
hidden_states = torch.cat([
|
420 |
+
per_token_hidden_states[-(hidden_state_skip_layer + 1)]
|
421 |
+
for per_token_hidden_states in outputs.hidden_states[1:]
|
422 |
+
], dim=1)
|
423 |
+
if self.apply_final_norm:
|
424 |
+
hidden_states = self.model.final_layer_norm(hidden_states)
|
425 |
+
|
426 |
+
# Make sequence mask from output sequences
|
427 |
+
output_max_valid_tokens = hidden_states.shape[1]
|
428 |
+
attention_mask = (outputs.sequences[:, -output_max_valid_tokens - 1:-1] != self.tokenizer.eos_token_id).long()
|
429 |
+
|
430 |
+
if return_texts:
|
431 |
+
text_outputs = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
|
432 |
+
return TextEncoderModelOutput(hidden_states, attention_mask, None, text_outputs)
|
433 |
+
else:
|
434 |
+
return TextEncoderModelOutput(hidden_states, attention_mask)
|
435 |
+
else:
|
436 |
+
raise ValueError(f"Unsupported text encoder infer mode: {self.infer_mode}")
|
437 |
+
|
438 |
+
def forward(
|
439 |
+
self,
|
440 |
+
text,
|
441 |
+
use_attention_mask=None,
|
442 |
+
output_hidden_states=False,
|
443 |
+
do_sample=False,
|
444 |
+
hidden_state_skip_layer=None,
|
445 |
+
return_texts=False
|
446 |
+
):
|
447 |
+
"""
|
448 |
+
Forward pass: encode text to hidden states.
|
449 |
+
|
450 |
+
Args:
|
451 |
+
text (str or list): Input text.
|
452 |
+
use_attention_mask (bool, optional): Whether to use attention mask.
|
453 |
+
output_hidden_states (bool): Whether to output all hidden states.
|
454 |
+
do_sample (bool): Whether to sample from the model (for decoder-only LLMs).
|
455 |
+
hidden_state_skip_layer (int, optional): Number of layers to skip from last for hidden state.
|
456 |
+
return_texts (bool): Whether to return decoded texts.
|
457 |
+
|
458 |
+
Returns:
|
459 |
+
TextEncoderModelOutput: Encoded output.
|
460 |
+
"""
|
461 |
+
batch_encoding = self.text2tokens(text)
|
462 |
+
return self.encode(
|
463 |
+
batch_encoding,
|
464 |
+
use_attention_mask=use_attention_mask,
|
465 |
+
output_hidden_states=output_hidden_states,
|
466 |
+
do_sample=do_sample,
|
467 |
+
hidden_state_skip_layer=hidden_state_skip_layer,
|
468 |
+
return_texts=return_texts
|
469 |
+
)
|
hyimage/models/text_encoder/byT5/__init__.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
5 |
+
|
6 |
+
|
7 |
+
def load_glyph_byT5_v2(args, device):
|
8 |
+
"""
|
9 |
+
Loads ByT5 tokenizer and encoder model for glyph encoding.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
args (dict): Configuration dictionary containing paths and settings.
|
13 |
+
device (str or torch.device): Device to load the model onto.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
dict: Dictionary with keys 'byt5_tokenizer', 'byt5_model', 'byt5_max_length'.
|
17 |
+
"""
|
18 |
+
byt5_tokenizer, byt5_model, byt5_max_length = create_byt5(args, device)
|
19 |
+
byt5_model = byt5_model.to(device=device)
|
20 |
+
return {
|
21 |
+
"byt5_tokenizer": byt5_tokenizer,
|
22 |
+
"byt5_model": byt5_model,
|
23 |
+
"byt5_max_length": byt5_max_length,
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def create_byt5(args, device):
|
28 |
+
"""
|
29 |
+
Create ByT5 tokenizer and encoder, load weights if provided.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
args (dict): Configuration dictionary.
|
33 |
+
device (str or torch.device): Device to load the model onto.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
tuple: (byt5_tokenizer, byt5_model, byt5_max_length)
|
37 |
+
"""
|
38 |
+
byt5_max_length = args['byt5_max_length']
|
39 |
+
byt5_config = dict(
|
40 |
+
byt5_name=args['byT5_google_path'],
|
41 |
+
special_token=True,
|
42 |
+
color_special_token=True,
|
43 |
+
font_special_token=True,
|
44 |
+
color_ann_path=args['multilingual_prompt_format_color_path'],
|
45 |
+
font_ann_path=args['multilingual_prompt_format_font_path'],
|
46 |
+
multilingual=True,
|
47 |
+
)
|
48 |
+
huggingface_cache_dir = None
|
49 |
+
byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
|
50 |
+
**byt5_config,
|
51 |
+
huggingface_cache_dir=huggingface_cache_dir,
|
52 |
+
device=device,
|
53 |
+
)
|
54 |
+
|
55 |
+
# Load custom checkpoint if provided
|
56 |
+
if args['byT5_ckpt_path'] is not None:
|
57 |
+
if "cuda" not in str(device):
|
58 |
+
byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=f"cuda:{device}")
|
59 |
+
else:
|
60 |
+
byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=device)
|
61 |
+
if 'state_dict' in byt5_state_dict:
|
62 |
+
sd = byt5_state_dict["state_dict"]
|
63 |
+
newsd = {}
|
64 |
+
for k, v in sd.items():
|
65 |
+
if k.startswith('module.text_tower.encoder.'):
|
66 |
+
newsd[k[len('module.text_tower.encoder.'):]] = v
|
67 |
+
byt5_state_dict = newsd
|
68 |
+
byt5_model.load_state_dict(byt5_state_dict)
|
69 |
+
byt5_model.requires_grad_(False)
|
70 |
+
return byt5_tokenizer, byt5_model, byt5_max_length
|
71 |
+
|
72 |
+
|
73 |
+
def add_special_token(
|
74 |
+
tokenizer,
|
75 |
+
text_encoder,
|
76 |
+
add_color,
|
77 |
+
add_font,
|
78 |
+
color_ann_path,
|
79 |
+
font_ann_path,
|
80 |
+
multilingual=False,
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Add special tokens for color and font to tokenizer and text encoder.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
tokenizer: Huggingface tokenizer.
|
87 |
+
text_encoder: Huggingface T5 encoder.
|
88 |
+
add_color (bool): Whether to add color tokens.
|
89 |
+
add_font (bool): Whether to add font tokens.
|
90 |
+
color_ann_path (str): Path to color annotation JSON.
|
91 |
+
font_ann_path (str): Path to font annotation JSON.
|
92 |
+
multilingual (bool): Whether to use multilingual font tokens.
|
93 |
+
"""
|
94 |
+
with open(font_ann_path, 'r') as f:
|
95 |
+
idx_font_dict = json.load(f)
|
96 |
+
with open(color_ann_path, 'r') as f:
|
97 |
+
idx_color_dict = json.load(f)
|
98 |
+
|
99 |
+
if multilingual:
|
100 |
+
font_token = [f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict]
|
101 |
+
else:
|
102 |
+
font_token = [f'<font-{i}>' for i in range(len(idx_font_dict))]
|
103 |
+
color_token = [f'<color-{i}>' for i in range(len(idx_color_dict))]
|
104 |
+
additional_special_tokens = []
|
105 |
+
if add_color:
|
106 |
+
additional_special_tokens += color_token
|
107 |
+
if add_font:
|
108 |
+
additional_special_tokens += font_token
|
109 |
+
|
110 |
+
tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
|
111 |
+
# Set mean_resizing=False to avoid PyTorch LAPACK dependency
|
112 |
+
text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
|
113 |
+
|
114 |
+
|
115 |
+
def load_byt5_and_byt5_tokenizer(
|
116 |
+
byt5_name='google/byt5-small',
|
117 |
+
special_token=False,
|
118 |
+
color_special_token=False,
|
119 |
+
font_special_token=False,
|
120 |
+
color_ann_path='assets/color_idx.json',
|
121 |
+
font_ann_path='assets/font_idx_512.json',
|
122 |
+
huggingface_cache_dir=None,
|
123 |
+
multilingual=False,
|
124 |
+
device=None,
|
125 |
+
):
|
126 |
+
"""
|
127 |
+
Load ByT5 encoder and tokenizer from Huggingface, and add special tokens if needed.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
byt5_name (str): Model name or path.
|
131 |
+
special_token (bool): Whether to add special tokens.
|
132 |
+
color_special_token (bool): Whether to add color tokens.
|
133 |
+
font_special_token (bool): Whether to add font tokens.
|
134 |
+
color_ann_path (str): Path to color annotation JSON.
|
135 |
+
font_ann_path (str): Path to font annotation JSON.
|
136 |
+
huggingface_cache_dir (str): Huggingface cache directory.
|
137 |
+
multilingual (bool): Whether to use multilingual font tokens.
|
138 |
+
device (str or torch.device): Device to load the model onto.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
tuple: (byt5_text_encoder, byt5_tokenizer)
|
142 |
+
"""
|
143 |
+
byt5_tokenizer = AutoTokenizer.from_pretrained(
|
144 |
+
byt5_name,
|
145 |
+
cache_dir=huggingface_cache_dir,
|
146 |
+
)
|
147 |
+
byt5_text_encoder = T5ForConditionalGeneration.from_pretrained(
|
148 |
+
byt5_name,
|
149 |
+
cache_dir=huggingface_cache_dir,
|
150 |
+
).get_encoder()
|
151 |
+
|
152 |
+
if "cuda" not in str(device):
|
153 |
+
device = torch.device(f"cuda:{device}")
|
154 |
+
else:
|
155 |
+
device = torch.device(device)
|
156 |
+
byt5_text_encoder = byt5_text_encoder.to(device)
|
157 |
+
|
158 |
+
if special_token:
|
159 |
+
add_special_token(
|
160 |
+
byt5_tokenizer,
|
161 |
+
byt5_text_encoder,
|
162 |
+
add_color=color_special_token,
|
163 |
+
add_font=font_special_token,
|
164 |
+
color_ann_path=color_ann_path,
|
165 |
+
font_ann_path=font_ann_path,
|
166 |
+
multilingual=multilingual,
|
167 |
+
)
|
168 |
+
return byt5_text_encoder, byt5_tokenizer
|
169 |
+
|
170 |
+
|
171 |
+
class ByT5Mapper(nn.Module):
|
172 |
+
"""
|
173 |
+
ByT5Mapper: Maps ByT5 encoder outputs to a new space, with optional residual connection.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
in_dim (int): Input dimension (must equal out_dim if use_residual).
|
177 |
+
out_dim (int): Output dimension after second linear layer.
|
178 |
+
hidden_dim (int): Hidden dimension for intermediate layer.
|
179 |
+
out_dim1 (int): Final output dimension.
|
180 |
+
use_residual (bool): Whether to use residual connection (default: True).
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
|
184 |
+
super().__init__()
|
185 |
+
if use_residual:
|
186 |
+
assert in_dim == out_dim
|
187 |
+
self.layernorm = nn.LayerNorm(in_dim)
|
188 |
+
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
189 |
+
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
190 |
+
self.fc3 = nn.Linear(out_dim, out_dim1)
|
191 |
+
self.use_residual = use_residual
|
192 |
+
self.act_fn = nn.GELU()
|
193 |
+
|
194 |
+
def forward(self, x):
|
195 |
+
"""
|
196 |
+
Forward pass for ByT5Mapper.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
x (Tensor): Input tensor of shape (..., in_dim).
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
Tensor: Output tensor of shape (..., out_dim1).
|
203 |
+
"""
|
204 |
+
residual = x
|
205 |
+
x = self.layernorm(x)
|
206 |
+
x = self.fc1(x)
|
207 |
+
x = self.act_fn(x)
|
208 |
+
x = self.fc2(x)
|
209 |
+
x2 = self.act_fn(x)
|
210 |
+
x2 = self.fc3(x2)
|
211 |
+
if self.use_residual:
|
212 |
+
x2 = x2 + residual
|
213 |
+
return x2
|
hyimage/models/vae/__init__.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from pathlib import Path
|
3 |
+
from hyimage.common.constants import PRECISION_TO_TYPE
|
4 |
+
from .hunyuanimage_vae import HunyuanVAE2D
|
5 |
+
|
6 |
+
def load_vae(device, vae_path: str = None, vae_precision: str = None):
|
7 |
+
config = HunyuanVAE2D.load_config(vae_path)
|
8 |
+
vae = HunyuanVAE2D.from_config(config)
|
9 |
+
|
10 |
+
if Path(vae_path).exists():
|
11 |
+
ckpt = torch.load(Path(vae_path) / "pytorch_model.ckpt", map_location='cpu')
|
12 |
+
if "state_dict" in ckpt:
|
13 |
+
ckpt = ckpt["state_dict"]
|
14 |
+
vae_ckpt = {}
|
15 |
+
for k, v in ckpt.items():
|
16 |
+
if k.startswith("vae."):
|
17 |
+
vae_ckpt[k.replace("vae.", "")] = v
|
18 |
+
vae.load_state_dict(vae_ckpt)
|
19 |
+
|
20 |
+
if vae_precision is not None:
|
21 |
+
vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
|
22 |
+
|
23 |
+
vae.requires_grad_(False)
|
24 |
+
|
25 |
+
if device is not None:
|
26 |
+
vae = vae.to(device)
|
27 |
+
|
28 |
+
vae.eval()
|
29 |
+
return vae
|
hyimage/models/vae/hunyuanimage_vae.py
ADDED
@@ -0,0 +1,779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from diffusers.configuration_utils import ConfigMixin
|
7 |
+
from diffusers.configuration_utils import register_to_config
|
8 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
9 |
+
from diffusers.models.modeling_utils import ModelMixin
|
10 |
+
from diffusers.utils import BaseOutput
|
11 |
+
from diffusers.utils.torch_utils import randn_tensor
|
12 |
+
from einops import rearrange
|
13 |
+
from torch import Tensor, nn
|
14 |
+
from torch.nn import Conv2d
|
15 |
+
|
16 |
+
|
17 |
+
class DiagonalGaussianDistribution:
|
18 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
19 |
+
if parameters.ndim == 3:
|
20 |
+
dim = 2 # (B, L, C)
|
21 |
+
elif parameters.ndim == 5 or parameters.ndim == 4:
|
22 |
+
dim = 1 # (B, C, T, H, W) / (B, C, H, W)
|
23 |
+
else:
|
24 |
+
raise NotImplementedError
|
25 |
+
self.parameters = parameters
|
26 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
27 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
28 |
+
self.deterministic = deterministic
|
29 |
+
self.std = torch.exp(0.5 * self.logvar)
|
30 |
+
self.var = torch.exp(self.logvar)
|
31 |
+
if self.deterministic:
|
32 |
+
zero_tensor = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
|
33 |
+
self.var = zero_tensor
|
34 |
+
self.std = zero_tensor
|
35 |
+
|
36 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
37 |
+
sample = randn_tensor(
|
38 |
+
self.mean.shape,
|
39 |
+
generator=generator,
|
40 |
+
device=self.parameters.device,
|
41 |
+
dtype=self.parameters.dtype,
|
42 |
+
)
|
43 |
+
return self.mean + self.std * sample
|
44 |
+
|
45 |
+
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
|
46 |
+
if self.deterministic:
|
47 |
+
return torch.tensor([0.0], device=self.parameters.device, dtype=self.parameters.dtype)
|
48 |
+
reduce_dim = list(range(1, self.mean.ndim))
|
49 |
+
if other is None:
|
50 |
+
return 0.5 * torch.sum(
|
51 |
+
self.mean.pow(2) + self.var - 1.0 - self.logvar,
|
52 |
+
dim=reduce_dim,
|
53 |
+
)
|
54 |
+
else:
|
55 |
+
return 0.5 * torch.sum(
|
56 |
+
(self.mean - other.mean).pow(2) / other.var
|
57 |
+
+ self.var / other.var
|
58 |
+
- 1.0
|
59 |
+
- self.logvar
|
60 |
+
+ other.logvar,
|
61 |
+
dim=reduce_dim,
|
62 |
+
)
|
63 |
+
|
64 |
+
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
|
65 |
+
if self.deterministic:
|
66 |
+
return torch.tensor([0.0], device=self.parameters.device, dtype=self.parameters.dtype)
|
67 |
+
logtwopi = np.log(2.0 * np.pi)
|
68 |
+
return 0.5 * torch.sum(
|
69 |
+
logtwopi + self.logvar + (sample - self.mean).pow(2) / self.var,
|
70 |
+
dim=dims,
|
71 |
+
)
|
72 |
+
|
73 |
+
def mode(self) -> torch.Tensor:
|
74 |
+
return self.mean
|
75 |
+
|
76 |
+
|
77 |
+
@dataclass
|
78 |
+
class DecoderOutput(BaseOutput):
|
79 |
+
"""Output of the decoder with sample and optional posterior distribution."""
|
80 |
+
sample: torch.FloatTensor
|
81 |
+
posterior: Optional[DiagonalGaussianDistribution] = None
|
82 |
+
|
83 |
+
|
84 |
+
def swish(x: Tensor) -> Tensor:
|
85 |
+
"""Swish activation function: x * sigmoid(x)."""
|
86 |
+
return x * torch.sigmoid(x)
|
87 |
+
|
88 |
+
|
89 |
+
def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
|
90 |
+
"""
|
91 |
+
Forward pass with optional gradient checkpointing for memory efficiency.
|
92 |
+
|
93 |
+
Parameters
|
94 |
+
----------
|
95 |
+
module : nn.Module
|
96 |
+
The module to run.
|
97 |
+
*inputs : Tensor
|
98 |
+
Inputs to the module.
|
99 |
+
use_checkpointing : bool
|
100 |
+
Whether to use gradient checkpointing.
|
101 |
+
"""
|
102 |
+
def create_custom_forward(module):
|
103 |
+
def custom_forward(*inputs):
|
104 |
+
return module(*inputs)
|
105 |
+
return custom_forward
|
106 |
+
|
107 |
+
if use_checkpointing:
|
108 |
+
return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False)
|
109 |
+
else:
|
110 |
+
return module(*inputs)
|
111 |
+
|
112 |
+
|
113 |
+
class AttnBlock(nn.Module):
|
114 |
+
"""Self-attention block for 3D tensors."""
|
115 |
+
|
116 |
+
def __init__(self, in_channels: int):
|
117 |
+
super().__init__()
|
118 |
+
self.in_channels = in_channels
|
119 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
120 |
+
self.q = Conv2d(in_channels, in_channels, kernel_size=1)
|
121 |
+
self.k = Conv2d(in_channels, in_channels, kernel_size=1)
|
122 |
+
self.v = Conv2d(in_channels, in_channels, kernel_size=1)
|
123 |
+
self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
|
124 |
+
|
125 |
+
def attention(self, x: Tensor) -> Tensor:
|
126 |
+
x = self.norm(x)
|
127 |
+
q = self.q(x)
|
128 |
+
k = self.k(x)
|
129 |
+
v = self.v(x)
|
130 |
+
|
131 |
+
b, c, h, w = q.shape
|
132 |
+
q = rearrange(q, "b c h w -> b (h w) c").contiguous()
|
133 |
+
k = rearrange(k, "b c h w -> b (h w) c").contiguous()
|
134 |
+
v = rearrange(v, "b c h w -> b (h w) c").contiguous()
|
135 |
+
|
136 |
+
x = nn.functional.scaled_dot_product_attention(q, k, v)
|
137 |
+
return rearrange(x, "b (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
138 |
+
|
139 |
+
def forward(self, x: Tensor) -> Tensor:
|
140 |
+
return x + self.proj_out(self.attention(x))
|
141 |
+
|
142 |
+
|
143 |
+
class ResnetBlock(nn.Module):
|
144 |
+
"""
|
145 |
+
Residual block with two convolutions and optional channel change.
|
146 |
+
|
147 |
+
Parameters
|
148 |
+
----------
|
149 |
+
in_channels : int
|
150 |
+
Number of input channels.
|
151 |
+
out_channels : int
|
152 |
+
Number of output channels.
|
153 |
+
"""
|
154 |
+
|
155 |
+
def __init__(self, in_channels: int, out_channels: int):
|
156 |
+
super().__init__()
|
157 |
+
self.in_channels = in_channels
|
158 |
+
out_channels = in_channels if out_channels is None else out_channels
|
159 |
+
self.out_channels = out_channels
|
160 |
+
|
161 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
162 |
+
self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
163 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
164 |
+
self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
165 |
+
|
166 |
+
if self.in_channels != self.out_channels:
|
167 |
+
self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
168 |
+
|
169 |
+
def forward(self, x: Tensor) -> Tensor:
|
170 |
+
h = x
|
171 |
+
h = self.norm1(h)
|
172 |
+
h = swish(h)
|
173 |
+
h = self.conv1(h)
|
174 |
+
h = self.norm2(h)
|
175 |
+
h = swish(h)
|
176 |
+
h = self.conv2(h)
|
177 |
+
|
178 |
+
if self.in_channels != self.out_channels:
|
179 |
+
x = self.nin_shortcut(x)
|
180 |
+
return x + h
|
181 |
+
|
182 |
+
|
183 |
+
class Downsample(nn.Module):
|
184 |
+
"""
|
185 |
+
Downsampling block for spatial reduction.
|
186 |
+
|
187 |
+
Parameters
|
188 |
+
----------
|
189 |
+
in_channels : int
|
190 |
+
Number of input channels.
|
191 |
+
out_channels : int
|
192 |
+
Number of output channels.
|
193 |
+
"""
|
194 |
+
|
195 |
+
def __init__(self, in_channels: int, out_channels: int):
|
196 |
+
super().__init__()
|
197 |
+
factor = 4
|
198 |
+
assert out_channels % factor == 0
|
199 |
+
|
200 |
+
self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
201 |
+
self.group_size = factor * in_channels // out_channels
|
202 |
+
|
203 |
+
def forward(self, x: Tensor) -> Tensor:
|
204 |
+
h = self.conv(x)
|
205 |
+
h = rearrange(h, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2)
|
206 |
+
shortcut = rearrange(x, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2)
|
207 |
+
|
208 |
+
B, C, H, W = shortcut.shape
|
209 |
+
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
|
210 |
+
return h + shortcut
|
211 |
+
|
212 |
+
|
213 |
+
class Upsample(nn.Module):
|
214 |
+
"""
|
215 |
+
Upsampling block for spatial expansion.
|
216 |
+
|
217 |
+
Parameters
|
218 |
+
----------
|
219 |
+
in_channels : int
|
220 |
+
Number of input channels.
|
221 |
+
out_channels : int
|
222 |
+
Number of output channels.
|
223 |
+
"""
|
224 |
+
|
225 |
+
def __init__(self, in_channels: int, out_channels: int):
|
226 |
+
super().__init__()
|
227 |
+
factor = 4
|
228 |
+
self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
|
229 |
+
self.repeats = factor * out_channels // in_channels
|
230 |
+
|
231 |
+
def forward(self, x: Tensor) -> Tensor:
|
232 |
+
h = self.conv(x)
|
233 |
+
h = rearrange(h, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2)
|
234 |
+
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
235 |
+
shortcut = rearrange(shortcut, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2)
|
236 |
+
return h + shortcut
|
237 |
+
|
238 |
+
|
239 |
+
class Encoder(nn.Module):
|
240 |
+
"""
|
241 |
+
Encoder network that compresses input to latent representation.
|
242 |
+
|
243 |
+
Parameters
|
244 |
+
----------
|
245 |
+
in_channels : int
|
246 |
+
Number of input channels.
|
247 |
+
z_channels : int
|
248 |
+
Number of latent channels.
|
249 |
+
block_out_channels : Tuple[int, ...]
|
250 |
+
Output channels for each block.
|
251 |
+
num_res_blocks : int
|
252 |
+
Number of residual blocks per block.
|
253 |
+
ffactor_spatial : int
|
254 |
+
Spatial downsampling factor.
|
255 |
+
downsample_match_channel : bool
|
256 |
+
Whether to match channels during downsampling.
|
257 |
+
"""
|
258 |
+
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
in_channels: int,
|
262 |
+
z_channels: int,
|
263 |
+
block_out_channels: Tuple[int, ...],
|
264 |
+
num_res_blocks: int,
|
265 |
+
ffactor_spatial: int,
|
266 |
+
downsample_match_channel: bool = True,
|
267 |
+
):
|
268 |
+
super().__init__()
|
269 |
+
assert block_out_channels[-1] % (2 * z_channels) == 0
|
270 |
+
|
271 |
+
self.z_channels = z_channels
|
272 |
+
self.block_out_channels = block_out_channels
|
273 |
+
self.num_res_blocks = num_res_blocks
|
274 |
+
|
275 |
+
self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
276 |
+
|
277 |
+
self.down = nn.ModuleList()
|
278 |
+
block_in = block_out_channels[0]
|
279 |
+
|
280 |
+
for i_level, ch in enumerate(block_out_channels):
|
281 |
+
block = nn.ModuleList()
|
282 |
+
block_out = ch
|
283 |
+
|
284 |
+
for _ in range(self.num_res_blocks):
|
285 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
286 |
+
block_in = block_out
|
287 |
+
|
288 |
+
down = nn.Module()
|
289 |
+
down.block = block
|
290 |
+
|
291 |
+
add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
|
292 |
+
|
293 |
+
if add_spatial_downsample:
|
294 |
+
assert i_level < len(block_out_channels) - 1
|
295 |
+
block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
|
296 |
+
down.downsample = Downsample(block_in, block_out)
|
297 |
+
block_in = block_out
|
298 |
+
|
299 |
+
self.down.append(down)
|
300 |
+
|
301 |
+
# Middle blocks with attention
|
302 |
+
self.mid = nn.Module()
|
303 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
304 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
305 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
306 |
+
|
307 |
+
# Output layers
|
308 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
309 |
+
self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
310 |
+
|
311 |
+
self.gradient_checkpointing = False
|
312 |
+
|
313 |
+
def forward(self, x: Tensor) -> Tensor:
|
314 |
+
use_checkpointing = bool(self.training and self.gradient_checkpointing)
|
315 |
+
|
316 |
+
# Downsampling
|
317 |
+
h = self.conv_in(x)
|
318 |
+
for i_level in range(len(self.block_out_channels)):
|
319 |
+
for i_block in range(self.num_res_blocks):
|
320 |
+
h = forward_with_checkpointing(
|
321 |
+
self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing
|
322 |
+
)
|
323 |
+
if hasattr(self.down[i_level], "downsample"):
|
324 |
+
h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing)
|
325 |
+
|
326 |
+
# Middle processing
|
327 |
+
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
|
328 |
+
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
|
329 |
+
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
|
330 |
+
|
331 |
+
# Output with shortcut connection
|
332 |
+
group_size = self.block_out_channels[-1] // (2 * self.z_channels)
|
333 |
+
shortcut = rearrange(h, "b (c r) h w -> b c r h w", r=group_size).mean(dim=2)
|
334 |
+
h = self.norm_out(h)
|
335 |
+
h = swish(h)
|
336 |
+
h = self.conv_out(h)
|
337 |
+
h += shortcut
|
338 |
+
return h
|
339 |
+
|
340 |
+
|
341 |
+
class Decoder(nn.Module):
|
342 |
+
"""
|
343 |
+
Decoder network that reconstructs output from latent representation.
|
344 |
+
|
345 |
+
Parameters
|
346 |
+
----------
|
347 |
+
z_channels : int
|
348 |
+
Number of latent channels.
|
349 |
+
out_channels : int
|
350 |
+
Number of output channels.
|
351 |
+
block_out_channels : Tuple[int, ...]
|
352 |
+
Output channels for each block.
|
353 |
+
num_res_blocks : int
|
354 |
+
Number of residual blocks per block.
|
355 |
+
ffactor_spatial : int
|
356 |
+
Spatial upsampling factor.
|
357 |
+
upsample_match_channel : bool
|
358 |
+
Whether to match channels during upsampling.
|
359 |
+
"""
|
360 |
+
|
361 |
+
def __init__(
|
362 |
+
self,
|
363 |
+
z_channels: int,
|
364 |
+
out_channels: int,
|
365 |
+
block_out_channels: Tuple[int, ...],
|
366 |
+
num_res_blocks: int,
|
367 |
+
ffactor_spatial: int,
|
368 |
+
upsample_match_channel: bool = True,
|
369 |
+
):
|
370 |
+
super().__init__()
|
371 |
+
assert block_out_channels[0] % z_channels == 0
|
372 |
+
|
373 |
+
self.z_channels = z_channels
|
374 |
+
self.block_out_channels = block_out_channels
|
375 |
+
self.num_res_blocks = num_res_blocks
|
376 |
+
|
377 |
+
block_in = block_out_channels[0]
|
378 |
+
self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
379 |
+
|
380 |
+
# Middle blocks with attention
|
381 |
+
self.mid = nn.Module()
|
382 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
383 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
384 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
385 |
+
|
386 |
+
# Upsampling blocks
|
387 |
+
self.up = nn.ModuleList()
|
388 |
+
for i_level, ch in enumerate(block_out_channels):
|
389 |
+
block = nn.ModuleList()
|
390 |
+
block_out = ch
|
391 |
+
|
392 |
+
for _ in range(self.num_res_blocks + 1):
|
393 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
394 |
+
block_in = block_out
|
395 |
+
|
396 |
+
up = nn.Module()
|
397 |
+
up.block = block
|
398 |
+
|
399 |
+
# Determine upsampling strategy
|
400 |
+
add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
|
401 |
+
|
402 |
+
if add_spatial_upsample:
|
403 |
+
assert i_level < len(block_out_channels) - 1
|
404 |
+
block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in
|
405 |
+
up.upsample = Upsample(block_in, block_out)
|
406 |
+
block_in = block_out
|
407 |
+
|
408 |
+
self.up.append(up)
|
409 |
+
|
410 |
+
# Output layers
|
411 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
412 |
+
self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
413 |
+
|
414 |
+
self.gradient_checkpointing = False
|
415 |
+
|
416 |
+
def forward(self, z: Tensor) -> Tensor:
|
417 |
+
use_checkpointing = bool(self.training and self.gradient_checkpointing)
|
418 |
+
|
419 |
+
repeats = self.block_out_channels[0] // self.z_channels
|
420 |
+
h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
421 |
+
|
422 |
+
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
|
423 |
+
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
|
424 |
+
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
|
425 |
+
|
426 |
+
for i_level in range(len(self.block_out_channels)):
|
427 |
+
for i_block in range(self.num_res_blocks + 1):
|
428 |
+
h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
|
429 |
+
if hasattr(self.up[i_level], "upsample"):
|
430 |
+
h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)
|
431 |
+
|
432 |
+
h = self.norm_out(h)
|
433 |
+
h = swish(h)
|
434 |
+
h = self.conv_out(h)
|
435 |
+
return h
|
436 |
+
|
437 |
+
|
438 |
+
class HunyuanVAE2D(ModelMixin, ConfigMixin):
|
439 |
+
"""
|
440 |
+
HunyuanVAE2D: A 2D image VAE model with spatial tiling support.
|
441 |
+
|
442 |
+
This model implements a variational autoencoder specifically designed for image data,
|
443 |
+
with support for memory-efficient processing through tiling strategies.
|
444 |
+
"""
|
445 |
+
|
446 |
+
_supports_gradient_checkpointing = True
|
447 |
+
|
448 |
+
@register_to_config
|
449 |
+
def __init__(
|
450 |
+
self,
|
451 |
+
in_channels: int,
|
452 |
+
out_channels: int,
|
453 |
+
latent_channels: int,
|
454 |
+
block_out_channels: Tuple[int, ...],
|
455 |
+
layers_per_block: int,
|
456 |
+
ffactor_spatial: int,
|
457 |
+
sample_size: int,
|
458 |
+
sample_tsize: int,
|
459 |
+
scaling_factor: float = None,
|
460 |
+
shift_factor: Optional[float] = None,
|
461 |
+
downsample_match_channel: bool = True,
|
462 |
+
upsample_match_channel: bool = True,
|
463 |
+
**kwargs,
|
464 |
+
):
|
465 |
+
super().__init__()
|
466 |
+
self.ffactor_spatial = ffactor_spatial
|
467 |
+
self.scaling_factor = scaling_factor
|
468 |
+
self.shift_factor = shift_factor
|
469 |
+
|
470 |
+
self.encoder = Encoder(
|
471 |
+
in_channels=in_channels,
|
472 |
+
z_channels=latent_channels,
|
473 |
+
block_out_channels=block_out_channels,
|
474 |
+
num_res_blocks=layers_per_block,
|
475 |
+
ffactor_spatial=ffactor_spatial,
|
476 |
+
downsample_match_channel=downsample_match_channel,
|
477 |
+
)
|
478 |
+
|
479 |
+
self.decoder = Decoder(
|
480 |
+
z_channels=latent_channels,
|
481 |
+
out_channels=out_channels,
|
482 |
+
block_out_channels=list(reversed(block_out_channels)),
|
483 |
+
num_res_blocks=layers_per_block,
|
484 |
+
ffactor_spatial=ffactor_spatial,
|
485 |
+
upsample_match_channel=upsample_match_channel,
|
486 |
+
)
|
487 |
+
|
488 |
+
# Tiling and slicing configuration
|
489 |
+
self.use_slicing = False
|
490 |
+
self.use_spatial_tiling = False
|
491 |
+
self.use_tiling_during_training = False
|
492 |
+
|
493 |
+
# Tiling parameters
|
494 |
+
self.tile_sample_min_size = sample_size
|
495 |
+
self.tile_latent_min_size = sample_size // ffactor_spatial
|
496 |
+
self.tile_overlap_factor = 0.25
|
497 |
+
|
498 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
499 |
+
"""
|
500 |
+
Enable or disable gradient checkpointing for memory efficiency.
|
501 |
+
|
502 |
+
Parameters
|
503 |
+
----------
|
504 |
+
module : nn.Module
|
505 |
+
The module to set.
|
506 |
+
value : bool
|
507 |
+
Whether to enable gradient checkpointing.
|
508 |
+
"""
|
509 |
+
if isinstance(module, (Encoder, Decoder)):
|
510 |
+
module.gradient_checkpointing = value
|
511 |
+
|
512 |
+
def enable_spatial_tiling(self, use_tiling: bool = True):
|
513 |
+
"""Enable or disable spatial tiling."""
|
514 |
+
self.use_spatial_tiling = use_tiling
|
515 |
+
|
516 |
+
def disable_spatial_tiling(self):
|
517 |
+
"""Disable spatial tiling."""
|
518 |
+
self.use_spatial_tiling = False
|
519 |
+
|
520 |
+
def enable_tiling(self, use_tiling: bool = True):
|
521 |
+
"""Enable or disable spatial tiling (alias for enable_spatial_tiling)."""
|
522 |
+
self.enable_spatial_tiling(use_tiling)
|
523 |
+
|
524 |
+
def disable_tiling(self):
|
525 |
+
"""Disable spatial tiling (alias for disable_spatial_tiling)."""
|
526 |
+
self.disable_spatial_tiling()
|
527 |
+
|
528 |
+
def enable_slicing(self):
|
529 |
+
"""Enable slicing for batch processing."""
|
530 |
+
self.use_slicing = True
|
531 |
+
|
532 |
+
def disable_slicing(self):
|
533 |
+
"""Disable slicing for batch processing."""
|
534 |
+
self.use_slicing = False
|
535 |
+
|
536 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
537 |
+
"""
|
538 |
+
Blend two tensors horizontally with smooth transition.
|
539 |
+
|
540 |
+
Parameters
|
541 |
+
----------
|
542 |
+
a : torch.Tensor
|
543 |
+
Left tensor.
|
544 |
+
b : torch.Tensor
|
545 |
+
Right tensor.
|
546 |
+
blend_extent : int
|
547 |
+
Number of columns to blend.
|
548 |
+
"""
|
549 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
550 |
+
for x in range(blend_extent):
|
551 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
552 |
+
x / blend_extent
|
553 |
+
)
|
554 |
+
return b
|
555 |
+
|
556 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
557 |
+
"""
|
558 |
+
Blend two tensors vertically with smooth transition.
|
559 |
+
|
560 |
+
Parameters
|
561 |
+
----------
|
562 |
+
a : torch.Tensor
|
563 |
+
Top tensor.
|
564 |
+
b : torch.Tensor
|
565 |
+
Bottom tensor.
|
566 |
+
blend_extent : int
|
567 |
+
Number of rows to blend.
|
568 |
+
"""
|
569 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
570 |
+
for y in range(blend_extent):
|
571 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
572 |
+
y / blend_extent
|
573 |
+
)
|
574 |
+
return b
|
575 |
+
|
576 |
+
def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
577 |
+
"""
|
578 |
+
Encode input using spatial tiling strategy.
|
579 |
+
|
580 |
+
Parameters
|
581 |
+
----------
|
582 |
+
x : torch.Tensor
|
583 |
+
Input tensor of shape (B, C, T, H, W).
|
584 |
+
"""
|
585 |
+
B, C, T, H, W = x.shape
|
586 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
587 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
588 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
589 |
+
|
590 |
+
rows = []
|
591 |
+
for i in range(0, H, overlap_size):
|
592 |
+
row = []
|
593 |
+
for j in range(0, W, overlap_size):
|
594 |
+
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
595 |
+
tile = self.encoder(tile)
|
596 |
+
row.append(tile)
|
597 |
+
rows.append(row)
|
598 |
+
|
599 |
+
result_rows = []
|
600 |
+
for i, row in enumerate(rows):
|
601 |
+
result_row = []
|
602 |
+
for j, tile in enumerate(row):
|
603 |
+
if i > 0:
|
604 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
605 |
+
if j > 0:
|
606 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
607 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
608 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
609 |
+
|
610 |
+
moments = torch.cat(result_rows, dim=-2)
|
611 |
+
return moments
|
612 |
+
|
613 |
+
def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
|
614 |
+
"""
|
615 |
+
Decode latent using spatial tiling strategy.
|
616 |
+
|
617 |
+
Parameters
|
618 |
+
----------
|
619 |
+
z : torch.Tensor
|
620 |
+
Latent tensor of shape (B, C, H, W).
|
621 |
+
"""
|
622 |
+
B, C, H, W = z.shape
|
623 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
624 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
625 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
626 |
+
|
627 |
+
rows = []
|
628 |
+
for i in range(0, H, overlap_size):
|
629 |
+
row = []
|
630 |
+
for j in range(0, W, overlap_size):
|
631 |
+
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
632 |
+
decoded = self.decoder(tile)
|
633 |
+
row.append(decoded)
|
634 |
+
rows.append(row)
|
635 |
+
|
636 |
+
result_rows = []
|
637 |
+
for i, row in enumerate(rows):
|
638 |
+
result_row = []
|
639 |
+
for j, tile in enumerate(row):
|
640 |
+
if i > 0:
|
641 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
642 |
+
if j > 0:
|
643 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
644 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
645 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
646 |
+
|
647 |
+
dec = torch.cat(result_rows, dim=-2)
|
648 |
+
return dec
|
649 |
+
|
650 |
+
def encode(self, x: Tensor, return_dict: bool = True):
|
651 |
+
"""
|
652 |
+
Encode input tensor to latent representation.
|
653 |
+
|
654 |
+
Parameters
|
655 |
+
----------
|
656 |
+
x : Tensor
|
657 |
+
Input tensor.
|
658 |
+
return_dict : bool
|
659 |
+
Whether to return a dict.
|
660 |
+
"""
|
661 |
+
original_ndim = x.ndim
|
662 |
+
if original_ndim == 5:
|
663 |
+
x = x.squeeze(2)
|
664 |
+
|
665 |
+
def _encode(x):
|
666 |
+
if self.use_spatial_tiling and (
|
667 |
+
x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
|
668 |
+
):
|
669 |
+
return self.spatial_tiled_encode(x)
|
670 |
+
return self.encoder(x)
|
671 |
+
|
672 |
+
# Process with or without slicing
|
673 |
+
if self.use_slicing and x.shape[0] > 1:
|
674 |
+
encoded_slices = [_encode(x_slice) for x_slice in x.split(1)]
|
675 |
+
h = torch.cat(encoded_slices)
|
676 |
+
else:
|
677 |
+
h = _encode(x)
|
678 |
+
|
679 |
+
if original_ndim == 5:
|
680 |
+
h = h.unsqueeze(2)
|
681 |
+
|
682 |
+
posterior = DiagonalGaussianDistribution(h)
|
683 |
+
|
684 |
+
if not return_dict:
|
685 |
+
return (posterior,)
|
686 |
+
|
687 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
688 |
+
|
689 |
+
def decode(self, z: Tensor, return_dict: bool = True, generator=None):
|
690 |
+
"""
|
691 |
+
Decode latent representation to output tensor.
|
692 |
+
|
693 |
+
Parameters
|
694 |
+
----------
|
695 |
+
z : Tensor
|
696 |
+
Latent tensor.
|
697 |
+
return_dict : bool
|
698 |
+
Whether to return a dict.
|
699 |
+
generator : unused
|
700 |
+
For compatibility.
|
701 |
+
"""
|
702 |
+
original_ndim = z.ndim
|
703 |
+
if original_ndim == 5:
|
704 |
+
z = z.squeeze(2)
|
705 |
+
|
706 |
+
def _decode(z):
|
707 |
+
if self.use_spatial_tiling and (
|
708 |
+
z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size
|
709 |
+
):
|
710 |
+
return self.spatial_tiled_decode(z)
|
711 |
+
return self.decoder(z)
|
712 |
+
|
713 |
+
if self.use_slicing and z.shape[0] > 1:
|
714 |
+
decoded_slices = [_decode(z_slice) for z_slice in z.split(1)]
|
715 |
+
decoded = torch.cat(decoded_slices)
|
716 |
+
else:
|
717 |
+
decoded = _decode(z)
|
718 |
+
|
719 |
+
if original_ndim == 5:
|
720 |
+
decoded = decoded.unsqueeze(2)
|
721 |
+
|
722 |
+
if not return_dict:
|
723 |
+
return (decoded,)
|
724 |
+
|
725 |
+
return DecoderOutput(sample=decoded)
|
726 |
+
|
727 |
+
def forward(
|
728 |
+
self,
|
729 |
+
sample: torch.Tensor,
|
730 |
+
sample_posterior: bool = False,
|
731 |
+
return_posterior: bool = True,
|
732 |
+
return_dict: bool = True,
|
733 |
+
):
|
734 |
+
"""
|
735 |
+
Forward pass through the VAE (Encode and Decode).
|
736 |
+
|
737 |
+
Parameters
|
738 |
+
----------
|
739 |
+
sample : torch.Tensor
|
740 |
+
Input tensor.
|
741 |
+
sample_posterior : bool
|
742 |
+
Whether to sample from the posterior.
|
743 |
+
return_posterior : bool
|
744 |
+
Whether to return the posterior.
|
745 |
+
return_dict : bool
|
746 |
+
Whether to return a dict.
|
747 |
+
"""
|
748 |
+
posterior = self.encode(sample).latent_dist
|
749 |
+
z = posterior.sample() if sample_posterior else posterior.mode()
|
750 |
+
dec = self.decode(z).sample
|
751 |
+
|
752 |
+
if return_dict:
|
753 |
+
return DecoderOutput(sample=dec, posterior=posterior)
|
754 |
+
else:
|
755 |
+
return (dec, posterior)
|
756 |
+
|
757 |
+
def load_state_dict(self, state_dict, strict=True):
|
758 |
+
"""
|
759 |
+
Load state dict, handling possible 5D weight tensors.
|
760 |
+
|
761 |
+
Parameters
|
762 |
+
----------
|
763 |
+
state_dict : dict
|
764 |
+
State dictionary.
|
765 |
+
strict : bool
|
766 |
+
Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict function.
|
767 |
+
"""
|
768 |
+
converted_state_dict = {}
|
769 |
+
|
770 |
+
for key, value in state_dict.items():
|
771 |
+
if 'weight' in key:
|
772 |
+
if len(value.shape) == 5 and value.shape[2] == 1:
|
773 |
+
converted_state_dict[key] = value.squeeze(2)
|
774 |
+
else:
|
775 |
+
converted_state_dict[key] = value
|
776 |
+
else:
|
777 |
+
converted_state_dict[key] = value
|
778 |
+
|
779 |
+
return super().load_state_dict(converted_state_dict, strict=strict)
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tqdm==4.67.1
|
2 |
+
torch>=2.6.0
|
3 |
+
einops==0.8.0
|
4 |
+
loguru==0.7.3
|
5 |
+
numpy==1.26.4
|
6 |
+
pillow==11.3.0
|
7 |
+
omegaconf>=2.3.0
|
8 |
+
torchaudio==2.6.0
|
9 |
+
diffusers>=0.32.0
|
10 |
+
safetensors==0.4.5
|
11 |
+
torchvision==0.21.0
|
12 |
+
huggingface-hub==0.34.0
|
13 |
+
transformers[accelerate,tiktoken]==4.56.0
|
14 |
+
wheel
|
15 |
+
setuptools
|
16 |
+
modelscope
|
17 |
+
huggingface_hub[cli]
|