fffiloni commited on
Commit
e462867
·
verified ·
1 Parent(s): 7a6f992

Migrated from GitHub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +12 -0
  2. INSTALL.md +54 -0
  3. LICENSE.txt +201 -0
  4. ORIGINAL_README.md +663 -0
  5. assets/comp_effic.png +3 -0
  6. assets/data_for_diff_stage.jpg +3 -0
  7. assets/i2v_res.png +3 -0
  8. assets/logo.png +0 -0
  9. assets/t2v_res.jpg +3 -0
  10. assets/vben_vs_sota.png +3 -0
  11. assets/video_dit_arch.jpg +3 -0
  12. assets/video_vae_res.jpg +3 -0
  13. examples/flf2v_input_first_frame.png +3 -0
  14. examples/flf2v_input_last_frame.png +3 -0
  15. examples/girl.png +3 -0
  16. examples/i2v_input.JPG +3 -0
  17. examples/snake.png +3 -0
  18. generate.py +572 -0
  19. gradio/fl2v_14B_singleGPU.py +252 -0
  20. gradio/i2v_14B_singleGPU.py +287 -0
  21. gradio/t2i_14B_singleGPU.py +205 -0
  22. gradio/t2v_1.3B_singleGPU.py +207 -0
  23. gradio/t2v_14B_singleGPU.py +205 -0
  24. gradio/vace.py +295 -0
  25. pyproject.toml +67 -0
  26. requirements.txt +16 -0
  27. tests/README.md +6 -0
  28. tests/test.sh +120 -0
  29. wan/__init__.py +5 -0
  30. wan/configs/__init__.py +53 -0
  31. wan/configs/shared_config.py +19 -0
  32. wan/configs/wan_i2v_14B.py +36 -0
  33. wan/configs/wan_t2v_14B.py +29 -0
  34. wan/configs/wan_t2v_1_3B.py +29 -0
  35. wan/distributed/__init__.py +0 -0
  36. wan/distributed/fsdp.py +41 -0
  37. wan/distributed/xdit_context_parallel.py +230 -0
  38. wan/first_last_frame2video.py +370 -0
  39. wan/image2video.py +347 -0
  40. wan/modules/__init__.py +18 -0
  41. wan/modules/attention.py +179 -0
  42. wan/modules/clip.py +542 -0
  43. wan/modules/model.py +630 -0
  44. wan/modules/t5.py +513 -0
  45. wan/modules/tokenizers.py +82 -0
  46. wan/modules/vace_model.py +233 -0
  47. wan/modules/vae.py +663 -0
  48. wan/modules/xlm_roberta.py +170 -0
  49. wan/text2video.py +267 -0
  50. wan/utils/__init__.py +10 -0
.gitattributes CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/comp_effic.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/data_for_diff_stage.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/i2v_res.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/t2v_res.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/vben_vs_sota.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/video_dit_arch.jpg filter=lfs diff=lfs merge=lfs -text
42
+ assets/video_vae_res.jpg filter=lfs diff=lfs merge=lfs -text
43
+ examples/flf2v_input_first_frame.png filter=lfs diff=lfs merge=lfs -text
44
+ examples/flf2v_input_last_frame.png filter=lfs diff=lfs merge=lfs -text
45
+ examples/girl.png filter=lfs diff=lfs merge=lfs -text
46
+ examples/i2v_input.JPG filter=lfs diff=lfs merge=lfs -text
47
+ examples/snake.png filter=lfs diff=lfs merge=lfs -text
INSTALL.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Installation Guide
2
+
3
+ ## Install with pip
4
+
5
+ ```bash
6
+ pip install .
7
+ pip install .[dev] # Installe aussi les outils de dev
8
+ ```
9
+
10
+ ## Install with Poetry
11
+
12
+ Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
13
+
14
+ To install all dependencies:
15
+
16
+ ```bash
17
+ poetry install
18
+ ```
19
+
20
+ ### Handling `flash-attn` Installation Issues
21
+
22
+ If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
23
+
24
+ #### No-Build-Isolation Installation (Recommended)
25
+ ```bash
26
+ poetry run pip install --upgrade pip setuptools wheel
27
+ poetry run pip install flash-attn --no-build-isolation
28
+ poetry install
29
+ ```
30
+
31
+ #### Install from Git (Alternative)
32
+ ```bash
33
+ poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
34
+ ```
35
+
36
+ ---
37
+
38
+ ### Running the Model
39
+
40
+ Once the installation is complete, you can run **Wan2.1** using:
41
+
42
+ ```bash
43
+ poetry run python generate.py --task t2v-14B --size '1280x720' --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
44
+ ```
45
+
46
+ #### Test
47
+ ```bash
48
+ pytest tests/
49
+ ```
50
+ #### Format
51
+ ```bash
52
+ black .
53
+ isort .
54
+ ```
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
ORIGINAL_README.md ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Wan2.1
2
+
3
+ <p align="center">
4
+ <img src="assets/logo.png" width="400"/>
5
+ <p>
6
+
7
+ <p align="center">
8
+ 💜 <a href="https://wan.video"><b>Wan</b></a> &nbsp&nbsp | &nbsp&nbsp 🖥️ <a href="https://github.com/Wan-Video/Wan2.1">GitHub</a> &nbsp&nbsp | &nbsp&nbsp🤗 <a href="https://huggingface.co/Wan-AI/">Hugging Face</a>&nbsp&nbsp | &nbsp&nbsp🤖 <a href="https://modelscope.cn/organization/Wan-AI">ModelScope</a>&nbsp&nbsp | &nbsp&nbsp 📑 <a href="https://arxiv.org/abs/2503.20314">Technical Report</a> &nbsp&nbsp | &nbsp&nbsp 📑 <a href="https://wan.video/welcome?spm=a2ty_o02.30011076.0.0.6c9ee41eCcluqg">Blog</a> &nbsp&nbsp | &nbsp&nbsp💬 <a href="https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg">WeChat Group</a>&nbsp&nbsp | &nbsp&nbsp 📖 <a href="https://discord.gg/AKNgpMK4Yj">Discord</a>&nbsp&nbsp
9
+ <br>
10
+
11
+ -----
12
+
13
+ [**Wan: Open and Advanced Large-Scale Video Generative Models**](https://arxiv.org/abs/2503.20314) <be>
14
+
15
+ In this repository, we present **Wan2.1**, a comprehensive and open suite of video foundation models that pushes the boundaries of video generation. **Wan2.1** offers these key features:
16
+ - 👍 **SOTA Performance**: **Wan2.1** consistently outperforms existing open-source models and state-of-the-art commercial solutions across multiple benchmarks.
17
+ - 👍 **Supports Consumer-grade GPUs**: The T2V-1.3B model requires only 8.19 GB VRAM, making it compatible with almost all consumer-grade GPUs. It can generate a 5-second 480P video on an RTX 4090 in about 4 minutes (without optimization techniques like quantization). Its performance is even comparable to some closed-source models.
18
+ - 👍 **Multiple Tasks**: **Wan2.1** excels in Text-to-Video, Image-to-Video, Video Editing, Text-to-Image, and Video-to-Audio, advancing the field of video generation.
19
+ - 👍 **Visual Text Generation**: **Wan2.1** is the first video model capable of generating both Chinese and English text, featuring robust text generation that enhances its practical applications.
20
+ - 👍 **Powerful Video VAE**: **Wan-VAE** delivers exceptional efficiency and performance, encoding and decoding 1080P videos of any length while preserving temporal information, making it an ideal foundation for video and image generation.
21
+
22
+ ## Video Demos
23
+
24
+ <div align="center">
25
+ <video src="https://github.com/user-attachments/assets/4aca6063-60bf-4953-bfb7-e265053f49ef" width="70%" poster=""> </video>
26
+ </div>
27
+
28
+ ## 🔥 Latest News!!
29
+
30
+ * May 14, 2025: 👋 We introduce **Wan2.1** [VACE](https://github.com/ali-vilab/VACE), an all-in-one model for video creation and editing, along with its [inference code](#run-vace), [weights](#model-download), and [technical report](https://arxiv.org/abs/2503.07598)!
31
+ * Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
32
+ * Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
33
+ * Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!
34
+ * Feb 27, 2025: 👋 **Wan2.1** has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
35
+ * Feb 25, 2025: 👋 We've released the inference code and weights of **Wan2.1**.
36
+
37
+ ## Community Works
38
+ If your work has improved **Wan2.1** and you would like more people to see it, please inform us.
39
+ - [Phantom](https://github.com/Phantom-video/Phantom) has developed a unified video generation framework for single and multi-subject references based on **Wan2.1-T2V-1.3B**. Please refer to [their examples](https://github.com/Phantom-video/Phantom).
40
+ - [UniAnimate-DiT](https://github.com/ali-vilab/UniAnimate-DiT), based on **Wan2.1-14B-I2V**, has trained a Human image animation model and has open-sourced the inference and training code. Feel free to enjoy it!
41
+ - [CFG-Zero](https://github.com/WeichenFan/CFG-Zero-star) enhances **Wan2.1** (covering both T2V and I2V models) from the perspective of CFG.
42
+ - [TeaCache](https://github.com/ali-vilab/TeaCache) now supports **Wan2.1** acceleration, capable of increasing speed by approximately 2x. Feel free to give it a try!
43
+ - [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) provides more support for **Wan2.1**, including video-to-video, FP8 quantization, VRAM optimization, LoRA training, and more. Please refer to [their examples](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo).
44
+
45
+
46
+ ## 📑 Todo List
47
+ - Wan2.1 Text-to-Video
48
+ - [x] Multi-GPU Inference code of the 14B and 1.3B models
49
+ - [x] Checkpoints of the 14B and 1.3B models
50
+ - [x] Gradio demo
51
+ - [x] ComfyUI integration
52
+ - [x] Diffusers integration
53
+ - [ ] Diffusers + Multi-GPU Inference
54
+ - Wan2.1 Image-to-Video
55
+ - [x] Multi-GPU Inference code of the 14B model
56
+ - [x] Checkpoints of the 14B model
57
+ - [x] Gradio demo
58
+ - [x] ComfyUI integration
59
+ - [x] Diffusers integration
60
+ - [ ] Diffusers + Multi-GPU Inference
61
+ - Wan2.1 First-Last-Frame-to-Video
62
+ - [x] Multi-GPU Inference code of the 14B model
63
+ - [x] Checkpoints of the 14B model
64
+ - [x] Gradio demo
65
+ - [ ] ComfyUI integration
66
+ - [ ] Diffusers integration
67
+ - [ ] Diffusers + Multi-GPU Inference
68
+ - Wan2.1 VACE
69
+ - [x] Multi-GPU Inference code of the 14B and 1.3B models
70
+ - [x] Checkpoints of the 14B and 1.3B models
71
+ - [x] Gradio demo
72
+ - [x] ComfyUI integration
73
+ - [ ] Diffusers integration
74
+ - [ ] Diffusers + Multi-GPU Inference
75
+
76
+ ## Quickstart
77
+
78
+ #### Installation
79
+ Clone the repo:
80
+ ```sh
81
+ git clone https://github.com/Wan-Video/Wan2.1.git
82
+ cd Wan2.1
83
+ ```
84
+
85
+ Install dependencies:
86
+ ```sh
87
+ # Ensure torch >= 2.4.0
88
+ pip install -r requirements.txt
89
+ ```
90
+
91
+
92
+ #### Model Download
93
+
94
+ | Models | Download Link | Notes |
95
+ |--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
96
+ | T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
97
+ | I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
98
+ | I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
99
+ | T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
100
+ | FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
101
+ | VACE-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | Supports 480P
102
+ | VACE-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | Supports both 480P and 720P
103
+
104
+ > 💡Note:
105
+ > * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
106
+ > * For the first-last frame to video generation, we train our model primarily on Chinese text-video pairs. Therefore, we recommend using Chinese prompt to achieve better results.
107
+
108
+
109
+ Download models using huggingface-cli:
110
+ ``` sh
111
+ pip install "huggingface_hub[cli]"
112
+ huggingface-cli download Wan-AI/Wan2.1-T2V-14B --local-dir ./Wan2.1-T2V-14B
113
+ ```
114
+
115
+ Download models using modelscope-cli:
116
+ ``` sh
117
+ pip install modelscope
118
+ modelscope download Wan-AI/Wan2.1-T2V-14B --local_dir ./Wan2.1-T2V-14B
119
+ ```
120
+ #### Run Text-to-Video Generation
121
+
122
+ This repository supports two Text-to-Video models (1.3B and 14B) and two resolutions (480P and 720P). The parameters and configurations for these models are as follows:
123
+
124
+ <table>
125
+ <thead>
126
+ <tr>
127
+ <th rowspan="2">Task</th>
128
+ <th colspan="2">Resolution</th>
129
+ <th rowspan="2">Model</th>
130
+ </tr>
131
+ <tr>
132
+ <th>480P</th>
133
+ <th>720P</th>
134
+ </tr>
135
+ </thead>
136
+ <tbody>
137
+ <tr>
138
+ <td>t2v-14B</td>
139
+ <td style="color: green;">✔️</td>
140
+ <td style="color: green;">✔️</td>
141
+ <td>Wan2.1-T2V-14B</td>
142
+ </tr>
143
+ <tr>
144
+ <td>t2v-1.3B</td>
145
+ <td style="color: green;">✔️</td>
146
+ <td style="color: red;">❌</td>
147
+ <td>Wan2.1-T2V-1.3B</td>
148
+ </tr>
149
+ </tbody>
150
+ </table>
151
+
152
+
153
+ ##### (1) Without Prompt Extension
154
+
155
+ To facilitate implementation, we will start with a basic version of the inference process that skips the [prompt extension](#2-using-prompt-extention) step.
156
+
157
+ - Single-GPU inference
158
+
159
+ ``` sh
160
+ python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
161
+ ```
162
+
163
+ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True` and `--t5_cpu` options to reduce GPU memory usage. For example, on an RTX 4090 GPU:
164
+
165
+ ``` sh
166
+ python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
167
+ ```
168
+
169
+ > 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance.
170
+
171
+
172
+ - Multi-GPU inference using FSDP + xDiT USP
173
+
174
+ We use FSDP and [xDiT](https://github.com/xdit-project/xDiT) USP to accelerate inference.
175
+
176
+ * Ulysess Strategy
177
+
178
+ If you want to use [`Ulysses`](https://arxiv.org/abs/2309.14509) strategy, you should set `--ulysses_size $GPU_NUMS`. Note that the `num_heads` should be divisible by `ulysses_size` if you wish to use `Ulysess` strategy. For the 1.3B model, the `num_heads` is `12` which can't be divided by 8 (as most multi-GPU machines have 8 GPUs). Therefore, it is recommended to use `Ring Strategy` instead.
179
+
180
+ * Ring Strategy
181
+
182
+ If you want to use [`Ring`](https://arxiv.org/pdf/2310.01889) strategy, you should set `--ring_size $GPU_NUMS`. Note that the `sequence length` should be divisible by `ring_size` when using the `Ring` strategy.
183
+
184
+ Of course, you can also combine the use of `Ulysses` and `Ring` strategies.
185
+
186
+
187
+ ``` sh
188
+ pip install "xfuser>=0.4.1"
189
+ torchrun --nproc_per_node=8 generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
190
+ ```
191
+
192
+
193
+ ##### (2) Using Prompt Extension
194
+
195
+ Extending the prompts can effectively enrich the details in the generated videos, further enhancing the video quality. Therefore, we recommend enabling prompt extension. We provide the following two methods for prompt extension:
196
+
197
+ - Use the Dashscope API for extension.
198
+ - Apply for a `dashscope.api_key` in advance ([EN](https://www.alibabacloud.com/help/en/model-studio/getting-started/first-api-call-to-qwen) | [CN](https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen)).
199
+ - Configure the environment variable `DASH_API_KEY` to specify the Dashscope API key. For users of Alibaba Cloud's international site, you also need to set the environment variable `DASH_API_URL` to 'https://dashscope-intl.aliyuncs.com/api/v1'. For more detailed instructions, please refer to the [dashscope document](https://www.alibabacloud.com/help/en/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c63.p38356.0.i1).
200
+ - Use the `qwen-plus` model for text-to-video tasks and `qwen-vl-max` for image-to-video tasks.
201
+ - You can modify the model used for extension with the parameter `--prompt_extend_model`. For example:
202
+ ```sh
203
+ DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'dashscope' --prompt_extend_target_lang 'zh'
204
+ ```
205
+
206
+ - Using a local model for extension.
207
+
208
+ - By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size.
209
+ - For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.
210
+ - For image-to-video or first-last-frame-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
211
+ - Larger models generally provide better extension results but require more GPU memory.
212
+ - You can modify the model used for extension with the parameter `--prompt_extend_model` , allowing you to specify either a local model path or a Hugging Face model. For example:
213
+
214
+ ``` sh
215
+ python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage" --use_prompt_extend --prompt_extend_method 'local_qwen' --prompt_extend_target_lang 'zh'
216
+ ```
217
+
218
+
219
+ ##### (3) Running with Diffusers
220
+
221
+ You can easily inference **Wan2.1**-T2V using Diffusers with the following command:
222
+ ``` python
223
+ import torch
224
+ from diffusers.utils import export_to_video
225
+ from diffusers import AutoencoderKLWan, WanPipeline
226
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
227
+
228
+ # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
229
+ model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
230
+ vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
231
+ flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
232
+ scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift)
233
+ pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
234
+ pipe.scheduler = scheduler
235
+ pipe.to("cuda")
236
+
237
+ prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
238
+ negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
239
+
240
+ output = pipe(
241
+ prompt=prompt,
242
+ negative_prompt=negative_prompt,
243
+ height=720,
244
+ width=1280,
245
+ num_frames=81,
246
+ guidance_scale=5.0,
247
+ ).frames[0]
248
+ export_to_video(output, "output.mp4", fps=16)
249
+ ```
250
+ > 💡Note: Please note that this example does not integrate Prompt Extension and distributed inference. We will soon update with the integrated prompt extension and multi-GPU version of Diffusers.
251
+
252
+
253
+ ##### (4) Running local gradio
254
+
255
+ ``` sh
256
+ cd gradio
257
+ # if one uses dashscope’s API for prompt extension
258
+ DASH_API_KEY=your_key python t2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir ./Wan2.1-T2V-14B
259
+
260
+ # if one uses a local model for prompt extension
261
+ python t2v_14B_singleGPU.py --prompt_extend_method 'local_qwen' --ckpt_dir ./Wan2.1-T2V-14B
262
+ ```
263
+
264
+
265
+
266
+ #### Run Image-to-Video Generation
267
+
268
+ Similar to Text-to-Video, Image-to-Video is also divided into processes with and without the prompt extension step. The specific parameters and their corresponding settings are as follows:
269
+ <table>
270
+ <thead>
271
+ <tr>
272
+ <th rowspan="2">Task</th>
273
+ <th colspan="2">Resolution</th>
274
+ <th rowspan="2">Model</th>
275
+ </tr>
276
+ <tr>
277
+ <th>480P</th>
278
+ <th>720P</th>
279
+ </tr>
280
+ </thead>
281
+ <tbody>
282
+ <tr>
283
+ <td>i2v-14B</td>
284
+ <td style="color: green;">❌</td>
285
+ <td style="color: green;">✔️</td>
286
+ <td>Wan2.1-I2V-14B-720P</td>
287
+ </tr>
288
+ <tr>
289
+ <td>i2v-14B</td>
290
+ <td style="color: green;">✔️</td>
291
+ <td style="color: red;">❌</td>
292
+ <td>Wan2.1-T2V-14B-480P</td>
293
+ </tr>
294
+ </tbody>
295
+ </table>
296
+
297
+
298
+ ##### (1) Without Prompt Extension
299
+
300
+ - Single-GPU inference
301
+ ```sh
302
+ python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
303
+ ```
304
+
305
+ > 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
306
+
307
+
308
+ - Multi-GPU inference using FSDP + xDiT USP
309
+
310
+ ```sh
311
+ pip install "xfuser>=0.4.1"
312
+ torchrun --nproc_per_node=8 generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
313
+ ```
314
+
315
+ ##### (2) Using Prompt Extension
316
+
317
+
318
+ The process of prompt extension can be referenced [here](#2-using-prompt-extention).
319
+
320
+ Run with local prompt extension using `Qwen/Qwen2.5-VL-7B-Instruct`:
321
+ ```
322
+ python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
323
+ ```
324
+
325
+ Run with remote prompt extension using `dashscope`:
326
+ ```
327
+ DASH_API_KEY=your_key python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
328
+ ```
329
+
330
+
331
+ ##### (3) Running with Diffusers
332
+
333
+ You can easily inference **Wan2.1**-I2V using Diffusers with the following command:
334
+ ``` python
335
+ import torch
336
+ import numpy as np
337
+ from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
338
+ from diffusers.utils import export_to_video, load_image
339
+ from transformers import CLIPVisionModel
340
+
341
+ # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
342
+ model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
343
+ image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
344
+ vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
345
+ pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16)
346
+ pipe.to("cuda")
347
+
348
+ image = load_image(
349
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
350
+ )
351
+ max_area = 720 * 1280
352
+ aspect_ratio = image.height / image.width
353
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
354
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
355
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
356
+ image = image.resize((width, height))
357
+ prompt = (
358
+ "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
359
+ "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
360
+ )
361
+ negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
362
+
363
+ output = pipe(
364
+ image=image,
365
+ prompt=prompt,
366
+ negative_prompt=negative_prompt,
367
+ height=height, width=width,
368
+ num_frames=81,
369
+ guidance_scale=5.0
370
+ ).frames[0]
371
+ export_to_video(output, "output.mp4", fps=16)
372
+
373
+ ```
374
+ > 💡Note: Please note that this example does not integrate Prompt Extension and distributed inference. We will soon update with the integrated prompt extension and multi-GPU version of Diffusers.
375
+
376
+
377
+ ##### (4) Running local gradio
378
+
379
+ ```sh
380
+ cd gradio
381
+ # if one only uses 480P model in gradio
382
+ DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P
383
+
384
+ # if one only uses 720P model in gradio
385
+ DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
386
+
387
+ # if one uses both 480P and 720P models in gradio
388
+ DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_480p ./Wan2.1-I2V-14B-480P --ckpt_dir_720p ./Wan2.1-I2V-14B-720P
389
+ ```
390
+
391
+
392
+ #### Run First-Last-Frame-to-Video Generation
393
+
394
+ First-Last-Frame-to-Video is also divided into processes with and without the prompt extension step. Currently, only 720P is supported. The specific parameters and corresponding settings are as follows:
395
+ <table>
396
+ <thead>
397
+ <tr>
398
+ <th rowspan="2">Task</th>
399
+ <th colspan="2">Resolution</th>
400
+ <th rowspan="2">Model</th>
401
+ </tr>
402
+ <tr>
403
+ <th>480P</th>
404
+ <th>720P</th>
405
+ </tr>
406
+ </thead>
407
+ <tbody>
408
+ <tr>
409
+ <td>flf2v-14B</td>
410
+ <td style="color: green;">❌</td>
411
+ <td style="color: green;">✔️</td>
412
+ <td>Wan2.1-FLF2V-14B-720P</td>
413
+ </tr>
414
+ </tbody>
415
+ </table>
416
+
417
+
418
+ ##### (1) Without Prompt Extension
419
+
420
+ - Single-GPU inference
421
+ ```sh
422
+ python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
423
+ ```
424
+
425
+ > 💡Similar to Image-to-Video, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
426
+
427
+
428
+ - Multi-GPU inference using FSDP + xDiT USP
429
+
430
+ ```sh
431
+ pip install "xfuser>=0.4.1"
432
+ torchrun --nproc_per_node=8 generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
433
+ ```
434
+
435
+ ##### (2) Using Prompt Extension
436
+
437
+
438
+ The process of prompt extension can be referenced [here](#2-using-prompt-extention).
439
+
440
+ Run with local prompt extension using `Qwen/Qwen2.5-VL-7B-Instruct`:
441
+ ```
442
+ python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
443
+ ```
444
+
445
+ Run with remote prompt extension using `dashscope`:
446
+ ```
447
+ DASH_API_KEY=your_key python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
448
+ ```
449
+
450
+
451
+ ##### (3) Running local gradio
452
+
453
+ ```sh
454
+ cd gradio
455
+ # use 720P model in gradio
456
+ DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-FLF2V-14B-720P
457
+ ```
458
+
459
+
460
+ #### Run VACE
461
+
462
+ [VACE](https://github.com/ali-vilab/VACE) now supports two models (1.3B and 14B) and two main resolutions (480P and 720P).
463
+ The input supports any resolution, but to achieve optimal results, the video size should fall within a specific range.
464
+ The parameters and configurations for these models are as follows:
465
+
466
+ <table>
467
+ <thead>
468
+ <tr>
469
+ <th rowspan="2">Task</th>
470
+ <th colspan="2">Resolution</th>
471
+ <th rowspan="2">Model</th>
472
+ </tr>
473
+ <tr>
474
+ <th>480P(~81x480x832)</th>
475
+ <th>720P(~81x720x1280)</th>
476
+ </tr>
477
+ </thead>
478
+ <tbody>
479
+ <tr>
480
+ <td>VACE</td>
481
+ <td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
482
+ <td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
483
+ <td>Wan2.1-VACE-14B</td>
484
+ </tr>
485
+ <tr>
486
+ <td>VACE</td>
487
+ <td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
488
+ <td style="color: red; text-align: center; vertical-align: middle;">❌</td>
489
+ <td>Wan2.1-VACE-1.3B</td>
490
+ </tr>
491
+ </tbody>
492
+ </table>
493
+
494
+ In VACE, users can input text prompt and optional video, mask, and image for video generation or editing. Detailed instructions for using VACE can be found in the [User Guide](https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).
495
+ The execution process is as follows:
496
+
497
+ ##### (1) Preprocessing
498
+
499
+ User-collected materials needs to be preprocessed into VACE-recognizable inputs, including `src_video`, `src_mask`, `src_ref_images`, and `prompt`.
500
+ For R2V (Reference-to-Video Generation), you may skip this preprocessing, but for V2V (Video-to-Video Editing) and MV2V (Masked Video-to-Video Editing) tasks, additional preprocessing is required to obtain video with conditions such as depth, pose or masked regions.
501
+ For more details, please refer to [vace_preproccess](https://github.com/ali-vilab/VACE/blob/main/vace/vace_preproccess.py).
502
+
503
+ ##### (2) cli inference
504
+
505
+ - Single-GPU inference
506
+ ```sh
507
+ python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇��头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
508
+ ```
509
+
510
+ - Multi-GPU inference using FSDP + xDiT USP
511
+
512
+ ```sh
513
+ torchrun --nproc_per_node=8 generate.py --task vace-14B --size 1280*720 --ckpt_dir ./Wan2.1-VACE-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
514
+ ```
515
+
516
+ ##### (3) Running local gradio
517
+ - Single-GPU inference
518
+ ```sh
519
+ python gradio/vace.py --ckpt_dir ./Wan2.1-VACE-1.3B
520
+ ```
521
+
522
+ - Multi-GPU inference using FSDP + xDiT USP
523
+ ```sh
524
+ python gradio/vace.py --mp --ulysses_size 8 --ckpt_dir ./Wan2.1-VACE-14B/
525
+ ```
526
+
527
+ #### Run Text-to-Image Generation
528
+
529
+ Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:
530
+
531
+ ##### (1) Without Prompt Extension
532
+
533
+ - Single-GPU inference
534
+ ```sh
535
+ python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人'
536
+ ```
537
+
538
+ - Multi-GPU inference using FSDP + xDiT USP
539
+
540
+ ```sh
541
+ torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --prompt '一个朴素端庄的美人' --ckpt_dir ./Wan2.1-T2V-14B
542
+ ```
543
+
544
+ ##### (2) With Prompt Extention
545
+
546
+ - Single-GPU inference
547
+ ```sh
548
+ python generate.py --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
549
+ ```
550
+
551
+ - Multi-GPU inference using FSDP + xDiT USP
552
+ ```sh
553
+ torchrun --nproc_per_node=8 generate.py --dit_fsdp --t5_fsdp --ulysses_size 8 --base_seed 0 --frame_num 1 --task t2i-14B --size 1024*1024 --ckpt_dir ./Wan2.1-T2V-14B --prompt '一个朴素端庄的美人' --use_prompt_extend
554
+ ```
555
+
556
+
557
+ ## Manual Evaluation
558
+
559
+ ##### (1) Text-to-Video Evaluation
560
+
561
+ Through manual evaluation, the results generated after prompt extension are superior to those from both closed-source and open-source models.
562
+
563
+ <div align="center">
564
+ <img src="assets/t2v_res.jpg" alt="" style="width: 80%;" />
565
+ </div>
566
+
567
+
568
+ ##### (2) Image-to-Video Evaluation
569
+
570
+ We also conducted extensive manual evaluations to evaluate the performance of the Image-to-Video model, and the results are presented in the table below. The results clearly indicate that **Wan2.1** outperforms both closed-source and open-source models.
571
+
572
+ <div align="center">
573
+ <img src="assets/i2v_res.png" alt="" style="width: 80%;" />
574
+ </div>
575
+
576
+
577
+ ## Computational Efficiency on Different GPUs
578
+
579
+ We test the computational efficiency of different **Wan2.1** models on different GPUs in the following table. The results are presented in the format: **Total time (s) / peak GPU memory (GB)**.
580
+
581
+
582
+ <div align="center">
583
+ <img src="assets/comp_effic.png" alt="" style="width: 80%;" />
584
+ </div>
585
+
586
+ > The parameter settings for the tests presented in this table are as follows:
587
+ > (1) For the 1.3B model on 8 GPUs, set `--ring_size 8` and `--ulysses_size 1`;
588
+ > (2) For the 14B model on 1 GPU, use `--offload_model True`;
589
+ > (3) For the 1.3B model on a single 4090 GPU, set `--offload_model True --t5_cpu`;
590
+ > (4) For all testings, no prompt extension was applied, meaning `--use_prompt_extend` was not enabled.
591
+
592
+ > 💡Note: T2V-14B is slower than I2V-14B because the former samples 50 steps while the latter uses 40 steps.
593
+
594
+
595
+ -------
596
+
597
+ ## Introduction of Wan2.1
598
+
599
+ **Wan2.1** is designed on the mainstream diffusion transformer paradigm, achieving significant advancements in generative capabilities through a series of innovations. These include our novel spatio-temporal variational autoencoder (VAE), scalable training strategies, large-scale data construction, and automated evaluation metrics. Collectively, these contributions enhance the model’s performance and versatility.
600
+
601
+
602
+ ##### (1) 3D Variational Autoencoders
603
+ We propose a novel 3D causal VAE architecture, termed **Wan-VAE** specifically designed for video generation. By combining multiple strategies, we improve spatio-temporal compression, reduce memory usage, and ensure temporal causality. **Wan-VAE** demonstrates significant advantages in performance efficiency compared to other open-source VAEs. Furthermore, our **Wan-VAE** can encode and decode unlimited-length 1080P videos without losing historical temporal information, making it particularly well-suited for video generation tasks.
604
+
605
+
606
+ <div align="center">
607
+ <img src="assets/video_vae_res.jpg" alt="" style="width: 80%;" />
608
+ </div>
609
+
610
+
611
+ ##### (2) Video Diffusion DiT
612
+
613
+ **Wan2.1** is designed using the Flow Matching framework within the paradigm of mainstream Diffusion Transformers. Our model's architecture uses the T5 Encoder to encode multilingual text input, with cross-attention in each transformer block embedding the text into the model structure. Additionally, we employ an MLP with a Linear layer and a SiLU layer to process the input time embeddings and predict six modulation parameters individually. This MLP is shared across all transformer blocks, with each block learning a distinct set of biases. Our experimental findings reveal a significant performance improvement with this approach at the same parameter scale.
614
+
615
+ <div align="center">
616
+ <img src="assets/video_dit_arch.jpg" alt="" style="width: 80%;" />
617
+ </div>
618
+
619
+
620
+ | Model | Dimension | Input Dimension | Output Dimension | Feedforward Dimension | Frequency Dimension | Number of Heads | Number of Layers |
621
+ |--------|-----------|-----------------|------------------|-----------------------|---------------------|-----------------|------------------|
622
+ | 1.3B | 1536 | 16 | 16 | 8960 | 256 | 12 | 30 |
623
+ | 14B | 5120 | 16 | 16 | 13824 | 256 | 40 | 40 |
624
+
625
+
626
+
627
+ ##### Data
628
+
629
+ We curated and deduplicated a candidate dataset comprising a vast amount of image and video data. During the data curation process, we designed a four-step data cleaning process, focusing on fundamental dimensions, visual quality and motion quality. Through the robust data processing pipeline, we can easily obtain high-quality, diverse, and large-scale training sets of images and videos.
630
+
631
+ ![figure1](assets/data_for_diff_stage.jpg "figure1")
632
+
633
+
634
+ ##### Comparisons to SOTA
635
+ We compared **Wan2.1** with leading open-source and closed-source models to evaluate the performance. Using our carefully designed set of 1,035 internal prompts, we tested across 14 major dimensions and 26 sub-dimensions. We then compute the total score by performing a weighted calculation on the scores of each dimension, utilizing weights derived from human preferences in the matching process. The detailed results are shown in the table below. These results demonstrate our model's superior performance compared to both open-source and closed-source models.
636
+
637
+ ![figure1](assets/vben_vs_sota.png "figure1")
638
+
639
+
640
+ ## Citation
641
+ If you find our work helpful, please cite us.
642
+
643
+ ```
644
+ @article{wan2025,
645
+ title={Wan: Open and Advanced Large-Scale Video Generative Models},
646
+ author={Ang Wang and Baole Ai and Bin Wen and Chaojie Mao and Chen-Wei Xie and Di Chen and Feiwu Yu and Haiming Zhao and Jianxiao Yang and Jianyuan Zeng and Jiayu Wang and Jingfeng Zhang and Jingren Zhou and Jinkai Wang and Jixuan Chen and Kai Zhu and Kang Zhao and Keyu Yan and Lianghua Huang and Mengyang Feng and Ningyi Zhang and Pandeng Li and Pingyu Wu and Ruihang Chu and Ruili Feng and Shiwei Zhang and Siyang Sun and Tao Fang and Tianxing Wang and Tianyi Gui and Tingyu Weng and Tong Shen and Wei Lin and Wei Wang and Wei Wang and Wenmeng Zhou and Wente Wang and Wenting Shen and Wenyuan Yu and Xianzhong Shi and Xiaoming Huang and Xin Xu and Yan Kou and Yangyu Lv and Yifei Li and Yijing Liu and Yiming Wang and Yingya Zhang and Yitong Huang and Yong Li and You Wu and Yu Liu and Yulin Pan and Yun Zheng and Yuntao Hong and Yupeng Shi and Yutong Feng and Zeyinzi Jiang and Zhen Han and Zhi-Fan Wu and Ziyu Liu},
647
+ journal = {arXiv preprint arXiv:2503.20314},
648
+ year={2025}
649
+ }
650
+ ```
651
+
652
+ ## License Agreement
653
+ The models in this repository are licensed under the Apache 2.0 License. We claim no rights over the your generated contents, granting you the freedom to use them while ensuring that your usage complies with the provisions of this license. You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. For a complete list of restrictions and details regarding your rights, please refer to the full text of the [license](LICENSE.txt).
654
+
655
+
656
+ ## Acknowledgements
657
+
658
+ We would like to thank the contributors to the [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [Qwen](https://huggingface.co/Qwen), [umt5-xxl](https://huggingface.co/google/umt5-xxl), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research.
659
+
660
+
661
+
662
+ ## Contact Us
663
+ If you would like to leave a message to our research or product teams, feel free to join our [Discord](https://discord.gg/AKNgpMK4Yj) or [WeChat groups](https://gw.alicdn.com/imgextra/i2/O1CN01tqjWFi1ByuyehkTSB_!!6000000000015-0-tps-611-1279.jpg)!
assets/comp_effic.png ADDED

Git LFS Details

  • SHA256: b0e225caffb4b31295ad150f95ee852e4c3dde4a00ac8f79a2ff500f2ce26b8d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
assets/data_for_diff_stage.jpg ADDED

Git LFS Details

  • SHA256: 59aec08409f2d46b0e640e4e120dc7cca52c08c3de56d026602dbcff1ebf241a
  • Pointer size: 131 Bytes
  • Size of remote file: 528 kB
assets/i2v_res.png ADDED

Git LFS Details

  • SHA256: 6823b3206d8d0cb18d3b5b949dec1217f1178109ba11f14e977b67e1f7b8a248
  • Pointer size: 131 Bytes
  • Size of remote file: 892 kB
assets/logo.png ADDED
assets/t2v_res.jpg ADDED

Git LFS Details

  • SHA256: 91db579092446be2a834bc67721a8e4346936f38c4edb912f459ca3e10f8f439
  • Pointer size: 131 Bytes
  • Size of remote file: 301 kB
assets/vben_vs_sota.png ADDED

Git LFS Details

  • SHA256: 9a0e86ca85046d2675f97984b88b6e74df07bba8a62a31ab8a1aef50d4eda44e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.55 MB
assets/video_dit_arch.jpg ADDED

Git LFS Details

  • SHA256: 195dceec6570289d8b01cc51d2e28a7786216f19de55b23978a52610d1646a66
  • Pointer size: 131 Bytes
  • Size of remote file: 643 kB
assets/video_vae_res.jpg ADDED

Git LFS Details

  • SHA256: d8f9e7f7353848056a615c8ef35ab86ec22976bb46cb27405008b4089701945c
  • Pointer size: 131 Bytes
  • Size of remote file: 213 kB
examples/flf2v_input_first_frame.png ADDED

Git LFS Details

  • SHA256: fc27f36ae244c3366d6d342c5fc78a959bfeddc7c6ddfb62fc255c18cdac977f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.78 MB
examples/flf2v_input_last_frame.png ADDED

Git LFS Details

  • SHA256: 981d904fda073c87225089971b4bbff6f15d1f2511a111903989f391c526820f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
examples/girl.png ADDED

Git LFS Details

  • SHA256: f461a83c0772dbe93a05ae6b8ce9fa77f0e7f5facb4402685b5410c0dc18397f
  • Pointer size: 131 Bytes
  • Size of remote file: 836 kB
examples/i2v_input.JPG ADDED

Git LFS Details

  • SHA256: 077e3d965090c9028c69c00931675f42e1acc815c6eb450ab291b3b72d211a8e
  • Pointer size: 131 Bytes
  • Size of remote file: 251 kB
examples/snake.png ADDED

Git LFS Details

  • SHA256: 60ae5e275f64de6ca99c5e63eaea6812fe09a6d7e7a233e483e700122ad08124
  • Pointer size: 131 Bytes
  • Size of remote file: 446 kB
generate.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ from datetime import datetime
4
+ import logging
5
+ import os
6
+ import sys
7
+ import warnings
8
+
9
+ warnings.filterwarnings('ignore')
10
+
11
+ import torch, random
12
+ import torch.distributed as dist
13
+ from PIL import Image
14
+
15
+ import wan
16
+ from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
17
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
18
+ from wan.utils.utils import cache_video, cache_image, str2bool
19
+
20
+ EXAMPLE_PROMPT = {
21
+ "t2v-1.3B": {
22
+ "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
23
+ },
24
+ "t2v-14B": {
25
+ "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
26
+ },
27
+ "t2i-14B": {
28
+ "prompt": "一个朴素端庄的美人",
29
+ },
30
+ "i2v-14B": {
31
+ "prompt":
32
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
33
+ "image":
34
+ "examples/i2v_input.JPG",
35
+ },
36
+ "flf2v-14B": {
37
+ "prompt":
38
+ "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
39
+ "first_frame":
40
+ "examples/flf2v_input_first_frame.png",
41
+ "last_frame":
42
+ "examples/flf2v_input_last_frame.png",
43
+ },
44
+ "vace-1.3B": {
45
+ "src_ref_images": 'examples/girl.png,examples/snake.png',
46
+ "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
47
+ },
48
+ "vace-14B": {
49
+ "src_ref_images": 'examples/girl.png,examples/snake.png',
50
+ "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
51
+ }
52
+ }
53
+
54
+
55
+ def _validate_args(args):
56
+ # Basic check
57
+ assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
58
+ assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
59
+ assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"
60
+
61
+ # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
62
+ if args.sample_steps is None:
63
+ args.sample_steps = 50
64
+ if "i2v" in args.task:
65
+ args.sample_steps = 40
66
+
67
+
68
+ if args.sample_shift is None:
69
+ args.sample_shift = 5.0
70
+ if "i2v" in args.task and args.size in ["832*480", "480*832"]:
71
+ args.sample_shift = 3.0
72
+ elif "flf2v" in args.task or "vace" in args.task:
73
+ args.sample_shift = 16
74
+
75
+
76
+ # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
77
+ if args.frame_num is None:
78
+ args.frame_num = 1 if "t2i" in args.task else 81
79
+
80
+ # T2I frame_num check
81
+ if "t2i" in args.task:
82
+ assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"
83
+
84
+ args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
85
+ 0, sys.maxsize)
86
+ # Size check
87
+ assert args.size in SUPPORTED_SIZES[
88
+ args.
89
+ task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
90
+
91
+
92
+ def _parse_args():
93
+ parser = argparse.ArgumentParser(
94
+ description="Generate a image or video from a text prompt or image using Wan"
95
+ )
96
+ parser.add_argument(
97
+ "--task",
98
+ type=str,
99
+ default="t2v-14B",
100
+ choices=list(WAN_CONFIGS.keys()),
101
+ help="The task to run.")
102
+ parser.add_argument(
103
+ "--size",
104
+ type=str,
105
+ default="1280*720",
106
+ choices=list(SIZE_CONFIGS.keys()),
107
+ help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
108
+ )
109
+ parser.add_argument(
110
+ "--frame_num",
111
+ type=int,
112
+ default=None,
113
+ help="How many frames to sample from a image or video. The number should be 4n+1"
114
+ )
115
+ parser.add_argument(
116
+ "--ckpt_dir",
117
+ type=str,
118
+ default=None,
119
+ help="The path to the checkpoint directory.")
120
+ parser.add_argument(
121
+ "--offload_model",
122
+ type=str2bool,
123
+ default=None,
124
+ help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
125
+ )
126
+ parser.add_argument(
127
+ "--ulysses_size",
128
+ type=int,
129
+ default=1,
130
+ help="The size of the ulysses parallelism in DiT.")
131
+ parser.add_argument(
132
+ "--ring_size",
133
+ type=int,
134
+ default=1,
135
+ help="The size of the ring attention parallelism in DiT.")
136
+ parser.add_argument(
137
+ "--t5_fsdp",
138
+ action="store_true",
139
+ default=False,
140
+ help="Whether to use FSDP for T5.")
141
+ parser.add_argument(
142
+ "--t5_cpu",
143
+ action="store_true",
144
+ default=False,
145
+ help="Whether to place T5 model on CPU.")
146
+ parser.add_argument(
147
+ "--dit_fsdp",
148
+ action="store_true",
149
+ default=False,
150
+ help="Whether to use FSDP for DiT.")
151
+ parser.add_argument(
152
+ "--save_file",
153
+ type=str,
154
+ default=None,
155
+ help="The file to save the generated image or video to.")
156
+ parser.add_argument(
157
+ "--src_video",
158
+ type=str,
159
+ default=None,
160
+ help="The file of the source video. Default None.")
161
+ parser.add_argument(
162
+ "--src_mask",
163
+ type=str,
164
+ default=None,
165
+ help="The file of the source mask. Default None.")
166
+ parser.add_argument(
167
+ "--src_ref_images",
168
+ type=str,
169
+ default=None,
170
+ help="The file list of the source reference images. Separated by ','. Default None.")
171
+ parser.add_argument(
172
+ "--prompt",
173
+ type=str,
174
+ default=None,
175
+ help="The prompt to generate the image or video from.")
176
+ parser.add_argument(
177
+ "--use_prompt_extend",
178
+ action="store_true",
179
+ default=False,
180
+ help="Whether to use prompt extend.")
181
+ parser.add_argument(
182
+ "--prompt_extend_method",
183
+ type=str,
184
+ default="local_qwen",
185
+ choices=["dashscope", "local_qwen"],
186
+ help="The prompt extend method to use.")
187
+ parser.add_argument(
188
+ "--prompt_extend_model",
189
+ type=str,
190
+ default=None,
191
+ help="The prompt extend model to use.")
192
+ parser.add_argument(
193
+ "--prompt_extend_target_lang",
194
+ type=str,
195
+ default="zh",
196
+ choices=["zh", "en"],
197
+ help="The target language of prompt extend.")
198
+ parser.add_argument(
199
+ "--base_seed",
200
+ type=int,
201
+ default=-1,
202
+ help="The seed to use for generating the image or video.")
203
+ parser.add_argument(
204
+ "--image",
205
+ type=str,
206
+ default=None,
207
+ help="[image to video] The image to generate the video from.")
208
+ parser.add_argument(
209
+ "--first_frame",
210
+ type=str,
211
+ default=None,
212
+ help="[first-last frame to video] The image (first frame) to generate the video from.")
213
+ parser.add_argument(
214
+ "--last_frame",
215
+ type=str,
216
+ default=None,
217
+ help="[first-last frame to video] The image (last frame) to generate the video from.")
218
+ parser.add_argument(
219
+ "--sample_solver",
220
+ type=str,
221
+ default='unipc',
222
+ choices=['unipc', 'dpm++'],
223
+ help="The solver used to sample.")
224
+ parser.add_argument(
225
+ "--sample_steps", type=int, default=None, help="The sampling steps.")
226
+ parser.add_argument(
227
+ "--sample_shift",
228
+ type=float,
229
+ default=None,
230
+ help="Sampling shift factor for flow matching schedulers.")
231
+ parser.add_argument(
232
+ "--sample_guide_scale",
233
+ type=float,
234
+ default=5.0,
235
+ help="Classifier free guidance scale.")
236
+
237
+ args = parser.parse_args()
238
+
239
+ _validate_args(args)
240
+
241
+ return args
242
+
243
+
244
+ def _init_logging(rank):
245
+ # logging
246
+ if rank == 0:
247
+ # set format
248
+ logging.basicConfig(
249
+ level=logging.INFO,
250
+ format="[%(asctime)s] %(levelname)s: %(message)s",
251
+ handlers=[logging.StreamHandler(stream=sys.stdout)])
252
+ else:
253
+ logging.basicConfig(level=logging.ERROR)
254
+
255
+
256
+ def generate(args):
257
+ rank = int(os.getenv("RANK", 0))
258
+ world_size = int(os.getenv("WORLD_SIZE", 1))
259
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
260
+ device = local_rank
261
+ _init_logging(rank)
262
+
263
+ if args.offload_model is None:
264
+ args.offload_model = False if world_size > 1 else True
265
+ logging.info(
266
+ f"offload_model is not specified, set to {args.offload_model}.")
267
+ if world_size > 1:
268
+ torch.cuda.set_device(local_rank)
269
+ dist.init_process_group(
270
+ backend="nccl",
271
+ init_method="env://",
272
+ rank=rank,
273
+ world_size=world_size)
274
+ else:
275
+ assert not (
276
+ args.t5_fsdp or args.dit_fsdp
277
+ ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
278
+ assert not (
279
+ args.ulysses_size > 1 or args.ring_size > 1
280
+ ), f"context parallel are not supported in non-distributed environments."
281
+
282
+ if args.ulysses_size > 1 or args.ring_size > 1:
283
+ assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
284
+ from xfuser.core.distributed import (initialize_model_parallel,
285
+ init_distributed_environment)
286
+ init_distributed_environment(
287
+ rank=dist.get_rank(), world_size=dist.get_world_size())
288
+
289
+ initialize_model_parallel(
290
+ sequence_parallel_degree=dist.get_world_size(),
291
+ ring_degree=args.ring_size,
292
+ ulysses_degree=args.ulysses_size,
293
+ )
294
+
295
+ if args.use_prompt_extend:
296
+ if args.prompt_extend_method == "dashscope":
297
+ prompt_expander = DashScopePromptExpander(
298
+ model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
299
+ elif args.prompt_extend_method == "local_qwen":
300
+ prompt_expander = QwenPromptExpander(
301
+ model_name=args.prompt_extend_model,
302
+ is_vl="i2v" in args.task,
303
+ device=rank)
304
+ else:
305
+ raise NotImplementedError(
306
+ f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
307
+
308
+ cfg = WAN_CONFIGS[args.task]
309
+ if args.ulysses_size > 1:
310
+ assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."
311
+
312
+ logging.info(f"Generation job args: {args}")
313
+ logging.info(f"Generation model config: {cfg}")
314
+
315
+ if dist.is_initialized():
316
+ base_seed = [args.base_seed] if rank == 0 else [None]
317
+ dist.broadcast_object_list(base_seed, src=0)
318
+ args.base_seed = base_seed[0]
319
+
320
+ if "t2v" in args.task or "t2i" in args.task:
321
+ if args.prompt is None:
322
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
323
+ logging.info(f"Input prompt: {args.prompt}")
324
+ if args.use_prompt_extend:
325
+ logging.info("Extending prompt ...")
326
+ if rank == 0:
327
+ prompt_output = prompt_expander(
328
+ args.prompt,
329
+ tar_lang=args.prompt_extend_target_lang,
330
+ seed=args.base_seed)
331
+ if prompt_output.status == False:
332
+ logging.info(
333
+ f"Extending prompt failed: {prompt_output.message}")
334
+ logging.info("Falling back to original prompt.")
335
+ input_prompt = args.prompt
336
+ else:
337
+ input_prompt = prompt_output.prompt
338
+ input_prompt = [input_prompt]
339
+ else:
340
+ input_prompt = [None]
341
+ if dist.is_initialized():
342
+ dist.broadcast_object_list(input_prompt, src=0)
343
+ args.prompt = input_prompt[0]
344
+ logging.info(f"Extended prompt: {args.prompt}")
345
+
346
+ logging.info("Creating WanT2V pipeline.")
347
+ wan_t2v = wan.WanT2V(
348
+ config=cfg,
349
+ checkpoint_dir=args.ckpt_dir,
350
+ device_id=device,
351
+ rank=rank,
352
+ t5_fsdp=args.t5_fsdp,
353
+ dit_fsdp=args.dit_fsdp,
354
+ use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
355
+ t5_cpu=args.t5_cpu,
356
+ )
357
+
358
+ logging.info(
359
+ f"Generating {'image' if 't2i' in args.task else 'video'} ...")
360
+ video = wan_t2v.generate(
361
+ args.prompt,
362
+ size=SIZE_CONFIGS[args.size],
363
+ frame_num=args.frame_num,
364
+ shift=args.sample_shift,
365
+ sample_solver=args.sample_solver,
366
+ sampling_steps=args.sample_steps,
367
+ guide_scale=args.sample_guide_scale,
368
+ seed=args.base_seed,
369
+ offload_model=args.offload_model)
370
+
371
+ elif "i2v" in args.task:
372
+ if args.prompt is None:
373
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
374
+ if args.image is None:
375
+ args.image = EXAMPLE_PROMPT[args.task]["image"]
376
+ logging.info(f"Input prompt: {args.prompt}")
377
+ logging.info(f"Input image: {args.image}")
378
+
379
+ img = Image.open(args.image).convert("RGB")
380
+ if args.use_prompt_extend:
381
+ logging.info("Extending prompt ...")
382
+ if rank == 0:
383
+ prompt_output = prompt_expander(
384
+ args.prompt,
385
+ tar_lang=args.prompt_extend_target_lang,
386
+ image=img,
387
+ seed=args.base_seed)
388
+ if prompt_output.status == False:
389
+ logging.info(
390
+ f"Extending prompt failed: {prompt_output.message}")
391
+ logging.info("Falling back to original prompt.")
392
+ input_prompt = args.prompt
393
+ else:
394
+ input_prompt = prompt_output.prompt
395
+ input_prompt = [input_prompt]
396
+ else:
397
+ input_prompt = [None]
398
+ if dist.is_initialized():
399
+ dist.broadcast_object_list(input_prompt, src=0)
400
+ args.prompt = input_prompt[0]
401
+ logging.info(f"Extended prompt: {args.prompt}")
402
+
403
+ logging.info("Creating WanI2V pipeline.")
404
+ wan_i2v = wan.WanI2V(
405
+ config=cfg,
406
+ checkpoint_dir=args.ckpt_dir,
407
+ device_id=device,
408
+ rank=rank,
409
+ t5_fsdp=args.t5_fsdp,
410
+ dit_fsdp=args.dit_fsdp,
411
+ use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
412
+ t5_cpu=args.t5_cpu,
413
+ )
414
+
415
+ logging.info("Generating video ...")
416
+ video = wan_i2v.generate(
417
+ args.prompt,
418
+ img,
419
+ max_area=MAX_AREA_CONFIGS[args.size],
420
+ frame_num=args.frame_num,
421
+ shift=args.sample_shift,
422
+ sample_solver=args.sample_solver,
423
+ sampling_steps=args.sample_steps,
424
+ guide_scale=args.sample_guide_scale,
425
+ seed=args.base_seed,
426
+ offload_model=args.offload_model)
427
+ elif "flf2v" in args.task:
428
+ if args.prompt is None:
429
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
430
+ if args.first_frame is None or args.last_frame is None:
431
+ args.first_frame = EXAMPLE_PROMPT[args.task]["first_frame"]
432
+ args.last_frame = EXAMPLE_PROMPT[args.task]["last_frame"]
433
+ logging.info(f"Input prompt: {args.prompt}")
434
+ logging.info(f"Input first frame: {args.first_frame}")
435
+ logging.info(f"Input last frame: {args.last_frame}")
436
+ first_frame = Image.open(args.first_frame).convert("RGB")
437
+ last_frame = Image.open(args.last_frame).convert("RGB")
438
+ if args.use_prompt_extend:
439
+ logging.info("Extending prompt ...")
440
+ if rank == 0:
441
+ prompt_output = prompt_expander(
442
+ args.prompt,
443
+ tar_lang=args.prompt_extend_target_lang,
444
+ image=[first_frame, last_frame],
445
+ seed=args.base_seed)
446
+ if prompt_output.status == False:
447
+ logging.info(
448
+ f"Extending prompt failed: {prompt_output.message}")
449
+ logging.info("Falling back to original prompt.")
450
+ input_prompt = args.prompt
451
+ else:
452
+ input_prompt = prompt_output.prompt
453
+ input_prompt = [input_prompt]
454
+ else:
455
+ input_prompt = [None]
456
+ if dist.is_initialized():
457
+ dist.broadcast_object_list(input_prompt, src=0)
458
+ args.prompt = input_prompt[0]
459
+ logging.info(f"Extended prompt: {args.prompt}")
460
+
461
+ logging.info("Creating WanFLF2V pipeline.")
462
+ wan_flf2v = wan.WanFLF2V(
463
+ config=cfg,
464
+ checkpoint_dir=args.ckpt_dir,
465
+ device_id=device,
466
+ rank=rank,
467
+ t5_fsdp=args.t5_fsdp,
468
+ dit_fsdp=args.dit_fsdp,
469
+ use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
470
+ t5_cpu=args.t5_cpu,
471
+ )
472
+
473
+ logging.info("Generating video ...")
474
+ video = wan_flf2v.generate(
475
+ args.prompt,
476
+ first_frame,
477
+ last_frame,
478
+ max_area=MAX_AREA_CONFIGS[args.size],
479
+ frame_num=args.frame_num,
480
+ shift=args.sample_shift,
481
+ sample_solver=args.sample_solver,
482
+ sampling_steps=args.sample_steps,
483
+ guide_scale=args.sample_guide_scale,
484
+ seed=args.base_seed,
485
+ offload_model=args.offload_model
486
+ )
487
+ elif "vace" in args.task:
488
+ if args.prompt is None:
489
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
490
+ args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
491
+ args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
492
+ args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
493
+
494
+ logging.info(f"Input prompt: {args.prompt}")
495
+ if args.use_prompt_extend and args.use_prompt_extend != 'plain':
496
+ logging.info("Extending prompt ...")
497
+ if rank == 0:
498
+ prompt = prompt_expander.forward(args.prompt)
499
+ logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
500
+ input_prompt = [prompt]
501
+ else:
502
+ input_prompt = [None]
503
+ if dist.is_initialized():
504
+ dist.broadcast_object_list(input_prompt, src=0)
505
+ args.prompt = input_prompt[0]
506
+ logging.info(f"Extended prompt: {args.prompt}")
507
+
508
+ logging.info("Creating VACE pipeline.")
509
+ wan_vace = wan.WanVace(
510
+ config=cfg,
511
+ checkpoint_dir=args.ckpt_dir,
512
+ device_id=device,
513
+ rank=rank,
514
+ t5_fsdp=args.t5_fsdp,
515
+ dit_fsdp=args.dit_fsdp,
516
+ use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
517
+ t5_cpu=args.t5_cpu,
518
+ )
519
+
520
+ src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
521
+ [args.src_mask],
522
+ [None if args.src_ref_images is None else args.src_ref_images.split(',')],
523
+ args.frame_num, SIZE_CONFIGS[args.size], device)
524
+
525
+ logging.info(f"Generating video...")
526
+ video = wan_vace.generate(
527
+ args.prompt,
528
+ src_video,
529
+ src_mask,
530
+ src_ref_images,
531
+ size=SIZE_CONFIGS[args.size],
532
+ frame_num=args.frame_num,
533
+ shift=args.sample_shift,
534
+ sample_solver=args.sample_solver,
535
+ sampling_steps=args.sample_steps,
536
+ guide_scale=args.sample_guide_scale,
537
+ seed=args.base_seed,
538
+ offload_model=args.offload_model)
539
+ else:
540
+ raise ValueError(f"Unkown task type: {args.task}")
541
+
542
+ if rank == 0:
543
+ if args.save_file is None:
544
+ formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
545
+ formatted_prompt = args.prompt.replace(" ", "_").replace("/",
546
+ "_")[:50]
547
+ suffix = '.png' if "t2i" in args.task else '.mp4'
548
+ args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
549
+
550
+ if "t2i" in args.task:
551
+ logging.info(f"Saving generated image to {args.save_file}")
552
+ cache_image(
553
+ tensor=video.squeeze(1)[None],
554
+ save_file=args.save_file,
555
+ nrow=1,
556
+ normalize=True,
557
+ value_range=(-1, 1))
558
+ else:
559
+ logging.info(f"Saving generated video to {args.save_file}")
560
+ cache_video(
561
+ tensor=video[None],
562
+ save_file=args.save_file,
563
+ fps=cfg.sample_fps,
564
+ nrow=1,
565
+ normalize=True,
566
+ value_range=(-1, 1))
567
+ logging.info("Finished.")
568
+
569
+
570
+ if __name__ == "__main__":
571
+ args = _parse_args()
572
+ generate(args)
gradio/fl2v_14B_singleGPU.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import gc
4
+ import os.path as osp
5
+ import os
6
+ import sys
7
+ import warnings
8
+
9
+ import gradio as gr
10
+
11
+ warnings.filterwarnings('ignore')
12
+
13
+ # Model
14
+ sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
15
+ import wan
16
+ from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
17
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
18
+ from wan.utils.utils import cache_video
19
+
20
+ # Global Var
21
+ prompt_expander = None
22
+ wan_flf2v_720P = None
23
+
24
+
25
+ # Button Func
26
+ def load_model(value):
27
+ global wan_flf2v_720P
28
+
29
+ if value == '------':
30
+ print("No model loaded")
31
+ return '------'
32
+
33
+ if value == '720P':
34
+ if args.ckpt_dir_720p is None:
35
+ print("Please specify the checkpoint directory for 720P model")
36
+ return '------'
37
+ if wan_flf2v_720P is not None:
38
+ pass
39
+ else:
40
+ gc.collect()
41
+
42
+ print("load 14B-720P flf2v model...", end='', flush=True)
43
+ cfg = WAN_CONFIGS['flf2v-14B']
44
+ wan_flf2v_720P = wan.WanFLF2V(
45
+ config=cfg,
46
+ checkpoint_dir=args.ckpt_dir_720p,
47
+ device_id=0,
48
+ rank=0,
49
+ t5_fsdp=False,
50
+ dit_fsdp=False,
51
+ use_usp=False,
52
+ )
53
+ print("done", flush=True)
54
+ return '720P'
55
+ return value
56
+
57
+
58
+ def prompt_enc(prompt, img_first, img_last, tar_lang):
59
+ print('prompt extend...')
60
+ if img_first is None or img_last is None:
61
+ print('Please upload the first and last frames')
62
+ return prompt
63
+ global prompt_expander
64
+ prompt_output = prompt_expander(
65
+ prompt, image=[img_first, img_last], tar_lang=tar_lang.lower())
66
+ if prompt_output.status == False:
67
+ return prompt
68
+ else:
69
+ return prompt_output.prompt
70
+
71
+
72
+ def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
73
+ guide_scale, shift_scale, seed, n_prompt):
74
+
75
+ if resolution == '------':
76
+ print(
77
+ 'Please specify the resolution ckpt dir or specify the resolution'
78
+ )
79
+ return None
80
+
81
+ else:
82
+ if resolution == '720P':
83
+ global wan_flf2v_720P
84
+ video = wan_flf2v_720P.generate(
85
+ flf2vid_prompt,
86
+ flf2vid_image_first,
87
+ flf2vid_image_last,
88
+ max_area=MAX_AREA_CONFIGS['720*1280'],
89
+ shift=shift_scale,
90
+ sampling_steps=sd_steps,
91
+ guide_scale=guide_scale,
92
+ n_prompt=n_prompt,
93
+ seed=seed,
94
+ offload_model=True)
95
+ pass
96
+ else:
97
+ print(
98
+ 'Sorry, currently only 720P is supported.'
99
+ )
100
+ return None
101
+
102
+ cache_video(
103
+ tensor=video[None],
104
+ save_file="example.mp4",
105
+ fps=16,
106
+ nrow=1,
107
+ normalize=True,
108
+ value_range=(-1, 1))
109
+
110
+ return "example.mp4"
111
+
112
+
113
+ # Interface
114
+ def gradio_interface():
115
+ with gr.Blocks() as demo:
116
+ gr.Markdown("""
117
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
118
+ Wan2.1 (FLF2V-14B)
119
+ </div>
120
+ <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
121
+ Wan: Open and Advanced Large-Scale Video Generative Models.
122
+ </div>
123
+ """)
124
+
125
+ with gr.Row():
126
+ with gr.Column():
127
+ resolution = gr.Dropdown(
128
+ label='Resolution',
129
+ choices=['------', '720P'],
130
+ value='------')
131
+ flf2vid_image_first = gr.Image(
132
+ type="pil",
133
+ label="Upload First Frame",
134
+ elem_id="image_upload",
135
+ )
136
+ flf2vid_image_last = gr.Image(
137
+ type="pil",
138
+ label="Upload Last Frame",
139
+ elem_id="image_upload",
140
+ )
141
+ flf2vid_prompt = gr.Textbox(
142
+ label="Prompt",
143
+ placeholder="Describe the video you want to generate",
144
+ )
145
+ tar_lang = gr.Radio(
146
+ choices=["ZH", "EN"],
147
+ label="Target language of prompt enhance",
148
+ value="ZH")
149
+ run_p_button = gr.Button(value="Prompt Enhance")
150
+
151
+ with gr.Accordion("Advanced Options", open=True):
152
+ with gr.Row():
153
+ sd_steps = gr.Slider(
154
+ label="Diffusion steps",
155
+ minimum=1,
156
+ maximum=1000,
157
+ value=50,
158
+ step=1)
159
+ guide_scale = gr.Slider(
160
+ label="Guide scale",
161
+ minimum=0,
162
+ maximum=20,
163
+ value=5.0,
164
+ step=1)
165
+ with gr.Row():
166
+ shift_scale = gr.Slider(
167
+ label="Shift scale",
168
+ minimum=0,
169
+ maximum=20,
170
+ value=5.0,
171
+ step=1)
172
+ seed = gr.Slider(
173
+ label="Seed",
174
+ minimum=-1,
175
+ maximum=2147483647,
176
+ step=1,
177
+ value=-1)
178
+ n_prompt = gr.Textbox(
179
+ label="Negative Prompt",
180
+ placeholder="Describe the negative prompt you want to add"
181
+ )
182
+
183
+ run_flf2v_button = gr.Button("Generate Video")
184
+
185
+ with gr.Column():
186
+ result_gallery = gr.Video(
187
+ label='Generated Video', interactive=False, height=600)
188
+
189
+ resolution.input(
190
+ fn=load_model, inputs=[resolution], outputs=[resolution])
191
+
192
+ run_p_button.click(
193
+ fn=prompt_enc,
194
+ inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
195
+ outputs=[flf2vid_prompt])
196
+
197
+ run_flf2v_button.click(
198
+ fn=flf2v_generation,
199
+ inputs=[
200
+ flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
201
+ guide_scale, shift_scale, seed, n_prompt
202
+ ],
203
+ outputs=[result_gallery],
204
+ )
205
+
206
+ return demo
207
+
208
+
209
+ # Main
210
+ def _parse_args():
211
+ parser = argparse.ArgumentParser(
212
+ description="Generate a video from a text prompt or image using Gradio")
213
+ parser.add_argument(
214
+ "--ckpt_dir_720p",
215
+ type=str,
216
+ default=None,
217
+ help="The path to the checkpoint directory.")
218
+ parser.add_argument(
219
+ "--prompt_extend_method",
220
+ type=str,
221
+ default="local_qwen",
222
+ choices=["dashscope", "local_qwen"],
223
+ help="The prompt extend method to use.")
224
+ parser.add_argument(
225
+ "--prompt_extend_model",
226
+ type=str,
227
+ default=None,
228
+ help="The prompt extend model to use.")
229
+
230
+ args = parser.parse_args()
231
+ assert args.ckpt_dir_720p is not None, "Please specify the checkpoint directory."
232
+
233
+ return args
234
+
235
+
236
+ if __name__ == '__main__':
237
+ args = _parse_args()
238
+
239
+ print("Step1: Init prompt_expander...", end='', flush=True)
240
+ if args.prompt_extend_method == "dashscope":
241
+ prompt_expander = DashScopePromptExpander(
242
+ model_name=args.prompt_extend_model, is_vl=True)
243
+ elif args.prompt_extend_method == "local_qwen":
244
+ prompt_expander = QwenPromptExpander(
245
+ model_name=args.prompt_extend_model, is_vl=True, device=0)
246
+ else:
247
+ raise NotImplementedError(
248
+ f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
249
+ print("done", flush=True)
250
+
251
+ demo = gradio_interface()
252
+ demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
gradio/i2v_14B_singleGPU.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import gc
4
+ import os.path as osp
5
+ import os
6
+ import sys
7
+ import warnings
8
+
9
+ import gradio as gr
10
+
11
+ warnings.filterwarnings('ignore')
12
+
13
+ # Model
14
+ sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
15
+ import wan
16
+ from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
17
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
18
+ from wan.utils.utils import cache_video
19
+
20
+ # Global Var
21
+ prompt_expander = None
22
+ wan_i2v_480P = None
23
+ wan_i2v_720P = None
24
+
25
+
26
+ # Button Func
27
+ def load_model(value):
28
+ global wan_i2v_480P, wan_i2v_720P
29
+
30
+ if value == '------':
31
+ print("No model loaded")
32
+ return '------'
33
+
34
+ if value == '720P':
35
+ if args.ckpt_dir_720p is None:
36
+ print("Please specify the checkpoint directory for 720P model")
37
+ return '------'
38
+ if wan_i2v_720P is not None:
39
+ pass
40
+ else:
41
+ del wan_i2v_480P
42
+ gc.collect()
43
+ wan_i2v_480P = None
44
+
45
+ print("load 14B-720P i2v model...", end='', flush=True)
46
+ cfg = WAN_CONFIGS['i2v-14B']
47
+ wan_i2v_720P = wan.WanI2V(
48
+ config=cfg,
49
+ checkpoint_dir=args.ckpt_dir_720p,
50
+ device_id=0,
51
+ rank=0,
52
+ t5_fsdp=False,
53
+ dit_fsdp=False,
54
+ use_usp=False,
55
+ )
56
+ print("done", flush=True)
57
+ return '720P'
58
+
59
+ if value == '480P':
60
+ if args.ckpt_dir_480p is None:
61
+ print("Please specify the checkpoint directory for 480P model")
62
+ return '------'
63
+ if wan_i2v_480P is not None:
64
+ pass
65
+ else:
66
+ del wan_i2v_720P
67
+ gc.collect()
68
+ wan_i2v_720P = None
69
+
70
+ print("load 14B-480P i2v model...", end='', flush=True)
71
+ cfg = WAN_CONFIGS['i2v-14B']
72
+ wan_i2v_480P = wan.WanI2V(
73
+ config=cfg,
74
+ checkpoint_dir=args.ckpt_dir_480p,
75
+ device_id=0,
76
+ rank=0,
77
+ t5_fsdp=False,
78
+ dit_fsdp=False,
79
+ use_usp=False,
80
+ )
81
+ print("done", flush=True)
82
+ return '480P'
83
+ return value
84
+
85
+
86
+ def prompt_enc(prompt, img, tar_lang):
87
+ print('prompt extend...')
88
+ if img is None:
89
+ print('Please upload an image')
90
+ return prompt
91
+ global prompt_expander
92
+ prompt_output = prompt_expander(
93
+ prompt, image=img, tar_lang=tar_lang.lower())
94
+ if prompt_output.status == False:
95
+ return prompt
96
+ else:
97
+ return prompt_output.prompt
98
+
99
+
100
+ def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
101
+ guide_scale, shift_scale, seed, n_prompt):
102
+ # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
103
+
104
+ if resolution == '------':
105
+ print(
106
+ 'Please specify at least one resolution ckpt dir or specify the resolution'
107
+ )
108
+ return None
109
+
110
+ else:
111
+ if resolution == '720P':
112
+ global wan_i2v_720P
113
+ video = wan_i2v_720P.generate(
114
+ img2vid_prompt,
115
+ img2vid_image,
116
+ max_area=MAX_AREA_CONFIGS['720*1280'],
117
+ shift=shift_scale,
118
+ sampling_steps=sd_steps,
119
+ guide_scale=guide_scale,
120
+ n_prompt=n_prompt,
121
+ seed=seed,
122
+ offload_model=True)
123
+ else:
124
+ global wan_i2v_480P
125
+ video = wan_i2v_480P.generate(
126
+ img2vid_prompt,
127
+ img2vid_image,
128
+ max_area=MAX_AREA_CONFIGS['480*832'],
129
+ shift=shift_scale,
130
+ sampling_steps=sd_steps,
131
+ guide_scale=guide_scale,
132
+ n_prompt=n_prompt,
133
+ seed=seed,
134
+ offload_model=True)
135
+
136
+ cache_video(
137
+ tensor=video[None],
138
+ save_file="example.mp4",
139
+ fps=16,
140
+ nrow=1,
141
+ normalize=True,
142
+ value_range=(-1, 1))
143
+
144
+ return "example.mp4"
145
+
146
+
147
+ # Interface
148
+ def gradio_interface():
149
+ with gr.Blocks() as demo:
150
+ gr.Markdown("""
151
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
152
+ Wan2.1 (I2V-14B)
153
+ </div>
154
+ <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
155
+ Wan: Open and Advanced Large-Scale Video Generative Models.
156
+ </div>
157
+ """)
158
+
159
+ with gr.Row():
160
+ with gr.Column():
161
+ resolution = gr.Dropdown(
162
+ label='Resolution',
163
+ choices=['------', '720P', '480P'],
164
+ value='------')
165
+
166
+ img2vid_image = gr.Image(
167
+ type="pil",
168
+ label="Upload Input Image",
169
+ elem_id="image_upload",
170
+ )
171
+ img2vid_prompt = gr.Textbox(
172
+ label="Prompt",
173
+ placeholder="Describe the video you want to generate",
174
+ )
175
+ tar_lang = gr.Radio(
176
+ choices=["ZH", "EN"],
177
+ label="Target language of prompt enhance",
178
+ value="ZH")
179
+ run_p_button = gr.Button(value="Prompt Enhance")
180
+
181
+ with gr.Accordion("Advanced Options", open=True):
182
+ with gr.Row():
183
+ sd_steps = gr.Slider(
184
+ label="Diffusion steps",
185
+ minimum=1,
186
+ maximum=1000,
187
+ value=50,
188
+ step=1)
189
+ guide_scale = gr.Slider(
190
+ label="Guide scale",
191
+ minimum=0,
192
+ maximum=20,
193
+ value=5.0,
194
+ step=1)
195
+ with gr.Row():
196
+ shift_scale = gr.Slider(
197
+ label="Shift scale",
198
+ minimum=0,
199
+ maximum=10,
200
+ value=5.0,
201
+ step=1)
202
+ seed = gr.Slider(
203
+ label="Seed",
204
+ minimum=-1,
205
+ maximum=2147483647,
206
+ step=1,
207
+ value=-1)
208
+ n_prompt = gr.Textbox(
209
+ label="Negative Prompt",
210
+ placeholder="Describe the negative prompt you want to add"
211
+ )
212
+
213
+ run_i2v_button = gr.Button("Generate Video")
214
+
215
+ with gr.Column():
216
+ result_gallery = gr.Video(
217
+ label='Generated Video', interactive=False, height=600)
218
+
219
+ resolution.input(
220
+ fn=load_model, inputs=[resolution], outputs=[resolution])
221
+
222
+ run_p_button.click(
223
+ fn=prompt_enc,
224
+ inputs=[img2vid_prompt, img2vid_image, tar_lang],
225
+ outputs=[img2vid_prompt])
226
+
227
+ run_i2v_button.click(
228
+ fn=i2v_generation,
229
+ inputs=[
230
+ img2vid_prompt, img2vid_image, resolution, sd_steps,
231
+ guide_scale, shift_scale, seed, n_prompt
232
+ ],
233
+ outputs=[result_gallery],
234
+ )
235
+
236
+ return demo
237
+
238
+
239
+ # Main
240
+ def _parse_args():
241
+ parser = argparse.ArgumentParser(
242
+ description="Generate a video from a text prompt or image using Gradio")
243
+ parser.add_argument(
244
+ "--ckpt_dir_720p",
245
+ type=str,
246
+ default=None,
247
+ help="The path to the checkpoint directory.")
248
+ parser.add_argument(
249
+ "--ckpt_dir_480p",
250
+ type=str,
251
+ default=None,
252
+ help="The path to the checkpoint directory.")
253
+ parser.add_argument(
254
+ "--prompt_extend_method",
255
+ type=str,
256
+ default="local_qwen",
257
+ choices=["dashscope", "local_qwen"],
258
+ help="The prompt extend method to use.")
259
+ parser.add_argument(
260
+ "--prompt_extend_model",
261
+ type=str,
262
+ default=None,
263
+ help="The prompt extend model to use.")
264
+
265
+ args = parser.parse_args()
266
+ assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
267
+
268
+ return args
269
+
270
+
271
+ if __name__ == '__main__':
272
+ args = _parse_args()
273
+
274
+ print("Step1: Init prompt_expander...", end='', flush=True)
275
+ if args.prompt_extend_method == "dashscope":
276
+ prompt_expander = DashScopePromptExpander(
277
+ model_name=args.prompt_extend_model, is_vl=True)
278
+ elif args.prompt_extend_method == "local_qwen":
279
+ prompt_expander = QwenPromptExpander(
280
+ model_name=args.prompt_extend_model, is_vl=True, device=0)
281
+ else:
282
+ raise NotImplementedError(
283
+ f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
284
+ print("done", flush=True)
285
+
286
+ demo = gradio_interface()
287
+ demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
gradio/t2i_14B_singleGPU.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import os.path as osp
4
+ import os
5
+ import sys
6
+ import warnings
7
+
8
+ import gradio as gr
9
+
10
+ warnings.filterwarnings('ignore')
11
+
12
+ # Model
13
+ sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
14
+ import wan
15
+ from wan.configs import WAN_CONFIGS
16
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
17
+ from wan.utils.utils import cache_image
18
+
19
+ # Global Var
20
+ prompt_expander = None
21
+ wan_t2i = None
22
+
23
+
24
+ # Button Func
25
+ def prompt_enc(prompt, tar_lang):
26
+ global prompt_expander
27
+ prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
28
+ if prompt_output.status == False:
29
+ return prompt
30
+ else:
31
+ return prompt_output.prompt
32
+
33
+
34
+ def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
35
+ shift_scale, seed, n_prompt):
36
+ global wan_t2i
37
+ # print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
38
+
39
+ W = int(resolution.split("*")[0])
40
+ H = int(resolution.split("*")[1])
41
+ video = wan_t2i.generate(
42
+ txt2img_prompt,
43
+ size=(W, H),
44
+ frame_num=1,
45
+ shift=shift_scale,
46
+ sampling_steps=sd_steps,
47
+ guide_scale=guide_scale,
48
+ n_prompt=n_prompt,
49
+ seed=seed,
50
+ offload_model=True)
51
+
52
+ cache_image(
53
+ tensor=video.squeeze(1)[None],
54
+ save_file="example.png",
55
+ nrow=1,
56
+ normalize=True,
57
+ value_range=(-1, 1))
58
+
59
+ return "example.png"
60
+
61
+
62
+ # Interface
63
+ def gradio_interface():
64
+ with gr.Blocks() as demo:
65
+ gr.Markdown("""
66
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
67
+ Wan2.1 (T2I-14B)
68
+ </div>
69
+ <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
70
+ Wan: Open and Advanced Large-Scale Video Generative Models.
71
+ </div>
72
+ """)
73
+
74
+ with gr.Row():
75
+ with gr.Column():
76
+ txt2img_prompt = gr.Textbox(
77
+ label="Prompt",
78
+ placeholder="Describe the image you want to generate",
79
+ )
80
+ tar_lang = gr.Radio(
81
+ choices=["ZH", "EN"],
82
+ label="Target language of prompt enhance",
83
+ value="ZH")
84
+ run_p_button = gr.Button(value="Prompt Enhance")
85
+
86
+ with gr.Accordion("Advanced Options", open=True):
87
+ resolution = gr.Dropdown(
88
+ label='Resolution(Width*Height)',
89
+ choices=[
90
+ '720*1280', '1280*720', '960*960', '1088*832',
91
+ '832*1088', '480*832', '832*480', '624*624',
92
+ '704*544', '544*704'
93
+ ],
94
+ value='720*1280')
95
+
96
+ with gr.Row():
97
+ sd_steps = gr.Slider(
98
+ label="Diffusion steps",
99
+ minimum=1,
100
+ maximum=1000,
101
+ value=50,
102
+ step=1)
103
+ guide_scale = gr.Slider(
104
+ label="Guide scale",
105
+ minimum=0,
106
+ maximum=20,
107
+ value=5.0,
108
+ step=1)
109
+ with gr.Row():
110
+ shift_scale = gr.Slider(
111
+ label="Shift scale",
112
+ minimum=0,
113
+ maximum=10,
114
+ value=5.0,
115
+ step=1)
116
+ seed = gr.Slider(
117
+ label="Seed",
118
+ minimum=-1,
119
+ maximum=2147483647,
120
+ step=1,
121
+ value=-1)
122
+ n_prompt = gr.Textbox(
123
+ label="Negative Prompt",
124
+ placeholder="Describe the negative prompt you want to add"
125
+ )
126
+
127
+ run_t2i_button = gr.Button("Generate Image")
128
+
129
+ with gr.Column():
130
+ result_gallery = gr.Image(
131
+ label='Generated Image', interactive=False, height=600)
132
+
133
+ run_p_button.click(
134
+ fn=prompt_enc,
135
+ inputs=[txt2img_prompt, tar_lang],
136
+ outputs=[txt2img_prompt])
137
+
138
+ run_t2i_button.click(
139
+ fn=t2i_generation,
140
+ inputs=[
141
+ txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
142
+ seed, n_prompt
143
+ ],
144
+ outputs=[result_gallery],
145
+ )
146
+
147
+ return demo
148
+
149
+
150
+ # Main
151
+ def _parse_args():
152
+ parser = argparse.ArgumentParser(
153
+ description="Generate a image from a text prompt or image using Gradio")
154
+ parser.add_argument(
155
+ "--ckpt_dir",
156
+ type=str,
157
+ default="cache",
158
+ help="The path to the checkpoint directory.")
159
+ parser.add_argument(
160
+ "--prompt_extend_method",
161
+ type=str,
162
+ default="local_qwen",
163
+ choices=["dashscope", "local_qwen"],
164
+ help="The prompt extend method to use.")
165
+ parser.add_argument(
166
+ "--prompt_extend_model",
167
+ type=str,
168
+ default=None,
169
+ help="The prompt extend model to use.")
170
+
171
+ args = parser.parse_args()
172
+
173
+ return args
174
+
175
+
176
+ if __name__ == '__main__':
177
+ args = _parse_args()
178
+
179
+ print("Step1: Init prompt_expander...", end='', flush=True)
180
+ if args.prompt_extend_method == "dashscope":
181
+ prompt_expander = DashScopePromptExpander(
182
+ model_name=args.prompt_extend_model, is_vl=False)
183
+ elif args.prompt_extend_method == "local_qwen":
184
+ prompt_expander = QwenPromptExpander(
185
+ model_name=args.prompt_extend_model, is_vl=False, device=0)
186
+ else:
187
+ raise NotImplementedError(
188
+ f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
189
+ print("done", flush=True)
190
+
191
+ print("Step2: Init 14B t2i model...", end='', flush=True)
192
+ cfg = WAN_CONFIGS['t2i-14B']
193
+ wan_t2i = wan.WanT2V(
194
+ config=cfg,
195
+ checkpoint_dir=args.ckpt_dir,
196
+ device_id=0,
197
+ rank=0,
198
+ t5_fsdp=False,
199
+ dit_fsdp=False,
200
+ use_usp=False,
201
+ )
202
+ print("done", flush=True)
203
+
204
+ demo = gradio_interface()
205
+ demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
gradio/t2v_1.3B_singleGPU.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import os.path as osp
4
+ import os
5
+ import sys
6
+ import warnings
7
+
8
+ import gradio as gr
9
+
10
+ warnings.filterwarnings('ignore')
11
+
12
+ # Model
13
+ sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
14
+ import wan
15
+ from wan.configs import WAN_CONFIGS
16
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
17
+ from wan.utils.utils import cache_video
18
+
19
+ # Global Var
20
+ prompt_expander = None
21
+ wan_t2v = None
22
+
23
+
24
+ # Button Func
25
+ def prompt_enc(prompt, tar_lang):
26
+ global prompt_expander
27
+ prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
28
+ if prompt_output.status == False:
29
+ return prompt
30
+ else:
31
+ return prompt_output.prompt
32
+
33
+
34
+ def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
35
+ shift_scale, seed, n_prompt):
36
+ global wan_t2v
37
+ # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
38
+
39
+ W = int(resolution.split("*")[0])
40
+ H = int(resolution.split("*")[1])
41
+ video = wan_t2v.generate(
42
+ txt2vid_prompt,
43
+ size=(W, H),
44
+ shift=shift_scale,
45
+ sampling_steps=sd_steps,
46
+ guide_scale=guide_scale,
47
+ n_prompt=n_prompt,
48
+ seed=seed,
49
+ offload_model=True)
50
+
51
+ cache_video(
52
+ tensor=video[None],
53
+ save_file="example.mp4",
54
+ fps=16,
55
+ nrow=1,
56
+ normalize=True,
57
+ value_range=(-1, 1))
58
+
59
+ return "example.mp4"
60
+
61
+
62
+ # Interface
63
+ def gradio_interface():
64
+ with gr.Blocks() as demo:
65
+ gr.Markdown("""
66
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
67
+ Wan2.1 (T2V-1.3B)
68
+ </div>
69
+ <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
70
+ Wan: Open and Advanced Large-Scale Video Generative Models.
71
+ </div>
72
+ """)
73
+
74
+ with gr.Row():
75
+ with gr.Column():
76
+ txt2vid_prompt = gr.Textbox(
77
+ label="Prompt",
78
+ placeholder="Describe the video you want to generate",
79
+ )
80
+ tar_lang = gr.Radio(
81
+ choices=["ZH", "EN"],
82
+ label="Target language of prompt enhance",
83
+ value="ZH")
84
+ run_p_button = gr.Button(value="Prompt Enhance")
85
+
86
+ with gr.Accordion("Advanced Options", open=True):
87
+ resolution = gr.Dropdown(
88
+ label='Resolution(Width*Height)',
89
+ choices=[
90
+ '480*832',
91
+ '832*480',
92
+ '624*624',
93
+ '704*544',
94
+ '544*704',
95
+ ],
96
+ value='480*832')
97
+
98
+ with gr.Row():
99
+ sd_steps = gr.Slider(
100
+ label="Diffusion steps",
101
+ minimum=1,
102
+ maximum=1000,
103
+ value=50,
104
+ step=1)
105
+ guide_scale = gr.Slider(
106
+ label="Guide scale",
107
+ minimum=0,
108
+ maximum=20,
109
+ value=6.0,
110
+ step=1)
111
+ with gr.Row():
112
+ shift_scale = gr.Slider(
113
+ label="Shift scale",
114
+ minimum=0,
115
+ maximum=20,
116
+ value=8.0,
117
+ step=1)
118
+ seed = gr.Slider(
119
+ label="Seed",
120
+ minimum=-1,
121
+ maximum=2147483647,
122
+ step=1,
123
+ value=-1)
124
+ n_prompt = gr.Textbox(
125
+ label="Negative Prompt",
126
+ placeholder="Describe the negative prompt you want to add"
127
+ )
128
+
129
+ run_t2v_button = gr.Button("Generate Video")
130
+
131
+ with gr.Column():
132
+ result_gallery = gr.Video(
133
+ label='Generated Video', interactive=False, height=600)
134
+
135
+ run_p_button.click(
136
+ fn=prompt_enc,
137
+ inputs=[txt2vid_prompt, tar_lang],
138
+ outputs=[txt2vid_prompt])
139
+
140
+ run_t2v_button.click(
141
+ fn=t2v_generation,
142
+ inputs=[
143
+ txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
144
+ seed, n_prompt
145
+ ],
146
+ outputs=[result_gallery],
147
+ )
148
+
149
+ return demo
150
+
151
+
152
+ # Main
153
+ def _parse_args():
154
+ parser = argparse.ArgumentParser(
155
+ description="Generate a video from a text prompt or image using Gradio")
156
+ parser.add_argument(
157
+ "--ckpt_dir",
158
+ type=str,
159
+ default="cache",
160
+ help="The path to the checkpoint directory.")
161
+ parser.add_argument(
162
+ "--prompt_extend_method",
163
+ type=str,
164
+ default="local_qwen",
165
+ choices=["dashscope", "local_qwen"],
166
+ help="The prompt extend method to use.")
167
+ parser.add_argument(
168
+ "--prompt_extend_model",
169
+ type=str,
170
+ default=None,
171
+ help="The prompt extend model to use.")
172
+
173
+ args = parser.parse_args()
174
+
175
+ return args
176
+
177
+
178
+ if __name__ == '__main__':
179
+ args = _parse_args()
180
+
181
+ print("Step1: Init prompt_expander...", end='', flush=True)
182
+ if args.prompt_extend_method == "dashscope":
183
+ prompt_expander = DashScopePromptExpander(
184
+ model_name=args.prompt_extend_model, is_vl=False)
185
+ elif args.prompt_extend_method == "local_qwen":
186
+ prompt_expander = QwenPromptExpander(
187
+ model_name=args.prompt_extend_model, is_vl=False, device=0)
188
+ else:
189
+ raise NotImplementedError(
190
+ f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
191
+ print("done", flush=True)
192
+
193
+ print("Step2: Init 1.3B t2v model...", end='', flush=True)
194
+ cfg = WAN_CONFIGS['t2v-1.3B']
195
+ wan_t2v = wan.WanT2V(
196
+ config=cfg,
197
+ checkpoint_dir=args.ckpt_dir,
198
+ device_id=0,
199
+ rank=0,
200
+ t5_fsdp=False,
201
+ dit_fsdp=False,
202
+ use_usp=False,
203
+ )
204
+ print("done", flush=True)
205
+
206
+ demo = gradio_interface()
207
+ demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
gradio/t2v_14B_singleGPU.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import os.path as osp
4
+ import os
5
+ import sys
6
+ import warnings
7
+
8
+ import gradio as gr
9
+
10
+ warnings.filterwarnings('ignore')
11
+
12
+ # Model
13
+ sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
14
+ import wan
15
+ from wan.configs import WAN_CONFIGS
16
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
17
+ from wan.utils.utils import cache_video
18
+
19
+ # Global Var
20
+ prompt_expander = None
21
+ wan_t2v = None
22
+
23
+
24
+ # Button Func
25
+ def prompt_enc(prompt, tar_lang):
26
+ global prompt_expander
27
+ prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
28
+ if prompt_output.status == False:
29
+ return prompt
30
+ else:
31
+ return prompt_output.prompt
32
+
33
+
34
+ def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
35
+ shift_scale, seed, n_prompt):
36
+ global wan_t2v
37
+ # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
38
+
39
+ W = int(resolution.split("*")[0])
40
+ H = int(resolution.split("*")[1])
41
+ video = wan_t2v.generate(
42
+ txt2vid_prompt,
43
+ size=(W, H),
44
+ shift=shift_scale,
45
+ sampling_steps=sd_steps,
46
+ guide_scale=guide_scale,
47
+ n_prompt=n_prompt,
48
+ seed=seed,
49
+ offload_model=True)
50
+
51
+ cache_video(
52
+ tensor=video[None],
53
+ save_file="example.mp4",
54
+ fps=16,
55
+ nrow=1,
56
+ normalize=True,
57
+ value_range=(-1, 1))
58
+
59
+ return "example.mp4"
60
+
61
+
62
+ # Interface
63
+ def gradio_interface():
64
+ with gr.Blocks() as demo:
65
+ gr.Markdown("""
66
+ <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
67
+ Wan2.1 (T2V-14B)
68
+ </div>
69
+ <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
70
+ Wan: Open and Advanced Large-Scale Video Generative Models.
71
+ </div>
72
+ """)
73
+
74
+ with gr.Row():
75
+ with gr.Column():
76
+ txt2vid_prompt = gr.Textbox(
77
+ label="Prompt",
78
+ placeholder="Describe the video you want to generate",
79
+ )
80
+ tar_lang = gr.Radio(
81
+ choices=["ZH", "EN"],
82
+ label="Target language of prompt enhance",
83
+ value="ZH")
84
+ run_p_button = gr.Button(value="Prompt Enhance")
85
+
86
+ with gr.Accordion("Advanced Options", open=True):
87
+ resolution = gr.Dropdown(
88
+ label='Resolution(Width*Height)',
89
+ choices=[
90
+ '720*1280', '1280*720', '960*960', '1088*832',
91
+ '832*1088', '480*832', '832*480', '624*624',
92
+ '704*544', '544*704'
93
+ ],
94
+ value='720*1280')
95
+
96
+ with gr.Row():
97
+ sd_steps = gr.Slider(
98
+ label="Diffusion steps",
99
+ minimum=1,
100
+ maximum=1000,
101
+ value=50,
102
+ step=1)
103
+ guide_scale = gr.Slider(
104
+ label="Guide scale",
105
+ minimum=0,
106
+ maximum=20,
107
+ value=5.0,
108
+ step=1)
109
+ with gr.Row():
110
+ shift_scale = gr.Slider(
111
+ label="Shift scale",
112
+ minimum=0,
113
+ maximum=10,
114
+ value=5.0,
115
+ step=1)
116
+ seed = gr.Slider(
117
+ label="Seed",
118
+ minimum=-1,
119
+ maximum=2147483647,
120
+ step=1,
121
+ value=-1)
122
+ n_prompt = gr.Textbox(
123
+ label="Negative Prompt",
124
+ placeholder="Describe the negative prompt you want to add"
125
+ )
126
+
127
+ run_t2v_button = gr.Button("Generate Video")
128
+
129
+ with gr.Column():
130
+ result_gallery = gr.Video(
131
+ label='Generated Video', interactive=False, height=600)
132
+
133
+ run_p_button.click(
134
+ fn=prompt_enc,
135
+ inputs=[txt2vid_prompt, tar_lang],
136
+ outputs=[txt2vid_prompt])
137
+
138
+ run_t2v_button.click(
139
+ fn=t2v_generation,
140
+ inputs=[
141
+ txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
142
+ seed, n_prompt
143
+ ],
144
+ outputs=[result_gallery],
145
+ )
146
+
147
+ return demo
148
+
149
+
150
+ # Main
151
+ def _parse_args():
152
+ parser = argparse.ArgumentParser(
153
+ description="Generate a video from a text prompt or image using Gradio")
154
+ parser.add_argument(
155
+ "--ckpt_dir",
156
+ type=str,
157
+ default="cache",
158
+ help="The path to the checkpoint directory.")
159
+ parser.add_argument(
160
+ "--prompt_extend_method",
161
+ type=str,
162
+ default="local_qwen",
163
+ choices=["dashscope", "local_qwen"],
164
+ help="The prompt extend method to use.")
165
+ parser.add_argument(
166
+ "--prompt_extend_model",
167
+ type=str,
168
+ default=None,
169
+ help="The prompt extend model to use.")
170
+
171
+ args = parser.parse_args()
172
+
173
+ return args
174
+
175
+
176
+ if __name__ == '__main__':
177
+ args = _parse_args()
178
+
179
+ print("Step1: Init prompt_expander...", end='', flush=True)
180
+ if args.prompt_extend_method == "dashscope":
181
+ prompt_expander = DashScopePromptExpander(
182
+ model_name=args.prompt_extend_model, is_vl=False)
183
+ elif args.prompt_extend_method == "local_qwen":
184
+ prompt_expander = QwenPromptExpander(
185
+ model_name=args.prompt_extend_model, is_vl=False, device=0)
186
+ else:
187
+ raise NotImplementedError(
188
+ f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
189
+ print("done", flush=True)
190
+
191
+ print("Step2: Init 14B t2v model...", end='', flush=True)
192
+ cfg = WAN_CONFIGS['t2v-14B']
193
+ wan_t2v = wan.WanT2V(
194
+ config=cfg,
195
+ checkpoint_dir=args.ckpt_dir,
196
+ device_id=0,
197
+ rank=0,
198
+ t5_fsdp=False,
199
+ dit_fsdp=False,
200
+ use_usp=False,
201
+ )
202
+ print("done", flush=True)
203
+
204
+ demo = gradio_interface()
205
+ demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
gradio/vace.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import argparse
5
+ import os
6
+ import sys
7
+ import datetime
8
+ import imageio
9
+ import numpy as np
10
+ import torch
11
+ import gradio as gr
12
+
13
+ sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
14
+ import wan
15
+ from wan import WanVace, WanVaceMP
16
+ from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
17
+
18
+
19
+ class FixedSizeQueue:
20
+ def __init__(self, max_size):
21
+ self.max_size = max_size
22
+ self.queue = []
23
+ def add(self, item):
24
+ self.queue.insert(0, item)
25
+ if len(self.queue) > self.max_size:
26
+ self.queue.pop()
27
+ def get(self):
28
+ return self.queue
29
+ def __repr__(self):
30
+ return str(self.queue)
31
+
32
+
33
+ class VACEInference:
34
+ def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
35
+ self.cfg = cfg
36
+ self.save_dir = cfg.save_dir
37
+ self.gallery_share = gallery_share
38
+ self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
39
+ if not skip_load:
40
+ if not args.mp:
41
+ self.pipe = WanVace(
42
+ config=WAN_CONFIGS[cfg.model_name],
43
+ checkpoint_dir=cfg.ckpt_dir,
44
+ device_id=0,
45
+ rank=0,
46
+ t5_fsdp=False,
47
+ dit_fsdp=False,
48
+ use_usp=False,
49
+ )
50
+ else:
51
+ self.pipe = WanVaceMP(
52
+ config=WAN_CONFIGS[cfg.model_name],
53
+ checkpoint_dir=cfg.ckpt_dir,
54
+ use_usp=True,
55
+ ulysses_size=cfg.ulysses_size,
56
+ ring_size=cfg.ring_size
57
+ )
58
+
59
+
60
+ def create_ui(self, *args, **kwargs):
61
+ gr.Markdown("""
62
+ <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
63
+ <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
64
+ </div>
65
+ """)
66
+ with gr.Row(variant='panel', equal_height=True):
67
+ with gr.Column(scale=1, min_width=0):
68
+ self.src_video = gr.Video(
69
+ label="src_video",
70
+ sources=['upload'],
71
+ value=None,
72
+ interactive=True)
73
+ with gr.Column(scale=1, min_width=0):
74
+ self.src_mask = gr.Video(
75
+ label="src_mask",
76
+ sources=['upload'],
77
+ value=None,
78
+ interactive=True)
79
+ #
80
+ with gr.Row(variant='panel', equal_height=True):
81
+ with gr.Column(scale=1, min_width=0):
82
+ with gr.Row(equal_height=True):
83
+ self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
84
+ height=200,
85
+ interactive=True,
86
+ type='filepath',
87
+ image_mode='RGB',
88
+ sources=['upload'],
89
+ elem_id="src_ref_image_1",
90
+ format='png')
91
+ self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
92
+ height=200,
93
+ interactive=True,
94
+ type='filepath',
95
+ image_mode='RGB',
96
+ sources=['upload'],
97
+ elem_id="src_ref_image_2",
98
+ format='png')
99
+ self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
100
+ height=200,
101
+ interactive=True,
102
+ type='filepath',
103
+ image_mode='RGB',
104
+ sources=['upload'],
105
+ elem_id="src_ref_image_3",
106
+ format='png')
107
+ with gr.Row(variant='panel', equal_height=True):
108
+ with gr.Column(scale=1):
109
+ self.prompt = gr.Textbox(
110
+ show_label=False,
111
+ placeholder="positive_prompt_input",
112
+ elem_id='positive_prompt',
113
+ container=True,
114
+ autofocus=True,
115
+ elem_classes='type_row',
116
+ visible=True,
117
+ lines=2)
118
+ self.negative_prompt = gr.Textbox(
119
+ show_label=False,
120
+ value=self.pipe.config.sample_neg_prompt,
121
+ placeholder="negative_prompt_input",
122
+ elem_id='negative_prompt',
123
+ container=True,
124
+ autofocus=False,
125
+ elem_classes='type_row',
126
+ visible=True,
127
+ interactive=True,
128
+ lines=1)
129
+ #
130
+ with gr.Row(variant='panel', equal_height=True):
131
+ with gr.Column(scale=1, min_width=0):
132
+ with gr.Row(equal_height=True):
133
+ self.shift_scale = gr.Slider(
134
+ label='shift_scale',
135
+ minimum=0.0,
136
+ maximum=100.0,
137
+ step=1.0,
138
+ value=16.0,
139
+ interactive=True)
140
+ self.sample_steps = gr.Slider(
141
+ label='sample_steps',
142
+ minimum=1,
143
+ maximum=100,
144
+ step=1,
145
+ value=25,
146
+ interactive=True)
147
+ self.context_scale = gr.Slider(
148
+ label='context_scale',
149
+ minimum=0.0,
150
+ maximum=2.0,
151
+ step=0.1,
152
+ value=1.0,
153
+ interactive=True)
154
+ self.guide_scale = gr.Slider(
155
+ label='guide_scale',
156
+ minimum=1,
157
+ maximum=10,
158
+ step=0.5,
159
+ value=5.0,
160
+ interactive=True)
161
+ self.infer_seed = gr.Slider(minimum=-1,
162
+ maximum=10000000,
163
+ value=2025,
164
+ label="Seed")
165
+ #
166
+ with gr.Accordion(label="Usable without source video", open=False):
167
+ with gr.Row(equal_height=True):
168
+ self.output_height = gr.Textbox(
169
+ label='resolutions_height',
170
+ # value=480,
171
+ value=720,
172
+ interactive=True)
173
+ self.output_width = gr.Textbox(
174
+ label='resolutions_width',
175
+ # value=832,
176
+ value=1280,
177
+ interactive=True)
178
+ self.frame_rate = gr.Textbox(
179
+ label='frame_rate',
180
+ value=16,
181
+ interactive=True)
182
+ self.num_frames = gr.Textbox(
183
+ label='num_frames',
184
+ value=81,
185
+ interactive=True)
186
+ #
187
+ with gr.Row(equal_height=True):
188
+ with gr.Column(scale=5):
189
+ self.generate_button = gr.Button(
190
+ value='Run',
191
+ elem_classes='type_row',
192
+ elem_id='generate_button',
193
+ visible=True)
194
+ with gr.Column(scale=1):
195
+ self.refresh_button = gr.Button(value='\U0001f504') # 🔄
196
+ #
197
+ self.output_gallery = gr.Gallery(
198
+ label="output_gallery",
199
+ value=[],
200
+ interactive=False,
201
+ allow_preview=True,
202
+ preview=True)
203
+
204
+
205
+ def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
206
+ output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
207
+ src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
208
+ x is not None]
209
+ src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
210
+ [src_mask],
211
+ [src_ref_images],
212
+ num_frames=num_frames,
213
+ image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
214
+ device=self.pipe.device)
215
+ video = self.pipe.generate(
216
+ prompt,
217
+ src_video,
218
+ src_mask,
219
+ src_ref_images,
220
+ size=(output_width, output_height),
221
+ context_scale=context_scale,
222
+ shift=shift_scale,
223
+ sampling_steps=sample_steps,
224
+ guide_scale=guide_scale,
225
+ n_prompt=negative_prompt,
226
+ seed=infer_seed,
227
+ offload_model=True)
228
+
229
+ name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
230
+ video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
231
+ video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
232
+
233
+ try:
234
+ writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
235
+ for frame in video_frames:
236
+ writer.append_data(frame)
237
+ writer.close()
238
+ print(video_path)
239
+ except Exception as e:
240
+ raise gr.Error(f"Video save error: {e}")
241
+
242
+ if self.gallery_share:
243
+ self.gallery_share_data.add(video_path)
244
+ return self.gallery_share_data.get()
245
+ else:
246
+ return [video_path]
247
+
248
+ def set_callbacks(self, **kwargs):
249
+ self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
250
+ self.gen_outputs = [self.output_gallery]
251
+ self.generate_button.click(self.generate,
252
+ inputs=self.gen_inputs,
253
+ outputs=self.gen_outputs,
254
+ queue=True)
255
+ self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
256
+
257
+
258
+ if __name__ == '__main__':
259
+ parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
260
+ parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
261
+ parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
262
+ parser.add_argument('--root_path', dest='root_path', help='', default=None)
263
+ parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
264
+ parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
265
+ parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
266
+ parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
267
+ parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
268
+ parser.add_argument(
269
+ "--ckpt_dir",
270
+ type=str,
271
+ # default='models/VACE-Wan2.1-1.3B-Preview',
272
+ default='models/Wan2.1-VACE-14B/',
273
+ help="The path to the checkpoint directory.",
274
+ )
275
+ parser.add_argument(
276
+ "--offload_to_cpu",
277
+ action="store_true",
278
+ help="Offloading unnecessary computations to CPU.",
279
+ )
280
+
281
+ args = parser.parse_args()
282
+
283
+ if not os.path.exists(args.save_dir):
284
+ os.makedirs(args.save_dir, exist_ok=True)
285
+
286
+ with gr.Blocks() as demo:
287
+ infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
288
+ infer_gr.create_ui()
289
+ infer_gr.set_callbacks()
290
+ allowed_paths = [args.save_dir]
291
+ demo.queue(status_update_rate=1).launch(server_name=args.server_name,
292
+ server_port=args.server_port,
293
+ root_path=args.root_path,
294
+ allowed_paths=allowed_paths,
295
+ show_error=True, debug=True)
pyproject.toml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "wan"
7
+ version = "2.1.0"
8
+ description = "Wan: Open and Advanced Large-Scale Video Generative Models"
9
+ authors = [
10
+ { name = "Wan Team", email = "[email protected]" }
11
+ ]
12
+ license = { file = "LICENSE.txt" }
13
+ readme = "README.md"
14
+ requires-python = ">=3.10,<4.0"
15
+ dependencies = [
16
+ "torch>=2.4.0",
17
+ "torchvision>=0.19.0",
18
+ "opencv-python>=4.9.0.80",
19
+ "diffusers>=0.31.0",
20
+ "transformers>=4.49.0",
21
+ "tokenizers>=0.20.3",
22
+ "accelerate>=1.1.1",
23
+ "tqdm",
24
+ "imageio",
25
+ "easydict",
26
+ "ftfy",
27
+ "dashscope",
28
+ "imageio-ffmpeg",
29
+ "flash_attn",
30
+ "gradio>=5.0.0",
31
+ "numpy>=1.23.5,<2"
32
+ ]
33
+
34
+ [project.optional-dependencies]
35
+ dev = [
36
+ "pytest",
37
+ "black",
38
+ "flake8",
39
+ "isort",
40
+ "mypy",
41
+ "huggingface-hub[cli]"
42
+ ]
43
+
44
+ [project.urls]
45
+ homepage = "https://wanxai.com"
46
+ documentation = "https://github.com/Wan-Video/Wan2.1"
47
+ repository = "https://github.com/Wan-Video/Wan2.1"
48
+ huggingface = "https://huggingface.co/Wan-AI/"
49
+ modelscope = "https://modelscope.cn/organization/Wan-AI"
50
+ discord = "https://discord.gg/p5XbdQV7"
51
+
52
+ [tool.setuptools]
53
+ packages = ["wan"]
54
+
55
+ [tool.setuptools.package-data]
56
+ "wan" = ["**/*.py"]
57
+
58
+ [tool.black]
59
+ line-length = 88
60
+
61
+ [tool.isort]
62
+ profile = "black"
63
+
64
+ [tool.mypy]
65
+ strict = true
66
+
67
+
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.4.0
2
+ torchvision>=0.19.0
3
+ opencv-python>=4.9.0.80
4
+ diffusers>=0.31.0
5
+ transformers>=4.49.0
6
+ tokenizers>=0.20.3
7
+ accelerate>=1.1.1
8
+ tqdm
9
+ imageio
10
+ easydict
11
+ ftfy
12
+ dashscope
13
+ imageio-ffmpeg
14
+ flash_attn
15
+ gradio>=5.0.0
16
+ numpy>=1.23.5,<2
tests/README.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a folder and specify the max GPU number you want to use.
3
+
4
+ ```bash
5
+ bash ./test.sh <local model dir> <gpu number>
6
+ ```
tests/test.sh ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+
4
+ if [ "$#" -eq 2 ]; then
5
+ MODEL_DIR=$(realpath "$1")
6
+ GPUS=$2
7
+ else
8
+ echo "Usage: $0 <local model dir> <gpu number>"
9
+ exit 1
10
+ fi
11
+
12
+ SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
13
+ REPO_ROOT="$(dirname "$SCRIPT_DIR")"
14
+ cd "$REPO_ROOT" || exit 1
15
+
16
+ PY_FILE=./generate.py
17
+
18
+
19
+ function t2v_1_3B() {
20
+ T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"
21
+
22
+ # 1-GPU Test
23
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
24
+ python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR
25
+
26
+ # Multiple GPU Test
27
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
28
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
29
+
30
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
31
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
32
+
33
+ if [ -n "${DASH_API_KEY+x}" ]; then
34
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
35
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
36
+ else
37
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
38
+ fi
39
+ }
40
+
41
+ function t2v_14B() {
42
+ T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
43
+
44
+ # 1-GPU Test
45
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
46
+ python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
47
+
48
+ # Multiple GPU Test
49
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
50
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
51
+
52
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
53
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
54
+ }
55
+
56
+
57
+
58
+ function t2i_14B() {
59
+ T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
60
+
61
+ # 1-GPU Test
62
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
63
+ python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
64
+
65
+ # Multiple GPU Test
66
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
67
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
68
+
69
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
70
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
71
+ }
72
+
73
+
74
+ function i2v_14B_480p() {
75
+ I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"
76
+
77
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
78
+ python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR
79
+
80
+ # Multiple GPU Test
81
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
82
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
83
+
84
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
85
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"
86
+
87
+ if [ -n "${DASH_API_KEY+x}" ]; then
88
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
89
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
90
+ else
91
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
92
+ fi
93
+ }
94
+
95
+
96
+ function i2v_14B_720p() {
97
+ I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"
98
+
99
+ # 1-GPU Test
100
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
101
+ python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR
102
+
103
+ # Multiple GPU Test
104
+ echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
105
+ torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
106
+ }
107
+
108
+ function vace_1_3B() {
109
+ VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
110
+ torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
111
+
112
+ }
113
+
114
+
115
+ t2i_14B
116
+ t2v_1_3B
117
+ t2v_14B
118
+ i2v_14B_480p
119
+ i2v_14B_720p
120
+ vace_1_3B
wan/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from . import configs, distributed, modules
2
+ from .image2video import WanI2V
3
+ from .text2video import WanT2V
4
+ from .first_last_frame2video import WanFLF2V
5
+ from .vace import WanVace, WanVaceMP
wan/configs/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import copy
3
+ import os
4
+
5
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
+
7
+ from .wan_i2v_14B import i2v_14B
8
+ from .wan_t2v_1_3B import t2v_1_3B
9
+ from .wan_t2v_14B import t2v_14B
10
+
11
+ # the config of t2i_14B is the same as t2v_14B
12
+ t2i_14B = copy.deepcopy(t2v_14B)
13
+ t2i_14B.__name__ = 'Config: Wan T2I 14B'
14
+
15
+ # the config of flf2v_14B is the same as i2v_14B
16
+ flf2v_14B = copy.deepcopy(i2v_14B)
17
+ flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
18
+ flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt
19
+
20
+ WAN_CONFIGS = {
21
+ 't2v-14B': t2v_14B,
22
+ 't2v-1.3B': t2v_1_3B,
23
+ 'i2v-14B': i2v_14B,
24
+ 't2i-14B': t2i_14B,
25
+ 'flf2v-14B': flf2v_14B,
26
+ 'vace-1.3B': t2v_1_3B,
27
+ 'vace-14B': t2v_14B,
28
+ }
29
+
30
+ SIZE_CONFIGS = {
31
+ '720*1280': (720, 1280),
32
+ '1280*720': (1280, 720),
33
+ '480*832': (480, 832),
34
+ '832*480': (832, 480),
35
+ '1024*1024': (1024, 1024),
36
+ }
37
+
38
+ MAX_AREA_CONFIGS = {
39
+ '720*1280': 720 * 1280,
40
+ '1280*720': 1280 * 720,
41
+ '480*832': 480 * 832,
42
+ '832*480': 832 * 480,
43
+ }
44
+
45
+ SUPPORTED_SIZES = {
46
+ 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
47
+ 't2v-1.3B': ('480*832', '832*480'),
48
+ 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
49
+ 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
50
+ 't2i-14B': tuple(SIZE_CONFIGS.keys()),
51
+ 'vace-1.3B': ('480*832', '832*480'),
52
+ 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
53
+ }
wan/configs/shared_config.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ #------------------------ Wan shared config ------------------------#
6
+ wan_shared_cfg = EasyDict()
7
+
8
+ # t5
9
+ wan_shared_cfg.t5_model = 'umt5_xxl'
10
+ wan_shared_cfg.t5_dtype = torch.bfloat16
11
+ wan_shared_cfg.text_len = 512
12
+
13
+ # transformer
14
+ wan_shared_cfg.param_dtype = torch.bfloat16
15
+
16
+ # inference
17
+ wan_shared_cfg.num_train_timesteps = 1000
18
+ wan_shared_cfg.sample_fps = 16
19
+ wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
wan/configs/wan_i2v_14B.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ from .shared_config import wan_shared_cfg
6
+
7
+ #------------------------ Wan I2V 14B ------------------------#
8
+
9
+ i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
10
+ i2v_14B.update(wan_shared_cfg)
11
+ i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
12
+
13
+ i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
14
+ i2v_14B.t5_tokenizer = 'google/umt5-xxl'
15
+
16
+ # clip
17
+ i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
18
+ i2v_14B.clip_dtype = torch.float16
19
+ i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
20
+ i2v_14B.clip_tokenizer = 'xlm-roberta-large'
21
+
22
+ # vae
23
+ i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
24
+ i2v_14B.vae_stride = (4, 8, 8)
25
+
26
+ # transformer
27
+ i2v_14B.patch_size = (1, 2, 2)
28
+ i2v_14B.dim = 5120
29
+ i2v_14B.ffn_dim = 13824
30
+ i2v_14B.freq_dim = 256
31
+ i2v_14B.num_heads = 40
32
+ i2v_14B.num_layers = 40
33
+ i2v_14B.window_size = (-1, -1)
34
+ i2v_14B.qk_norm = True
35
+ i2v_14B.cross_attn_norm = True
36
+ i2v_14B.eps = 1e-6
wan/configs/wan_t2v_14B.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan T2V 14B ------------------------#
7
+
8
+ t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
9
+ t2v_14B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ t2v_14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ t2v_14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ t2v_14B.patch_size = (1, 2, 2)
21
+ t2v_14B.dim = 5120
22
+ t2v_14B.ffn_dim = 13824
23
+ t2v_14B.freq_dim = 256
24
+ t2v_14B.num_heads = 40
25
+ t2v_14B.num_layers = 40
26
+ t2v_14B.window_size = (-1, -1)
27
+ t2v_14B.qk_norm = True
28
+ t2v_14B.cross_attn_norm = True
29
+ t2v_14B.eps = 1e-6
wan/configs/wan_t2v_1_3B.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan T2V 1.3B ------------------------#
7
+
8
+ t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
9
+ t2v_1_3B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ t2v_1_3B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ t2v_1_3B.patch_size = (1, 2, 2)
21
+ t2v_1_3B.dim = 1536
22
+ t2v_1_3B.ffn_dim = 8960
23
+ t2v_1_3B.freq_dim = 256
24
+ t2v_1_3B.num_heads = 12
25
+ t2v_1_3B.num_layers = 30
26
+ t2v_1_3B.window_size = (-1, -1)
27
+ t2v_1_3B.qk_norm = True
28
+ t2v_1_3B.cross_attn_norm = True
29
+ t2v_1_3B.eps = 1e-6
wan/distributed/__init__.py ADDED
File without changes
wan/distributed/fsdp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+ def shard_model(
12
+ model,
13
+ device_id,
14
+ param_dtype=torch.bfloat16,
15
+ reduce_dtype=torch.float32,
16
+ buffer_dtype=torch.float32,
17
+ process_group=None,
18
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
19
+ sync_module_states=True,
20
+ ):
21
+ model = FSDP(
22
+ module=model,
23
+ process_group=process_group,
24
+ sharding_strategy=sharding_strategy,
25
+ auto_wrap_policy=partial(
26
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
27
+ mixed_precision=MixedPrecision(
28
+ param_dtype=param_dtype,
29
+ reduce_dtype=reduce_dtype,
30
+ buffer_dtype=buffer_dtype),
31
+ device_id=device_id,
32
+ sync_module_states=sync_module_states)
33
+ return model
34
+
35
+ def free_model(model):
36
+ for m in model.modules():
37
+ if isinstance(m, FSDP):
38
+ _free_storage(m._handle.flat_param.data)
39
+ del model
40
+ gc.collect()
41
+ torch.cuda.empty_cache()
wan/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size,
6
+ get_sp_group)
7
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8
+
9
+ from ..modules.model import sinusoidal_embedding_1d
10
+
11
+
12
+ def pad_freqs(original_tensor, target_len):
13
+ seq_len, s1, s2 = original_tensor.shape
14
+ pad_size = target_len - seq_len
15
+ padding_tensor = torch.ones(
16
+ pad_size,
17
+ s1,
18
+ s2,
19
+ dtype=original_tensor.dtype,
20
+ device=original_tensor.device)
21
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
22
+ return padded_tensor
23
+
24
+
25
+ @amp.autocast(enabled=False)
26
+ def rope_apply(x, grid_sizes, freqs):
27
+ """
28
+ x: [B, L, N, C].
29
+ grid_sizes: [B, 3].
30
+ freqs: [M, C // 2].
31
+ """
32
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
33
+ # split freqs
34
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
35
+
36
+ # loop over samples
37
+ output = []
38
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
39
+ seq_len = f * h * w
40
+
41
+ # precompute multipliers
42
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
43
+ s, n, -1, 2))
44
+ freqs_i = torch.cat([
45
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
46
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
47
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
48
+ ],
49
+ dim=-1).reshape(seq_len, 1, -1)
50
+
51
+ # apply rotary embedding
52
+ sp_size = get_sequence_parallel_world_size()
53
+ sp_rank = get_sequence_parallel_rank()
54
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
55
+ s_per_rank = s
56
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
57
+ s_per_rank), :, :]
58
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
59
+ x_i = torch.cat([x_i, x[i, s:]])
60
+
61
+ # append to collection
62
+ output.append(x_i)
63
+ return torch.stack(output).float()
64
+
65
+
66
+ def usp_dit_forward_vace(
67
+ self,
68
+ x,
69
+ vace_context,
70
+ seq_len,
71
+ kwargs
72
+ ):
73
+ # embeddings
74
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
75
+ c = [u.flatten(2).transpose(1, 2) for u in c]
76
+ c = torch.cat([
77
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
78
+ dim=1) for u in c
79
+ ])
80
+
81
+ # arguments
82
+ new_kwargs = dict(x=x)
83
+ new_kwargs.update(kwargs)
84
+
85
+ # Context Parallel
86
+ c = torch.chunk(
87
+ c, get_sequence_parallel_world_size(),
88
+ dim=1)[get_sequence_parallel_rank()]
89
+
90
+ hints = []
91
+ for block in self.vace_blocks:
92
+ c, c_skip = block(c, **new_kwargs)
93
+ hints.append(c_skip)
94
+ return hints
95
+
96
+
97
+ def usp_dit_forward(
98
+ self,
99
+ x,
100
+ t,
101
+ context,
102
+ seq_len,
103
+ vace_context=None,
104
+ vace_context_scale=1.0,
105
+ clip_fea=None,
106
+ y=None,
107
+ ):
108
+ """
109
+ x: A list of videos each with shape [C, T, H, W].
110
+ t: [B].
111
+ context: A list of text embeddings each with shape [L, C].
112
+ """
113
+ if self.model_type == 'i2v':
114
+ assert clip_fea is not None and y is not None
115
+ # params
116
+ device = self.patch_embedding.weight.device
117
+ if self.freqs.device != device:
118
+ self.freqs = self.freqs.to(device)
119
+
120
+ if self.model_type != 'vace' and y is not None:
121
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
122
+
123
+ # embeddings
124
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
125
+ grid_sizes = torch.stack(
126
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
127
+ x = [u.flatten(2).transpose(1, 2) for u in x]
128
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
129
+ assert seq_lens.max() <= seq_len
130
+ x = torch.cat([
131
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
132
+ for u in x
133
+ ])
134
+
135
+ # time embeddings
136
+ with amp.autocast(dtype=torch.float32):
137
+ e = self.time_embedding(
138
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
139
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
140
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
141
+
142
+ # context
143
+ context_lens = None
144
+ context = self.text_embedding(
145
+ torch.stack([
146
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
147
+ for u in context
148
+ ]))
149
+
150
+ if self.model_type != 'vace' and clip_fea is not None:
151
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
152
+ context = torch.concat([context_clip, context], dim=1)
153
+
154
+ # arguments
155
+ kwargs = dict(
156
+ e=e0,
157
+ seq_lens=seq_lens,
158
+ grid_sizes=grid_sizes,
159
+ freqs=self.freqs,
160
+ context=context,
161
+ context_lens=context_lens)
162
+
163
+ # Context Parallel
164
+ x = torch.chunk(
165
+ x, get_sequence_parallel_world_size(),
166
+ dim=1)[get_sequence_parallel_rank()]
167
+
168
+ if self.model_type == 'vace':
169
+ hints = self.forward_vace(x, vace_context, seq_len, kwargs)
170
+ kwargs['hints'] = hints
171
+ kwargs['context_scale'] = vace_context_scale
172
+
173
+ for block in self.blocks:
174
+ x = block(x, **kwargs)
175
+
176
+ # head
177
+ x = self.head(x, e)
178
+
179
+ # Context Parallel
180
+ x = get_sp_group().all_gather(x, dim=1)
181
+
182
+ # unpatchify
183
+ x = self.unpatchify(x, grid_sizes)
184
+ return [u.float() for u in x]
185
+
186
+
187
+ def usp_attn_forward(self,
188
+ x,
189
+ seq_lens,
190
+ grid_sizes,
191
+ freqs,
192
+ dtype=torch.bfloat16):
193
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
194
+ half_dtypes = (torch.float16, torch.bfloat16)
195
+
196
+ def half(x):
197
+ return x if x.dtype in half_dtypes else x.to(dtype)
198
+
199
+ # query, key, value function
200
+ def qkv_fn(x):
201
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
202
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
203
+ v = self.v(x).view(b, s, n, d)
204
+ return q, k, v
205
+
206
+ q, k, v = qkv_fn(x)
207
+ q = rope_apply(q, grid_sizes, freqs)
208
+ k = rope_apply(k, grid_sizes, freqs)
209
+
210
+ # TODO: We should use unpaded q,k,v for attention.
211
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
212
+ # if k_lens is not None:
213
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
214
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
215
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
216
+
217
+ x = xFuserLongContextAttention()(
218
+ None,
219
+ query=half(q),
220
+ key=half(k),
221
+ value=half(v),
222
+ window_size=self.window_size)
223
+
224
+ # TODO: padding after attention.
225
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
226
+
227
+ # output
228
+ x = x.flatten(2)
229
+ x = self.o(x)
230
+ return x
wan/first_last_frame2video.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ import torchvision.transforms.functional as TF
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .modules.clip import CLIPModel
21
+ from .modules.model import WanModel
22
+ from .modules.t5 import T5EncoderModel
23
+ from .modules.vae import WanVAE
24
+ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas, retrieve_timesteps)
26
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+
28
+
29
+ class WanFLF2V:
30
+
31
+ def __init__(
32
+ self,
33
+ config,
34
+ checkpoint_dir,
35
+ device_id=0,
36
+ rank=0,
37
+ t5_fsdp=False,
38
+ dit_fsdp=False,
39
+ use_usp=False,
40
+ t5_cpu=False,
41
+ init_on_cpu=True,
42
+ ):
43
+ r"""
44
+ Initializes the image-to-video generation model components.
45
+
46
+ Args:
47
+ config (EasyDict):
48
+ Object containing model parameters initialized from config.py
49
+ checkpoint_dir (`str`):
50
+ Path to directory containing model checkpoints
51
+ device_id (`int`, *optional*, defaults to 0):
52
+ Id of target GPU device
53
+ rank (`int`, *optional*, defaults to 0):
54
+ Process rank for distributed training
55
+ t5_fsdp (`bool`, *optional*, defaults to False):
56
+ Enable FSDP sharding for T5 model
57
+ dit_fsdp (`bool`, *optional*, defaults to False):
58
+ Enable FSDP sharding for DiT model
59
+ use_usp (`bool`, *optional*, defaults to False):
60
+ Enable distribution strategy of USP.
61
+ t5_cpu (`bool`, *optional*, defaults to False):
62
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
63
+ init_on_cpu (`bool`, *optional*, defaults to True):
64
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
65
+ """
66
+ self.device = torch.device(f"cuda:{device_id}")
67
+ self.config = config
68
+ self.rank = rank
69
+ self.use_usp = use_usp
70
+ self.t5_cpu = t5_cpu
71
+
72
+ self.num_train_timesteps = config.num_train_timesteps
73
+ self.param_dtype = config.param_dtype
74
+
75
+ shard_fn = partial(shard_model, device_id=device_id)
76
+ self.text_encoder = T5EncoderModel(
77
+ text_len=config.text_len,
78
+ dtype=config.t5_dtype,
79
+ device=torch.device('cpu'),
80
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
81
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
82
+ shard_fn=shard_fn if t5_fsdp else None,
83
+ )
84
+
85
+ self.vae_stride = config.vae_stride
86
+ self.patch_size = config.patch_size
87
+ self.vae = WanVAE(
88
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
89
+ device=self.device)
90
+
91
+ self.clip = CLIPModel(
92
+ dtype=config.clip_dtype,
93
+ device=self.device,
94
+ checkpoint_path=os.path.join(checkpoint_dir,
95
+ config.clip_checkpoint),
96
+ tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
97
+
98
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
99
+ self.model = WanModel.from_pretrained(checkpoint_dir)
100
+ self.model.eval().requires_grad_(False)
101
+
102
+ if t5_fsdp or dit_fsdp or use_usp:
103
+ init_on_cpu = False
104
+
105
+ if use_usp:
106
+ from xfuser.core.distributed import \
107
+ get_sequence_parallel_world_size
108
+
109
+ from .distributed.xdit_context_parallel import (usp_attn_forward,
110
+ usp_dit_forward)
111
+ for block in self.model.blocks:
112
+ block.self_attn.forward = types.MethodType(
113
+ usp_attn_forward, block.self_attn)
114
+ self.model.forward = types.MethodType(usp_dit_forward, self.model)
115
+ self.sp_size = get_sequence_parallel_world_size()
116
+ else:
117
+ self.sp_size = 1
118
+
119
+ if dist.is_initialized():
120
+ dist.barrier()
121
+ if dit_fsdp:
122
+ self.model = shard_fn(self.model)
123
+ else:
124
+ if not init_on_cpu:
125
+ self.model.to(self.device)
126
+
127
+ self.sample_neg_prompt = config.sample_neg_prompt
128
+
129
+ def generate(self,
130
+ input_prompt,
131
+ first_frame,
132
+ last_frame,
133
+ max_area=720 * 1280,
134
+ frame_num=81,
135
+ shift=16,
136
+ sample_solver='unipc',
137
+ sampling_steps=50,
138
+ guide_scale=5.5,
139
+ n_prompt="",
140
+ seed=-1,
141
+ offload_model=True):
142
+ r"""
143
+ Generates video frames from input first-last frame and text prompt using diffusion process.
144
+
145
+ Args:
146
+ input_prompt (`str`):
147
+ Text prompt for content generation.
148
+ first_frame (PIL.Image.Image):
149
+ Input image tensor. Shape: [3, H, W]
150
+ last_frame (PIL.Image.Image):
151
+ Input image tensor. Shape: [3, H, W]
152
+ [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized
153
+ to match first_frame.
154
+ max_area (`int`, *optional*, defaults to 720*1280):
155
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
156
+ frame_num (`int`, *optional*, defaults to 81):
157
+ How many frames to sample from a video. The number should be 4n+1
158
+ shift (`float`, *optional*, defaults to 5.0):
159
+ Noise schedule shift parameter. Affects temporal dynamics
160
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
161
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
162
+ Solver used to sample the video.
163
+ sampling_steps (`int`, *optional*, defaults to 40):
164
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
165
+ guide_scale (`float`, *optional*, defaults 5.0):
166
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
167
+ n_prompt (`str`, *optional*, defaults to ""):
168
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
169
+ seed (`int`, *optional*, defaults to -1):
170
+ Random seed for noise generation. If -1, use random seed
171
+ offload_model (`bool`, *optional*, defaults to True):
172
+ If True, offloads models to CPU during generation to save VRAM
173
+
174
+ Returns:
175
+ torch.Tensor:
176
+ Generated video frames tensor. Dimensions: (C, N H, W) where:
177
+ - C: Color channels (3 for RGB)
178
+ - N: Number of frames (81)
179
+ - H: Frame height (from max_area)
180
+ - W: Frame width from max_area)
181
+ """
182
+ first_frame_size = first_frame.size
183
+ last_frame_size = last_frame.size
184
+ first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(self.device)
185
+ last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(self.device)
186
+
187
+ F = frame_num
188
+ first_frame_h, first_frame_w = first_frame.shape[1:]
189
+ aspect_ratio = first_frame_h / first_frame_w
190
+ lat_h = round(
191
+ np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
192
+ self.patch_size[1] * self.patch_size[1])
193
+ lat_w = round(
194
+ np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
195
+ self.patch_size[2] * self.patch_size[2])
196
+ first_frame_h = lat_h * self.vae_stride[1]
197
+ first_frame_w = lat_w * self.vae_stride[2]
198
+ if first_frame_size != last_frame_size:
199
+ # 1. resize
200
+ last_frame_resize_ratio = max(
201
+ first_frame_size[0] / last_frame_size[0],
202
+ first_frame_size[1] / last_frame_size[1]
203
+ )
204
+ last_frame_size = [
205
+ round(last_frame_size[0] * last_frame_resize_ratio),
206
+ round(last_frame_size[1] * last_frame_resize_ratio),
207
+ ]
208
+ # 2. center crop
209
+ last_frame = TF.center_crop(last_frame, last_frame_size)
210
+
211
+ max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
212
+ self.patch_size[1] * self.patch_size[2])
213
+ max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
214
+
215
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
216
+ seed_g = torch.Generator(device=self.device)
217
+ seed_g.manual_seed(seed)
218
+ noise = torch.randn(
219
+ 16,
220
+ (F - 1) // 4 + 1,
221
+ lat_h,
222
+ lat_w,
223
+ dtype=torch.float32,
224
+ generator=seed_g,
225
+ device=self.device)
226
+
227
+ msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
228
+ msk[:, 1: -1] = 0
229
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
230
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
231
+ msk = msk.transpose(1, 2)[0]
232
+
233
+ if n_prompt == "":
234
+ n_prompt = self.sample_neg_prompt
235
+
236
+ # preprocess
237
+ if not self.t5_cpu:
238
+ self.text_encoder.model.to(self.device)
239
+ context = self.text_encoder([input_prompt], self.device)
240
+ context_null = self.text_encoder([n_prompt], self.device)
241
+ if offload_model:
242
+ self.text_encoder.model.cpu()
243
+ else:
244
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
245
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
246
+ context = [t.to(self.device) for t in context]
247
+ context_null = [t.to(self.device) for t in context_null]
248
+
249
+ self.clip.model.to(self.device)
250
+ clip_context = self.clip.visual([first_frame[:, None, :, :], last_frame[:, None, :, :]])
251
+ if offload_model:
252
+ self.clip.model.cpu()
253
+
254
+ y = self.vae.encode([
255
+ torch.concat([
256
+ torch.nn.functional.interpolate(
257
+ first_frame[None].cpu(),
258
+ size=(first_frame_h, first_frame_w),
259
+ mode='bicubic'
260
+ ).transpose(0, 1),
261
+ torch.zeros(3, F - 2, first_frame_h, first_frame_w),
262
+ torch.nn.functional.interpolate(
263
+ last_frame[None].cpu(),
264
+ size=(first_frame_h, first_frame_w),
265
+ mode='bicubic'
266
+ ).transpose(0, 1),
267
+ ], dim=1).to(self.device)
268
+ ])[0]
269
+ y = torch.concat([msk, y])
270
+
271
+ @contextmanager
272
+ def noop_no_sync():
273
+ yield
274
+
275
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
276
+
277
+ # evaluation mode
278
+ with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
279
+
280
+ if sample_solver == 'unipc':
281
+ sample_scheduler = FlowUniPCMultistepScheduler(
282
+ num_train_timesteps=self.num_train_timesteps,
283
+ shift=1,
284
+ use_dynamic_shifting=False)
285
+ sample_scheduler.set_timesteps(
286
+ sampling_steps, device=self.device, shift=shift)
287
+ timesteps = sample_scheduler.timesteps
288
+ elif sample_solver == 'dpm++':
289
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
290
+ num_train_timesteps=self.num_train_timesteps,
291
+ shift=1,
292
+ use_dynamic_shifting=False)
293
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
294
+ timesteps, _ = retrieve_timesteps(
295
+ sample_scheduler,
296
+ device=self.device,
297
+ sigmas=sampling_sigmas)
298
+ else:
299
+ raise NotImplementedError("Unsupported solver.")
300
+
301
+ # sample videos
302
+ latent = noise
303
+
304
+ arg_c = {
305
+ 'context': [context[0]],
306
+ 'clip_fea': clip_context,
307
+ 'seq_len': max_seq_len,
308
+ 'y': [y],
309
+ }
310
+
311
+ arg_null = {
312
+ 'context': context_null,
313
+ 'clip_fea': clip_context,
314
+ 'seq_len': max_seq_len,
315
+ 'y': [y],
316
+ }
317
+
318
+ if offload_model:
319
+ torch.cuda.empty_cache()
320
+
321
+ self.model.to(self.device)
322
+ for _, t in enumerate(tqdm(timesteps)):
323
+ latent_model_input = [latent.to(self.device)]
324
+ timestep = [t]
325
+
326
+ timestep = torch.stack(timestep).to(self.device)
327
+
328
+ noise_pred_cond = self.model(
329
+ latent_model_input, t=timestep, **arg_c)[0].to(
330
+ torch.device('cpu') if offload_model else self.device)
331
+ if offload_model:
332
+ torch.cuda.empty_cache()
333
+ noise_pred_uncond = self.model(
334
+ latent_model_input, t=timestep, **arg_null)[0].to(
335
+ torch.device('cpu') if offload_model else self.device)
336
+ if offload_model:
337
+ torch.cuda.empty_cache()
338
+ noise_pred = noise_pred_uncond + guide_scale * (
339
+ noise_pred_cond - noise_pred_uncond)
340
+
341
+ latent = latent.to(
342
+ torch.device('cpu') if offload_model else self.device)
343
+
344
+ temp_x0 = sample_scheduler.step(
345
+ noise_pred.unsqueeze(0),
346
+ t,
347
+ latent.unsqueeze(0),
348
+ return_dict=False,
349
+ generator=seed_g)[0]
350
+ latent = temp_x0.squeeze(0)
351
+
352
+ x0 = [latent.to(self.device)]
353
+ del latent_model_input, timestep
354
+
355
+ if offload_model:
356
+ self.model.cpu()
357
+ torch.cuda.empty_cache()
358
+
359
+ if self.rank == 0:
360
+ videos = self.vae.decode(x0)
361
+
362
+ del noise, latent
363
+ del sample_scheduler
364
+ if offload_model:
365
+ gc.collect()
366
+ torch.cuda.synchronize()
367
+ if dist.is_initialized():
368
+ dist.barrier()
369
+
370
+ return videos[0] if self.rank == 0 else None
wan/image2video.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ import torchvision.transforms.functional as TF
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .modules.clip import CLIPModel
21
+ from .modules.model import WanModel
22
+ from .modules.t5 import T5EncoderModel
23
+ from .modules.vae import WanVAE
24
+ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas, retrieve_timesteps)
26
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+
28
+
29
+ class WanI2V:
30
+
31
+ def __init__(
32
+ self,
33
+ config,
34
+ checkpoint_dir,
35
+ device_id=0,
36
+ rank=0,
37
+ t5_fsdp=False,
38
+ dit_fsdp=False,
39
+ use_usp=False,
40
+ t5_cpu=False,
41
+ init_on_cpu=True,
42
+ ):
43
+ r"""
44
+ Initializes the image-to-video generation model components.
45
+
46
+ Args:
47
+ config (EasyDict):
48
+ Object containing model parameters initialized from config.py
49
+ checkpoint_dir (`str`):
50
+ Path to directory containing model checkpoints
51
+ device_id (`int`, *optional*, defaults to 0):
52
+ Id of target GPU device
53
+ rank (`int`, *optional*, defaults to 0):
54
+ Process rank for distributed training
55
+ t5_fsdp (`bool`, *optional*, defaults to False):
56
+ Enable FSDP sharding for T5 model
57
+ dit_fsdp (`bool`, *optional*, defaults to False):
58
+ Enable FSDP sharding for DiT model
59
+ use_usp (`bool`, *optional*, defaults to False):
60
+ Enable distribution strategy of USP.
61
+ t5_cpu (`bool`, *optional*, defaults to False):
62
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
63
+ init_on_cpu (`bool`, *optional*, defaults to True):
64
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
65
+ """
66
+ self.device = torch.device(f"cuda:{device_id}")
67
+ self.config = config
68
+ self.rank = rank
69
+ self.use_usp = use_usp
70
+ self.t5_cpu = t5_cpu
71
+
72
+ self.num_train_timesteps = config.num_train_timesteps
73
+ self.param_dtype = config.param_dtype
74
+
75
+ shard_fn = partial(shard_model, device_id=device_id)
76
+ self.text_encoder = T5EncoderModel(
77
+ text_len=config.text_len,
78
+ dtype=config.t5_dtype,
79
+ device=torch.device('cpu'),
80
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
81
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
82
+ shard_fn=shard_fn if t5_fsdp else None,
83
+ )
84
+
85
+ self.vae_stride = config.vae_stride
86
+ self.patch_size = config.patch_size
87
+ self.vae = WanVAE(
88
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
89
+ device=self.device)
90
+
91
+ self.clip = CLIPModel(
92
+ dtype=config.clip_dtype,
93
+ device=self.device,
94
+ checkpoint_path=os.path.join(checkpoint_dir,
95
+ config.clip_checkpoint),
96
+ tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
97
+
98
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
99
+ self.model = WanModel.from_pretrained(checkpoint_dir)
100
+ self.model.eval().requires_grad_(False)
101
+
102
+ if t5_fsdp or dit_fsdp or use_usp:
103
+ init_on_cpu = False
104
+
105
+ if use_usp:
106
+ from xfuser.core.distributed import \
107
+ get_sequence_parallel_world_size
108
+
109
+ from .distributed.xdit_context_parallel import (usp_attn_forward,
110
+ usp_dit_forward)
111
+ for block in self.model.blocks:
112
+ block.self_attn.forward = types.MethodType(
113
+ usp_attn_forward, block.self_attn)
114
+ self.model.forward = types.MethodType(usp_dit_forward, self.model)
115
+ self.sp_size = get_sequence_parallel_world_size()
116
+ else:
117
+ self.sp_size = 1
118
+
119
+ if dist.is_initialized():
120
+ dist.barrier()
121
+ if dit_fsdp:
122
+ self.model = shard_fn(self.model)
123
+ else:
124
+ if not init_on_cpu:
125
+ self.model.to(self.device)
126
+
127
+ self.sample_neg_prompt = config.sample_neg_prompt
128
+
129
+ def generate(self,
130
+ input_prompt,
131
+ img,
132
+ max_area=720 * 1280,
133
+ frame_num=81,
134
+ shift=5.0,
135
+ sample_solver='unipc',
136
+ sampling_steps=40,
137
+ guide_scale=5.0,
138
+ n_prompt="",
139
+ seed=-1,
140
+ offload_model=True):
141
+ r"""
142
+ Generates video frames from input image and text prompt using diffusion process.
143
+
144
+ Args:
145
+ input_prompt (`str`):
146
+ Text prompt for content generation.
147
+ img (PIL.Image.Image):
148
+ Input image tensor. Shape: [3, H, W]
149
+ max_area (`int`, *optional*, defaults to 720*1280):
150
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
151
+ frame_num (`int`, *optional*, defaults to 81):
152
+ How many frames to sample from a video. The number should be 4n+1
153
+ shift (`float`, *optional*, defaults to 5.0):
154
+ Noise schedule shift parameter. Affects temporal dynamics
155
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
156
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
157
+ Solver used to sample the video.
158
+ sampling_steps (`int`, *optional*, defaults to 40):
159
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
160
+ guide_scale (`float`, *optional*, defaults 5.0):
161
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
162
+ n_prompt (`str`, *optional*, defaults to ""):
163
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
164
+ seed (`int`, *optional*, defaults to -1):
165
+ Random seed for noise generation. If -1, use random seed
166
+ offload_model (`bool`, *optional*, defaults to True):
167
+ If True, offloads models to CPU during generation to save VRAM
168
+
169
+ Returns:
170
+ torch.Tensor:
171
+ Generated video frames tensor. Dimensions: (C, N H, W) where:
172
+ - C: Color channels (3 for RGB)
173
+ - N: Number of frames (81)
174
+ - H: Frame height (from max_area)
175
+ - W: Frame width from max_area)
176
+ """
177
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
178
+
179
+ F = frame_num
180
+ h, w = img.shape[1:]
181
+ aspect_ratio = h / w
182
+ lat_h = round(
183
+ np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
184
+ self.patch_size[1] * self.patch_size[1])
185
+ lat_w = round(
186
+ np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
187
+ self.patch_size[2] * self.patch_size[2])
188
+ h = lat_h * self.vae_stride[1]
189
+ w = lat_w * self.vae_stride[2]
190
+
191
+ max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
192
+ self.patch_size[1] * self.patch_size[2])
193
+ max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
194
+
195
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
196
+ seed_g = torch.Generator(device=self.device)
197
+ seed_g.manual_seed(seed)
198
+ noise = torch.randn(
199
+ 16,
200
+ (F - 1) // 4 + 1,
201
+ lat_h,
202
+ lat_w,
203
+ dtype=torch.float32,
204
+ generator=seed_g,
205
+ device=self.device)
206
+
207
+ msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
208
+ msk[:, 1:] = 0
209
+ msk = torch.concat([
210
+ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
211
+ ],
212
+ dim=1)
213
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
214
+ msk = msk.transpose(1, 2)[0]
215
+
216
+ if n_prompt == "":
217
+ n_prompt = self.sample_neg_prompt
218
+
219
+ # preprocess
220
+ if not self.t5_cpu:
221
+ self.text_encoder.model.to(self.device)
222
+ context = self.text_encoder([input_prompt], self.device)
223
+ context_null = self.text_encoder([n_prompt], self.device)
224
+ if offload_model:
225
+ self.text_encoder.model.cpu()
226
+ else:
227
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
228
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
229
+ context = [t.to(self.device) for t in context]
230
+ context_null = [t.to(self.device) for t in context_null]
231
+
232
+ self.clip.model.to(self.device)
233
+ clip_context = self.clip.visual([img[:, None, :, :]])
234
+ if offload_model:
235
+ self.clip.model.cpu()
236
+
237
+ y = self.vae.encode([
238
+ torch.concat([
239
+ torch.nn.functional.interpolate(
240
+ img[None].cpu(), size=(h, w), mode='bicubic').transpose(
241
+ 0, 1),
242
+ torch.zeros(3, F - 1, h, w)
243
+ ],
244
+ dim=1).to(self.device)
245
+ ])[0]
246
+ y = torch.concat([msk, y])
247
+
248
+ @contextmanager
249
+ def noop_no_sync():
250
+ yield
251
+
252
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
253
+
254
+ # evaluation mode
255
+ with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
256
+
257
+ if sample_solver == 'unipc':
258
+ sample_scheduler = FlowUniPCMultistepScheduler(
259
+ num_train_timesteps=self.num_train_timesteps,
260
+ shift=1,
261
+ use_dynamic_shifting=False)
262
+ sample_scheduler.set_timesteps(
263
+ sampling_steps, device=self.device, shift=shift)
264
+ timesteps = sample_scheduler.timesteps
265
+ elif sample_solver == 'dpm++':
266
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
267
+ num_train_timesteps=self.num_train_timesteps,
268
+ shift=1,
269
+ use_dynamic_shifting=False)
270
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
271
+ timesteps, _ = retrieve_timesteps(
272
+ sample_scheduler,
273
+ device=self.device,
274
+ sigmas=sampling_sigmas)
275
+ else:
276
+ raise NotImplementedError("Unsupported solver.")
277
+
278
+ # sample videos
279
+ latent = noise
280
+
281
+ arg_c = {
282
+ 'context': [context[0]],
283
+ 'clip_fea': clip_context,
284
+ 'seq_len': max_seq_len,
285
+ 'y': [y],
286
+ }
287
+
288
+ arg_null = {
289
+ 'context': context_null,
290
+ 'clip_fea': clip_context,
291
+ 'seq_len': max_seq_len,
292
+ 'y': [y],
293
+ }
294
+
295
+ if offload_model:
296
+ torch.cuda.empty_cache()
297
+
298
+ self.model.to(self.device)
299
+ for _, t in enumerate(tqdm(timesteps)):
300
+ latent_model_input = [latent.to(self.device)]
301
+ timestep = [t]
302
+
303
+ timestep = torch.stack(timestep).to(self.device)
304
+
305
+ noise_pred_cond = self.model(
306
+ latent_model_input, t=timestep, **arg_c)[0].to(
307
+ torch.device('cpu') if offload_model else self.device)
308
+ if offload_model:
309
+ torch.cuda.empty_cache()
310
+ noise_pred_uncond = self.model(
311
+ latent_model_input, t=timestep, **arg_null)[0].to(
312
+ torch.device('cpu') if offload_model else self.device)
313
+ if offload_model:
314
+ torch.cuda.empty_cache()
315
+ noise_pred = noise_pred_uncond + guide_scale * (
316
+ noise_pred_cond - noise_pred_uncond)
317
+
318
+ latent = latent.to(
319
+ torch.device('cpu') if offload_model else self.device)
320
+
321
+ temp_x0 = sample_scheduler.step(
322
+ noise_pred.unsqueeze(0),
323
+ t,
324
+ latent.unsqueeze(0),
325
+ return_dict=False,
326
+ generator=seed_g)[0]
327
+ latent = temp_x0.squeeze(0)
328
+
329
+ x0 = [latent.to(self.device)]
330
+ del latent_model_input, timestep
331
+
332
+ if offload_model:
333
+ self.model.cpu()
334
+ torch.cuda.empty_cache()
335
+
336
+ if self.rank == 0:
337
+ videos = self.vae.decode(x0)
338
+
339
+ del noise, latent
340
+ del sample_scheduler
341
+ if offload_model:
342
+ gc.collect()
343
+ torch.cuda.synchronize()
344
+ if dist.is_initialized():
345
+ dist.barrier()
346
+
347
+ return videos[0] if self.rank == 0 else None
wan/modules/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .attention import flash_attention
2
+ from .model import WanModel
3
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
4
+ from .tokenizers import HuggingfaceTokenizer
5
+ from .vace_model import VaceWanModel
6
+ from .vae import WanVAE
7
+
8
+ __all__ = [
9
+ 'WanVAE',
10
+ 'WanModel',
11
+ 'VaceWanModel',
12
+ 'T5Model',
13
+ 'T5Encoder',
14
+ 'T5Decoder',
15
+ 'T5EncoderModel',
16
+ 'HuggingfaceTokenizer',
17
+ 'flash_attention',
18
+ ]
wan/modules/attention.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+ FLASH_ATTN_3_AVAILABLE = True
7
+ except ModuleNotFoundError:
8
+ FLASH_ATTN_3_AVAILABLE = False
9
+
10
+ try:
11
+ import flash_attn
12
+ FLASH_ATTN_2_AVAILABLE = True
13
+ except ModuleNotFoundError:
14
+ FLASH_ATTN_2_AVAILABLE = False
15
+
16
+ import warnings
17
+
18
+ __all__ = [
19
+ 'flash_attention',
20
+ 'attention',
21
+ ]
22
+
23
+
24
+ def flash_attention(
25
+ q,
26
+ k,
27
+ v,
28
+ q_lens=None,
29
+ k_lens=None,
30
+ dropout_p=0.,
31
+ softmax_scale=None,
32
+ q_scale=None,
33
+ causal=False,
34
+ window_size=(-1, -1),
35
+ deterministic=False,
36
+ dtype=torch.bfloat16,
37
+ version=None,
38
+ ):
39
+ """
40
+ q: [B, Lq, Nq, C1].
41
+ k: [B, Lk, Nk, C1].
42
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
43
+ q_lens: [B].
44
+ k_lens: [B].
45
+ dropout_p: float. Dropout probability.
46
+ softmax_scale: float. The scaling of QK^T before applying softmax.
47
+ causal: bool. Whether to apply causal attention mask.
48
+ window_size: (left right). If not (-1, -1), apply sliding window local attention.
49
+ deterministic: bool. If True, slightly slower and uses more memory.
50
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
51
+ """
52
+ half_dtypes = (torch.float16, torch.bfloat16)
53
+ assert dtype in half_dtypes
54
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
55
+
56
+ # params
57
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
58
+
59
+ def half(x):
60
+ return x if x.dtype in half_dtypes else x.to(dtype)
61
+
62
+ # preprocess query
63
+ if q_lens is None:
64
+ q = half(q.flatten(0, 1))
65
+ q_lens = torch.tensor(
66
+ [lq] * b, dtype=torch.int32).to(
67
+ device=q.device, non_blocking=True)
68
+ else:
69
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
70
+
71
+ # preprocess key, value
72
+ if k_lens is None:
73
+ k = half(k.flatten(0, 1))
74
+ v = half(v.flatten(0, 1))
75
+ k_lens = torch.tensor(
76
+ [lk] * b, dtype=torch.int32).to(
77
+ device=k.device, non_blocking=True)
78
+ else:
79
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
80
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
81
+
82
+ q = q.to(v.dtype)
83
+ k = k.to(v.dtype)
84
+
85
+ if q_scale is not None:
86
+ q = q * q_scale
87
+
88
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
89
+ warnings.warn(
90
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
91
+ )
92
+
93
+ # apply attention
94
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
95
+ # Note: dropout_p, window_size are not supported in FA3 now.
96
+ x = flash_attn_interface.flash_attn_varlen_func(
97
+ q=q,
98
+ k=k,
99
+ v=v,
100
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
101
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
102
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
103
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
104
+ seqused_q=None,
105
+ seqused_k=None,
106
+ max_seqlen_q=lq,
107
+ max_seqlen_k=lk,
108
+ softmax_scale=softmax_scale,
109
+ causal=causal,
110
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
111
+ else:
112
+ assert FLASH_ATTN_2_AVAILABLE
113
+ x = flash_attn.flash_attn_varlen_func(
114
+ q=q,
115
+ k=k,
116
+ v=v,
117
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
118
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
119
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
120
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
121
+ max_seqlen_q=lq,
122
+ max_seqlen_k=lk,
123
+ dropout_p=dropout_p,
124
+ softmax_scale=softmax_scale,
125
+ causal=causal,
126
+ window_size=window_size,
127
+ deterministic=deterministic).unflatten(0, (b, lq))
128
+
129
+ # output
130
+ return x.type(out_dtype)
131
+
132
+
133
+ def attention(
134
+ q,
135
+ k,
136
+ v,
137
+ q_lens=None,
138
+ k_lens=None,
139
+ dropout_p=0.,
140
+ softmax_scale=None,
141
+ q_scale=None,
142
+ causal=False,
143
+ window_size=(-1, -1),
144
+ deterministic=False,
145
+ dtype=torch.bfloat16,
146
+ fa_version=None,
147
+ ):
148
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
149
+ return flash_attention(
150
+ q=q,
151
+ k=k,
152
+ v=v,
153
+ q_lens=q_lens,
154
+ k_lens=k_lens,
155
+ dropout_p=dropout_p,
156
+ softmax_scale=softmax_scale,
157
+ q_scale=q_scale,
158
+ causal=causal,
159
+ window_size=window_size,
160
+ deterministic=deterministic,
161
+ dtype=dtype,
162
+ version=fa_version,
163
+ )
164
+ else:
165
+ if q_lens is not None or k_lens is not None:
166
+ warnings.warn(
167
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
168
+ )
169
+ attn_mask = None
170
+
171
+ q = q.transpose(1, 2).to(dtype)
172
+ k = k.transpose(1, 2).to(dtype)
173
+ v = v.transpose(1, 2).to(dtype)
174
+
175
+ out = torch.nn.functional.scaled_dot_product_attention(
176
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
177
+
178
+ out = out.transpose(1, 2).contiguous()
179
+ return out
wan/modules/clip.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+
11
+ from .attention import flash_attention
12
+ from .tokenizers import HuggingfaceTokenizer
13
+ from .xlm_roberta import XLMRoberta
14
+
15
+ __all__ = [
16
+ 'XLMRobertaCLIP',
17
+ 'clip_xlm_roberta_vit_h_14',
18
+ 'CLIPModel',
19
+ ]
20
+
21
+
22
+ def pos_interpolate(pos, seq_len):
23
+ if pos.size(1) == seq_len:
24
+ return pos
25
+ else:
26
+ src_grid = int(math.sqrt(pos.size(1)))
27
+ tar_grid = int(math.sqrt(seq_len))
28
+ n = pos.size(1) - src_grid * src_grid
29
+ return torch.cat([
30
+ pos[:, :n],
31
+ F.interpolate(
32
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
33
+ 0, 3, 1, 2),
34
+ size=(tar_grid, tar_grid),
35
+ mode='bicubic',
36
+ align_corners=False).flatten(2).transpose(1, 2)
37
+ ],
38
+ dim=1)
39
+
40
+
41
+ class QuickGELU(nn.Module):
42
+
43
+ def forward(self, x):
44
+ return x * torch.sigmoid(1.702 * x)
45
+
46
+
47
+ class LayerNorm(nn.LayerNorm):
48
+
49
+ def forward(self, x):
50
+ return super().forward(x.float()).type_as(x)
51
+
52
+
53
+ class SelfAttention(nn.Module):
54
+
55
+ def __init__(self,
56
+ dim,
57
+ num_heads,
58
+ causal=False,
59
+ attn_dropout=0.0,
60
+ proj_dropout=0.0):
61
+ assert dim % num_heads == 0
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = dim // num_heads
66
+ self.causal = causal
67
+ self.attn_dropout = attn_dropout
68
+ self.proj_dropout = proj_dropout
69
+
70
+ # layers
71
+ self.to_qkv = nn.Linear(dim, dim * 3)
72
+ self.proj = nn.Linear(dim, dim)
73
+
74
+ def forward(self, x):
75
+ """
76
+ x: [B, L, C].
77
+ """
78
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
79
+
80
+ # compute query, key, value
81
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
82
+
83
+ # compute attention
84
+ p = self.attn_dropout if self.training else 0.0
85
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
86
+ x = x.reshape(b, s, c)
87
+
88
+ # output
89
+ x = self.proj(x)
90
+ x = F.dropout(x, self.proj_dropout, self.training)
91
+ return x
92
+
93
+
94
+ class SwiGLU(nn.Module):
95
+
96
+ def __init__(self, dim, mid_dim):
97
+ super().__init__()
98
+ self.dim = dim
99
+ self.mid_dim = mid_dim
100
+
101
+ # layers
102
+ self.fc1 = nn.Linear(dim, mid_dim)
103
+ self.fc2 = nn.Linear(dim, mid_dim)
104
+ self.fc3 = nn.Linear(mid_dim, dim)
105
+
106
+ def forward(self, x):
107
+ x = F.silu(self.fc1(x)) * self.fc2(x)
108
+ x = self.fc3(x)
109
+ return x
110
+
111
+
112
+ class AttentionBlock(nn.Module):
113
+
114
+ def __init__(self,
115
+ dim,
116
+ mlp_ratio,
117
+ num_heads,
118
+ post_norm=False,
119
+ causal=False,
120
+ activation='quick_gelu',
121
+ attn_dropout=0.0,
122
+ proj_dropout=0.0,
123
+ norm_eps=1e-5):
124
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
125
+ super().__init__()
126
+ self.dim = dim
127
+ self.mlp_ratio = mlp_ratio
128
+ self.num_heads = num_heads
129
+ self.post_norm = post_norm
130
+ self.causal = causal
131
+ self.norm_eps = norm_eps
132
+
133
+ # layers
134
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
135
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
136
+ proj_dropout)
137
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
138
+ if activation == 'swi_glu':
139
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
140
+ else:
141
+ self.mlp = nn.Sequential(
142
+ nn.Linear(dim, int(dim * mlp_ratio)),
143
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
144
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
145
+
146
+ def forward(self, x):
147
+ if self.post_norm:
148
+ x = x + self.norm1(self.attn(x))
149
+ x = x + self.norm2(self.mlp(x))
150
+ else:
151
+ x = x + self.attn(self.norm1(x))
152
+ x = x + self.mlp(self.norm2(x))
153
+ return x
154
+
155
+
156
+ class AttentionPool(nn.Module):
157
+
158
+ def __init__(self,
159
+ dim,
160
+ mlp_ratio,
161
+ num_heads,
162
+ activation='gelu',
163
+ proj_dropout=0.0,
164
+ norm_eps=1e-5):
165
+ assert dim % num_heads == 0
166
+ super().__init__()
167
+ self.dim = dim
168
+ self.mlp_ratio = mlp_ratio
169
+ self.num_heads = num_heads
170
+ self.head_dim = dim // num_heads
171
+ self.proj_dropout = proj_dropout
172
+ self.norm_eps = norm_eps
173
+
174
+ # layers
175
+ gain = 1.0 / math.sqrt(dim)
176
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
177
+ self.to_q = nn.Linear(dim, dim)
178
+ self.to_kv = nn.Linear(dim, dim * 2)
179
+ self.proj = nn.Linear(dim, dim)
180
+ self.norm = LayerNorm(dim, eps=norm_eps)
181
+ self.mlp = nn.Sequential(
182
+ nn.Linear(dim, int(dim * mlp_ratio)),
183
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
184
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
185
+
186
+ def forward(self, x):
187
+ """
188
+ x: [B, L, C].
189
+ """
190
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
191
+
192
+ # compute query, key, value
193
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
194
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
195
+
196
+ # compute attention
197
+ x = flash_attention(q, k, v, version=2)
198
+ x = x.reshape(b, 1, c)
199
+
200
+ # output
201
+ x = self.proj(x)
202
+ x = F.dropout(x, self.proj_dropout, self.training)
203
+
204
+ # mlp
205
+ x = x + self.mlp(self.norm(x))
206
+ return x[:, 0]
207
+
208
+
209
+ class VisionTransformer(nn.Module):
210
+
211
+ def __init__(self,
212
+ image_size=224,
213
+ patch_size=16,
214
+ dim=768,
215
+ mlp_ratio=4,
216
+ out_dim=512,
217
+ num_heads=12,
218
+ num_layers=12,
219
+ pool_type='token',
220
+ pre_norm=True,
221
+ post_norm=False,
222
+ activation='quick_gelu',
223
+ attn_dropout=0.0,
224
+ proj_dropout=0.0,
225
+ embedding_dropout=0.0,
226
+ norm_eps=1e-5):
227
+ if image_size % patch_size != 0:
228
+ print(
229
+ '[WARNING] image_size is not divisible by patch_size',
230
+ flush=True)
231
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
232
+ out_dim = out_dim or dim
233
+ super().__init__()
234
+ self.image_size = image_size
235
+ self.patch_size = patch_size
236
+ self.num_patches = (image_size // patch_size)**2
237
+ self.dim = dim
238
+ self.mlp_ratio = mlp_ratio
239
+ self.out_dim = out_dim
240
+ self.num_heads = num_heads
241
+ self.num_layers = num_layers
242
+ self.pool_type = pool_type
243
+ self.post_norm = post_norm
244
+ self.norm_eps = norm_eps
245
+
246
+ # embeddings
247
+ gain = 1.0 / math.sqrt(dim)
248
+ self.patch_embedding = nn.Conv2d(
249
+ 3,
250
+ dim,
251
+ kernel_size=patch_size,
252
+ stride=patch_size,
253
+ bias=not pre_norm)
254
+ if pool_type in ('token', 'token_fc'):
255
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
256
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
257
+ 1, self.num_patches +
258
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
259
+ self.dropout = nn.Dropout(embedding_dropout)
260
+
261
+ # transformer
262
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
263
+ self.transformer = nn.Sequential(*[
264
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
265
+ activation, attn_dropout, proj_dropout, norm_eps)
266
+ for _ in range(num_layers)
267
+ ])
268
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
269
+
270
+ # head
271
+ if pool_type == 'token':
272
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
273
+ elif pool_type == 'token_fc':
274
+ self.head = nn.Linear(dim, out_dim)
275
+ elif pool_type == 'attn_pool':
276
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
277
+ proj_dropout, norm_eps)
278
+
279
+ def forward(self, x, interpolation=False, use_31_block=False):
280
+ b = x.size(0)
281
+
282
+ # embeddings
283
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
284
+ if self.pool_type in ('token', 'token_fc'):
285
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
286
+ if interpolation:
287
+ e = pos_interpolate(self.pos_embedding, x.size(1))
288
+ else:
289
+ e = self.pos_embedding
290
+ x = self.dropout(x + e)
291
+ if self.pre_norm is not None:
292
+ x = self.pre_norm(x)
293
+
294
+ # transformer
295
+ if use_31_block:
296
+ x = self.transformer[:-1](x)
297
+ return x
298
+ else:
299
+ x = self.transformer(x)
300
+ return x
301
+
302
+
303
+ class XLMRobertaWithHead(XLMRoberta):
304
+
305
+ def __init__(self, **kwargs):
306
+ self.out_dim = kwargs.pop('out_dim')
307
+ super().__init__(**kwargs)
308
+
309
+ # head
310
+ mid_dim = (self.dim + self.out_dim) // 2
311
+ self.head = nn.Sequential(
312
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
313
+ nn.Linear(mid_dim, self.out_dim, bias=False))
314
+
315
+ def forward(self, ids):
316
+ # xlm-roberta
317
+ x = super().forward(ids)
318
+
319
+ # average pooling
320
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
321
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
322
+
323
+ # head
324
+ x = self.head(x)
325
+ return x
326
+
327
+
328
+ class XLMRobertaCLIP(nn.Module):
329
+
330
+ def __init__(self,
331
+ embed_dim=1024,
332
+ image_size=224,
333
+ patch_size=14,
334
+ vision_dim=1280,
335
+ vision_mlp_ratio=4,
336
+ vision_heads=16,
337
+ vision_layers=32,
338
+ vision_pool='token',
339
+ vision_pre_norm=True,
340
+ vision_post_norm=False,
341
+ activation='gelu',
342
+ vocab_size=250002,
343
+ max_text_len=514,
344
+ type_size=1,
345
+ pad_id=1,
346
+ text_dim=1024,
347
+ text_heads=16,
348
+ text_layers=24,
349
+ text_post_norm=True,
350
+ text_dropout=0.1,
351
+ attn_dropout=0.0,
352
+ proj_dropout=0.0,
353
+ embedding_dropout=0.0,
354
+ norm_eps=1e-5):
355
+ super().__init__()
356
+ self.embed_dim = embed_dim
357
+ self.image_size = image_size
358
+ self.patch_size = patch_size
359
+ self.vision_dim = vision_dim
360
+ self.vision_mlp_ratio = vision_mlp_ratio
361
+ self.vision_heads = vision_heads
362
+ self.vision_layers = vision_layers
363
+ self.vision_pre_norm = vision_pre_norm
364
+ self.vision_post_norm = vision_post_norm
365
+ self.activation = activation
366
+ self.vocab_size = vocab_size
367
+ self.max_text_len = max_text_len
368
+ self.type_size = type_size
369
+ self.pad_id = pad_id
370
+ self.text_dim = text_dim
371
+ self.text_heads = text_heads
372
+ self.text_layers = text_layers
373
+ self.text_post_norm = text_post_norm
374
+ self.norm_eps = norm_eps
375
+
376
+ # models
377
+ self.visual = VisionTransformer(
378
+ image_size=image_size,
379
+ patch_size=patch_size,
380
+ dim=vision_dim,
381
+ mlp_ratio=vision_mlp_ratio,
382
+ out_dim=embed_dim,
383
+ num_heads=vision_heads,
384
+ num_layers=vision_layers,
385
+ pool_type=vision_pool,
386
+ pre_norm=vision_pre_norm,
387
+ post_norm=vision_post_norm,
388
+ activation=activation,
389
+ attn_dropout=attn_dropout,
390
+ proj_dropout=proj_dropout,
391
+ embedding_dropout=embedding_dropout,
392
+ norm_eps=norm_eps)
393
+ self.textual = XLMRobertaWithHead(
394
+ vocab_size=vocab_size,
395
+ max_seq_len=max_text_len,
396
+ type_size=type_size,
397
+ pad_id=pad_id,
398
+ dim=text_dim,
399
+ out_dim=embed_dim,
400
+ num_heads=text_heads,
401
+ num_layers=text_layers,
402
+ post_norm=text_post_norm,
403
+ dropout=text_dropout)
404
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
405
+
406
+ def forward(self, imgs, txt_ids):
407
+ """
408
+ imgs: [B, 3, H, W] of torch.float32.
409
+ - mean: [0.48145466, 0.4578275, 0.40821073]
410
+ - std: [0.26862954, 0.26130258, 0.27577711]
411
+ txt_ids: [B, L] of torch.long.
412
+ Encoded by data.CLIPTokenizer.
413
+ """
414
+ xi = self.visual(imgs)
415
+ xt = self.textual(txt_ids)
416
+ return xi, xt
417
+
418
+ def param_groups(self):
419
+ groups = [{
420
+ 'params': [
421
+ p for n, p in self.named_parameters()
422
+ if 'norm' in n or n.endswith('bias')
423
+ ],
424
+ 'weight_decay': 0.0
425
+ }, {
426
+ 'params': [
427
+ p for n, p in self.named_parameters()
428
+ if not ('norm' in n or n.endswith('bias'))
429
+ ]
430
+ }]
431
+ return groups
432
+
433
+
434
+ def _clip(pretrained=False,
435
+ pretrained_name=None,
436
+ model_cls=XLMRobertaCLIP,
437
+ return_transforms=False,
438
+ return_tokenizer=False,
439
+ tokenizer_padding='eos',
440
+ dtype=torch.float32,
441
+ device='cpu',
442
+ **kwargs):
443
+ # init a model on device
444
+ with torch.device(device):
445
+ model = model_cls(**kwargs)
446
+
447
+ # set device
448
+ model = model.to(dtype=dtype, device=device)
449
+ output = (model,)
450
+
451
+ # init transforms
452
+ if return_transforms:
453
+ # mean and std
454
+ if 'siglip' in pretrained_name.lower():
455
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
456
+ else:
457
+ mean = [0.48145466, 0.4578275, 0.40821073]
458
+ std = [0.26862954, 0.26130258, 0.27577711]
459
+
460
+ # transforms
461
+ transforms = T.Compose([
462
+ T.Resize((model.image_size, model.image_size),
463
+ interpolation=T.InterpolationMode.BICUBIC),
464
+ T.ToTensor(),
465
+ T.Normalize(mean=mean, std=std)
466
+ ])
467
+ output += (transforms,)
468
+ return output[0] if len(output) == 1 else output
469
+
470
+
471
+ def clip_xlm_roberta_vit_h_14(
472
+ pretrained=False,
473
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
474
+ **kwargs):
475
+ cfg = dict(
476
+ embed_dim=1024,
477
+ image_size=224,
478
+ patch_size=14,
479
+ vision_dim=1280,
480
+ vision_mlp_ratio=4,
481
+ vision_heads=16,
482
+ vision_layers=32,
483
+ vision_pool='token',
484
+ activation='gelu',
485
+ vocab_size=250002,
486
+ max_text_len=514,
487
+ type_size=1,
488
+ pad_id=1,
489
+ text_dim=1024,
490
+ text_heads=16,
491
+ text_layers=24,
492
+ text_post_norm=True,
493
+ text_dropout=0.1,
494
+ attn_dropout=0.0,
495
+ proj_dropout=0.0,
496
+ embedding_dropout=0.0)
497
+ cfg.update(**kwargs)
498
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
499
+
500
+
501
+ class CLIPModel:
502
+
503
+ def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
504
+ self.dtype = dtype
505
+ self.device = device
506
+ self.checkpoint_path = checkpoint_path
507
+ self.tokenizer_path = tokenizer_path
508
+
509
+ # init model
510
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
511
+ pretrained=False,
512
+ return_transforms=True,
513
+ return_tokenizer=False,
514
+ dtype=dtype,
515
+ device=device)
516
+ self.model = self.model.eval().requires_grad_(False)
517
+ logging.info(f'loading {checkpoint_path}')
518
+ self.model.load_state_dict(
519
+ torch.load(checkpoint_path, map_location='cpu'))
520
+
521
+ # init tokenizer
522
+ self.tokenizer = HuggingfaceTokenizer(
523
+ name=tokenizer_path,
524
+ seq_len=self.model.max_text_len - 2,
525
+ clean='whitespace')
526
+
527
+ def visual(self, videos):
528
+ # preprocess
529
+ size = (self.model.image_size,) * 2
530
+ videos = torch.cat([
531
+ F.interpolate(
532
+ u.transpose(0, 1),
533
+ size=size,
534
+ mode='bicubic',
535
+ align_corners=False) for u in videos
536
+ ])
537
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
538
+
539
+ # forward
540
+ with torch.cuda.amp.autocast(dtype=self.dtype):
541
+ out = self.model.visual(videos, use_31_block=True)
542
+ return out
wan/modules/model.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn as nn
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+
10
+ from .attention import flash_attention
11
+
12
+ __all__ = ['WanModel']
13
+
14
+ T5_CONTEXT_TOKEN_NUMBER = 512
15
+ FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2
16
+
17
+
18
+ def sinusoidal_embedding_1d(dim, position):
19
+ # preprocess
20
+ assert dim % 2 == 0
21
+ half = dim // 2
22
+ position = position.type(torch.float64)
23
+
24
+ # calculation
25
+ sinusoid = torch.outer(
26
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
27
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
28
+ return x
29
+
30
+
31
+ @amp.autocast(enabled=False)
32
+ def rope_params(max_seq_len, dim, theta=10000):
33
+ assert dim % 2 == 0
34
+ freqs = torch.outer(
35
+ torch.arange(max_seq_len),
36
+ 1.0 / torch.pow(theta,
37
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
38
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
39
+ return freqs
40
+
41
+
42
+ @amp.autocast(enabled=False)
43
+ def rope_apply(x, grid_sizes, freqs):
44
+ n, c = x.size(2), x.size(3) // 2
45
+
46
+ # split freqs
47
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
48
+
49
+ # loop over samples
50
+ output = []
51
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
52
+ seq_len = f * h * w
53
+
54
+ # precompute multipliers
55
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
56
+ seq_len, n, -1, 2))
57
+ freqs_i = torch.cat([
58
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
59
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
60
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
61
+ ],
62
+ dim=-1).reshape(seq_len, 1, -1)
63
+
64
+ # apply rotary embedding
65
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
66
+ x_i = torch.cat([x_i, x[i, seq_len:]])
67
+
68
+ # append to collection
69
+ output.append(x_i)
70
+ return torch.stack(output).float()
71
+
72
+
73
+ class WanRMSNorm(nn.Module):
74
+
75
+ def __init__(self, dim, eps=1e-5):
76
+ super().__init__()
77
+ self.dim = dim
78
+ self.eps = eps
79
+ self.weight = nn.Parameter(torch.ones(dim))
80
+
81
+ def forward(self, x):
82
+ r"""
83
+ Args:
84
+ x(Tensor): Shape [B, L, C]
85
+ """
86
+ return self._norm(x.float()).type_as(x) * self.weight
87
+
88
+ def _norm(self, x):
89
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
90
+
91
+
92
+ class WanLayerNorm(nn.LayerNorm):
93
+
94
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
95
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
96
+
97
+ def forward(self, x):
98
+ r"""
99
+ Args:
100
+ x(Tensor): Shape [B, L, C]
101
+ """
102
+ return super().forward(x.float()).type_as(x)
103
+
104
+
105
+ class WanSelfAttention(nn.Module):
106
+
107
+ def __init__(self,
108
+ dim,
109
+ num_heads,
110
+ window_size=(-1, -1),
111
+ qk_norm=True,
112
+ eps=1e-6):
113
+ assert dim % num_heads == 0
114
+ super().__init__()
115
+ self.dim = dim
116
+ self.num_heads = num_heads
117
+ self.head_dim = dim // num_heads
118
+ self.window_size = window_size
119
+ self.qk_norm = qk_norm
120
+ self.eps = eps
121
+
122
+ # layers
123
+ self.q = nn.Linear(dim, dim)
124
+ self.k = nn.Linear(dim, dim)
125
+ self.v = nn.Linear(dim, dim)
126
+ self.o = nn.Linear(dim, dim)
127
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
128
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
129
+
130
+ def forward(self, x, seq_lens, grid_sizes, freqs):
131
+ r"""
132
+ Args:
133
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
134
+ seq_lens(Tensor): Shape [B]
135
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
136
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
137
+ """
138
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
139
+
140
+ # query, key, value function
141
+ def qkv_fn(x):
142
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
143
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
144
+ v = self.v(x).view(b, s, n, d)
145
+ return q, k, v
146
+
147
+ q, k, v = qkv_fn(x)
148
+
149
+ x = flash_attention(
150
+ q=rope_apply(q, grid_sizes, freqs),
151
+ k=rope_apply(k, grid_sizes, freqs),
152
+ v=v,
153
+ k_lens=seq_lens,
154
+ window_size=self.window_size)
155
+
156
+ # output
157
+ x = x.flatten(2)
158
+ x = self.o(x)
159
+ return x
160
+
161
+
162
+ class WanT2VCrossAttention(WanSelfAttention):
163
+
164
+ def forward(self, x, context, context_lens):
165
+ r"""
166
+ Args:
167
+ x(Tensor): Shape [B, L1, C]
168
+ context(Tensor): Shape [B, L2, C]
169
+ context_lens(Tensor): Shape [B]
170
+ """
171
+ b, n, d = x.size(0), self.num_heads, self.head_dim
172
+
173
+ # compute query, key, value
174
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
175
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
176
+ v = self.v(context).view(b, -1, n, d)
177
+
178
+ # compute attention
179
+ x = flash_attention(q, k, v, k_lens=context_lens)
180
+
181
+ # output
182
+ x = x.flatten(2)
183
+ x = self.o(x)
184
+ return x
185
+
186
+
187
+ class WanI2VCrossAttention(WanSelfAttention):
188
+
189
+ def __init__(self,
190
+ dim,
191
+ num_heads,
192
+ window_size=(-1, -1),
193
+ qk_norm=True,
194
+ eps=1e-6):
195
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
196
+
197
+ self.k_img = nn.Linear(dim, dim)
198
+ self.v_img = nn.Linear(dim, dim)
199
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
200
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
201
+
202
+ def forward(self, x, context, context_lens):
203
+ r"""
204
+ Args:
205
+ x(Tensor): Shape [B, L1, C]
206
+ context(Tensor): Shape [B, L2, C]
207
+ context_lens(Tensor): Shape [B]
208
+ """
209
+ image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
210
+ context_img = context[:, :image_context_length]
211
+ context = context[:, image_context_length:]
212
+ b, n, d = x.size(0), self.num_heads, self.head_dim
213
+
214
+ # compute query, key, value
215
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
216
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
217
+ v = self.v(context).view(b, -1, n, d)
218
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
219
+ v_img = self.v_img(context_img).view(b, -1, n, d)
220
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
221
+ # compute attention
222
+ x = flash_attention(q, k, v, k_lens=context_lens)
223
+
224
+ # output
225
+ x = x.flatten(2)
226
+ img_x = img_x.flatten(2)
227
+ x = x + img_x
228
+ x = self.o(x)
229
+ return x
230
+
231
+
232
+ WAN_CROSSATTENTION_CLASSES = {
233
+ 't2v_cross_attn': WanT2VCrossAttention,
234
+ 'i2v_cross_attn': WanI2VCrossAttention,
235
+ }
236
+
237
+
238
+ class WanAttentionBlock(nn.Module):
239
+
240
+ def __init__(self,
241
+ cross_attn_type,
242
+ dim,
243
+ ffn_dim,
244
+ num_heads,
245
+ window_size=(-1, -1),
246
+ qk_norm=True,
247
+ cross_attn_norm=False,
248
+ eps=1e-6):
249
+ super().__init__()
250
+ self.dim = dim
251
+ self.ffn_dim = ffn_dim
252
+ self.num_heads = num_heads
253
+ self.window_size = window_size
254
+ self.qk_norm = qk_norm
255
+ self.cross_attn_norm = cross_attn_norm
256
+ self.eps = eps
257
+
258
+ # layers
259
+ self.norm1 = WanLayerNorm(dim, eps)
260
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
261
+ eps)
262
+ self.norm3 = WanLayerNorm(
263
+ dim, eps,
264
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
265
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
266
+ num_heads,
267
+ (-1, -1),
268
+ qk_norm,
269
+ eps)
270
+ self.norm2 = WanLayerNorm(dim, eps)
271
+ self.ffn = nn.Sequential(
272
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
273
+ nn.Linear(ffn_dim, dim))
274
+
275
+ # modulation
276
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
277
+
278
+ def forward(
279
+ self,
280
+ x,
281
+ e,
282
+ seq_lens,
283
+ grid_sizes,
284
+ freqs,
285
+ context,
286
+ context_lens,
287
+ ):
288
+ r"""
289
+ Args:
290
+ x(Tensor): Shape [B, L, C]
291
+ e(Tensor): Shape [B, 6, C]
292
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
293
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
294
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
295
+ """
296
+ assert e.dtype == torch.float32
297
+ with amp.autocast(dtype=torch.float32):
298
+ e = (self.modulation + e).chunk(6, dim=1)
299
+ assert e[0].dtype == torch.float32
300
+
301
+ # self-attention
302
+ y = self.self_attn(
303
+ self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
304
+ freqs)
305
+ with amp.autocast(dtype=torch.float32):
306
+ x = x + y * e[2]
307
+
308
+ # cross-attention & ffn function
309
+ def cross_attn_ffn(x, context, context_lens, e):
310
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
311
+ y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
312
+ with amp.autocast(dtype=torch.float32):
313
+ x = x + y * e[5]
314
+ return x
315
+
316
+ x = cross_attn_ffn(x, context, context_lens, e)
317
+ return x
318
+
319
+
320
+ class Head(nn.Module):
321
+
322
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
323
+ super().__init__()
324
+ self.dim = dim
325
+ self.out_dim = out_dim
326
+ self.patch_size = patch_size
327
+ self.eps = eps
328
+
329
+ # layers
330
+ out_dim = math.prod(patch_size) * out_dim
331
+ self.norm = WanLayerNorm(dim, eps)
332
+ self.head = nn.Linear(dim, out_dim)
333
+
334
+ # modulation
335
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5)
336
+
337
+ def forward(self, x, e):
338
+ r"""
339
+ Args:
340
+ x(Tensor): Shape [B, L1, C]
341
+ e(Tensor): Shape [B, C]
342
+ """
343
+ assert e.dtype == torch.float32
344
+ with amp.autocast(dtype=torch.float32):
345
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
346
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
347
+ return x
348
+
349
+
350
+ class MLPProj(torch.nn.Module):
351
+
352
+ def __init__(self, in_dim, out_dim, flf_pos_emb=False):
353
+ super().__init__()
354
+
355
+ self.proj = torch.nn.Sequential(
356
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
357
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
358
+ torch.nn.LayerNorm(out_dim))
359
+ if flf_pos_emb: # NOTE: we only use this for `flf2v`
360
+ self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
361
+
362
+ def forward(self, image_embeds):
363
+ if hasattr(self, 'emb_pos'):
364
+ bs, n, d = image_embeds.shape
365
+ image_embeds = image_embeds.view(-1, 2 * n, d)
366
+ image_embeds = image_embeds + self.emb_pos
367
+ clip_extra_context_tokens = self.proj(image_embeds)
368
+ return clip_extra_context_tokens
369
+
370
+
371
+ class WanModel(ModelMixin, ConfigMixin):
372
+ r"""
373
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
374
+ """
375
+
376
+ ignore_for_config = [
377
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
378
+ ]
379
+ _no_split_modules = ['WanAttentionBlock']
380
+
381
+ @register_to_config
382
+ def __init__(self,
383
+ model_type='t2v',
384
+ patch_size=(1, 2, 2),
385
+ text_len=512,
386
+ in_dim=16,
387
+ dim=2048,
388
+ ffn_dim=8192,
389
+ freq_dim=256,
390
+ text_dim=4096,
391
+ out_dim=16,
392
+ num_heads=16,
393
+ num_layers=32,
394
+ window_size=(-1, -1),
395
+ qk_norm=True,
396
+ cross_attn_norm=True,
397
+ eps=1e-6):
398
+ r"""
399
+ Initialize the diffusion model backbone.
400
+
401
+ Args:
402
+ model_type (`str`, *optional*, defaults to 't2v'):
403
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
404
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
405
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
406
+ text_len (`int`, *optional*, defaults to 512):
407
+ Fixed length for text embeddings
408
+ in_dim (`int`, *optional*, defaults to 16):
409
+ Input video channels (C_in)
410
+ dim (`int`, *optional*, defaults to 2048):
411
+ Hidden dimension of the transformer
412
+ ffn_dim (`int`, *optional*, defaults to 8192):
413
+ Intermediate dimension in feed-forward network
414
+ freq_dim (`int`, *optional*, defaults to 256):
415
+ Dimension for sinusoidal time embeddings
416
+ text_dim (`int`, *optional*, defaults to 4096):
417
+ Input dimension for text embeddings
418
+ out_dim (`int`, *optional*, defaults to 16):
419
+ Output video channels (C_out)
420
+ num_heads (`int`, *optional*, defaults to 16):
421
+ Number of attention heads
422
+ num_layers (`int`, *optional*, defaults to 32):
423
+ Number of transformer blocks
424
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
425
+ Window size for local attention (-1 indicates global attention)
426
+ qk_norm (`bool`, *optional*, defaults to True):
427
+ Enable query/key normalization
428
+ cross_attn_norm (`bool`, *optional*, defaults to False):
429
+ Enable cross-attention normalization
430
+ eps (`float`, *optional*, defaults to 1e-6):
431
+ Epsilon value for normalization layers
432
+ """
433
+
434
+ super().__init__()
435
+
436
+ assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
437
+ self.model_type = model_type
438
+
439
+ self.patch_size = patch_size
440
+ self.text_len = text_len
441
+ self.in_dim = in_dim
442
+ self.dim = dim
443
+ self.ffn_dim = ffn_dim
444
+ self.freq_dim = freq_dim
445
+ self.text_dim = text_dim
446
+ self.out_dim = out_dim
447
+ self.num_heads = num_heads
448
+ self.num_layers = num_layers
449
+ self.window_size = window_size
450
+ self.qk_norm = qk_norm
451
+ self.cross_attn_norm = cross_attn_norm
452
+ self.eps = eps
453
+
454
+ # embeddings
455
+ self.patch_embedding = nn.Conv3d(
456
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
457
+ self.text_embedding = nn.Sequential(
458
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
459
+ nn.Linear(dim, dim))
460
+
461
+ self.time_embedding = nn.Sequential(
462
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
463
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
464
+
465
+ # blocks
466
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
467
+ self.blocks = nn.ModuleList([
468
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
469
+ window_size, qk_norm, cross_attn_norm, eps)
470
+ for _ in range(num_layers)
471
+ ])
472
+
473
+ # head
474
+ self.head = Head(dim, out_dim, patch_size, eps)
475
+
476
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
477
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
478
+ d = dim // num_heads
479
+ self.freqs = torch.cat([
480
+ rope_params(1024, d - 4 * (d // 6)),
481
+ rope_params(1024, 2 * (d // 6)),
482
+ rope_params(1024, 2 * (d // 6))
483
+ ],
484
+ dim=1)
485
+
486
+ if model_type == 'i2v' or model_type == 'flf2v':
487
+ self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
488
+
489
+ # initialize weights
490
+ self.init_weights()
491
+
492
+ def forward(
493
+ self,
494
+ x,
495
+ t,
496
+ context,
497
+ seq_len,
498
+ clip_fea=None,
499
+ y=None,
500
+ ):
501
+ r"""
502
+ Forward pass through the diffusion model
503
+
504
+ Args:
505
+ x (List[Tensor]):
506
+ List of input video tensors, each with shape [C_in, F, H, W]
507
+ t (Tensor):
508
+ Diffusion timesteps tensor of shape [B]
509
+ context (List[Tensor]):
510
+ List of text embeddings each with shape [L, C]
511
+ seq_len (`int`):
512
+ Maximum sequence length for positional encoding
513
+ clip_fea (Tensor, *optional*):
514
+ CLIP image features for image-to-video mode or first-last-frame-to-video mode
515
+ y (List[Tensor], *optional*):
516
+ Conditional video inputs for image-to-video mode, same shape as x
517
+
518
+ Returns:
519
+ List[Tensor]:
520
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
521
+ """
522
+ if self.model_type == 'i2v' or self.model_type == 'flf2v':
523
+ assert clip_fea is not None and y is not None
524
+ # params
525
+ device = self.patch_embedding.weight.device
526
+ if self.freqs.device != device:
527
+ self.freqs = self.freqs.to(device)
528
+
529
+ if y is not None:
530
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
531
+
532
+ # embeddings
533
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
534
+ grid_sizes = torch.stack(
535
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
536
+ x = [u.flatten(2).transpose(1, 2) for u in x]
537
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
538
+ assert seq_lens.max() <= seq_len
539
+ x = torch.cat([
540
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
541
+ dim=1) for u in x
542
+ ])
543
+
544
+ # time embeddings
545
+ with amp.autocast(dtype=torch.float32):
546
+ e = self.time_embedding(
547
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
548
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
549
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
550
+
551
+ # context
552
+ context_lens = None
553
+ context = self.text_embedding(
554
+ torch.stack([
555
+ torch.cat(
556
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
557
+ for u in context
558
+ ]))
559
+
560
+ if clip_fea is not None:
561
+ context_clip = self.img_emb(clip_fea) # bs x 257 (x2) x dim
562
+ context = torch.concat([context_clip, context], dim=1)
563
+
564
+ # arguments
565
+ kwargs = dict(
566
+ e=e0,
567
+ seq_lens=seq_lens,
568
+ grid_sizes=grid_sizes,
569
+ freqs=self.freqs,
570
+ context=context,
571
+ context_lens=context_lens)
572
+
573
+ for block in self.blocks:
574
+ x = block(x, **kwargs)
575
+
576
+ # head
577
+ x = self.head(x, e)
578
+
579
+ # unpatchify
580
+ x = self.unpatchify(x, grid_sizes)
581
+ return [u.float() for u in x]
582
+
583
+ def unpatchify(self, x, grid_sizes):
584
+ r"""
585
+ Reconstruct video tensors from patch embeddings.
586
+
587
+ Args:
588
+ x (List[Tensor]):
589
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
590
+ grid_sizes (Tensor):
591
+ Original spatial-temporal grid dimensions before patching,
592
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
593
+
594
+ Returns:
595
+ List[Tensor]:
596
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
597
+ """
598
+
599
+ c = self.out_dim
600
+ out = []
601
+ for u, v in zip(x, grid_sizes.tolist()):
602
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
603
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
604
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
605
+ out.append(u)
606
+ return out
607
+
608
+ def init_weights(self):
609
+ r"""
610
+ Initialize model parameters using Xavier initialization.
611
+ """
612
+
613
+ # basic init
614
+ for m in self.modules():
615
+ if isinstance(m, nn.Linear):
616
+ nn.init.xavier_uniform_(m.weight)
617
+ if m.bias is not None:
618
+ nn.init.zeros_(m.bias)
619
+
620
+ # init embeddings
621
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
622
+ for m in self.text_embedding.modules():
623
+ if isinstance(m, nn.Linear):
624
+ nn.init.normal_(m.weight, std=.02)
625
+ for m in self.time_embedding.modules():
626
+ if isinstance(m, nn.Linear):
627
+ nn.init.normal_(m.weight, std=.02)
628
+
629
+ # init output layer
630
+ nn.init.zeros_(self.head.head.weight)
wan/modules/t5.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .tokenizers import HuggingfaceTokenizer
11
+
12
+ __all__ = [
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ ]
18
+
19
+
20
+ def fp16_clamp(x):
21
+ if x.dtype == torch.float16 and torch.isinf(x).any():
22
+ clamp = torch.finfo(x.dtype).max - 1000
23
+ x = torch.clamp(x, min=-clamp, max=clamp)
24
+ return x
25
+
26
+
27
+ def init_weights(m):
28
+ if isinstance(m, T5LayerNorm):
29
+ nn.init.ones_(m.weight)
30
+ elif isinstance(m, T5Model):
31
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
32
+ elif isinstance(m, T5FeedForward):
33
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
+ elif isinstance(m, T5Attention):
37
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
38
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
41
+ elif isinstance(m, T5RelativeEmbedding):
42
+ nn.init.normal_(
43
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
44
+
45
+
46
+ class GELU(nn.Module):
47
+
48
+ def forward(self, x):
49
+ return 0.5 * x * (1.0 + torch.tanh(
50
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
51
+
52
+
53
+ class T5LayerNorm(nn.Module):
54
+
55
+ def __init__(self, dim, eps=1e-6):
56
+ super(T5LayerNorm, self).__init__()
57
+ self.dim = dim
58
+ self.eps = eps
59
+ self.weight = nn.Parameter(torch.ones(dim))
60
+
61
+ def forward(self, x):
62
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
63
+ self.eps)
64
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
65
+ x = x.type_as(self.weight)
66
+ return self.weight * x
67
+
68
+
69
+ class T5Attention(nn.Module):
70
+
71
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
72
+ assert dim_attn % num_heads == 0
73
+ super(T5Attention, self).__init__()
74
+ self.dim = dim
75
+ self.dim_attn = dim_attn
76
+ self.num_heads = num_heads
77
+ self.head_dim = dim_attn // num_heads
78
+
79
+ # layers
80
+ self.q = nn.Linear(dim, dim_attn, bias=False)
81
+ self.k = nn.Linear(dim, dim_attn, bias=False)
82
+ self.v = nn.Linear(dim, dim_attn, bias=False)
83
+ self.o = nn.Linear(dim_attn, dim, bias=False)
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ def forward(self, x, context=None, mask=None, pos_bias=None):
87
+ """
88
+ x: [B, L1, C].
89
+ context: [B, L2, C] or None.
90
+ mask: [B, L2] or [B, L1, L2] or None.
91
+ """
92
+ # check inputs
93
+ context = x if context is None else context
94
+ b, n, c = x.size(0), self.num_heads, self.head_dim
95
+
96
+ # compute query, key, value
97
+ q = self.q(x).view(b, -1, n, c)
98
+ k = self.k(context).view(b, -1, n, c)
99
+ v = self.v(context).view(b, -1, n, c)
100
+
101
+ # attention bias
102
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
103
+ if pos_bias is not None:
104
+ attn_bias += pos_bias
105
+ if mask is not None:
106
+ assert mask.ndim in [2, 3]
107
+ mask = mask.view(b, 1, 1,
108
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
109
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
110
+
111
+ # compute attention (T5 does not use scaling)
112
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
113
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
114
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
115
+
116
+ # output
117
+ x = x.reshape(b, -1, n * c)
118
+ x = self.o(x)
119
+ x = self.dropout(x)
120
+ return x
121
+
122
+
123
+ class T5FeedForward(nn.Module):
124
+
125
+ def __init__(self, dim, dim_ffn, dropout=0.1):
126
+ super(T5FeedForward, self).__init__()
127
+ self.dim = dim
128
+ self.dim_ffn = dim_ffn
129
+
130
+ # layers
131
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
132
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
133
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
134
+ self.dropout = nn.Dropout(dropout)
135
+
136
+ def forward(self, x):
137
+ x = self.fc1(x) * self.gate(x)
138
+ x = self.dropout(x)
139
+ x = self.fc2(x)
140
+ x = self.dropout(x)
141
+ return x
142
+
143
+
144
+ class T5SelfAttention(nn.Module):
145
+
146
+ def __init__(self,
147
+ dim,
148
+ dim_attn,
149
+ dim_ffn,
150
+ num_heads,
151
+ num_buckets,
152
+ shared_pos=True,
153
+ dropout=0.1):
154
+ super(T5SelfAttention, self).__init__()
155
+ self.dim = dim
156
+ self.dim_attn = dim_attn
157
+ self.dim_ffn = dim_ffn
158
+ self.num_heads = num_heads
159
+ self.num_buckets = num_buckets
160
+ self.shared_pos = shared_pos
161
+
162
+ # layers
163
+ self.norm1 = T5LayerNorm(dim)
164
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
165
+ self.norm2 = T5LayerNorm(dim)
166
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
167
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
168
+ num_buckets, num_heads, bidirectional=True)
169
+
170
+ def forward(self, x, mask=None, pos_bias=None):
171
+ e = pos_bias if self.shared_pos else self.pos_embedding(
172
+ x.size(1), x.size(1))
173
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
174
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
175
+ return x
176
+
177
+
178
+ class T5CrossAttention(nn.Module):
179
+
180
+ def __init__(self,
181
+ dim,
182
+ dim_attn,
183
+ dim_ffn,
184
+ num_heads,
185
+ num_buckets,
186
+ shared_pos=True,
187
+ dropout=0.1):
188
+ super(T5CrossAttention, self).__init__()
189
+ self.dim = dim
190
+ self.dim_attn = dim_attn
191
+ self.dim_ffn = dim_ffn
192
+ self.num_heads = num_heads
193
+ self.num_buckets = num_buckets
194
+ self.shared_pos = shared_pos
195
+
196
+ # layers
197
+ self.norm1 = T5LayerNorm(dim)
198
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
199
+ self.norm2 = T5LayerNorm(dim)
200
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
201
+ self.norm3 = T5LayerNorm(dim)
202
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
203
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
204
+ num_buckets, num_heads, bidirectional=False)
205
+
206
+ def forward(self,
207
+ x,
208
+ mask=None,
209
+ encoder_states=None,
210
+ encoder_mask=None,
211
+ pos_bias=None):
212
+ e = pos_bias if self.shared_pos else self.pos_embedding(
213
+ x.size(1), x.size(1))
214
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
215
+ x = fp16_clamp(x + self.cross_attn(
216
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
217
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
218
+ return x
219
+
220
+
221
+ class T5RelativeEmbedding(nn.Module):
222
+
223
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
224
+ super(T5RelativeEmbedding, self).__init__()
225
+ self.num_buckets = num_buckets
226
+ self.num_heads = num_heads
227
+ self.bidirectional = bidirectional
228
+ self.max_dist = max_dist
229
+
230
+ # layers
231
+ self.embedding = nn.Embedding(num_buckets, num_heads)
232
+
233
+ def forward(self, lq, lk):
234
+ device = self.embedding.weight.device
235
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
236
+ # torch.arange(lq).unsqueeze(1).to(device)
237
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
238
+ torch.arange(lq, device=device).unsqueeze(1)
239
+ rel_pos = self._relative_position_bucket(rel_pos)
240
+ rel_pos_embeds = self.embedding(rel_pos)
241
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
242
+ 0) # [1, N, Lq, Lk]
243
+ return rel_pos_embeds.contiguous()
244
+
245
+ def _relative_position_bucket(self, rel_pos):
246
+ # preprocess
247
+ if self.bidirectional:
248
+ num_buckets = self.num_buckets // 2
249
+ rel_buckets = (rel_pos > 0).long() * num_buckets
250
+ rel_pos = torch.abs(rel_pos)
251
+ else:
252
+ num_buckets = self.num_buckets
253
+ rel_buckets = 0
254
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
255
+
256
+ # embeddings for small and large positions
257
+ max_exact = num_buckets // 2
258
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
259
+ math.log(self.max_dist / max_exact) *
260
+ (num_buckets - max_exact)).long()
261
+ rel_pos_large = torch.min(
262
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
263
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
264
+ return rel_buckets
265
+
266
+
267
+ class T5Encoder(nn.Module):
268
+
269
+ def __init__(self,
270
+ vocab,
271
+ dim,
272
+ dim_attn,
273
+ dim_ffn,
274
+ num_heads,
275
+ num_layers,
276
+ num_buckets,
277
+ shared_pos=True,
278
+ dropout=0.1):
279
+ super(T5Encoder, self).__init__()
280
+ self.dim = dim
281
+ self.dim_attn = dim_attn
282
+ self.dim_ffn = dim_ffn
283
+ self.num_heads = num_heads
284
+ self.num_layers = num_layers
285
+ self.num_buckets = num_buckets
286
+ self.shared_pos = shared_pos
287
+
288
+ # layers
289
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
290
+ else nn.Embedding(vocab, dim)
291
+ self.pos_embedding = T5RelativeEmbedding(
292
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
293
+ self.dropout = nn.Dropout(dropout)
294
+ self.blocks = nn.ModuleList([
295
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
296
+ shared_pos, dropout) for _ in range(num_layers)
297
+ ])
298
+ self.norm = T5LayerNorm(dim)
299
+
300
+ # initialize weights
301
+ self.apply(init_weights)
302
+
303
+ def forward(self, ids, mask=None):
304
+ x = self.token_embedding(ids)
305
+ x = self.dropout(x)
306
+ e = self.pos_embedding(x.size(1),
307
+ x.size(1)) if self.shared_pos else None
308
+ for block in self.blocks:
309
+ x = block(x, mask, pos_bias=e)
310
+ x = self.norm(x)
311
+ x = self.dropout(x)
312
+ return x
313
+
314
+
315
+ class T5Decoder(nn.Module):
316
+
317
+ def __init__(self,
318
+ vocab,
319
+ dim,
320
+ dim_attn,
321
+ dim_ffn,
322
+ num_heads,
323
+ num_layers,
324
+ num_buckets,
325
+ shared_pos=True,
326
+ dropout=0.1):
327
+ super(T5Decoder, self).__init__()
328
+ self.dim = dim
329
+ self.dim_attn = dim_attn
330
+ self.dim_ffn = dim_ffn
331
+ self.num_heads = num_heads
332
+ self.num_layers = num_layers
333
+ self.num_buckets = num_buckets
334
+ self.shared_pos = shared_pos
335
+
336
+ # layers
337
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
338
+ else nn.Embedding(vocab, dim)
339
+ self.pos_embedding = T5RelativeEmbedding(
340
+ num_buckets, num_heads, bidirectional=False) if shared_pos else None
341
+ self.dropout = nn.Dropout(dropout)
342
+ self.blocks = nn.ModuleList([
343
+ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
344
+ shared_pos, dropout) for _ in range(num_layers)
345
+ ])
346
+ self.norm = T5LayerNorm(dim)
347
+
348
+ # initialize weights
349
+ self.apply(init_weights)
350
+
351
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
352
+ b, s = ids.size()
353
+
354
+ # causal mask
355
+ if mask is None:
356
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
357
+ elif mask.ndim == 2:
358
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
359
+
360
+ # layers
361
+ x = self.token_embedding(ids)
362
+ x = self.dropout(x)
363
+ e = self.pos_embedding(x.size(1),
364
+ x.size(1)) if self.shared_pos else None
365
+ for block in self.blocks:
366
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
367
+ x = self.norm(x)
368
+ x = self.dropout(x)
369
+ return x
370
+
371
+
372
+ class T5Model(nn.Module):
373
+
374
+ def __init__(self,
375
+ vocab_size,
376
+ dim,
377
+ dim_attn,
378
+ dim_ffn,
379
+ num_heads,
380
+ encoder_layers,
381
+ decoder_layers,
382
+ num_buckets,
383
+ shared_pos=True,
384
+ dropout=0.1):
385
+ super(T5Model, self).__init__()
386
+ self.vocab_size = vocab_size
387
+ self.dim = dim
388
+ self.dim_attn = dim_attn
389
+ self.dim_ffn = dim_ffn
390
+ self.num_heads = num_heads
391
+ self.encoder_layers = encoder_layers
392
+ self.decoder_layers = decoder_layers
393
+ self.num_buckets = num_buckets
394
+
395
+ # layers
396
+ self.token_embedding = nn.Embedding(vocab_size, dim)
397
+ self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
398
+ num_heads, encoder_layers, num_buckets,
399
+ shared_pos, dropout)
400
+ self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
401
+ num_heads, decoder_layers, num_buckets,
402
+ shared_pos, dropout)
403
+ self.head = nn.Linear(dim, vocab_size, bias=False)
404
+
405
+ # initialize weights
406
+ self.apply(init_weights)
407
+
408
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
409
+ x = self.encoder(encoder_ids, encoder_mask)
410
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
411
+ x = self.head(x)
412
+ return x
413
+
414
+
415
+ def _t5(name,
416
+ encoder_only=False,
417
+ decoder_only=False,
418
+ return_tokenizer=False,
419
+ tokenizer_kwargs={},
420
+ dtype=torch.float32,
421
+ device='cpu',
422
+ **kwargs):
423
+ # sanity check
424
+ assert not (encoder_only and decoder_only)
425
+
426
+ # params
427
+ if encoder_only:
428
+ model_cls = T5Encoder
429
+ kwargs['vocab'] = kwargs.pop('vocab_size')
430
+ kwargs['num_layers'] = kwargs.pop('encoder_layers')
431
+ _ = kwargs.pop('decoder_layers')
432
+ elif decoder_only:
433
+ model_cls = T5Decoder
434
+ kwargs['vocab'] = kwargs.pop('vocab_size')
435
+ kwargs['num_layers'] = kwargs.pop('decoder_layers')
436
+ _ = kwargs.pop('encoder_layers')
437
+ else:
438
+ model_cls = T5Model
439
+
440
+ # init model
441
+ with torch.device(device):
442
+ model = model_cls(**kwargs)
443
+
444
+ # set device
445
+ model = model.to(dtype=dtype, device=device)
446
+
447
+ # init tokenizer
448
+ if return_tokenizer:
449
+ from .tokenizers import HuggingfaceTokenizer
450
+ tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
451
+ return model, tokenizer
452
+ else:
453
+ return model
454
+
455
+
456
+ def umt5_xxl(**kwargs):
457
+ cfg = dict(
458
+ vocab_size=256384,
459
+ dim=4096,
460
+ dim_attn=4096,
461
+ dim_ffn=10240,
462
+ num_heads=64,
463
+ encoder_layers=24,
464
+ decoder_layers=24,
465
+ num_buckets=32,
466
+ shared_pos=False,
467
+ dropout=0.1)
468
+ cfg.update(**kwargs)
469
+ return _t5('umt5-xxl', **cfg)
470
+
471
+
472
+ class T5EncoderModel:
473
+
474
+ def __init__(
475
+ self,
476
+ text_len,
477
+ dtype=torch.bfloat16,
478
+ device=torch.cuda.current_device(),
479
+ checkpoint_path=None,
480
+ tokenizer_path=None,
481
+ shard_fn=None,
482
+ ):
483
+ self.text_len = text_len
484
+ self.dtype = dtype
485
+ self.device = device
486
+ self.checkpoint_path = checkpoint_path
487
+ self.tokenizer_path = tokenizer_path
488
+
489
+ # init model
490
+ model = umt5_xxl(
491
+ encoder_only=True,
492
+ return_tokenizer=False,
493
+ dtype=dtype,
494
+ device=device).eval().requires_grad_(False)
495
+ logging.info(f'loading {checkpoint_path}')
496
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
497
+ self.model = model
498
+ if shard_fn is not None:
499
+ self.model = shard_fn(self.model, sync_module_states=False)
500
+ else:
501
+ self.model.to(self.device)
502
+ # init tokenizer
503
+ self.tokenizer = HuggingfaceTokenizer(
504
+ name=tokenizer_path, seq_len=text_len, clean='whitespace')
505
+
506
+ def __call__(self, texts, device):
507
+ ids, mask = self.tokenizer(
508
+ texts, return_mask=True, add_special_tokens=True)
509
+ ids = ids.to(device)
510
+ mask = mask.to(device)
511
+ seq_lens = mask.gt(0).sum(dim=1).long()
512
+ context = self.model(ids, mask)
513
+ return [u[:v] for u, v in zip(context, seq_lens)]
wan/modules/tokenizers.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ['HuggingfaceTokenizer']
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r'\s+', ' ', text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace('_', ' ')
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans('', '', string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string))
30
+ else:
31
+ text = text.translate(str.maketrans('', '', string.punctuation))
32
+ text = text.lower()
33
+ text = re.sub(r'\s+', ' ', text)
34
+ return text.strip()
35
+
36
+
37
+ class HuggingfaceTokenizer:
38
+
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop('return_mask', False)
51
+
52
+ # arguments
53
+ _kwargs = {'return_tensors': 'pt'}
54
+ if self.seq_len is not None:
55
+ _kwargs.update({
56
+ 'padding': 'max_length',
57
+ 'truncation': True,
58
+ 'max_length': self.seq_len
59
+ })
60
+ _kwargs.update(**kwargs)
61
+
62
+ # tokenization
63
+ if isinstance(sequence, str):
64
+ sequence = [sequence]
65
+ if self.clean:
66
+ sequence = [self._clean(u) for u in sequence]
67
+ ids = self.tokenizer(sequence, **_kwargs)
68
+
69
+ # output
70
+ if return_mask:
71
+ return ids.input_ids, ids.attention_mask
72
+ else:
73
+ return ids.input_ids
74
+
75
+ def _clean(self, text):
76
+ if self.clean == 'whitespace':
77
+ text = whitespace_clean(basic_clean(text))
78
+ elif self.clean == 'lower':
79
+ text = whitespace_clean(basic_clean(text)).lower()
80
+ elif self.clean == 'canonicalize':
81
+ text = canonicalize(basic_clean(text))
82
+ return text
wan/modules/vace_model.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+ import torch.nn as nn
5
+ from diffusers.configuration_utils import register_to_config
6
+ from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
7
+
8
+
9
+ class VaceWanAttentionBlock(WanAttentionBlock):
10
+ def __init__(
11
+ self,
12
+ cross_attn_type,
13
+ dim,
14
+ ffn_dim,
15
+ num_heads,
16
+ window_size=(-1, -1),
17
+ qk_norm=True,
18
+ cross_attn_norm=False,
19
+ eps=1e-6,
20
+ block_id=0
21
+ ):
22
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
23
+ self.block_id = block_id
24
+ if block_id == 0:
25
+ self.before_proj = nn.Linear(self.dim, self.dim)
26
+ nn.init.zeros_(self.before_proj.weight)
27
+ nn.init.zeros_(self.before_proj.bias)
28
+ self.after_proj = nn.Linear(self.dim, self.dim)
29
+ nn.init.zeros_(self.after_proj.weight)
30
+ nn.init.zeros_(self.after_proj.bias)
31
+
32
+ def forward(self, c, x, **kwargs):
33
+ if self.block_id == 0:
34
+ c = self.before_proj(c) + x
35
+
36
+ c = super().forward(c, **kwargs)
37
+ c_skip = self.after_proj(c)
38
+ return c, c_skip
39
+
40
+
41
+ class BaseWanAttentionBlock(WanAttentionBlock):
42
+ def __init__(
43
+ self,
44
+ cross_attn_type,
45
+ dim,
46
+ ffn_dim,
47
+ num_heads,
48
+ window_size=(-1, -1),
49
+ qk_norm=True,
50
+ cross_attn_norm=False,
51
+ eps=1e-6,
52
+ block_id=None
53
+ ):
54
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
55
+ self.block_id = block_id
56
+
57
+ def forward(self, x, hints, context_scale=1.0, **kwargs):
58
+ x = super().forward(x, **kwargs)
59
+ if self.block_id is not None:
60
+ x = x + hints[self.block_id] * context_scale
61
+ return x
62
+
63
+
64
+ class VaceWanModel(WanModel):
65
+ @register_to_config
66
+ def __init__(self,
67
+ vace_layers=None,
68
+ vace_in_dim=None,
69
+ model_type='vace',
70
+ patch_size=(1, 2, 2),
71
+ text_len=512,
72
+ in_dim=16,
73
+ dim=2048,
74
+ ffn_dim=8192,
75
+ freq_dim=256,
76
+ text_dim=4096,
77
+ out_dim=16,
78
+ num_heads=16,
79
+ num_layers=32,
80
+ window_size=(-1, -1),
81
+ qk_norm=True,
82
+ cross_attn_norm=True,
83
+ eps=1e-6):
84
+ super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
85
+ num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
86
+
87
+ self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
88
+ self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
89
+
90
+ assert 0 in self.vace_layers
91
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
92
+
93
+ # blocks
94
+ self.blocks = nn.ModuleList([
95
+ BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
96
+ self.cross_attn_norm, self.eps,
97
+ block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
98
+ for i in range(self.num_layers)
99
+ ])
100
+
101
+ # vace blocks
102
+ self.vace_blocks = nn.ModuleList([
103
+ VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
104
+ self.cross_attn_norm, self.eps, block_id=i)
105
+ for i in self.vace_layers
106
+ ])
107
+
108
+ # vace patch embeddings
109
+ self.vace_patch_embedding = nn.Conv3d(
110
+ self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
111
+ )
112
+
113
+ def forward_vace(
114
+ self,
115
+ x,
116
+ vace_context,
117
+ seq_len,
118
+ kwargs
119
+ ):
120
+ # embeddings
121
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
122
+ c = [u.flatten(2).transpose(1, 2) for u in c]
123
+ c = torch.cat([
124
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
125
+ dim=1) for u in c
126
+ ])
127
+
128
+ # arguments
129
+ new_kwargs = dict(x=x)
130
+ new_kwargs.update(kwargs)
131
+
132
+ hints = []
133
+ for block in self.vace_blocks:
134
+ c, c_skip = block(c, **new_kwargs)
135
+ hints.append(c_skip)
136
+ return hints
137
+
138
+ def forward(
139
+ self,
140
+ x,
141
+ t,
142
+ vace_context,
143
+ context,
144
+ seq_len,
145
+ vace_context_scale=1.0,
146
+ clip_fea=None,
147
+ y=None,
148
+ ):
149
+ r"""
150
+ Forward pass through the diffusion model
151
+
152
+ Args:
153
+ x (List[Tensor]):
154
+ List of input video tensors, each with shape [C_in, F, H, W]
155
+ t (Tensor):
156
+ Diffusion timesteps tensor of shape [B]
157
+ context (List[Tensor]):
158
+ List of text embeddings each with shape [L, C]
159
+ seq_len (`int`):
160
+ Maximum sequence length for positional encoding
161
+ clip_fea (Tensor, *optional*):
162
+ CLIP image features for image-to-video mode
163
+ y (List[Tensor], *optional*):
164
+ Conditional video inputs for image-to-video mode, same shape as x
165
+
166
+ Returns:
167
+ List[Tensor]:
168
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
169
+ """
170
+ # if self.model_type == 'i2v':
171
+ # assert clip_fea is not None and y is not None
172
+ # params
173
+ device = self.patch_embedding.weight.device
174
+ if self.freqs.device != device:
175
+ self.freqs = self.freqs.to(device)
176
+
177
+ # if y is not None:
178
+ # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
179
+
180
+ # embeddings
181
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
182
+ grid_sizes = torch.stack(
183
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
184
+ x = [u.flatten(2).transpose(1, 2) for u in x]
185
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
186
+ assert seq_lens.max() <= seq_len
187
+ x = torch.cat([
188
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
189
+ dim=1) for u in x
190
+ ])
191
+
192
+ # time embeddings
193
+ with amp.autocast(dtype=torch.float32):
194
+ e = self.time_embedding(
195
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
196
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
197
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
198
+
199
+ # context
200
+ context_lens = None
201
+ context = self.text_embedding(
202
+ torch.stack([
203
+ torch.cat(
204
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
205
+ for u in context
206
+ ]))
207
+
208
+ # if clip_fea is not None:
209
+ # context_clip = self.img_emb(clip_fea) # bs x 257 x dim
210
+ # context = torch.concat([context_clip, context], dim=1)
211
+
212
+ # arguments
213
+ kwargs = dict(
214
+ e=e0,
215
+ seq_lens=seq_lens,
216
+ grid_sizes=grid_sizes,
217
+ freqs=self.freqs,
218
+ context=context,
219
+ context_lens=context_lens)
220
+
221
+ hints = self.forward_vace(x, vace_context, seq_len, kwargs)
222
+ kwargs['hints'] = hints
223
+ kwargs['context_scale'] = vace_context_scale
224
+
225
+ for block in self.blocks:
226
+ x = block(x, **kwargs)
227
+
228
+ # head
229
+ x = self.head(x, e)
230
+
231
+ # unpatchify
232
+ x = self.unpatchify(x, grid_sizes)
233
+ return [u.float() for u in x]
wan/modules/vae.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ __all__ = [
11
+ 'WanVAE',
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3d convolusion.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
25
+ self.padding[1], 2 * self.padding[0], 0)
26
+ self.padding = (0, 0, 0)
27
+
28
+ def forward(self, x, cache_x=None):
29
+ padding = list(self._padding)
30
+ if cache_x is not None and self._padding[4] > 0:
31
+ cache_x = cache_x.to(x.device)
32
+ x = torch.cat([cache_x, x], dim=2)
33
+ padding[4] -= cache_x.shape[2]
34
+ x = F.pad(x, padding)
35
+
36
+ return super().forward(x)
37
+
38
+
39
+ class RMS_norm(nn.Module):
40
+
41
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
42
+ super().__init__()
43
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
44
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
45
+
46
+ self.channel_first = channel_first
47
+ self.scale = dim**0.5
48
+ self.gamma = nn.Parameter(torch.ones(shape))
49
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
50
+
51
+ def forward(self, x):
52
+ return F.normalize(
53
+ x, dim=(1 if self.channel_first else
54
+ -1)) * self.scale * self.gamma + self.bias
55
+
56
+
57
+ class Upsample(nn.Upsample):
58
+
59
+ def forward(self, x):
60
+ """
61
+ Fix bfloat16 support for nearest neighbor interpolation.
62
+ """
63
+ return super().forward(x.float()).type_as(x)
64
+
65
+
66
+ class Resample(nn.Module):
67
+
68
+ def __init__(self, dim, mode):
69
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
70
+ 'downsample3d')
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.mode = mode
74
+
75
+ # layers
76
+ if mode == 'upsample2d':
77
+ self.resample = nn.Sequential(
78
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
79
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
80
+ elif mode == 'upsample3d':
81
+ self.resample = nn.Sequential(
82
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
84
+ self.time_conv = CausalConv3d(
85
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
86
+
87
+ elif mode == 'downsample2d':
88
+ self.resample = nn.Sequential(
89
+ nn.ZeroPad2d((0, 1, 0, 1)),
90
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
91
+ elif mode == 'downsample3d':
92
+ self.resample = nn.Sequential(
93
+ nn.ZeroPad2d((0, 1, 0, 1)),
94
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
+ self.time_conv = CausalConv3d(
96
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
97
+
98
+ else:
99
+ self.resample = nn.Identity()
100
+
101
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
102
+ b, c, t, h, w = x.size()
103
+ if self.mode == 'upsample3d':
104
+ if feat_cache is not None:
105
+ idx = feat_idx[0]
106
+ if feat_cache[idx] is None:
107
+ feat_cache[idx] = 'Rep'
108
+ feat_idx[0] += 1
109
+ else:
110
+
111
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
112
+ if cache_x.shape[2] < 2 and feat_cache[
113
+ idx] is not None and feat_cache[idx] != 'Rep':
114
+ # cache last frame of last two chunk
115
+ cache_x = torch.cat([
116
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
117
+ cache_x.device), cache_x
118
+ ],
119
+ dim=2)
120
+ if cache_x.shape[2] < 2 and feat_cache[
121
+ idx] is not None and feat_cache[idx] == 'Rep':
122
+ cache_x = torch.cat([
123
+ torch.zeros_like(cache_x).to(cache_x.device),
124
+ cache_x
125
+ ],
126
+ dim=2)
127
+ if feat_cache[idx] == 'Rep':
128
+ x = self.time_conv(x)
129
+ else:
130
+ x = self.time_conv(x, feat_cache[idx])
131
+ feat_cache[idx] = cache_x
132
+ feat_idx[0] += 1
133
+
134
+ x = x.reshape(b, 2, c, t, h, w)
135
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
136
+ 3)
137
+ x = x.reshape(b, c, t * 2, h, w)
138
+ t = x.shape[2]
139
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
140
+ x = self.resample(x)
141
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
142
+
143
+ if self.mode == 'downsample3d':
144
+ if feat_cache is not None:
145
+ idx = feat_idx[0]
146
+ if feat_cache[idx] is None:
147
+ feat_cache[idx] = x.clone()
148
+ feat_idx[0] += 1
149
+ else:
150
+
151
+ cache_x = x[:, :, -1:, :, :].clone()
152
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
153
+ # # cache last frame of last two chunk
154
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
155
+
156
+ x = self.time_conv(
157
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
158
+ feat_cache[idx] = cache_x
159
+ feat_idx[0] += 1
160
+ return x
161
+
162
+ def init_weight(self, conv):
163
+ conv_weight = conv.weight
164
+ nn.init.zeros_(conv_weight)
165
+ c1, c2, t, h, w = conv_weight.size()
166
+ one_matrix = torch.eye(c1, c2)
167
+ init_matrix = one_matrix
168
+ nn.init.zeros_(conv_weight)
169
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
170
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
171
+ conv.weight.data.copy_(conv_weight)
172
+ nn.init.zeros_(conv.bias.data)
173
+
174
+ def init_weight2(self, conv):
175
+ conv_weight = conv.weight.data
176
+ nn.init.zeros_(conv_weight)
177
+ c1, c2, t, h, w = conv_weight.size()
178
+ init_matrix = torch.eye(c1 // 2, c2)
179
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
180
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
181
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
182
+ conv.weight.data.copy_(conv_weight)
183
+ nn.init.zeros_(conv.bias.data)
184
+
185
+
186
+ class ResidualBlock(nn.Module):
187
+
188
+ def __init__(self, in_dim, out_dim, dropout=0.0):
189
+ super().__init__()
190
+ self.in_dim = in_dim
191
+ self.out_dim = out_dim
192
+
193
+ # layers
194
+ self.residual = nn.Sequential(
195
+ RMS_norm(in_dim, images=False), nn.SiLU(),
196
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
197
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
198
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
199
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
200
+ if in_dim != out_dim else nn.Identity()
201
+
202
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
203
+ h = self.shortcut(x)
204
+ for layer in self.residual:
205
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
206
+ idx = feat_idx[0]
207
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
208
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
209
+ # cache last frame of last two chunk
210
+ cache_x = torch.cat([
211
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
212
+ cache_x.device), cache_x
213
+ ],
214
+ dim=2)
215
+ x = layer(x, feat_cache[idx])
216
+ feat_cache[idx] = cache_x
217
+ feat_idx[0] += 1
218
+ else:
219
+ x = layer(x)
220
+ return x + h
221
+
222
+
223
+ class AttentionBlock(nn.Module):
224
+ """
225
+ Causal self-attention with a single head.
226
+ """
227
+
228
+ def __init__(self, dim):
229
+ super().__init__()
230
+ self.dim = dim
231
+
232
+ # layers
233
+ self.norm = RMS_norm(dim)
234
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
235
+ self.proj = nn.Conv2d(dim, dim, 1)
236
+
237
+ # zero out the last layer params
238
+ nn.init.zeros_(self.proj.weight)
239
+
240
+ def forward(self, x):
241
+ identity = x
242
+ b, c, t, h, w = x.size()
243
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
244
+ x = self.norm(x)
245
+ # compute query, key, value
246
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
247
+ -1).permute(0, 1, 3,
248
+ 2).contiguous().chunk(
249
+ 3, dim=-1)
250
+
251
+ # apply attention
252
+ x = F.scaled_dot_product_attention(
253
+ q,
254
+ k,
255
+ v,
256
+ )
257
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
258
+
259
+ # output
260
+ x = self.proj(x)
261
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
262
+ return x + identity
263
+
264
+
265
+ class Encoder3d(nn.Module):
266
+
267
+ def __init__(self,
268
+ dim=128,
269
+ z_dim=4,
270
+ dim_mult=[1, 2, 4, 4],
271
+ num_res_blocks=2,
272
+ attn_scales=[],
273
+ temperal_downsample=[True, True, False],
274
+ dropout=0.0):
275
+ super().__init__()
276
+ self.dim = dim
277
+ self.z_dim = z_dim
278
+ self.dim_mult = dim_mult
279
+ self.num_res_blocks = num_res_blocks
280
+ self.attn_scales = attn_scales
281
+ self.temperal_downsample = temperal_downsample
282
+
283
+ # dimensions
284
+ dims = [dim * u for u in [1] + dim_mult]
285
+ scale = 1.0
286
+
287
+ # init block
288
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
289
+
290
+ # downsample blocks
291
+ downsamples = []
292
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
293
+ # residual (+attention) blocks
294
+ for _ in range(num_res_blocks):
295
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
296
+ if scale in attn_scales:
297
+ downsamples.append(AttentionBlock(out_dim))
298
+ in_dim = out_dim
299
+
300
+ # downsample block
301
+ if i != len(dim_mult) - 1:
302
+ mode = 'downsample3d' if temperal_downsample[
303
+ i] else 'downsample2d'
304
+ downsamples.append(Resample(out_dim, mode=mode))
305
+ scale /= 2.0
306
+ self.downsamples = nn.Sequential(*downsamples)
307
+
308
+ # middle blocks
309
+ self.middle = nn.Sequential(
310
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
311
+ ResidualBlock(out_dim, out_dim, dropout))
312
+
313
+ # output blocks
314
+ self.head = nn.Sequential(
315
+ RMS_norm(out_dim, images=False), nn.SiLU(),
316
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
317
+
318
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
319
+ if feat_cache is not None:
320
+ idx = feat_idx[0]
321
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
322
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
323
+ # cache last frame of last two chunk
324
+ cache_x = torch.cat([
325
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
326
+ cache_x.device), cache_x
327
+ ],
328
+ dim=2)
329
+ x = self.conv1(x, feat_cache[idx])
330
+ feat_cache[idx] = cache_x
331
+ feat_idx[0] += 1
332
+ else:
333
+ x = self.conv1(x)
334
+
335
+ ## downsamples
336
+ for layer in self.downsamples:
337
+ if feat_cache is not None:
338
+ x = layer(x, feat_cache, feat_idx)
339
+ else:
340
+ x = layer(x)
341
+
342
+ ## middle
343
+ for layer in self.middle:
344
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
345
+ x = layer(x, feat_cache, feat_idx)
346
+ else:
347
+ x = layer(x)
348
+
349
+ ## head
350
+ for layer in self.head:
351
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
352
+ idx = feat_idx[0]
353
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
354
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
355
+ # cache last frame of last two chunk
356
+ cache_x = torch.cat([
357
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
358
+ cache_x.device), cache_x
359
+ ],
360
+ dim=2)
361
+ x = layer(x, feat_cache[idx])
362
+ feat_cache[idx] = cache_x
363
+ feat_idx[0] += 1
364
+ else:
365
+ x = layer(x)
366
+ return x
367
+
368
+
369
+ class Decoder3d(nn.Module):
370
+
371
+ def __init__(self,
372
+ dim=128,
373
+ z_dim=4,
374
+ dim_mult=[1, 2, 4, 4],
375
+ num_res_blocks=2,
376
+ attn_scales=[],
377
+ temperal_upsample=[False, True, True],
378
+ dropout=0.0):
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.z_dim = z_dim
382
+ self.dim_mult = dim_mult
383
+ self.num_res_blocks = num_res_blocks
384
+ self.attn_scales = attn_scales
385
+ self.temperal_upsample = temperal_upsample
386
+
387
+ # dimensions
388
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
389
+ scale = 1.0 / 2**(len(dim_mult) - 2)
390
+
391
+ # init block
392
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
393
+
394
+ # middle blocks
395
+ self.middle = nn.Sequential(
396
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
397
+ ResidualBlock(dims[0], dims[0], dropout))
398
+
399
+ # upsample blocks
400
+ upsamples = []
401
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
402
+ # residual (+attention) blocks
403
+ if i == 1 or i == 2 or i == 3:
404
+ in_dim = in_dim // 2
405
+ for _ in range(num_res_blocks + 1):
406
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
407
+ if scale in attn_scales:
408
+ upsamples.append(AttentionBlock(out_dim))
409
+ in_dim = out_dim
410
+
411
+ # upsample block
412
+ if i != len(dim_mult) - 1:
413
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
414
+ upsamples.append(Resample(out_dim, mode=mode))
415
+ scale *= 2.0
416
+ self.upsamples = nn.Sequential(*upsamples)
417
+
418
+ # output blocks
419
+ self.head = nn.Sequential(
420
+ RMS_norm(out_dim, images=False), nn.SiLU(),
421
+ CausalConv3d(out_dim, 3, 3, padding=1))
422
+
423
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
424
+ ## conv1
425
+ if feat_cache is not None:
426
+ idx = feat_idx[0]
427
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
428
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
429
+ # cache last frame of last two chunk
430
+ cache_x = torch.cat([
431
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
432
+ cache_x.device), cache_x
433
+ ],
434
+ dim=2)
435
+ x = self.conv1(x, feat_cache[idx])
436
+ feat_cache[idx] = cache_x
437
+ feat_idx[0] += 1
438
+ else:
439
+ x = self.conv1(x)
440
+
441
+ ## middle
442
+ for layer in self.middle:
443
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
444
+ x = layer(x, feat_cache, feat_idx)
445
+ else:
446
+ x = layer(x)
447
+
448
+ ## upsamples
449
+ for layer in self.upsamples:
450
+ if feat_cache is not None:
451
+ x = layer(x, feat_cache, feat_idx)
452
+ else:
453
+ x = layer(x)
454
+
455
+ ## head
456
+ for layer in self.head:
457
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
458
+ idx = feat_idx[0]
459
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
460
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
461
+ # cache last frame of last two chunk
462
+ cache_x = torch.cat([
463
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
464
+ cache_x.device), cache_x
465
+ ],
466
+ dim=2)
467
+ x = layer(x, feat_cache[idx])
468
+ feat_cache[idx] = cache_x
469
+ feat_idx[0] += 1
470
+ else:
471
+ x = layer(x)
472
+ return x
473
+
474
+
475
+ def count_conv3d(model):
476
+ count = 0
477
+ for m in model.modules():
478
+ if isinstance(m, CausalConv3d):
479
+ count += 1
480
+ return count
481
+
482
+
483
+ class WanVAE_(nn.Module):
484
+
485
+ def __init__(self,
486
+ dim=128,
487
+ z_dim=4,
488
+ dim_mult=[1, 2, 4, 4],
489
+ num_res_blocks=2,
490
+ attn_scales=[],
491
+ temperal_downsample=[True, True, False],
492
+ dropout=0.0):
493
+ super().__init__()
494
+ self.dim = dim
495
+ self.z_dim = z_dim
496
+ self.dim_mult = dim_mult
497
+ self.num_res_blocks = num_res_blocks
498
+ self.attn_scales = attn_scales
499
+ self.temperal_downsample = temperal_downsample
500
+ self.temperal_upsample = temperal_downsample[::-1]
501
+
502
+ # modules
503
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
504
+ attn_scales, self.temperal_downsample, dropout)
505
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
506
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
507
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
508
+ attn_scales, self.temperal_upsample, dropout)
509
+
510
+ def forward(self, x):
511
+ mu, log_var = self.encode(x)
512
+ z = self.reparameterize(mu, log_var)
513
+ x_recon = self.decode(z)
514
+ return x_recon, mu, log_var
515
+
516
+ def encode(self, x, scale):
517
+ self.clear_cache()
518
+ ## cache
519
+ t = x.shape[2]
520
+ iter_ = 1 + (t - 1) // 4
521
+ ## 对encode输入的x,按时间拆分为1、4、4、4....
522
+ for i in range(iter_):
523
+ self._enc_conv_idx = [0]
524
+ if i == 0:
525
+ out = self.encoder(
526
+ x[:, :, :1, :, :],
527
+ feat_cache=self._enc_feat_map,
528
+ feat_idx=self._enc_conv_idx)
529
+ else:
530
+ out_ = self.encoder(
531
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
532
+ feat_cache=self._enc_feat_map,
533
+ feat_idx=self._enc_conv_idx)
534
+ out = torch.cat([out, out_], 2)
535
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
536
+ if isinstance(scale[0], torch.Tensor):
537
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
538
+ 1, self.z_dim, 1, 1, 1)
539
+ else:
540
+ mu = (mu - scale[0]) * scale[1]
541
+ self.clear_cache()
542
+ return mu
543
+
544
+ def decode(self, z, scale):
545
+ self.clear_cache()
546
+ # z: [b,c,t,h,w]
547
+ if isinstance(scale[0], torch.Tensor):
548
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
549
+ 1, self.z_dim, 1, 1, 1)
550
+ else:
551
+ z = z / scale[1] + scale[0]
552
+ iter_ = z.shape[2]
553
+ x = self.conv2(z)
554
+ for i in range(iter_):
555
+ self._conv_idx = [0]
556
+ if i == 0:
557
+ out = self.decoder(
558
+ x[:, :, i:i + 1, :, :],
559
+ feat_cache=self._feat_map,
560
+ feat_idx=self._conv_idx)
561
+ else:
562
+ out_ = self.decoder(
563
+ x[:, :, i:i + 1, :, :],
564
+ feat_cache=self._feat_map,
565
+ feat_idx=self._conv_idx)
566
+ out = torch.cat([out, out_], 2)
567
+ self.clear_cache()
568
+ return out
569
+
570
+ def reparameterize(self, mu, log_var):
571
+ std = torch.exp(0.5 * log_var)
572
+ eps = torch.randn_like(std)
573
+ return eps * std + mu
574
+
575
+ def sample(self, imgs, deterministic=False):
576
+ mu, log_var = self.encode(imgs)
577
+ if deterministic:
578
+ return mu
579
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
580
+ return mu + std * torch.randn_like(std)
581
+
582
+ def clear_cache(self):
583
+ self._conv_num = count_conv3d(self.decoder)
584
+ self._conv_idx = [0]
585
+ self._feat_map = [None] * self._conv_num
586
+ #cache encode
587
+ self._enc_conv_num = count_conv3d(self.encoder)
588
+ self._enc_conv_idx = [0]
589
+ self._enc_feat_map = [None] * self._enc_conv_num
590
+
591
+
592
+ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
593
+ """
594
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
595
+ """
596
+ # params
597
+ cfg = dict(
598
+ dim=96,
599
+ z_dim=z_dim,
600
+ dim_mult=[1, 2, 4, 4],
601
+ num_res_blocks=2,
602
+ attn_scales=[],
603
+ temperal_downsample=[False, True, True],
604
+ dropout=0.0)
605
+ cfg.update(**kwargs)
606
+
607
+ # init model
608
+ with torch.device('meta'):
609
+ model = WanVAE_(**cfg)
610
+
611
+ # load checkpoint
612
+ logging.info(f'loading {pretrained_path}')
613
+ model.load_state_dict(
614
+ torch.load(pretrained_path, map_location=device), assign=True)
615
+
616
+ return model
617
+
618
+
619
+ class WanVAE:
620
+
621
+ def __init__(self,
622
+ z_dim=16,
623
+ vae_pth='cache/vae_step_411000.pth',
624
+ dtype=torch.float,
625
+ device="cuda"):
626
+ self.dtype = dtype
627
+ self.device = device
628
+
629
+ mean = [
630
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
631
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
632
+ ]
633
+ std = [
634
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
635
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
636
+ ]
637
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
638
+ self.std = torch.tensor(std, dtype=dtype, device=device)
639
+ self.scale = [self.mean, 1.0 / self.std]
640
+
641
+ # init model
642
+ self.model = _video_vae(
643
+ pretrained_path=vae_pth,
644
+ z_dim=z_dim,
645
+ ).eval().requires_grad_(False).to(device)
646
+
647
+ def encode(self, videos):
648
+ """
649
+ videos: A list of videos each with shape [C, T, H, W].
650
+ """
651
+ with amp.autocast(dtype=self.dtype):
652
+ return [
653
+ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
654
+ for u in videos
655
+ ]
656
+
657
+ def decode(self, zs):
658
+ with amp.autocast(dtype=self.dtype):
659
+ return [
660
+ self.model.decode(u.unsqueeze(0),
661
+ self.scale).float().clamp_(-1, 1).squeeze(0)
662
+ for u in zs
663
+ ]
wan/modules/xlm_roberta.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ['XLMRoberta', 'xlm_roberta_large']
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+
12
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13
+ assert dim % num_heads == 0
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.num_heads = num_heads
17
+ self.head_dim = dim // num_heads
18
+ self.eps = eps
19
+
20
+ # layers
21
+ self.q = nn.Linear(dim, dim)
22
+ self.k = nn.Linear(dim, dim)
23
+ self.v = nn.Linear(dim, dim)
24
+ self.o = nn.Linear(dim, dim)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ def forward(self, x, mask):
28
+ """
29
+ x: [B, L, C].
30
+ """
31
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32
+
33
+ # compute query, key, value
34
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37
+
38
+ # compute attention
39
+ p = self.dropout.p if self.training else 0.0
40
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
41
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42
+
43
+ # output
44
+ x = self.o(x)
45
+ x = self.dropout(x)
46
+ return x
47
+
48
+
49
+ class AttentionBlock(nn.Module):
50
+
51
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52
+ super().__init__()
53
+ self.dim = dim
54
+ self.num_heads = num_heads
55
+ self.post_norm = post_norm
56
+ self.eps = eps
57
+
58
+ # layers
59
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
60
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
61
+ self.ffn = nn.Sequential(
62
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63
+ nn.Dropout(dropout))
64
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
65
+
66
+ def forward(self, x, mask):
67
+ if self.post_norm:
68
+ x = self.norm1(x + self.attn(x, mask))
69
+ x = self.norm2(x + self.ffn(x))
70
+ else:
71
+ x = x + self.attn(self.norm1(x), mask)
72
+ x = x + self.ffn(self.norm2(x))
73
+ return x
74
+
75
+
76
+ class XLMRoberta(nn.Module):
77
+ """
78
+ XLMRobertaModel with no pooler and no LM head.
79
+ """
80
+
81
+ def __init__(self,
82
+ vocab_size=250002,
83
+ max_seq_len=514,
84
+ type_size=1,
85
+ pad_id=1,
86
+ dim=1024,
87
+ num_heads=16,
88
+ num_layers=24,
89
+ post_norm=True,
90
+ dropout=0.1,
91
+ eps=1e-5):
92
+ super().__init__()
93
+ self.vocab_size = vocab_size
94
+ self.max_seq_len = max_seq_len
95
+ self.type_size = type_size
96
+ self.pad_id = pad_id
97
+ self.dim = dim
98
+ self.num_heads = num_heads
99
+ self.num_layers = num_layers
100
+ self.post_norm = post_norm
101
+ self.eps = eps
102
+
103
+ # embeddings
104
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105
+ self.type_embedding = nn.Embedding(type_size, dim)
106
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ # blocks
110
+ self.blocks = nn.ModuleList([
111
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112
+ for _ in range(num_layers)
113
+ ])
114
+
115
+ # norm layer
116
+ self.norm = nn.LayerNorm(dim, eps=eps)
117
+
118
+ def forward(self, ids):
119
+ """
120
+ ids: [B, L] of torch.LongTensor.
121
+ """
122
+ b, s = ids.shape
123
+ mask = ids.ne(self.pad_id).long()
124
+
125
+ # embeddings
126
+ x = self.token_embedding(ids) + \
127
+ self.type_embedding(torch.zeros_like(ids)) + \
128
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129
+ if self.post_norm:
130
+ x = self.norm(x)
131
+ x = self.dropout(x)
132
+
133
+ # blocks
134
+ mask = torch.where(
135
+ mask.view(b, 1, 1, s).gt(0), 0.0,
136
+ torch.finfo(x.dtype).min)
137
+ for block in self.blocks:
138
+ x = block(x, mask)
139
+
140
+ # output
141
+ if not self.post_norm:
142
+ x = self.norm(x)
143
+ return x
144
+
145
+
146
+ def xlm_roberta_large(pretrained=False,
147
+ return_tokenizer=False,
148
+ device='cpu',
149
+ **kwargs):
150
+ """
151
+ XLMRobertaLarge adapted from Huggingface.
152
+ """
153
+ # params
154
+ cfg = dict(
155
+ vocab_size=250002,
156
+ max_seq_len=514,
157
+ type_size=1,
158
+ pad_id=1,
159
+ dim=1024,
160
+ num_heads=16,
161
+ num_layers=24,
162
+ post_norm=True,
163
+ dropout=0.1,
164
+ eps=1e-5)
165
+ cfg.update(**kwargs)
166
+
167
+ # init a model on device
168
+ with torch.device(device):
169
+ model = XLMRoberta(**cfg)
170
+ return model
wan/text2video.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import torch
13
+ import torch.cuda.amp as amp
14
+ import torch.distributed as dist
15
+ from tqdm import tqdm
16
+
17
+ from .distributed.fsdp import shard_model
18
+ from .modules.model import WanModel
19
+ from .modules.t5 import T5EncoderModel
20
+ from .modules.vae import WanVAE
21
+ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
22
+ get_sampling_sigmas, retrieve_timesteps)
23
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
24
+
25
+
26
+ class WanT2V:
27
+
28
+ def __init__(
29
+ self,
30
+ config,
31
+ checkpoint_dir,
32
+ device_id=0,
33
+ rank=0,
34
+ t5_fsdp=False,
35
+ dit_fsdp=False,
36
+ use_usp=False,
37
+ t5_cpu=False,
38
+ ):
39
+ r"""
40
+ Initializes the Wan text-to-video generation model components.
41
+
42
+ Args:
43
+ config (EasyDict):
44
+ Object containing model parameters initialized from config.py
45
+ checkpoint_dir (`str`):
46
+ Path to directory containing model checkpoints
47
+ device_id (`int`, *optional*, defaults to 0):
48
+ Id of target GPU device
49
+ rank (`int`, *optional*, defaults to 0):
50
+ Process rank for distributed training
51
+ t5_fsdp (`bool`, *optional*, defaults to False):
52
+ Enable FSDP sharding for T5 model
53
+ dit_fsdp (`bool`, *optional*, defaults to False):
54
+ Enable FSDP sharding for DiT model
55
+ use_usp (`bool`, *optional*, defaults to False):
56
+ Enable distribution strategy of USP.
57
+ t5_cpu (`bool`, *optional*, defaults to False):
58
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
59
+ """
60
+ self.device = torch.device(f"cuda:{device_id}")
61
+ self.config = config
62
+ self.rank = rank
63
+ self.t5_cpu = t5_cpu
64
+
65
+ self.num_train_timesteps = config.num_train_timesteps
66
+ self.param_dtype = config.param_dtype
67
+
68
+ shard_fn = partial(shard_model, device_id=device_id)
69
+ self.text_encoder = T5EncoderModel(
70
+ text_len=config.text_len,
71
+ dtype=config.t5_dtype,
72
+ device=torch.device('cpu'),
73
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
74
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
75
+ shard_fn=shard_fn if t5_fsdp else None)
76
+
77
+ self.vae_stride = config.vae_stride
78
+ self.patch_size = config.patch_size
79
+ self.vae = WanVAE(
80
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
81
+ device=self.device)
82
+
83
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
84
+ self.model = WanModel.from_pretrained(checkpoint_dir)
85
+ self.model.eval().requires_grad_(False)
86
+
87
+ if use_usp:
88
+ from xfuser.core.distributed import \
89
+ get_sequence_parallel_world_size
90
+
91
+ from .distributed.xdit_context_parallel import (usp_attn_forward,
92
+ usp_dit_forward)
93
+ for block in self.model.blocks:
94
+ block.self_attn.forward = types.MethodType(
95
+ usp_attn_forward, block.self_attn)
96
+ self.model.forward = types.MethodType(usp_dit_forward, self.model)
97
+ self.sp_size = get_sequence_parallel_world_size()
98
+ else:
99
+ self.sp_size = 1
100
+
101
+ if dist.is_initialized():
102
+ dist.barrier()
103
+ if dit_fsdp:
104
+ self.model = shard_fn(self.model)
105
+ else:
106
+ self.model.to(self.device)
107
+
108
+ self.sample_neg_prompt = config.sample_neg_prompt
109
+
110
+ def generate(self,
111
+ input_prompt,
112
+ size=(1280, 720),
113
+ frame_num=81,
114
+ shift=5.0,
115
+ sample_solver='unipc',
116
+ sampling_steps=50,
117
+ guide_scale=5.0,
118
+ n_prompt="",
119
+ seed=-1,
120
+ offload_model=True):
121
+ r"""
122
+ Generates video frames from text prompt using diffusion process.
123
+
124
+ Args:
125
+ input_prompt (`str`):
126
+ Text prompt for content generation
127
+ size (tupele[`int`], *optional*, defaults to (1280,720)):
128
+ Controls video resolution, (width,height).
129
+ frame_num (`int`, *optional*, defaults to 81):
130
+ How many frames to sample from a video. The number should be 4n+1
131
+ shift (`float`, *optional*, defaults to 5.0):
132
+ Noise schedule shift parameter. Affects temporal dynamics
133
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
134
+ Solver used to sample the video.
135
+ sampling_steps (`int`, *optional*, defaults to 40):
136
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
137
+ guide_scale (`float`, *optional*, defaults 5.0):
138
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity
139
+ n_prompt (`str`, *optional*, defaults to ""):
140
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
141
+ seed (`int`, *optional*, defaults to -1):
142
+ Random seed for noise generation. If -1, use random seed.
143
+ offload_model (`bool`, *optional*, defaults to True):
144
+ If True, offloads models to CPU during generation to save VRAM
145
+
146
+ Returns:
147
+ torch.Tensor:
148
+ Generated video frames tensor. Dimensions: (C, N H, W) where:
149
+ - C: Color channels (3 for RGB)
150
+ - N: Number of frames (81)
151
+ - H: Frame height (from size)
152
+ - W: Frame width from size)
153
+ """
154
+ # preprocess
155
+ F = frame_num
156
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
157
+ size[1] // self.vae_stride[1],
158
+ size[0] // self.vae_stride[2])
159
+
160
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
161
+ (self.patch_size[1] * self.patch_size[2]) *
162
+ target_shape[1] / self.sp_size) * self.sp_size
163
+
164
+ if n_prompt == "":
165
+ n_prompt = self.sample_neg_prompt
166
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
167
+ seed_g = torch.Generator(device=self.device)
168
+ seed_g.manual_seed(seed)
169
+
170
+ if not self.t5_cpu:
171
+ self.text_encoder.model.to(self.device)
172
+ context = self.text_encoder([input_prompt], self.device)
173
+ context_null = self.text_encoder([n_prompt], self.device)
174
+ if offload_model:
175
+ self.text_encoder.model.cpu()
176
+ else:
177
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
178
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
179
+ context = [t.to(self.device) for t in context]
180
+ context_null = [t.to(self.device) for t in context_null]
181
+
182
+ noise = [
183
+ torch.randn(
184
+ target_shape[0],
185
+ target_shape[1],
186
+ target_shape[2],
187
+ target_shape[3],
188
+ dtype=torch.float32,
189
+ device=self.device,
190
+ generator=seed_g)
191
+ ]
192
+
193
+ @contextmanager
194
+ def noop_no_sync():
195
+ yield
196
+
197
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
198
+
199
+ # evaluation mode
200
+ with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
201
+
202
+ if sample_solver == 'unipc':
203
+ sample_scheduler = FlowUniPCMultistepScheduler(
204
+ num_train_timesteps=self.num_train_timesteps,
205
+ shift=1,
206
+ use_dynamic_shifting=False)
207
+ sample_scheduler.set_timesteps(
208
+ sampling_steps, device=self.device, shift=shift)
209
+ timesteps = sample_scheduler.timesteps
210
+ elif sample_solver == 'dpm++':
211
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
212
+ num_train_timesteps=self.num_train_timesteps,
213
+ shift=1,
214
+ use_dynamic_shifting=False)
215
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
216
+ timesteps, _ = retrieve_timesteps(
217
+ sample_scheduler,
218
+ device=self.device,
219
+ sigmas=sampling_sigmas)
220
+ else:
221
+ raise NotImplementedError("Unsupported solver.")
222
+
223
+ # sample videos
224
+ latents = noise
225
+
226
+ arg_c = {'context': context, 'seq_len': seq_len}
227
+ arg_null = {'context': context_null, 'seq_len': seq_len}
228
+
229
+ for _, t in enumerate(tqdm(timesteps)):
230
+ latent_model_input = latents
231
+ timestep = [t]
232
+
233
+ timestep = torch.stack(timestep)
234
+
235
+ self.model.to(self.device)
236
+ noise_pred_cond = self.model(
237
+ latent_model_input, t=timestep, **arg_c)[0]
238
+ noise_pred_uncond = self.model(
239
+ latent_model_input, t=timestep, **arg_null)[0]
240
+
241
+ noise_pred = noise_pred_uncond + guide_scale * (
242
+ noise_pred_cond - noise_pred_uncond)
243
+
244
+ temp_x0 = sample_scheduler.step(
245
+ noise_pred.unsqueeze(0),
246
+ t,
247
+ latents[0].unsqueeze(0),
248
+ return_dict=False,
249
+ generator=seed_g)[0]
250
+ latents = [temp_x0.squeeze(0)]
251
+
252
+ x0 = latents
253
+ if offload_model:
254
+ self.model.cpu()
255
+ torch.cuda.empty_cache()
256
+ if self.rank == 0:
257
+ videos = self.vae.decode(x0)
258
+
259
+ del noise, latents
260
+ del sample_scheduler
261
+ if offload_model:
262
+ gc.collect()
263
+ torch.cuda.synchronize()
264
+ if dist.is_initialized():
265
+ dist.barrier()
266
+
267
+ return videos[0] if self.rank == 0 else None
wan/utils/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
2
+ retrieve_timesteps)
3
+ from .fm_solvers_unipc import FlowUniPCMultistepScheduler
4
+ from .vace_processor import VaceVideoProcessor
5
+
6
+ __all__ = [
7
+ 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
8
+ 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
9
+ 'VaceVideoProcessor'
10
+ ]