svjack committed on
Commit c43bea2 · verified · 1 Parent(s): 99df6e6

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +5 -0
  2. .ipynb_checkpoints/README-checkpoint.md +147 -0
  3. .python-version +1 -0
  4. Gundam_outputs/Gundam_w1_3_lora-000001.safetensors +3 -0
  5. Gundam_outputs/Gundam_w1_3_lora-000002.safetensors +3 -0
  6. Gundam_outputs/Gundam_w1_3_lora-000003.safetensors +3 -0
  7. Gundam_outputs/Gundam_w1_3_lora-000004.safetensors +3 -0
  8. Gundam_outputs/Gundam_w1_3_lora-000005.safetensors +3 -0
  9. Gundam_outputs/Gundam_w1_3_lora-000006.safetensors +3 -0
  10. Gundam_outputs/Gundam_w1_3_lora-000007.safetensors +3 -0
  11. Gundam_outputs/Gundam_w1_3_lora-000008.safetensors +3 -0
  12. Gundam_outputs/Gundam_w1_3_lora-000009.safetensors +3 -0
  13. Gundam_outputs/Gundam_w1_3_lora-000010.safetensors +3 -0
  14. Gundam_outputs/Gundam_w1_3_lora-000011.safetensors +3 -0
  15. Gundam_outputs/Gundam_w1_3_lora-000012.safetensors +3 -0
  16. Gundam_outputs/Gundam_w1_3_lora-000013.safetensors +3 -0
  17. Gundam_outputs/Gundam_w1_3_lora-000014.safetensors +3 -0
  18. Gundam_outputs/Gundam_w1_3_lora-000015.safetensors +3 -0
  19. Gundam_outputs/Gundam_w1_3_lora-000016.safetensors +3 -0
  20. Gundam_outputs/Gundam_w1_3_lora-000017.safetensors +3 -0
  21. Gundam_outputs/Gundam_w1_3_lora-000018.safetensors +3 -0
  22. Gundam_outputs/Gundam_w1_3_lora-000019.safetensors +3 -0
  23. Gundam_outputs/Gundam_w1_3_lora-000020.safetensors +3 -0
  24. Gundam_outputs/Gundam_w1_3_lora-000021.safetensors +3 -0
  25. Gundam_outputs/Gundam_w1_3_lora-000022.safetensors +3 -0
  26. Gundam_outputs/Gundam_w1_3_lora-000023.safetensors +3 -0
  27. Gundam_outputs/Gundam_w1_3_lora-000024.safetensors +3 -0
  28. Gundam_outputs/Gundam_w1_3_lora-000025.safetensors +3 -0
  29. Gundam_outputs/Gundam_w1_3_lora-000026.safetensors +3 -0
  30. Gundam_outputs/Gundam_w1_3_lora-000027.safetensors +3 -0
  31. README.md +147 -0
  32. cache_latents.py +281 -0
  33. cache_text_encoder_outputs.py +214 -0
  34. convert_lora.py +135 -0
  35. dataset/__init__.py +0 -0
  36. dataset/config_utils.py +372 -0
  37. dataset/dataset_config.md +387 -0
  38. dataset/image_video_dataset.py +1400 -0
  39. docs/advanced_config.md +151 -0
  40. docs/sampling_during_training.md +108 -0
  41. docs/wan.md +241 -0
  42. hunyuan_model/__init__.py +0 -0
  43. hunyuan_model/activation_layers.py +23 -0
  44. hunyuan_model/attention.py +295 -0
  45. hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
  46. hunyuan_model/embed_layers.py +132 -0
  47. hunyuan_model/helpers.py +40 -0
  48. hunyuan_model/mlp_layers.py +118 -0
  49. hunyuan_model/models.py +1044 -0
  50. hunyuan_model/modulate_layers.py +76 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
1
+ __pycache__/
2
+ .venv
3
+ venv/
4
+ logs/
5
+ uv.lock
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,147 @@
1
+ # Gundam Text-to-Video Generation
2
+
3
+ This repository contains the necessary steps and scripts to generate videos using the Gundam text-to-video model. The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create high-quality anime-style videos based on textual prompts.
4
+
5
+ ## Prerequisites
6
+
7
+ Before proceeding, ensure that you have the following installed on your system:
8
+
9
+ • **Ubuntu** (or a compatible Linux distribution)
10
+ • **Python 3.x**
11
+ • **pip** (Python package manager)
12
+ • **Git**
13
+ • **Git LFS** (Git Large File Storage)
14
+ • **FFmpeg**
15
+
16
+ ## Installation
17
+
18
+ 1. **Update and Install Dependencies**
19
+
20
+ ```bash
21
+ sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
22
+ ```
23
+
24
+ 2. **Clone the Repository**
25
+
26
+ ```bash
27
+ git clone https://huggingface.co/svjack/Gundam_wan_2_1_1_3_B_text2video_lora
28
+ cd Gundam_wan_2_1_1_3_B_text2video_lora
29
+ ```
30
+
31
+ 3. **Install Python Dependencies**
32
+
33
+ ```bash
34
+ pip install torch torchvision
35
+ pip install -r requirements.txt
36
+ pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
37
+ pip install moviepy==1.0.3
38
+ pip install sageattention==1.0.6
39
+ ```
40
+
41
+ 4. **Download Model Weights**
42
+
43
+ ```bash
44
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
45
+ wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
46
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
47
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
48
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
49
+ ```
50
+
51
+ ## Usage
52
+
53
+ To generate a video, use the `wan_generate_video.py` script with the appropriate parameters. Below are examples of how to generate videos using the Gundam model.
54
+
55
+ #### Gundam Moon Background
56
+
57
+ ```bash
58
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
59
+ --save_path save --output_type both \
60
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
61
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
62
+ --attn_mode torch \
63
+ --lora_weight Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
64
+ --lora_multiplier 1.0 \
65
+ --prompt "In the style of Gundam , The video features a large, black robot with a predominantly humanoid form. walk one step at a time forward, with the moon as the background"
66
+
67
+ ```
68
+
69
+
70
+
71
+ #### Pink Butterfly Gundam
72
+
73
+ ```bash
74
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
75
+ --save_path save --output_type both \
76
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
77
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
78
+ --attn_mode torch \
79
+ --lora_weight Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
80
+ --lora_multiplier 1.0 \
81
+ --prompt "In the style of Gundam , The video features a large, pink robot with a butterfly form. suggesting smooth and cute. The background is plain and light-colored, ensuring the focus remains on the robot."
82
+
83
+ ```
84
+
85
+
86
+ #### Yellow Gundam
87
+
88
+ ```bash
89
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
90
+ --save_path save --output_type both \
91
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
92
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
93
+ --attn_mode torch \
94
+ --lora_weight Gundam_outputs/Gundam_w1_3_lora-000025.safetensors \
95
+ --lora_multiplier 1.0 \
96
+ --prompt "In the style of Gundam , The video features a large, yellow robot with two slender arms suggesting strength and speed. The background is plain and light-colored, ensuring the focus remains on the robot."
97
+
98
+
99
+ ```
100
+
101
+
102
+
103
+ ## Parameters
104
+
105
+ * `--fp8`: Enable FP8 precision (optional).
106
+ * `--task`: Specify the task (e.g., `t2v-1.3B`).
107
+ * `--video_size`: Set the resolution of the generated video (e.g., `480 832`, as used in the examples above).
108
+ * `--video_length`: Define the length of the video in frames.
109
+ * `--infer_steps`: Number of inference steps.
110
+ * `--save_path`: Directory to save the generated video.
111
+ * `--output_type`: Output type (e.g., `both` for video and frames).
112
+ * `--dit`: Path to the diffusion model weights.
113
+ * `--vae`: Path to the VAE model weights.
114
+ * `--t5`: Path to the T5 model weights.
115
+ * `--attn_mode`: Attention mode (e.g., `torch`).
116
+ * `--lora_weight`: Path to the LoRA weights.
117
+ * `--lora_multiplier`: Multiplier for LoRA weights.
118
+ * `--prompt`: Textual prompt for video generation.
119
+
120
+
121
+
122
+ ## Output
123
+
124
+ The generated video and frames will be saved in the specified `save_path` directory.
125
+
126
+ ## Troubleshooting
127
+
128
+ • Ensure all dependencies are correctly installed.
129
+ • Verify that the model weights are downloaded and placed in the correct locations.
130
+ • Check for any missing Python packages and install them using `pip`.
131
+
132
+ ## License
133
+
134
+ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
135
+
136
+ ## Acknowledgments
137
+
138
+ • **Hugging Face** for hosting the model weights.
139
+ • **Wan-AI** for providing the pre-trained models.
140
+ • **DeepBeepMeep** for contributing to the model weights.
141
+
142
+ ## Contact
143
+
144
+ For any questions or issues, please open an issue on the repository or contact the maintainer.
145
+
146
+ ---
147
+
.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.10
Gundam_outputs/Gundam_w1_3_lora-000001.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bc379f7399e80d4ff4c1ea763b1abe8d06a8b368ebbf058151a88fabd50d29d
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd616ff62de984e314cb00e2813fed91680f330dc3f29cd72c7dc0895b48d084
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10b9fd3544961a69e630ff3a9e3e42c6ea55c9d68c9b0c3d0bbc0aed64218b61
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83dd9043e893384c048eaaa4ce614ed20fd8c74908e767aa94df9b2f6148e32b
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000005.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b3c9e2a14d9587b7b7df26e839afcbdbc51df03c5744db181c36ed4441f7300
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e98b6d3d424de07a26a0cf4e3ddf8c9bea53f77dc2f0a687d5ea8dee528eb5bd
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000007.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90b40c48bef2fed6cb15225da1c2852fde299b5dfea1beed9868a0c1dc07569d
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000008.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc4a71c6cb63d54c9b03a9b015b3b23fef4887fb10425988cc536fbd6967230e
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000009.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ae9e2d10ff2822e93daa67fe7db4cdcba7747744c3c97a5feee3fa55dc8b9c1
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000010.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dba6602ad083af939ba6133e9cb3c6bdb5a9332a79f6fa27d4a291600730bc9a
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000011.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ce0c3acb45bb190331dae7b90b1fbc90b0aa5dbea2a44a489ae8e352937d23a
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000012.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e71353404f12ce7487ba4b392198afbabe69f5eaba2fee5d1b5a10ec3bee22eb
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000013.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5178a67e3f6865d4d8a70755da4c40bf82efff8c668967f15f80290f0a343db0
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000014.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e673899a7bd91058d92add206f39ffb497b961a7f2414fc121bed93ee68b230
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000015.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44f25284fa970b0a247412ccc93c627f9cb6027bfb70fafc94510c15d2041f78
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000016.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1b03ab55fc6fe90e7251784f293344c59f3a06dc2891bf964d4d9ab7d6a194a
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000017.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89902b3ec65576f80c8c3eefd6c69e4e485f2a78d1a040ecd7b15d60e91eafa8
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000018.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbd0c4f04b15dad756e85a4624849b13402e8f6bb1889f144928c9bf23cb0b8e
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000019.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f47dd80d1ecb4b98c1b43f05fb34b2327cd015127446d53b65770a66c66109b
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000020.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50efde2fd8b78d0d83ae24c059834464a38a2b5d695f70e1cec41093c1e1959f
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a580ac6b1c8e8c5ab8a028ce26e262f2c6af55e46639cd6c59088a6f9e984619
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000022.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a38e2e7d6ab1272da8dc8911df28e1a49dbbd602490828b7dfeacda135e0d96c
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000023.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c35d3c93ba0cf30b62dcd3956702dd462be9af09727fbe3316ae4b4b5ded2642
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000024.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a203c9a6106795e6895d16dc11cca70f62e094b82b09f9bda3e0d2129da0efe
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000025.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee416f97d83fa66a95d64aa0995b13ae0bfd640a6591baf17b6f2dbb18624da9
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000026.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d546a7a0714094d1620fb11d5a6e3bbc6d42456f57c0778e710ac77c12e51b9b
3
+ size 87594680
Gundam_outputs/Gundam_w1_3_lora-000027.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ceaa7bd1f4ac61bcbf3a36163524d2ff6b9854e2d095940c98f79d73b4bcf778
3
+ size 87594680
README.md ADDED
@@ -0,0 +1,147 @@
1
+ # Gundam Text-to-Video Generation
2
+
3
+ This repository contains the necessary steps and scripts to generate videos using the Gundam text-to-video model. The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create high-quality anime-style videos based on textual prompts.
4
+
5
+ ## Prerequisites
6
+
7
+ Before proceeding, ensure that you have the following installed on your system:
8
+
9
+ • **Ubuntu** (or a compatible Linux distribution)
10
+ • **Python 3.x**
11
+ • **pip** (Python package manager)
12
+ • **Git**
13
+ • **Git LFS** (Git Large File Storage)
14
+ • **FFmpeg**
15
+
16
+ ## Installation
17
+
18
+ 1. **Update and Install Dependencies**
19
+
20
+ ```bash
21
+ sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
22
+ ```
23
+
24
+ 2. **Clone the Repository**
25
+
26
+ ```bash
27
+ git clone https://huggingface.co/svjack/Gundam_wan_2_1_1_3_B_text2video_lora
28
+ cd Gundam_wan_2_1_1_3_B_text2video_lora
29
+ ```
30
+
31
+ 3. **Install Python Dependencies**
32
+
33
+ ```bash
34
+ pip install torch torchvision
35
+ pip install -r requirements.txt
36
+ pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
37
+ pip install moviepy==1.0.3
38
+ pip install sageattention==1.0.6
39
+ ```
40
+
41
+ 4. **Download Model Weights**
42
+
43
+ ```bash
44
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
45
+ wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
46
+ wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
47
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
48
+ wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
49
+ ```
50
+
51
+ ## Usage
52
+
53
+ To generate a video, use the `wan_generate_video.py` script with the appropriate parameters. Below are examples of how to generate videos using the Gundam model.
54
+
55
+ #### Gundam Moon Background
56
+
57
+ ```bash
58
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
59
+ --save_path save --output_type both \
60
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
61
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
62
+ --attn_mode torch \
63
+ --lora_weight Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
64
+ --lora_multiplier 1.0 \
65
+ --prompt "In the style of Gundam , The video features a large, black robot with a predominantly humanoid form. walk one step at a time forward, with the moon as the background"
66
+
67
+ ```
68
+
69
+
70
+
71
+ #### Pink Butterfly Gundam
72
+
73
+ ```bash
74
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
75
+ --save_path save --output_type both \
76
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
77
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
78
+ --attn_mode torch \
79
+ --lora_weight Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
80
+ --lora_multiplier 1.0 \
81
+ --prompt "In the style of Gundam , The video features a large, pink robot with a butterfly form. suggesting smooth and cute. The background is plain and light-colored, ensuring the focus remains on the robot."
82
+
83
+ ```
84
+
85
+
86
+ #### Yellow Gundam
87
+
88
+ ```bash
89
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
90
+ --save_path save --output_type both \
91
+ --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
92
+ --t5 models_t5_umt5-xxl-enc-bf16.pth \
93
+ --attn_mode torch \
94
+ --lora_weight Gundam_outputs/Gundam_w1_3_lora-000025.safetensors \
95
+ --lora_multiplier 1.0 \
96
+ --prompt "In the style of Gundam , The video features a large, yellow robot with two slender arms suggesting strength and speed. The background is plain and light-colored, ensuring the focus remains on the robot."
97
+
98
+
99
+ ```
100
+
101
+
102
+
103
+ ## Parameters
104
+
105
+ * `--fp8`: Enable FP8 precision (optional).
106
+ * `--task`: Specify the task (e.g., `t2v-1.3B`).
107
+ * `--video_size`: Set the resolution of the generated video (e.g., `480 832`, as used in the examples above).
108
+ * `--video_length`: Define the length of the video in frames.
109
+ * `--infer_steps`: Number of inference steps.
110
+ * `--save_path`: Directory to save the generated video.
111
+ * `--output_type`: Output type (e.g., `both` for video and frames).
112
+ * `--dit`: Path to the diffusion model weights.
113
+ * `--vae`: Path to the VAE model weights.
114
+ * `--t5`: Path to the T5 model weights.
115
+ * `--attn_mode`: Attention mode (e.g., `torch`).
116
+ * `--lora_weight`: Path to the LoRA weights.
117
+ * `--lora_multiplier`: Multiplier for LoRA weights.
118
+ * `--prompt`: Textual prompt for video generation.
119
+
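Putting these parameters together, the examples above all follow the same template. A minimal sketch (the LoRA epoch number and the prompt text are placeholders to fill in):

```bash
python wan_generate_video.py --fp8 --task t2v-1.3B \
  --video_size 480 832 --video_length 81 --infer_steps 20 \
  --save_path save --output_type both \
  --dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
  --t5 models_t5_umt5-xxl-enc-bf16.pth \
  --attn_mode torch \
  --lora_weight Gundam_outputs/Gundam_w1_3_lora-<epoch>.safetensors \
  --lora_multiplier 1.0 \
  --prompt "In the style of Gundam , <your scene description>"
```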
120
+
121
+
122
+ ## Output
123
+
124
+ The generated video and frames will be saved in the specified `save_path` directory.
125
+
126
+ ## Troubleshooting
127
+
128
+ • Ensure all dependencies are correctly installed.
129
+ • Verify that the model weights are downloaded and placed in the correct locations.
130
+ • Check for any missing Python packages and install them using `pip`.
131
+
132
+ ## License
133
+
134
+ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
135
+
136
+ ## Acknowledgments
137
+
138
+ • **Hugging Face** for hosting the model weights.
139
+ • **Wan-AI** for providing the pre-trained models.
140
+ • **DeepBeepMeep** for contributing to the model weights.
141
+
142
+ ## Contact
143
+
144
+ For any questions or issues, please open an issue on the repository or contact the maintainer.
145
+
146
+ ---
147
+
cache_latents.py ADDED
@@ -0,0 +1,281 @@
1
+ import argparse
2
+ import os
3
+ import glob
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from dataset import config_utils
11
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
12
+ from PIL import Image
13
+
14
+ import logging
15
+
16
+ from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache, ARCHITECTURE_HUNYUAN_VIDEO
17
+ from hunyuan_model.vae import load_vae
18
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
+ def show_image(image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray]) -> int:
26
+ import cv2
27
+
28
+ imgs = (
29
+ [image]
30
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
31
+ else [image[0], image[-1]]
32
+ )
33
+ if len(imgs) > 1:
34
+ print(f"Number of images: {len(image)}")
35
+ for i, img in enumerate(imgs):
36
+ if len(imgs) > 1:
37
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
38
+ else:
39
+ print(f"Image: {img.shape}")
40
+ cv2_img = np.array(img) if isinstance(img, Image.Image) else img
41
+ cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
42
+ cv2.imshow("image", cv2_img)
43
+ k = cv2.waitKey(0)
44
+ cv2.destroyAllWindows()
45
+ if k == ord("q") or k == ord("d"):
46
+ return k
47
+ return k
48
+
49
+
50
+ def show_console(
51
+ image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray],
52
+ width: int,
53
+ back: str,
54
+ interactive: bool = False,
55
+ ) -> int:
56
+ from ascii_magic import from_pillow_image, Back
57
+
58
+ back = None
59
+ if back is not None:
60
+ back = getattr(Back, back.upper())
61
+
62
+ k = None
63
+ imgs = (
64
+ [image]
65
+ if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
66
+ else [image[0], image[-1]]
67
+ )
68
+ if len(imgs) > 1:
69
+ print(f"Number of images: {len(image)}")
70
+ for i, img in enumerate(imgs):
71
+ if len(imgs) > 1:
72
+ print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
73
+ else:
74
+ print(f"Image: {img.shape}")
75
+ pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
76
+ ascii_img = from_pillow_image(pil_img)
77
+ ascii_img.to_terminal(columns=width, back=back)
78
+
79
+ if interactive:
80
+ k = input("Press q to quit, d to next dataset, other key to next: ")
81
+ if k == "q" or k == "d":
82
+ return ord(k)
83
+
84
+ if not interactive:
85
+ return ord(" ")
86
+ return ord(k) if k else ord(" ")
87
+
88
+
89
+ def show_datasets(
90
+ datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
91
+ ):
92
+ print(f"d: next dataset, q: quit")
93
+
94
+ num_workers = max(1, os.cpu_count() - 1)
95
+ for i, dataset in enumerate(datasets):
96
+ print(f"Dataset [{i}]")
97
+ batch_index = 0
98
+ num_images_to_show = console_num_images
99
+ k = None
100
+ for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
101
+ print(f"bucket resolution: {key}, count: {len(batch)}")
102
+ for j, item_info in enumerate(batch):
103
+ item_info: ItemInfo
104
+ print(f"{batch_index}-{j}: {item_info}")
105
+ if debug_mode == "image":
106
+ k = show_image(item_info.content)
107
+ elif debug_mode == "console":
108
+ k = show_console(item_info.content, console_width, console_back, console_num_images is None)
109
+ if num_images_to_show is not None:
110
+ num_images_to_show -= 1
111
+ if num_images_to_show == 0:
112
+ k = ord("d") # next dataset
113
+
114
+ if k == ord("q"):
115
+ return
116
+ elif k == ord("d"):
117
+ break
118
+ if k == ord("d"):
119
+ break
120
+ batch_index += 1
121
+
122
+
123
+ def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
124
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
125
+ if len(contents.shape) == 4:
126
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
127
+
128
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
129
+ contents = contents.to(vae.device, dtype=vae.dtype)
130
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
131
+
132
+ h, w = contents.shape[3], contents.shape[4]
133
+ if h < 8 or w < 8:
134
+ item = batch[0] # other items should have the same size
135
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
136
+
137
+ # print(f"encode batch: {contents.shape}")
138
+ with torch.no_grad():
139
+ latent = vae.encode(contents).latent_dist.sample()
140
+ # latent = latent * vae.config.scaling_factor
141
+
142
+ # # debug: decode and save
143
+ # with torch.no_grad():
144
+ # latent_to_decode = latent / vae.config.scaling_factor
145
+ # images = vae.decode(latent_to_decode, return_dict=False)[0]
146
+ # images = (images / 2 + 0.5).clamp(0, 1)
147
+ # images = images.cpu().float().numpy()
148
+ # images = (images * 255).astype(np.uint8)
149
+ # images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
150
+ # for b in range(images.shape[0]):
151
+ # for f in range(images.shape[1]):
152
+ # fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
153
+ # img = Image.fromarray(images[b, f])
154
+ # img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
155
+
156
+ for item, l in zip(batch, latent):
157
+ # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
158
+ save_latent_cache(item, l)
159
+
160
+
161
+ def encode_datasets(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
162
+ num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
163
+ for i, dataset in enumerate(datasets):
164
+ logger.info(f"Encoding dataset [{i}]")
165
+ all_latent_cache_paths = []
166
+ for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
167
+ all_latent_cache_paths.extend([item.latent_cache_path for item in batch])
168
+
169
+ if args.skip_existing:
170
+ filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
171
+ if len(filtered_batch) == 0:
172
+ continue
173
+ batch = filtered_batch
174
+
175
+ bs = args.batch_size if args.batch_size is not None else len(batch)
176
+ for i in range(0, len(batch), bs):
177
+ encode(batch[i : i + bs])
178
+
179
+ # normalize paths
180
+ all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
181
+ all_latent_cache_paths = set(all_latent_cache_paths)
182
+
183
+ # remove old cache files not in the dataset
184
+ all_cache_files = dataset.get_all_latent_cache_files()
185
+ for cache_file in all_cache_files:
186
+ if os.path.normpath(cache_file) not in all_latent_cache_paths:
187
+ if args.keep_cache:
188
+ logger.info(f"Keep cache file not in the dataset: {cache_file}")
189
+ else:
190
+ os.remove(cache_file)
191
+ logger.info(f"Removed old cache file: {cache_file}")
192
+
193
+
194
+ def main(args):
195
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
196
+ device = torch.device(device)
197
+
198
+ # Load dataset config
199
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
200
+ logger.info(f"Load dataset config from {args.dataset_config}")
201
+ user_config = config_utils.load_user_config(args.dataset_config)
202
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
203
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
204
+
205
+ datasets = train_dataset_group.datasets
206
+
207
+ if args.debug_mode is not None:
208
+ show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
209
+ return
210
+
211
+ assert args.vae is not None, "vae checkpoint is required"
212
+
213
+ # Load VAE model: HunyuanVideo VAE model is float16
214
+ vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
215
+ vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
216
+ vae.eval()
217
+ logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
218
+
219
+ if args.vae_chunk_size is not None:
220
+ vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
221
+ logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
222
+ if args.vae_spatial_tile_sample_min_size is not None:
223
+ vae.enable_spatial_tiling(True)
224
+ vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
225
+ vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
226
+ elif args.vae_tiling:
227
+ vae.enable_spatial_tiling(True)
228
+
229
+ # Encode images
230
+ def encode(one_batch: list[ItemInfo]):
231
+ encode_and_save_batch(vae, one_batch)
232
+
233
+ encode_datasets(datasets, encode, args)
234
+
235
+
236
+ def setup_parser_common() -> argparse.ArgumentParser:
237
+ parser = argparse.ArgumentParser()
238
+
239
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
240
+ parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
241
+ parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
242
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
243
+ parser.add_argument(
244
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
245
+ )
246
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
247
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
248
+ parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
249
+ parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
250
+ parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
251
+ parser.add_argument(
252
+ "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
253
+ )
254
+ parser.add_argument(
255
+ "--console_num_images",
256
+ type=int,
257
+ default=None,
258
+ help="debug mode: not interactive, number of images to show for each dataset",
259
+ )
260
+ return parser
261
+
262
+
263
+ def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
264
+ parser.add_argument(
265
+ "--vae_tiling",
266
+ action="store_true",
267
+ help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
268
+ )
269
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
270
+ parser.add_argument(
271
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
272
+ )
273
+ return parser
274
+
275
+
276
+ if __name__ == "__main__":
277
+ parser = setup_parser_common()
278
+ parser = hv_setup_parser(parser)
279
+
280
+ args = parser.parse_args()
281
+ main(args)
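For reference, a minimal invocation sketch for this latent-caching script, assuming a dataset config at `dataset/config.toml` and a locally downloaded HunyuanVideo VAE checkpoint (both paths are placeholders):

```bash
# Pre-compute and cache VAE latents for every item in the dataset config.
# The chunk size and spatial tiling options trade throughput for lower VRAM use.
python cache_latents.py \
  --dataset_config dataset/config.toml \
  --vae path/to/hunyuan_video_vae.safetensors \
  --vae_chunk_size 32 \
  --vae_spatial_tile_sample_min_size 256 \
  --skip_existing
```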
cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,214 @@
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ import accelerate
12
+
13
+ from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, BaseDataset, ItemInfo, save_text_encoder_output_cache
14
+ from hunyuan_model import text_encoder as text_encoder_module
15
+ from hunyuan_model.text_encoder import TextEncoder
16
+
17
+ import logging
18
+
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
+ def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
26
+ data_type = "video" # video only, image is not supported
27
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
28
+
29
+ with torch.no_grad():
30
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
31
+
32
+ return prompt_outputs.hidden_state, prompt_outputs.attention_mask
33
+
34
+
35
+ def encode_and_save_batch(
36
+ text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
37
+ ):
38
+ prompts = [item.caption for item in batch]
39
+ # print(prompts)
40
+
41
+ # encode prompt
42
+ if accelerator is not None:
43
+ with accelerator.autocast():
44
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
45
+ else:
46
+ prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
47
+
48
+ # # convert to fp16 if needed
49
+ # if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
50
+ # prompt_embeds = prompt_embeds.to(text_encoder.dtype)
51
+
52
+ # save prompt cache
53
+ for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
54
+ save_text_encoder_output_cache(item, embed, mask, is_llm)
55
+
56
+
57
+ def prepare_cache_files_and_paths(datasets: list[BaseDataset]):
58
+ all_cache_files_for_dataset = [] # existing cache files
59
+ all_cache_paths_for_dataset = [] # all cache paths in the dataset
60
+ for dataset in datasets:
61
+ all_cache_files = [os.path.normpath(file) for file in dataset.get_all_text_encoder_output_cache_files()]
62
+ all_cache_files = set(all_cache_files)
63
+ all_cache_files_for_dataset.append(all_cache_files)
64
+
65
+ all_cache_paths_for_dataset.append(set())
66
+ return all_cache_files_for_dataset, all_cache_paths_for_dataset
67
+
68
+
69
+ def process_text_encoder_batches(
70
+ num_workers: Optional[int],
71
+ skip_existing: bool,
72
+ batch_size: int,
73
+ datasets: list[BaseDataset],
74
+ all_cache_files_for_dataset: list[set],
75
+ all_cache_paths_for_dataset: list[set],
76
+ encode: callable,
77
+ ):
78
+ num_workers = num_workers if num_workers is not None else max(1, os.cpu_count() - 1)
79
+ for i, dataset in enumerate(datasets):
80
+ logger.info(f"Encoding dataset [{i}]")
81
+ all_cache_files = all_cache_files_for_dataset[i]
82
+ all_cache_paths = all_cache_paths_for_dataset[i]
83
+ for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
84
+ # update cache files (it's ok if we update it multiple times)
85
+ all_cache_paths.update([os.path.normpath(item.text_encoder_output_cache_path) for item in batch])
86
+
87
+ # skip existing cache files
88
+ if skip_existing:
89
+ filtered_batch = [
90
+ item for item in batch if not os.path.normpath(item.text_encoder_output_cache_path) in all_cache_files
91
+ ]
92
+ # print(f"Filtered {len(batch) - len(filtered_batch)} existing cache files")
93
+ if len(filtered_batch) == 0:
94
+ continue
95
+ batch = filtered_batch
96
+
97
+ bs = batch_size if batch_size is not None else len(batch)
98
+ for i in range(0, len(batch), bs):
99
+ encode(batch[i : i + bs])
100
+
101
+
102
+ def post_process_cache_files(
103
+ datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set]
104
+ ):
105
+ for i, dataset in enumerate(datasets):
106
+ all_cache_files = all_cache_files_for_dataset[i]
107
+ all_cache_paths = all_cache_paths_for_dataset[i]
108
+ for cache_file in all_cache_files:
109
+ if cache_file not in all_cache_paths:
110
+ if args.keep_cache:
111
+ logger.info(f"Keep cache file not in the dataset: {cache_file}")
112
+ else:
113
+ os.remove(cache_file)
114
+ logger.info(f"Removed old cache file: {cache_file}")
115
+
116
+
117
+ def main(args):
118
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
119
+ device = torch.device(device)
120
+
121
+ # Load dataset config
122
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer())
123
+ logger.info(f"Load dataset config from {args.dataset_config}")
124
+ user_config = config_utils.load_user_config(args.dataset_config)
125
+ blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
126
+ train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
127
+
128
+ datasets = train_dataset_group.datasets
129
+
130
+ # define accelerator for fp8 inference
131
+ accelerator = None
132
+ if args.fp8_llm:
133
+ accelerator = accelerate.Accelerator(mixed_precision="fp16")
134
+
135
+ # prepare cache files and paths: all_cache_files_for_dataset = existing cache files, all_cache_paths_for_dataset = all cache paths in the dataset
136
+ all_cache_files_for_dataset, all_cache_paths_for_dataset = prepare_cache_files_and_paths(datasets)
137
+
138
+ # Load Text Encoder 1
139
+ text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
140
+ logger.info(f"loading text encoder 1: {args.text_encoder1}")
141
+ text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
142
+ text_encoder_1.to(device=device)
143
+
144
+ # Encode with Text Encoder 1 (LLM)
145
+ logger.info("Encoding with Text Encoder 1")
146
+
147
+ def encode_for_text_encoder_1(batch: list[ItemInfo]):
148
+ encode_and_save_batch(text_encoder_1, batch, is_llm=True, accelerator=accelerator)
149
+
150
+ process_text_encoder_batches(
151
+ args.num_workers,
152
+ args.skip_existing,
153
+ args.batch_size,
154
+ datasets,
155
+ all_cache_files_for_dataset,
156
+ all_cache_paths_for_dataset,
157
+ encode_for_text_encoder_1,
158
+ )
159
+ del text_encoder_1
160
+
161
+ # Load Text Encoder 2
162
+ logger.info(f"loading text encoder 2: {args.text_encoder2}")
163
+ text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
164
+ text_encoder_2.to(device=device)
165
+
166
+ # Encode with Text Encoder 2
167
+ logger.info("Encoding with Text Encoder 2")
168
+
169
+ def encode_for_text_encoder_2(batch: list[ItemInfo]):
170
+ encode_and_save_batch(text_encoder_2, batch, is_llm=False, accelerator=None)
171
+
172
+ process_text_encoder_batches(
173
+ args.num_workers,
174
+ args.skip_existing,
175
+ args.batch_size,
176
+ datasets,
177
+ all_cache_files_for_dataset,
178
+ all_cache_paths_for_dataset,
179
+ encode_for_text_encoder_2,
180
+ )
181
+ del text_encoder_2
182
+
183
+ # remove cache files not in dataset
184
+ post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset)
185
+
186
+
187
+ def setup_parser_common():
188
+ parser = argparse.ArgumentParser()
189
+
190
+ parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
191
+ parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
192
+ parser.add_argument(
193
+ "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
194
+ )
195
+ parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
196
+ parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
197
+ parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
198
+ return parser
199
+
200
+
201
+ def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
202
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
203
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
204
+ parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
205
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
206
+ return parser
207
+
208
+
209
+ if __name__ == "__main__":
210
+ parser = setup_parser_common()
211
+ parser = hv_setup_parser(parser)
212
+
213
+ args = parser.parse_args()
214
+ main(args)
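A minimal invocation sketch for the text-encoder caching script, assuming both HunyuanVideo text encoder directories have been downloaded locally (directory paths are placeholders):

```bash
# Cache outputs of Text Encoder 1 (LLM) and Text Encoder 2 for all captions.
# --fp8_llm optionally runs Text Encoder 1 in fp8 to reduce VRAM usage.
python cache_text_encoder_outputs.py \
  --dataset_config dataset/config.toml \
  --text_encoder1 path/to/text_encoder_1 \
  --text_encoder2 path/to/text_encoder_2 \
  --batch_size 16 \
  --skip_existing
```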
convert_lora.py ADDED
@@ -0,0 +1,135 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from safetensors.torch import load_file, save_file
5
+ from safetensors import safe_open
6
+ from utils import model_utils
7
+
8
+ import logging
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+
15
+ def convert_from_diffusers(prefix, weights_sd):
16
+ # convert from diffusers(?) to default LoRA
17
+ # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
18
+ # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
19
+
20
+ # note: Diffusers has no alpha, so alpha is set to rank
21
+ new_weights_sd = {}
22
+ lora_dims = {}
23
+ for key, weight in weights_sd.items():
24
+ diffusers_prefix, key_body = key.split(".", 1)
25
+ if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
26
+ logger.warning(f"unexpected key: {key} in diffusers format")
27
+ continue
28
+
29
+ new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
30
+ new_weights_sd[new_key] = weight
31
+
32
+ lora_name = new_key.split(".")[0] # before first dot
33
+ if lora_name not in lora_dims and "lora_down" in new_key:
34
+ lora_dims[lora_name] = weight.shape[0]
35
+
36
+ # add alpha with rank
37
+ for lora_name, dim in lora_dims.items():
38
+ new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
39
+
40
+ return new_weights_sd
41
+
42
+
43
+ def convert_to_diffusers(prefix, weights_sd):
44
+ # convert from default LoRA to diffusers
45
+
46
+ # get alphas
47
+ lora_alphas = {}
48
+ for key, weight in weights_sd.items():
49
+ if key.startswith(prefix):
50
+ lora_name = key.split(".", 1)[0] # before first dot
51
+ if lora_name not in lora_alphas and "alpha" in key:
52
+ lora_alphas[lora_name] = weight
53
+
54
+ new_weights_sd = {}
55
+ for key, weight in weights_sd.items():
56
+ if key.startswith(prefix):
57
+ if "alpha" in key:
58
+ continue
59
+
60
+ lora_name = key.split(".", 1)[0] # before first dot
61
+
62
+ module_name = lora_name[len(prefix) :] # remove "lora_unet_"
63
+ module_name = module_name.replace("_", ".") # replace "_" with "."
64
+ if ".cross.attn." in module_name or ".self.attn." in module_name:
65
+ # Wan2.1 lora name to module name: ugly but works
66
+ module_name = module_name.replace("cross.attn", "cross_attn") # fix cross attn
67
+ module_name = module_name.replace("self.attn", "self_attn") # fix self attn
68
+ else:
69
+ # HunyuanVideo lora name to module name: ugly but works
70
+ module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
71
+ module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
72
+ module_name = module_name.replace("img.", "img_") # fix img
73
+ module_name = module_name.replace("txt.", "txt_") # fix txt
74
+ module_name = module_name.replace("attn.", "attn_") # fix attn
75
+
76
+ diffusers_prefix = "diffusion_model"
77
+ if "lora_down" in key:
78
+ new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
79
+ dim = weight.shape[0]
80
+ elif "lora_up" in key:
81
+ new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
82
+ dim = weight.shape[1]
83
+ else:
84
+ logger.warning(f"unexpected key: {key} in default LoRA format")
85
+ continue
86
+
87
+ # scale weight by alpha
88
+ if lora_name in lora_alphas:
89
+ # we scale both down and up, so scale is sqrt
90
+ scale = lora_alphas[lora_name] / dim
91
+ scale = scale.sqrt()
92
+ weight = weight * scale
93
+ else:
94
+ logger.warning(f"missing alpha for {lora_name}")
95
+
96
+ new_weights_sd[new_key] = weight
97
+
98
+ return new_weights_sd
99
+
100
+
101
+ def convert(input_file, output_file, target_format):
102
+ logger.info(f"loading {input_file}")
103
+ weights_sd = load_file(input_file)
104
+ with safe_open(input_file, framework="pt") as f:
105
+ metadata = f.metadata()
106
+
107
+ logger.info(f"converting to {target_format}")
108
+ prefix = "lora_unet_"
109
+ if target_format == "default":
110
+ new_weights_sd = convert_from_diffusers(prefix, weights_sd)
111
+ metadata = metadata or {}
112
+ model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
113
+ elif target_format == "other":
114
+ new_weights_sd = convert_to_diffusers(prefix, weights_sd)
115
+ else:
116
+ raise ValueError(f"unknown target format: {target_format}")
117
+
118
+ logger.info(f"saving to {output_file}")
119
+ save_file(new_weights_sd, output_file, metadata=metadata)
120
+
121
+ logger.info("done")
122
+
123
+
124
+ def parse_args():
125
+ parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
126
+ parser.add_argument("--input", type=str, required=True, help="input model file")
127
+ parser.add_argument("--output", type=str, required=True, help="output model file")
128
+ parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
129
+ args = parser.parse_args()
130
+ return args
131
+
132
+
133
+ if __name__ == "__main__":
134
+ args = parse_args()
135
+ convert(args.input, args.output, args.target)
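As a usage sketch, converting one of the checkpoints in `Gundam_outputs/` to the Diffusers-style key layout (the `other` target), assuming the checkpoint was saved in the trainer's default `lora_unet_*` format; the output filename is arbitrary:

```bash
python convert_lora.py \
  --input Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
  --output Gundam_w1_3_lora-000027_diffusers.safetensors \
  --target other
```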
dataset/__init__.py ADDED
File without changes
dataset/config_utils.py ADDED
@@ -0,0 +1,372 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
18
+
19
+ from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ @dataclass
28
+ class BaseDatasetParams:
29
+ resolution: Tuple[int, int] = (960, 544)
30
+ enable_bucket: bool = False
31
+ bucket_no_upscale: bool = False
32
+ caption_extension: Optional[str] = None
33
+ batch_size: int = 1
34
+ num_repeats: int = 1
35
+ cache_directory: Optional[str] = None
36
+ debug_dataset: bool = False
37
+ architecture: str = "no_default" # short style like "hv" or "wan"
38
+
39
+
40
+ @dataclass
41
+ class ImageDatasetParams(BaseDatasetParams):
42
+ image_directory: Optional[str] = None
43
+ image_jsonl_file: Optional[str] = None
44
+
45
+
46
+ @dataclass
47
+ class VideoDatasetParams(BaseDatasetParams):
48
+ video_directory: Optional[str] = None
49
+ video_jsonl_file: Optional[str] = None
50
+ target_frames: Sequence[int] = (1,)
51
+ frame_extraction: Optional[str] = "head"
52
+ frame_stride: Optional[int] = 1
53
+ frame_sample: Optional[int] = 1
54
+
55
+
56
+ @dataclass
57
+ class DatasetBlueprint:
58
+ is_image_dataset: bool
59
+ params: Union[ImageDatasetParams, VideoDatasetParams]
60
+
61
+
62
+ @dataclass
63
+ class DatasetGroupBlueprint:
64
+ datasets: Sequence[DatasetBlueprint]
65
+
66
+
67
+ @dataclass
68
+ class Blueprint:
69
+ dataset_group: DatasetGroupBlueprint
70
+
71
+
72
+ class ConfigSanitizer:
73
+ # @curry
74
+ @staticmethod
75
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
76
+ Schema(ExactSequence([klass, klass]))(value)
77
+ return tuple(value)
78
+
79
+ # @curry
80
+ @staticmethod
81
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
82
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
83
+ try:
84
+ Schema(klass)(value)
85
+ return (value, value)
86
+ except:
87
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
88
+
89
+ # datasets schema
90
+ DATASET_ASCENDABLE_SCHEMA = {
91
+ "caption_extension": str,
92
+ "batch_size": int,
93
+ "num_repeats": int,
94
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
95
+ "enable_bucket": bool,
96
+ "bucket_no_upscale": bool,
97
+ }
98
+ IMAGE_DATASET_DISTINCT_SCHEMA = {
99
+ "image_directory": str,
100
+ "image_jsonl_file": str,
101
+ "cache_directory": str,
102
+ }
103
+ VIDEO_DATASET_DISTINCT_SCHEMA = {
104
+ "video_directory": str,
105
+ "video_jsonl_file": str,
106
+ "target_frames": [int],
107
+ "frame_extraction": str,
108
+ "frame_stride": int,
109
+ "frame_sample": int,
110
+ "cache_directory": str,
111
+ }
112
+
113
+ # options handled by argparse but not handled by user config
114
+ ARGPARSE_SPECIFIC_SCHEMA = {
115
+ "debug_dataset": bool,
116
+ }
117
+
118
+ def __init__(self) -> None:
119
+ self.image_dataset_schema = self.__merge_dict(
120
+ self.DATASET_ASCENDABLE_SCHEMA,
121
+ self.IMAGE_DATASET_DISTINCT_SCHEMA,
122
+ )
123
+ self.video_dataset_schema = self.__merge_dict(
124
+ self.DATASET_ASCENDABLE_SCHEMA,
125
+ self.VIDEO_DATASET_DISTINCT_SCHEMA,
126
+ )
127
+
128
+ def validate_flex_dataset(dataset_config: dict):
129
+ if "target_frames" in dataset_config:
130
+ return Schema(self.video_dataset_schema)(dataset_config)
131
+ else:
132
+ return Schema(self.image_dataset_schema)(dataset_config)
133
+
134
+ self.dataset_schema = validate_flex_dataset
135
+
136
+ self.general_schema = self.__merge_dict(
137
+ self.DATASET_ASCENDABLE_SCHEMA,
138
+ )
139
+ self.user_config_validator = Schema(
140
+ {
141
+ "general": self.general_schema,
142
+ "datasets": [self.dataset_schema],
143
+ }
144
+ )
145
+ self.argparse_schema = self.__merge_dict(
146
+ self.ARGPARSE_SPECIFIC_SCHEMA,
147
+ )
148
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
149
+
150
+ def sanitize_user_config(self, user_config: dict) -> dict:
151
+ try:
152
+ return self.user_config_validator(user_config)
153
+ except MultipleInvalid:
154
+ # TODO: clarify the error message
155
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
156
+ raise
157
+
158
+ # NOTE: By nature, the argument parser result does not need to be sanitized.
159
+ # However, doing so helps us detect program bugs.
160
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
161
+ try:
162
+ return self.argparse_config_validator(argparse_namespace)
163
+ except MultipleInvalid:
164
+ # XXX: this should be a bug
165
+ logger.error(
166
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
167
+ )
168
+ raise
169
+
170
+ # NOTE: value would be overwritten by latter dict if there is already the same key
171
+ @staticmethod
172
+ def __merge_dict(*dict_list: dict) -> dict:
173
+ merged = {}
174
+ for schema in dict_list:
175
+ # merged |= schema
176
+ for k, v in schema.items():
177
+ merged[k] = v
178
+ return merged
179
+
180
+
181
+ class BlueprintGenerator:
182
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
183
+
184
+ def __init__(self, sanitizer: ConfigSanitizer):
185
+ self.sanitizer = sanitizer
186
+
187
+ # runtime_params is for parameters that are only configurable at runtime, such as the tokenizer
188
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
189
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
190
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
191
+
192
+ argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
193
+ general_config = sanitized_user_config.get("general", {})
194
+
195
+ dataset_blueprints = []
196
+ for dataset_config in sanitized_user_config.get("datasets", []):
197
+ is_image_dataset = "target_frames" not in dataset_config
198
+ if is_image_dataset:
199
+ dataset_params_klass = ImageDatasetParams
200
+ else:
201
+ dataset_params_klass = VideoDatasetParams
202
+
203
+ params = self.generate_params_by_fallbacks(
204
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
205
+ )
206
+ dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
207
+
208
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
209
+
210
+ return Blueprint(dataset_group_blueprint)
211
+
212
+ @staticmethod
213
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
214
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
215
+ search_value = BlueprintGenerator.search_value
216
+ default_params = asdict(param_klass())
217
+ param_names = default_params.keys()
218
+
219
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
220
+
221
+ return param_klass(**params)
222
+
223
+ @staticmethod
224
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
225
+ for cand in fallbacks:
226
+ value = cand.get(key)
227
+ if value is not None:
228
+ return value
229
+
230
+ return default_value
231
+
232
+
233
+ # if training is True, it will return a dataset group for training, otherwise for caching
234
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
235
+ datasets: List[Union[ImageDataset, VideoDataset]] = []
236
+
237
+ for dataset_blueprint in dataset_group_blueprint.datasets:
238
+ if dataset_blueprint.is_image_dataset:
239
+ dataset_klass = ImageDataset
240
+ else:
241
+ dataset_klass = VideoDataset
242
+
243
+ dataset = dataset_klass(**asdict(dataset_blueprint.params))
244
+ datasets.append(dataset)
245
+
246
+ # assertion
247
+ cache_directories = [dataset.cache_directory for dataset in datasets]
248
+ num_of_unique_cache_directories = len(set(cache_directories))
249
+ if num_of_unique_cache_directories != len(cache_directories):
250
+ raise ValueError(
251
+ "cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)"
252
+ + " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)"
253
+ )
254
+
255
+ # print info
256
+ info = ""
257
+ for i, dataset in enumerate(datasets):
258
+ is_image_dataset = isinstance(dataset, ImageDataset)
259
+ info += dedent(
260
+ f"""\
261
+ [Dataset {i}]
262
+ is_image_dataset: {is_image_dataset}
263
+ resolution: {dataset.resolution}
264
+ batch_size: {dataset.batch_size}
265
+ num_repeats: {dataset.num_repeats}
266
+ caption_extension: "{dataset.caption_extension}"
267
+ enable_bucket: {dataset.enable_bucket}
268
+ bucket_no_upscale: {dataset.bucket_no_upscale}
269
+ cache_directory: "{dataset.cache_directory}"
270
+ debug_dataset: {dataset.debug_dataset}
271
+ """
272
+ )
273
+
274
+ if is_image_dataset:
275
+ info += indent(
276
+ dedent(
277
+ f"""\
278
+ image_directory: "{dataset.image_directory}"
279
+ image_jsonl_file: "{dataset.image_jsonl_file}"
280
+ \n"""
281
+ ),
282
+ " ",
283
+ )
284
+ else:
285
+ info += indent(
286
+ dedent(
287
+ f"""\
288
+ video_directory: "{dataset.video_directory}"
289
+ video_jsonl_file: "{dataset.video_jsonl_file}"
290
+ target_frames: {dataset.target_frames}
291
+ frame_extraction: {dataset.frame_extraction}
292
+ frame_stride: {dataset.frame_stride}
293
+ frame_sample: {dataset.frame_sample}
294
+ \n"""
295
+ ),
296
+ " ",
297
+ )
298
+ logger.info(f"{info}")
299
+
300
+ # make buckets first because it determines the length of dataset
301
+ # and set the same seed for all datasets
302
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
303
+ for i, dataset in enumerate(datasets):
304
+ # logger.info(f"[Dataset {i}]")
305
+ dataset.set_seed(seed)
306
+ if training:
307
+ dataset.prepare_for_training()
308
+
309
+ return DatasetGroup(datasets)
310
+
311
+
312
+ def load_user_config(file: str) -> dict:
313
+ file: Path = Path(file)
314
+ if not file.is_file():
315
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
316
+
317
+ if file.name.lower().endswith(".json"):
318
+ try:
319
+ with open(file, "r", encoding="utf-8") as f:
320
+ config = json.load(f)
321
+ except Exception:
322
+ logger.error(
323
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
324
+ )
325
+ raise
326
+ elif file.name.lower().endswith(".toml"):
327
+ try:
328
+ config = toml.load(file)
329
+ except Exception:
330
+ logger.error(
331
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
332
+ )
333
+ raise
334
+ else:
335
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
336
+
337
+ return config
338
+
339
+
340
+ # for config test
341
+ if __name__ == "__main__":
342
+ parser = argparse.ArgumentParser()
343
+ parser.add_argument("dataset_config")
344
+ config_args, remain = parser.parse_known_args()
345
+
346
+ parser = argparse.ArgumentParser()
347
+ parser.add_argument("--debug_dataset", action="store_true")
348
+ argparse_namespace = parser.parse_args(remain)
349
+
350
+ logger.info("[argparse_namespace]")
351
+ logger.info(f"{vars(argparse_namespace)}")
352
+
353
+ user_config = load_user_config(config_args.dataset_config)
354
+
355
+ logger.info("")
356
+ logger.info("[user_config]")
357
+ logger.info(f"{user_config}")
358
+
359
+ sanitizer = ConfigSanitizer()
360
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
361
+
362
+ logger.info("")
363
+ logger.info("[sanitized_user_config]")
364
+ logger.info(f"{sanitized_user_config}")
365
+
366
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
367
+
368
+ logger.info("")
369
+ logger.info("[blueprint]")
370
+ logger.info(f"{blueprint}")
371
+
372
+ dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
dataset/dataset_config.md ADDED
@@ -0,0 +1,387 @@
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ ## Dataset Configuration
4
+
5
+ <details>
6
+ <summary>English</summary>
7
+
8
+ Please create a TOML file for dataset configuration.
9
+
10
+ Image and video datasets are supported. The configuration file can include multiple datasets (image or video), each using either caption text files or a metadata JSONL file.
11
+
12
+ The cache directory must be different for each dataset.
13
+ </details>
14
+
15
+ <details>
16
+ <summary>日本語</summary>
17
+
18
+ データセットの設定を行うためのTOMLファイルを作成してください。
19
+
20
+ 画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。
21
+
22
+ キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。
23
+ </details>
24
+
25
+ ### Sample for Image Dataset with Caption Text Files
26
+
27
+ ```toml
28
+ # resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
29
+ # otherwise, the default values will be used for each item
30
+
31
+ # general configurations
32
+ [general]
33
+ resolution = [960, 544]
34
+ caption_extension = ".txt"
35
+ batch_size = 1
36
+ enable_bucket = true
37
+ bucket_no_upscale = false
38
+
39
+ [[datasets]]
40
+ image_directory = "/path/to/image_dir"
41
+ cache_directory = "/path/to/cache_directory"
42
+ num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance multiple datasets with different sizes.
43
+
44
+ # other datasets can be added here. each dataset can have different configurations
45
+ ```
46
+
47
+ <details>
48
+ <summary>English</summary>
49
+
50
+ `cache_directory` is optional; the default is None, which uses the same directory as the image directory. However, we recommend setting the cache directory explicitly to avoid accidentally sharing cache files between different datasets.
51
+
52
+ `num_repeats` is also available. It is optional, default is 1 (no repeat). It repeats the images (or videos) that many times to expand the dataset. For example, if `num_repeats = 2` and there are 20 images in the dataset, each image will appear twice (with the same caption), for a total of 40 images. It is useful for balancing multiple datasets of different sizes (see the sketch after the Japanese section below).
53
+
54
+ </details>
55
+
56
+ <details>
57
+ <summary>日本語</summary>
58
+
59
+ `cache_directory` はオプションです。デフォルトは画像ディレクトリと同じディレクトリに設定されます。ただし、異なるデータセット間でキャッシュファイルが共有されるのを防ぐために、明示的に別のキャッシュディレクトリを設定することをお勧めします。
60
+
61
+ `num_repeats` はオプションで、デフォルトは 1 です(繰り返しなし)。画像(や動画)を、その回数だけ単純に繰り返してデータセットを拡張します。たとえば`num_repeats = 2`としたとき、画像20枚のデータセットなら、各画像が2枚ずつ(同一のキャプションで)計40枚存在した場合と同じになります。異なるデータ数のデータセット間でバランスを取るために使用可能です。
62
+
63
+ resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
64
+
65
+ `[[datasets]]`以下を追加することで、他のデータセットを追加できます。各データセットには異なる設定を持てます。
66
+ </details>
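+ 
+ The following is a minimal sketch of balancing two datasets of different sizes with `num_repeats` (the paths are placeholders):
+ 
+ ```toml
+ # hypothetical example: the first dataset is small, the second is large
+ [[datasets]]
+ image_directory = "/path/to/small_image_dir"
+ cache_directory = "/path/to/cache_small"
+ num_repeats = 4  # repeat the small dataset so it is seen about as often as the large one
+ 
+ [[datasets]]
+ image_directory = "/path/to/large_image_dir"
+ cache_directory = "/path/to/cache_large"
+ num_repeats = 1
+ ```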
67
+
68
+ ### Sample for Image Dataset with Metadata JSONL File
69
+
70
+ ```toml
71
+ # resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
72
+ # caption_extension is not required for metadata jsonl file
73
+ # cache_directory is required for each dataset with metadata jsonl file
74
+
75
+ # general configurations
76
+ [general]
77
+ resolution = [960, 544]
78
+ batch_size = 1
79
+ enable_bucket = true
80
+ bucket_no_upscale = false
81
+
82
+ [[datasets]]
83
+ image_jsonl_file = "/path/to/metadata.jsonl"
84
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
85
+ num_repeats = 1 # optional, default is 1. Same as above.
86
+
87
+ # other datasets can be added here. each dataset can have different configurations
88
+ ```
89
+
90
+ JSONL file format for metadata:
91
+
92
+ ```json
93
+ {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
94
+ {"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
95
+ ```
96
+
97
+ <details>
98
+ <summary>日本語</summary>
99
+
100
+ resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
101
+
102
+ metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
103
+
104
+ キャプションによるデータセットと同様に、複数のデータセットを追加できます。各データセットには異なる設定を持てます。
105
+ </details>
106
+
107
+
108
+ ### Sample for Video Dataset with Caption Text Files
109
+
110
+ ```toml
111
+ # resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample,
112
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
113
+ # num_repeats is also available for video dataset, example is not shown here
114
+
115
+ # general configurations
116
+ [general]
117
+ resolution = [960, 544]
118
+ caption_extension = ".txt"
119
+ batch_size = 1
120
+ enable_bucket = true
121
+ bucket_no_upscale = false
122
+
123
+ [[datasets]]
124
+ video_directory = "/path/to/video_dir"
125
+ cache_directory = "/path/to/cache_directory" # recommended to set cache directory
126
+ target_frames = [1, 25, 45]
127
+ frame_extraction = "head"
128
+
129
+ # other datasets can be added here. each dataset can have different configurations
130
+ ```
131
+
132
+ <details>
133
+ <summary>日本語</summary>
134
+
135
+ resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
136
+
137
+ 他の注意事項は画像データセットと同様です。
138
+ </details>
139
+
140
+ ### Sample for Video Dataset with Metadata JSONL File
141
+
142
+ ```toml
143
+ # resolution, target_frames, frame_extraction, frame_stride, frame_sample,
144
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
145
+ # caption_extension is not required for metadata jsonl file
146
+ # cache_directory is required for each dataset with metadata jsonl file
147
+
148
+ # general configurations
149
+ [general]
150
+ resolution = [960, 544]
151
+ batch_size = 1
152
+ enable_bucket = true
153
+ bucket_no_upscale = false
154
+
155
+ [[datasets]]
156
+ video_jsonl_file = "/path/to/metadata.jsonl"
157
+ target_frames = [1, 25, 45]
158
+ frame_extraction = "head"
159
+ cache_directory = "/path/to/cache_directory_head"
160
+
161
+ # same metadata jsonl file can be used for multiple datasets
162
+ [[datasets]]
163
+ video_jsonl_file = "/path/to/metadata.jsonl"
164
+ target_frames = [1]
165
+ frame_stride = 10
166
+ cache_directory = "/path/to/cache_directory_stride"
167
+
168
+ # other datasets can be added here. each dataset can have different configurations
169
+ ```
170
+
171
+ JSONL file format for metadata:
172
+
173
+ ```json
174
+ {"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
175
+ {"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
176
+ ```
177
+
178
+ <details>
179
+ <summary>日本語</summary>
180
+
181
+ resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
182
+
183
+ metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
184
+
185
+ 他の注意事項は今までのデータセットと同様です。
186
+ </details>
187
+
188
+ ### frame_extraction Options
189
+
190
+ <details>
191
+ <summary>English</summary>
192
+
193
+ - `head`: Extract the first N frames from the video.
194
+ - `chunk`: Extract frames by splitting the video into chunks of N frames.
195
+ - `slide`: Extract frames from the video with a stride of `frame_stride`.
196
+ - `uniform`: Extract `frame_sample` windows of N frames, spaced evenly across the video.
197
+
198
+ For example, consider a video with 40 frames. The following diagrams illustrate each extraction; a short code sketch of the `uniform` scheme follows the diagrams.
199
+ </details>
200
+
201
+ <details>
202
+ <summary>日本語</summary>
203
+
204
+ - `head`: 動画から最初のNフレームを抽出します。
205
+ - `chunk`: 動画をNフレームずつに分割してフレームを抽出します。
206
+ - `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。
207
+ - `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。
208
+
209
+ 例えば、40フレームの動画を例とした抽出について、以下の図で説明します。
210
+ </details>
211
+
212
+ ```
213
+ Original Video, 40 frames: x = frame, o = no frame
214
+ oooooooooooooooooooooooooooooooooooooooo
215
+
216
+ head, target_frames = [1, 13, 25] -> extract head frames:
217
+ xooooooooooooooooooooooooooooooooooooooo
218
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
219
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
220
+
221
+ chunk, target_frames = [13, 25] -> extract frames by splitting into chunks, into 13 and 25 frames:
222
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
223
+ oooooooooooooxxxxxxxxxxxxxoooooooooooooo
224
+ ooooooooooooooooooooooooooxxxxxxxxxxxxxo
225
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
226
+
227
+ NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will cause all frames to be extracted.
228
+ 注: frame_extraction "chunk" を使用する場合、target_frames に 1 を含めないでください。全てのフレームが抽出されてしまいます。
229
+
230
+ slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
231
+ xooooooooooooooooooooooooooooooooooooooo
232
+ ooooooooooxooooooooooooooooooooooooooooo
233
+ ooooooooooooooooooooxooooooooooooooooooo
234
+ ooooooooooooooooooooooooooooooxooooooooo
235
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
236
+ ooooooooooxxxxxxxxxxxxxooooooooooooooooo
237
+ ooooooooooooooooooooxxxxxxxxxxxxxooooooo
238
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
239
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
240
+
241
+ uniform, target_frames = [1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
242
+ xooooooooooooooooooooooooooooooooooooooo
243
+ oooooooooooooxoooooooooooooooooooooooooo
244
+ oooooooooooooooooooooooooxoooooooooooooo
245
+ ooooooooooooooooooooooooooooooooooooooox
246
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
247
+ oooooooooxxxxxxxxxxxxxoooooooooooooooooo
248
+ ooooooooooooooooooxxxxxxxxxxxxxooooooooo
249
+ oooooooooooooooooooooooooooxxxxxxxxxxxxx
250
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
251
+ oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
252
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
253
+ oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
254
+ ```
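+ 
+ The following is a minimal sketch of how the `uniform` start positions in the diagrams above can be computed, assuming even spacing of the window starts (the actual implementation in this repository may differ):
+ 
+ ```python
+ # Illustrative only: evenly spaced window starts for "uniform" extraction.
+ def uniform_start_indices(video_length: int, target_frames: int, frame_sample: int) -> list[int]:
+     last_start = video_length - target_frames  # last valid window start
+     if last_start <= 0 or frame_sample <= 1:
+         return [0]
+     return [round(i * last_start / (frame_sample - 1)) for i in range(frame_sample)]
+ 
+ # 40-frame video, target_frames = 13, frame_sample = 4 -> [0, 9, 18, 27], matching the diagram
+ print(uniform_start_indices(40, 13, 4))
+ ```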
255
+
256
+ ## Specifications
257
+
258
+ ```toml
259
+ # general configurations
260
+ [general]
261
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
262
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
263
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
264
+ num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance multiple datasets with different sizes.
265
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
266
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
267
+
268
+ ### Image Dataset
269
+
270
+ # sample image dataset with caption text files
271
+ [[datasets]]
272
+ image_directory = "/path/to/image_dir"
273
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
274
+ resolution = [960, 544] # required if general resolution is not set
275
+ batch_size = 4 # optional, overwrite the default batch size
276
+ num_repeats = 1 # optional, overwrite the default num_repeats
277
+ enable_bucket = false # optional, overwrite the default bucketing setting
278
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
279
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
280
+
281
+ # sample image dataset with metadata **jsonl** file
282
+ [[datasets]]
283
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
284
+ resolution = [960, 544] # required if general resolution is not set
285
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
286
+ # caption_extension is not required for metadata jsonl file
287
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
288
+
289
+ ### Video Dataset
290
+
291
+ # sample video dataset with caption text files
292
+ [[datasets]]
293
+ video_directory = "/path/to/video_dir"
294
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
295
+ resolution = [960, 544] # required if general resolution is not set
296
+
297
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...); see the check snippet after this block
298
+
299
+ # NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will cause all frames to be extracted.
300
+
301
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
302
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
303
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
304
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
305
+
306
+ # sample video dataset with metadata jsonl file
307
+ [[datasets]]
308
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
309
+
310
+ target_frames = [1, 79]
311
+
312
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
313
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
314
+ ```
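+ 
+ As a quick way to check the `target_frames` constraint above (each element must be of the form N*4+1), a small helper like the following can be used (illustrative only, not part of this repository):
+ 
+ ```python
+ def validate_target_frames(target_frames: list[int]) -> None:
+     for f in target_frames:
+         if f < 1 or (f - 1) % 4 != 0:
+             raise ValueError(f"target_frames value {f} is not of the form N*4+1")
+ 
+ validate_target_frames([1, 25, 79])  # OK
+ # validate_target_frames([24])       # would raise ValueError
+ ```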
315
+
316
+ <!--
317
+ # sample image dataset with lance
318
+ [[datasets]]
319
+ image_lance_dataset = "/path/to/lance_dataset"
320
+ resolution = [960, 544] # required if general resolution is not set
321
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
322
+ -->
323
+
324
+ Metadata in .json format will be supported in the near future.
325
+
326
+
327
+
328
+ <!--
329
+
330
+ ```toml
331
+ # general configurations
332
+ [general]
333
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
334
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
335
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
336
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
337
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
338
+
339
+ # sample image dataset with caption text files
340
+ [[datasets]]
341
+ image_directory = "/path/to/image_dir"
342
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
343
+ resolution = [960, 544] # required if general resolution is not set
344
+ batch_size = 4 # optional, overwrite the default batch size
345
+ enable_bucket = false # optional, overwrite the default bucketing setting
346
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
347
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
348
+
349
+ # sample image dataset with metadata **jsonl** file
350
+ [[datasets]]
351
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
352
+ resolution = [960, 544] # required if general resolution is not set
353
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
354
+ # caption_extension is not required for metadata jsonl file
355
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
356
+
357
+ # sample video dataset with caption text files
358
+ [[datasets]]
359
+ video_directory = "/path/to/video_dir"
360
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
361
+ resolution = [960, 544] # required if general resolution is not set
362
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
363
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
364
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
365
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
366
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
367
+
368
+ # sample video dataset with metadata jsonl file
369
+ [[datasets]]
370
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
371
+ target_frames = [1, 79]
372
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
373
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
374
+ ```
375
+
376
+ # sample image dataset with lance
377
+ [[datasets]]
378
+ image_lance_dataset = "/path/to/lance_dataset"
379
+ resolution = [960, 544] # required if general resolution is not set
380
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
381
+
382
+ The metadata with .json file will be supported in the near future.
383
+
384
+
385
+
386
+
387
+ -->
dataset/image_video_dataset.py ADDED
@@ -0,0 +1,1400 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ import time
8
+ from typing import Optional, Sequence, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from safetensors.torch import save_file, load_file
13
+ from safetensors import safe_open
14
+ from PIL import Image
15
+ import cv2
16
+ import av
17
+
18
+ from utils import safetensors_utils
19
+ from utils.model_utils import dtype_to_str
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
+ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
28
+
29
+ try:
30
+ import pillow_avif
31
+
32
+ IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
33
+ except:
34
+ pass
35
+
36
+ # JPEG-XL on Linux
37
+ try:
38
+ from jxlpy import JXLImagePlugin
39
+
40
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
41
+ except:
42
+ pass
43
+
44
+ # JPEG-XL on Windows
45
+ try:
46
+ import pillow_jxl
47
+
48
+ IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
49
+ except:
50
+ pass
51
+
52
+ VIDEO_EXTENSIONS = [
53
+ ".mp4",
54
+ ".webm",
55
+ ".avi",
56
+ ".mkv",
57
+ ".mov",
58
+ ".flv",
59
+ ".wmv",
60
+ ".m4v",
61
+ ".mpg",
62
+ ".mpeg",
63
+ ".MP4",
64
+ ".WEBM",
65
+ ".AVI",
66
+ ".MKV",
67
+ ".MOV",
68
+ ".FLV",
69
+ ".WMV",
70
+ ".M4V",
71
+ ".MPG",
72
+ ".MPEG",
73
+ ] # some of them are not tested
74
+
75
+ ARCHITECTURE_HUNYUAN_VIDEO = "hv"
76
+ ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
77
+ ARCHITECTURE_WAN = "wan"
78
+ ARCHITECTURE_WAN_FULL = "wan"
79
+
80
+
81
+ def glob_images(directory, base="*"):
82
+ img_paths = []
83
+ for ext in IMAGE_EXTENSIONS:
84
+ if base == "*":
85
+ img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
86
+ else:
87
+ img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
88
+ img_paths = list(set(img_paths)) # remove duplicates
89
+ img_paths.sort()
90
+ return img_paths
91
+
92
+
93
+ def glob_videos(directory, base="*"):
94
+ video_paths = []
95
+ for ext in VIDEO_EXTENSIONS:
96
+ if base == "*":
97
+ video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
98
+ else:
99
+ video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
100
+ video_paths = list(set(video_paths)) # remove duplicates
101
+ video_paths.sort()
102
+ return video_paths
103
+
104
+
105
+ def divisible_by(num: int, divisor: int) -> int:
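+ # round num down to the nearest multiple of divisor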
106
+ return num - num % divisor
107
+
108
+
109
+ def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
110
+ """
111
+ Resize the image to cover the bucket resolution and center-crop it to exactly (width, height); returns a numpy array.
112
+ """
113
+ is_pil_image = isinstance(image, Image.Image)
114
+ if is_pil_image:
115
+ image_width, image_height = image.size
116
+ else:
117
+ image_height, image_width = image.shape[:2]
118
+
119
+ if bucket_reso == (image_width, image_height):
120
+ return np.array(image) if is_pil_image else image
121
+
122
+ bucket_width, bucket_height = bucket_reso
123
+ if bucket_width == image_width or bucket_height == image_height:
124
+ image = np.array(image) if is_pil_image else image
125
+ else:
126
+ # resize the image to the bucket resolution to match the short side
127
+ scale_width = bucket_width / image_width
128
+ scale_height = bucket_height / image_height
129
+ scale = max(scale_width, scale_height)
130
+ image_width = int(image_width * scale + 0.5)
131
+ image_height = int(image_height * scale + 0.5)
132
+
133
+ if scale > 1:
134
+ image = Image.fromarray(image) if not is_pil_image else image
135
+ image = image.resize((image_width, image_height), Image.LANCZOS)
136
+ image = np.array(image)
137
+ else:
138
+ image = np.array(image) if is_pil_image else image
139
+ image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
140
+
141
+ # crop the image to the bucket resolution
142
+ crop_left = (image_width - bucket_width) // 2
143
+ crop_top = (image_height - bucket_height) // 2
144
+ image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
145
+ return image
146
+
147
+
148
+ class ItemInfo:
149
+ def __init__(
150
+ self,
151
+ item_key: str,
152
+ caption: str,
153
+ original_size: tuple[int, int],
154
+ bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
155
+ frame_count: Optional[int] = None,
156
+ content: Optional[np.ndarray] = None,
157
+ latent_cache_path: Optional[str] = None,
158
+ ) -> None:
159
+ self.item_key = item_key
160
+ self.caption = caption
161
+ self.original_size = original_size
162
+ self.bucket_size = bucket_size
163
+ self.frame_count = frame_count
164
+ self.content = content
165
+ self.latent_cache_path = latent_cache_path
166
+ self.text_encoder_output_cache_path: Optional[str] = None
167
+
168
+ def __str__(self) -> str:
169
+ return (
170
+ f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
171
+ + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
172
+ + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
173
+ )
174
+
175
+
176
+ # We use a simple if-else approach to support multiple architectures.
177
+ # Maybe we can use a plugin system in the future.
178
+
179
+ # the keys of the dict are `<content_type>_FxHxW_<dtype>` for latents
180
+ # and `<content_type>_<dtype|mask>` for other tensors
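+ # e.g. "latents_<F>x<H>x<W>_float16", "llm_float16", "llm_mask", "varlen_t5_bfloat16"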
181
+
182
+
183
+ def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
184
+ """HunyuanVideo architecture only"""
185
+ assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
186
+
187
+ _, F, H, W = latent.shape
188
+ dtype_str = dtype_to_str(latent.dtype)
189
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
190
+
191
+ save_latent_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
192
+
193
+
194
+ def save_latent_cache_wan(
195
+ item_info: ItemInfo, latent: torch.Tensor, clip_embed: Optional[torch.Tensor], image_latent: Optional[torch.Tensor]
196
+ ):
197
+ """Wan architecture only"""
198
+ assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"
199
+
200
+ _, F, H, W = latent.shape
201
+ dtype_str = dtype_to_str(latent.dtype)
202
+ sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
203
+
204
+ if clip_embed is not None:
205
+ sd[f"clip_{dtype_str}"] = clip_embed.detach().cpu()
206
+
207
+ if image_latent is not None:
208
+ sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()
209
+
210
+ save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
211
+
212
+
213
+ def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
214
+ metadata = {
215
+ "architecture": arch_fullname,
216
+ "width": f"{item_info.original_size[0]}",
217
+ "height": f"{item_info.original_size[1]}",
218
+ "format_version": "1.0.1",
219
+ }
220
+ if item_info.frame_count is not None:
221
+ metadata["frame_count"] = f"{item_info.frame_count}"
222
+
223
+ for key, value in sd.items():
224
+ # NaN check and show warning, replace NaN with 0
225
+ if torch.isnan(value).any():
226
+ logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
227
+ value[torch.isnan(value)] = 0
228
+
229
+ latent_dir = os.path.dirname(item_info.latent_cache_path)
230
+ os.makedirs(latent_dir, exist_ok=True)
231
+
232
+ save_file(sd, item_info.latent_cache_path, metadata=metadata)
233
+
234
+
235
+ def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
236
+ """HunyuanVideo architecture only"""
237
+ assert (
238
+ embed.dim() == 1 or embed.dim() == 2
239
+ ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
240
+ assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
241
+
242
+ sd = {}
243
+ dtype_str = dtype_to_str(embed.dtype)
244
+ text_encoder_type = "llm" if is_llm else "clipL"
245
+ sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
246
+ if mask is not None:
247
+ sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
248
+
249
+ save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
250
+
251
+
252
+ def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor):
253
+ """Wan architecture only. Wan2.1 only has a single text encoder"""
254
+
255
+ sd = {}
256
+ dtype_str = dtype_to_str(embed.dtype)
257
+ text_encoder_type = "t5"
258
+ sd[f"varlen_{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
259
+
260
+ save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
261
+
262
+
263
+ def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
264
+ for key, value in sd.items():
265
+ # NaN check and show warning, replace NaN with 0
266
+ if torch.isnan(value).any():
267
+ logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
268
+ value[torch.isnan(value)] = 0
269
+
270
+ metadata = {
271
+ "architecture": arch_fullname,
272
+ "caption1": item_info.caption,
273
+ "format_version": "1.0.1",
274
+ }
275
+
276
+ if os.path.exists(item_info.text_encoder_output_cache_path):
277
+ # load existing cache and update metadata
278
+ with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
279
+ existing_metadata = f.metadata()
280
+ for key in f.keys():
281
+ if key not in sd: # avoid overwriting by existing cache, we keep the new one
282
+ sd[key] = f.get_tensor(key)
283
+
284
+ assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
285
+ if existing_metadata["caption1"] != metadata["caption1"]:
286
+ logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
287
+ # TODO verify format_version
288
+
289
+ existing_metadata.pop("caption1", None)
290
+ existing_metadata.pop("format_version", None)
291
+ metadata.update(existing_metadata) # copy existing metadata except caption and format_version
292
+ else:
293
+ text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
294
+ os.makedirs(text_encoder_output_dir, exist_ok=True)
295
+
296
+ safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
297
+
298
+
299
+ class BucketSelector:
300
+ RESOLUTION_STEPS_HUNYUAN = 16
301
+ RESOLUTION_STEPS_WAN = 16
302
+
303
+ def __init__(
304
+ self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
305
+ ):
306
+ self.resolution = resolution
307
+ self.bucket_area = resolution[0] * resolution[1]
308
+ self.architecture = architecture
309
+
310
+ if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
311
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
312
+ elif self.architecture == ARCHITECTURE_WAN:
313
+ self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
314
+ else:
315
+ raise ValueError(f"Invalid architecture: {self.architecture}")
316
+
317
+ if not enable_bucket:
318
+ # only define one bucket
319
+ self.bucket_resolutions = [resolution]
320
+ self.no_upscale = False
321
+ else:
322
+ # prepare bucket resolution
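+ # widths step from about half of sqrt(bucket_area) up to sqrt(bucket_area); for each width the height is chosen to keep the area roughly constant, and the swapped (h, w) pair is also added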
323
+ self.no_upscale = no_upscale
324
+ sqrt_size = int(math.sqrt(self.bucket_area))
325
+ min_size = divisible_by(sqrt_size // 2, self.reso_steps)
326
+ self.bucket_resolutions = []
327
+ for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
328
+ h = divisible_by(self.bucket_area // w, self.reso_steps)
329
+ self.bucket_resolutions.append((w, h))
330
+ self.bucket_resolutions.append((h, w))
331
+
332
+ self.bucket_resolutions = list(set(self.bucket_resolutions))
333
+ self.bucket_resolutions.sort()
334
+
335
+ # calculate aspect ratio to find the nearest resolution
336
+ self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
337
+
338
+ def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
339
+ """
340
+ return the bucket resolution for the given image size, (width, height)
341
+ """
342
+ area = image_size[0] * image_size[1]
343
+ if self.no_upscale and area <= self.bucket_area:
344
+ w, h = image_size
345
+ w = divisible_by(w, self.reso_steps)
346
+ h = divisible_by(h, self.reso_steps)
347
+ return w, h
348
+
349
+ aspect_ratio = image_size[0] / image_size[1]
350
+ ar_errors = self.aspect_ratios - aspect_ratio
351
+ bucket_id = np.abs(ar_errors).argmin()
352
+ return self.bucket_resolutions[bucket_id]
353
+
354
+
355
+ def load_video(
356
+ video_path: str,
357
+ start_frame: Optional[int] = None,
358
+ end_frame: Optional[int] = None,
359
+ bucket_selector: Optional[BucketSelector] = None,
360
+ bucket_reso: Optional[tuple[int, int]] = None,
361
+ ) -> list[np.ndarray]:
362
+ """
363
+ bucket_reso: if given, resize the video to the bucket resolution, (width, height)
364
+ """
365
+ container = av.open(video_path)
366
+ video = []
367
+ for i, frame in enumerate(container.decode(video=0)):
368
+ if start_frame is not None and i < start_frame:
369
+ continue
370
+ if end_frame is not None and i >= end_frame:
371
+ break
372
+ frame = frame.to_image()
373
+
374
+ if bucket_selector is not None and bucket_reso is None:
375
+ bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
376
+
377
+ if bucket_reso is not None:
378
+ frame = resize_image_to_bucket(frame, bucket_reso)
379
+ else:
380
+ frame = np.array(frame)
381
+
382
+ video.append(frame)
383
+ container.close()
384
+ return video
385
+
386
+
387
+ class BucketBatchManager:
388
+
389
+ def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
390
+ self.batch_size = batch_size
391
+ self.buckets = bucketed_item_info
392
+ self.bucket_resos = list(self.buckets.keys())
393
+ self.bucket_resos.sort()
394
+
395
+ self.bucket_batch_indices = []
396
+ for bucket_reso in self.bucket_resos:
397
+ bucket = self.buckets[bucket_reso]
398
+ num_batches = math.ceil(len(bucket) / self.batch_size)
399
+ for i in range(num_batches):
400
+ self.bucket_batch_indices.append((bucket_reso, i))
401
+
402
+ self.shuffle()
403
+
404
+ def show_bucket_info(self):
405
+ for bucket_reso in self.bucket_resos:
406
+ bucket = self.buckets[bucket_reso]
407
+ logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
408
+
409
+ logger.info(f"total batches: {len(self)}")
410
+
411
+ def shuffle(self):
412
+ for bucket in self.buckets.values():
413
+ random.shuffle(bucket)
414
+ random.shuffle(self.bucket_batch_indices)
415
+
416
+ def __len__(self):
417
+ return len(self.bucket_batch_indices)
418
+
419
+ def __getitem__(self, idx):
420
+ bucket_reso, batch_idx = self.bucket_batch_indices[idx]
421
+ bucket = self.buckets[bucket_reso]
422
+ start = batch_idx * self.batch_size
423
+ end = min(start + self.batch_size, len(bucket))
424
+
425
+ batch_tensor_data = {}
426
+ varlen_keys = set()
427
+ for item_info in bucket[start:end]:
428
+ sd_latent = load_file(item_info.latent_cache_path)
429
+ sd_te = load_file(item_info.text_encoder_output_cache_path)
430
+ sd = {**sd_latent, **sd_te}
431
+
432
+ # TODO refactor this
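+ # cache keys look like "latents_FxHxW_<dtype>", "<te>_<dtype>" or "<te>_mask", optionally prefixed with "varlen_"; strip the prefix/suffixes to get the batch key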
433
+ for key in sd.keys():
434
+ is_varlen_key = key.startswith("varlen_") # varlen keys are not stacked
435
+ content_key = key
436
+
437
+ if is_varlen_key:
438
+ content_key = content_key.replace("varlen_", "")
439
+
440
+ if content_key.endswith("_mask"):
441
+ pass
442
+ else:
443
+ content_key = content_key.rsplit("_", 1)[0] # remove dtype
444
+ if content_key.startswith("latents_"):
445
+ content_key = content_key.rsplit("_", 1)[0] # remove FxHxW
446
+
447
+ if content_key not in batch_tensor_data:
448
+ batch_tensor_data[content_key] = []
449
+ batch_tensor_data[content_key].append(sd[key])
450
+
451
+ if is_varlen_key:
452
+ varlen_keys.add(content_key)
453
+
454
+ for key in batch_tensor_data.keys():
455
+ if key not in varlen_keys:
456
+ batch_tensor_data[key] = torch.stack(batch_tensor_data[key])
457
+
458
+ return batch_tensor_data
459
+
460
+
461
+ class ContentDatasource:
462
+ def __init__(self):
463
+ self.caption_only = False
464
+
465
+ def set_caption_only(self, caption_only: bool):
466
+ self.caption_only = caption_only
467
+
468
+ def is_indexable(self):
469
+ return False
470
+
471
+ def get_caption(self, idx: int) -> tuple[str, str]:
472
+ """
473
+ Returns caption. May not be called if is_indexable() returns False.
474
+ """
475
+ raise NotImplementedError
476
+
477
+ def __len__(self):
478
+ raise NotImplementedError
479
+
480
+ def __iter__(self):
481
+ raise NotImplementedError
482
+
483
+ def __next__(self):
484
+ raise NotImplementedError
485
+
486
+
487
+ class ImageDatasource(ContentDatasource):
488
+ def __init__(self):
489
+ super().__init__()
490
+
491
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
492
+ """
493
+ Returns image data as a tuple of image path, image, and caption for the given index.
494
+ Key must be unique and valid as a file name.
495
+ May not be called if is_indexable() returns False.
496
+ """
497
+ raise NotImplementedError
498
+
499
+
500
+ class ImageDirectoryDatasource(ImageDatasource):
501
+ def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
502
+ super().__init__()
503
+ self.image_directory = image_directory
504
+ self.caption_extension = caption_extension
505
+ self.current_idx = 0
506
+
507
+ # glob images
508
+ logger.info(f"glob images in {self.image_directory}")
509
+ self.image_paths = glob_images(self.image_directory)
510
+ logger.info(f"found {len(self.image_paths)} images")
511
+
512
+ def is_indexable(self):
513
+ return True
514
+
515
+ def __len__(self):
516
+ return len(self.image_paths)
517
+
518
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
519
+ image_path = self.image_paths[idx]
520
+ image = Image.open(image_path).convert("RGB")
521
+
522
+ _, caption = self.get_caption(idx)
523
+
524
+ return image_path, image, caption
525
+
526
+ def get_caption(self, idx: int) -> tuple[str, str]:
527
+ image_path = self.image_paths[idx]
528
+ caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
529
+ with open(caption_path, "r", encoding="utf-8") as f:
530
+ caption = f.read().strip()
531
+ return image_path, caption
532
+
533
+ def __iter__(self):
534
+ self.current_idx = 0
535
+ return self
536
+
537
+ def __next__(self) -> callable:
538
+ """
539
+ Returns a fetcher function that returns image data.
540
+ """
541
+ if self.current_idx >= len(self.image_paths):
542
+ raise StopIteration
543
+
544
+ if self.caption_only:
545
+
546
+ def create_caption_fetcher(index):
547
+ return lambda: self.get_caption(index)
548
+
549
+ fetcher = create_caption_fetcher(self.current_idx)
550
+ else:
551
+
552
+ def create_image_fetcher(index):
553
+ return lambda: self.get_image_data(index)
554
+
555
+ fetcher = create_image_fetcher(self.current_idx)
556
+
557
+ self.current_idx += 1
558
+ return fetcher
559
+
560
+
561
+ class ImageJsonlDatasource(ImageDatasource):
562
+ def __init__(self, image_jsonl_file: str):
563
+ super().__init__()
564
+ self.image_jsonl_file = image_jsonl_file
565
+ self.current_idx = 0
566
+
567
+ # load jsonl
568
+ logger.info(f"load image jsonl from {self.image_jsonl_file}")
569
+ self.data = []
570
+ with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
571
+ for line in f:
572
+ try:
573
+ data = json.loads(line)
574
+ except json.JSONDecodeError:
575
+ logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
576
+ raise
577
+ self.data.append(data)
578
+ logger.info(f"loaded {len(self.data)} images")
579
+
580
+ def is_indexable(self):
581
+ return True
582
+
583
+ def __len__(self):
584
+ return len(self.data)
585
+
586
+ def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
587
+ data = self.data[idx]
588
+ image_path = data["image_path"]
589
+ image = Image.open(image_path).convert("RGB")
590
+
591
+ caption = data["caption"]
592
+
593
+ return image_path, image, caption
594
+
595
+ def get_caption(self, idx: int) -> tuple[str, str]:
596
+ data = self.data[idx]
597
+ image_path = data["image_path"]
598
+ caption = data["caption"]
599
+ return image_path, caption
600
+
601
+ def __iter__(self):
602
+ self.current_idx = 0
603
+ return self
604
+
605
+ def __next__(self) -> callable:
606
+ if self.current_idx >= len(self.data):
607
+ raise StopIteration
608
+
609
+ if self.caption_only:
610
+
611
+ def create_caption_fetcher(index):
612
+ return lambda: self.get_caption(index)
613
+
614
+ fetcher = create_caption_fetcher(self.current_idx)
615
+
616
+ else:
617
+
618
+ def create_fetcher(index):
619
+ return lambda: self.get_image_data(index)
620
+
621
+ fetcher = create_fetcher(self.current_idx)
622
+
623
+ self.current_idx += 1
624
+ return fetcher
625
+
626
+
627
+ class VideoDatasource(ContentDatasource):
628
+ def __init__(self):
629
+ super().__init__()
630
+
631
+ # None means all frames
632
+ self.start_frame = None
633
+ self.end_frame = None
634
+
635
+ self.bucket_selector = None
636
+
637
+ def __len__(self):
638
+ raise NotImplementedError
639
+
640
+ def get_video_data_from_path(
641
+ self,
642
+ video_path: str,
643
+ start_frame: Optional[int] = None,
644
+ end_frame: Optional[int] = None,
645
+ bucket_selector: Optional[BucketSelector] = None,
646
+ ) -> tuple[str, list[Image.Image], str]:
647
+ # this method can resize the video if bucket_selector is given to reduce the memory usage
648
+
649
+ start_frame = start_frame if start_frame is not None else self.start_frame
650
+ end_frame = end_frame if end_frame is not None else self.end_frame
651
+ bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
652
+
653
+ video = load_video(video_path, start_frame, end_frame, bucket_selector)
654
+ return video
655
+
656
+ def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
657
+ self.start_frame = start_frame
658
+ self.end_frame = end_frame
659
+
660
+ def set_bucket_selector(self, bucket_selector: BucketSelector):
661
+ self.bucket_selector = bucket_selector
662
+
663
+ def __iter__(self):
664
+ raise NotImplementedError
665
+
666
+ def __next__(self):
667
+ raise NotImplementedError
668
+
669
+
670
+ class VideoDirectoryDatasource(VideoDatasource):
671
+ def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
672
+ super().__init__()
673
+ self.video_directory = video_directory
674
+ self.caption_extension = caption_extension
675
+ self.current_idx = 0
676
+
677
+ # glob videos
678
+ logger.info(f"glob videos in {self.video_directory}")
679
+ self.video_paths = glob_videos(self.video_directory)
680
+ logger.info(f"found {len(self.video_paths)} videos")
681
+
682
+ def is_indexable(self):
683
+ return True
684
+
685
+ def __len__(self):
686
+ return len(self.video_paths)
687
+
688
+ def get_video_data(
689
+ self,
690
+ idx: int,
691
+ start_frame: Optional[int] = None,
692
+ end_frame: Optional[int] = None,
693
+ bucket_selector: Optional[BucketSelector] = None,
694
+ ) -> tuple[str, list[Image.Image], str]:
695
+ video_path = self.video_paths[idx]
696
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
697
+
698
+ _, caption = self.get_caption(idx)
699
+
700
+ return video_path, video, caption
701
+
702
+ def get_caption(self, idx: int) -> tuple[str, str]:
703
+ video_path = self.video_paths[idx]
704
+ caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
705
+ with open(caption_path, "r", encoding="utf-8") as f:
706
+ caption = f.read().strip()
707
+ return video_path, caption
708
+
709
+ def __iter__(self):
710
+ self.current_idx = 0
711
+ return self
712
+
713
+ def __next__(self):
714
+ if self.current_idx >= len(self.video_paths):
715
+ raise StopIteration
716
+
717
+ if self.caption_only:
718
+
719
+ def create_caption_fetcher(index):
720
+ return lambda: self.get_caption(index)
721
+
722
+ fetcher = create_caption_fetcher(self.current_idx)
723
+
724
+ else:
725
+
726
+ def create_fetcher(index):
727
+ return lambda: self.get_video_data(index)
728
+
729
+ fetcher = create_fetcher(self.current_idx)
730
+
731
+ self.current_idx += 1
732
+ return fetcher
733
+
734
+
735
+ class VideoJsonlDatasource(VideoDatasource):
736
+ def __init__(self, video_jsonl_file: str):
737
+ super().__init__()
738
+ self.video_jsonl_file = video_jsonl_file
739
+ self.current_idx = 0
740
+
741
+ # load jsonl
742
+ logger.info(f"load video jsonl from {self.video_jsonl_file}")
743
+ self.data = []
744
+ with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
745
+ for line in f:
746
+ data = json.loads(line)
747
+ self.data.append(data)
748
+ logger.info(f"loaded {len(self.data)} videos")
749
+
750
+ def is_indexable(self):
751
+ return True
752
+
753
+ def __len__(self):
754
+ return len(self.data)
755
+
756
+ def get_video_data(
757
+ self,
758
+ idx: int,
759
+ start_frame: Optional[int] = None,
760
+ end_frame: Optional[int] = None,
761
+ bucket_selector: Optional[BucketSelector] = None,
762
+ ) -> tuple[str, list[Image.Image], str]:
763
+ data = self.data[idx]
764
+ video_path = data["video_path"]
765
+ video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
766
+
767
+ caption = data["caption"]
768
+
769
+ return video_path, video, caption
770
+
771
+ def get_caption(self, idx: int) -> tuple[str, str]:
772
+ data = self.data[idx]
773
+ video_path = data["video_path"]
774
+ caption = data["caption"]
775
+ return video_path, caption
776
+
777
+ def __iter__(self):
778
+ self.current_idx = 0
779
+ return self
780
+
781
+ def __next__(self):
782
+ if self.current_idx >= len(self.data):
783
+ raise StopIteration
784
+
785
+ if self.caption_only:
786
+
787
+ def create_caption_fetcher(index):
788
+ return lambda: self.get_caption(index)
789
+
790
+ fetcher = create_caption_fetcher(self.current_idx)
791
+
792
+ else:
793
+
794
+ def create_fetcher(index):
795
+ return lambda: self.get_video_data(index)
796
+
797
+ fetcher = create_fetcher(self.current_idx)
798
+
799
+ self.current_idx += 1
800
+ return fetcher
801
+
802
+
803
+ class BaseDataset(torch.utils.data.Dataset):
804
+ def __init__(
805
+ self,
806
+ resolution: Tuple[int, int] = (960, 544),
807
+ caption_extension: Optional[str] = None,
808
+ batch_size: int = 1,
809
+ num_repeats: int = 1,
810
+ enable_bucket: bool = False,
811
+ bucket_no_upscale: bool = False,
812
+ cache_directory: Optional[str] = None,
813
+ debug_dataset: bool = False,
814
+ architecture: str = "no_default",
815
+ ):
816
+ self.resolution = resolution
817
+ self.caption_extension = caption_extension
818
+ self.batch_size = batch_size
819
+ self.num_repeats = num_repeats
820
+ self.enable_bucket = enable_bucket
821
+ self.bucket_no_upscale = bucket_no_upscale
822
+ self.cache_directory = cache_directory
823
+ self.debug_dataset = debug_dataset
824
+ self.architecture = architecture
825
+ self.seed = None
826
+ self.current_epoch = 0
827
+
828
+ if not self.enable_bucket:
829
+ self.bucket_no_upscale = False
830
+
831
+ def get_metadata(self) -> dict:
832
+ metadata = {
833
+ "resolution": self.resolution,
834
+ "caption_extension": self.caption_extension,
835
+ "batch_size_per_device": self.batch_size,
836
+ "num_repeats": self.num_repeats,
837
+ "enable_bucket": bool(self.enable_bucket),
838
+ "bucket_no_upscale": bool(self.bucket_no_upscale),
839
+ }
840
+ return metadata
841
+
842
+ def get_all_latent_cache_files(self):
843
+ return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
844
+
845
+ def get_all_text_encoder_output_cache_files(self):
846
+ return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}_te.safetensors"))
847
+
848
+ def get_latent_cache_path(self, item_info: ItemInfo) -> str:
849
+ """
850
+ Returns the cache path for the latent tensor.
851
+
852
+ item_info: ItemInfo object
853
+
854
+ Returns:
855
+ str: cache path
856
+
857
+ cache_path is based on the item_key and the resolution.
858
+ """
859
+ w, h = item_info.original_size
860
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
861
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
862
+ return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors")
863
+
864
+ def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
865
+ basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
866
+ assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
867
+ return os.path.join(self.cache_directory, f"{basename}_{self.architecture}_te.safetensors")
868
+
869
+ def retrieve_latent_cache_batches(self, num_workers: int):
870
+ raise NotImplementedError
871
+
872
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
873
+ raise NotImplementedError
874
+
875
+ def prepare_for_training(self):
876
+ pass
877
+
878
+ def set_seed(self, seed: int):
879
+ self.seed = seed
880
+
881
+ def set_current_epoch(self, epoch):
882
+ if not self.current_epoch == epoch: # shuffle buckets when epoch is incremented
883
+ if epoch > self.current_epoch:
884
+ logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
885
+ num_epochs = epoch - self.current_epoch
886
+ for _ in range(num_epochs):
887
+ self.current_epoch += 1
888
+ self.shuffle_buckets()
889
+ # self.current_epoch seems to be reset to 0 in the next epoch; it may be caused by skipped_dataloader?
890
+ else:
891
+ logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
892
+ self.current_epoch = epoch
893
+
894
+ def set_current_step(self, step):
895
+ self.current_step = step
896
+
897
+ def set_max_train_steps(self, max_train_steps):
898
+ self.max_train_steps = max_train_steps
899
+
900
+ def shuffle_buckets(self):
901
+ raise NotImplementedError
902
+
903
+ def __len__(self):
904
+ raise NotImplementedError
905
+
906
+ def __getitem__(self, idx):
907
+ raise NotImplementedError
908
+
909
+ def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
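+ # fetch captions with a thread pool and yield lists of ItemInfo of size batch_size (the last batch may be smaller)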
910
+ datasource.set_caption_only(True)
911
+ executor = ThreadPoolExecutor(max_workers=num_workers)
912
+
913
+ data: list[ItemInfo] = []
914
+ futures = []
915
+
916
+ def aggregate_future(consume_all: bool = False):
917
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
918
+ completed_futures = [future for future in futures if future.done()]
919
+ if len(completed_futures) == 0:
920
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
921
+ time.sleep(0.1)
922
+ continue
923
+ else:
924
+ break # submit batch if possible
925
+
926
+ for future in completed_futures:
927
+ item_key, caption = future.result()
928
+ item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
929
+ item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
930
+ data.append(item_info)
931
+
932
+ futures.remove(future)
933
+
934
+ def submit_batch(flush: bool = False):
935
+ nonlocal data
936
+ if len(data) >= batch_size or (len(data) > 0 and flush):
937
+ batch = data[0:batch_size]
938
+ if len(data) > batch_size:
939
+ data = data[batch_size:]
940
+ else:
941
+ data = []
942
+ return batch
943
+ return None
944
+
945
+ for fetch_op in datasource:
946
+ future = executor.submit(fetch_op)
947
+ futures.append(future)
948
+ aggregate_future()
949
+ while True:
950
+ batch = submit_batch()
951
+ if batch is None:
952
+ break
953
+ yield batch
954
+
955
+ aggregate_future(consume_all=True)
956
+ while True:
957
+ batch = submit_batch(flush=True)
958
+ if batch is None:
959
+ break
960
+ yield batch
961
+
962
+ executor.shutdown()
963
+
964
+
965
+ class ImageDataset(BaseDataset):
966
+ def __init__(
967
+ self,
968
+ resolution: Tuple[int, int],
969
+ caption_extension: Optional[str],
970
+ batch_size: int,
971
+ num_repeats: int,
972
+ enable_bucket: bool,
973
+ bucket_no_upscale: bool,
974
+ image_directory: Optional[str] = None,
975
+ image_jsonl_file: Optional[str] = None,
976
+ cache_directory: Optional[str] = None,
977
+ debug_dataset: bool = False,
978
+ architecture: str = "no_default",
979
+ ):
980
+ super(ImageDataset, self).__init__(
981
+ resolution,
982
+ caption_extension,
983
+ batch_size,
984
+ num_repeats,
985
+ enable_bucket,
986
+ bucket_no_upscale,
987
+ cache_directory,
988
+ debug_dataset,
989
+ architecture,
990
+ )
991
+ self.image_directory = image_directory
992
+ self.image_jsonl_file = image_jsonl_file
993
+ if image_directory is not None:
994
+ self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
995
+ elif image_jsonl_file is not None:
996
+ self.datasource = ImageJsonlDatasource(image_jsonl_file)
997
+ else:
998
+ raise ValueError("image_directory or image_jsonl_file must be specified")
999
+
1000
+ if self.cache_directory is None:
1001
+ self.cache_directory = self.image_directory
1002
+
1003
+ self.batch_manager = None
1004
+ self.num_train_items = 0
1005
+
1006
+ def get_metadata(self):
1007
+ metadata = super().get_metadata()
1008
+ if self.image_directory is not None:
1009
+ metadata["image_directory"] = os.path.basename(self.image_directory)
1010
+ if self.image_jsonl_file is not None:
1011
+ metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
1012
+ return metadata
1013
+
1014
+ def get_total_image_count(self):
1015
+ return len(self.datasource) if self.datasource.is_indexable() else None
1016
+
1017
+ def retrieve_latent_cache_batches(self, num_workers: int):
1018
+ buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
1019
+ executor = ThreadPoolExecutor(max_workers=num_workers)
1020
+
1021
+ batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
1022
+ futures = []
1023
+
1024
+ # aggregate futures and sort by bucket resolution
1025
+ def aggregate_future(consume_all: bool = False):
1026
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
1027
+ completed_futures = [future for future in futures if future.done()]
1028
+ if len(completed_futures) == 0:
1029
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
1030
+ time.sleep(0.1)
1031
+ continue
1032
+ else:
1033
+ break # submit batch if possible
1034
+
1035
+ for future in completed_futures:
1036
+ original_size, item_key, image, caption = future.result()
1037
+ bucket_height, bucket_width = image.shape[:2]
1038
+ bucket_reso = (bucket_width, bucket_height)
1039
+
1040
+ item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
1041
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1042
+
1043
+ if bucket_reso not in batches:
1044
+ batches[bucket_reso] = []
1045
+ batches[bucket_reso].append(item_info)
1046
+
1047
+ futures.remove(future)
1048
+
1049
+ # submit batch if some bucket has enough items
1050
+ def submit_batch(flush: bool = False):
1051
+ for key in batches:
1052
+ if len(batches[key]) >= self.batch_size or flush:
1053
+ batch = batches[key][0 : self.batch_size]
1054
+ if len(batches[key]) > self.batch_size:
1055
+ batches[key] = batches[key][self.batch_size :]
1056
+ else:
1057
+ del batches[key]
1058
+ return key, batch
1059
+ return None, None
1060
+
1061
+ for fetch_op in self.datasource:
1062
+
1063
+ # fetch and resize image in a separate thread
1064
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
1065
+ image_key, image, caption = op()
1066
+ image: Image.Image
1067
+ image_size = image.size
1068
+
1069
+ bucket_reso = buckset_selector.get_bucket_resolution(image_size)
1070
+ image = resize_image_to_bucket(image, bucket_reso)
1071
+ return image_size, image_key, image, caption
1072
+
1073
+ future = executor.submit(fetch_and_resize, fetch_op)
1074
+ futures.append(future)
1075
+ aggregate_future()
1076
+ while True:
1077
+ key, batch = submit_batch()
1078
+ if key is None:
1079
+ break
1080
+ yield key, batch
1081
+
1082
+ aggregate_future(consume_all=True)
1083
+ while True:
1084
+ key, batch = submit_batch(flush=True)
1085
+ if key is None:
1086
+ break
1087
+ yield key, batch
1088
+
1089
+ executor.shutdown()
1090
+
1091
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
1092
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
1093
+
1094
+ def prepare_for_training(self):
1095
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
1096
+
1097
+ # glob cache files
1098
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
1099
+
1100
+ # assign cache files to item info
1101
+ bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
1102
+ for cache_file in latent_cache_files:
1103
+ tokens = os.path.basename(cache_file).split("_")
1104
+
1105
+ image_size = tokens[-2] # 0000x0000
1106
+ image_width, image_height = map(int, image_size.split("x"))
1107
+ image_size = (image_width, image_height)
1108
+
1109
+ item_key = "_".join(tokens[:-2])
1110
+ text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
1111
+ if not os.path.exists(text_encoder_output_cache_file):
1112
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
1113
+ continue
1114
+
1115
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1116
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
1117
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1118
+
1119
+ bucket = bucketed_item_info.get(bucket_reso, [])
1120
+ for _ in range(self.num_repeats):
1121
+ bucket.append(item_info)
1122
+ bucketed_item_info[bucket_reso] = bucket
1123
+
1124
+ # prepare batch manager
1125
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
1126
+ self.batch_manager.show_bucket_info()
1127
+
1128
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
1129
+
1130
+ def shuffle_buckets(self):
1131
+ # set random seed for this epoch
1132
+ random.seed(self.seed + self.current_epoch)
1133
+ self.batch_manager.shuffle()
1134
+
1135
+ def __len__(self):
1136
+ if self.batch_manager is None:
1137
+ return 100 # dummy value
1138
+ return len(self.batch_manager)
1139
+
1140
+ def __getitem__(self, idx):
1141
+ return self.batch_manager[idx]
1142
+
1143
+
1144
+ class VideoDataset(BaseDataset):
1145
+ def __init__(
1146
+ self,
1147
+ resolution: Tuple[int, int],
1148
+ caption_extension: Optional[str],
1149
+ batch_size: int,
1150
+ num_repeats: int,
1151
+ enable_bucket: bool,
1152
+ bucket_no_upscale: bool,
1153
+ frame_extraction: Optional[str] = "head",
1154
+ frame_stride: Optional[int] = 1,
1155
+ frame_sample: Optional[int] = 1,
1156
+ target_frames: Optional[list[int]] = None,
1157
+ video_directory: Optional[str] = None,
1158
+ video_jsonl_file: Optional[str] = None,
1159
+ cache_directory: Optional[str] = None,
1160
+ debug_dataset: bool = False,
1161
+ architecture: str = "no_default",
1162
+ ):
1163
+ super(VideoDataset, self).__init__(
1164
+ resolution,
1165
+ caption_extension,
1166
+ batch_size,
1167
+ num_repeats,
1168
+ enable_bucket,
1169
+ bucket_no_upscale,
1170
+ cache_directory,
1171
+ debug_dataset,
1172
+ architecture,
1173
+ )
1174
+ self.video_directory = video_directory
1175
+ self.video_jsonl_file = video_jsonl_file
1176
+ self.target_frames = target_frames
1177
+ self.frame_extraction = frame_extraction
1178
+ self.frame_stride = frame_stride
1179
+ self.frame_sample = frame_sample
1180
+
1181
+ if video_directory is not None:
1182
+ self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
1183
+ elif video_jsonl_file is not None:
1184
+ self.datasource = VideoJsonlDatasource(video_jsonl_file)
+ else:
+ raise ValueError("video_directory or video_jsonl_file must be specified")
1185
+
1186
+ if self.frame_extraction == "uniform" and self.frame_sample == 1:
1187
+ self.frame_extraction = "head"
1188
+ logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
1189
+ if self.frame_extraction == "head":
1190
+ # head extraction. we can limit the number of frames to be extracted
1191
+ self.datasource.set_start_and_end_frame(0, max(self.target_frames))
1192
+
1193
+ if self.cache_directory is None:
1194
+ self.cache_directory = self.video_directory
1195
+
1196
+ self.batch_manager = None
1197
+ self.num_train_items = 0
1198
+
1199
+ def get_metadata(self):
1200
+ metadata = super().get_metadata()
1201
+ if self.video_directory is not None:
1202
+ metadata["video_directory"] = os.path.basename(self.video_directory)
1203
+ if self.video_jsonl_file is not None:
1204
+ metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
1205
+ metadata["frame_extraction"] = self.frame_extraction
1206
+ metadata["frame_stride"] = self.frame_stride
1207
+ metadata["frame_sample"] = self.frame_sample
1208
+ metadata["target_frames"] = self.target_frames
1209
+ return metadata
1210
+
1211
+ def retrieve_latent_cache_batches(self, num_workers: int):
1212
+ buckset_selector = BucketSelector(self.resolution, architecture=self.architecture)
1213
+ self.datasource.set_bucket_selector(buckset_selector)
1214
+
1215
+ executor = ThreadPoolExecutor(max_workers=num_workers)
1216
+
1217
+ # key: (width, height, frame_count), value: [ItemInfo]
1218
+ batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
1219
+ futures = []
1220
+
1221
+ def aggregate_future(consume_all: bool = False):
1222
+ while len(futures) >= num_workers or (consume_all and len(futures) > 0):
1223
+ completed_futures = [future for future in futures if future.done()]
1224
+ if len(completed_futures) == 0:
1225
+ if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
1226
+ time.sleep(0.1)
1227
+ continue
1228
+ else:
1229
+ break # submit batch if possible
1230
+
1231
+ for future in completed_futures:
1232
+ original_frame_size, video_key, video, caption = future.result()
1233
+
1234
+ frame_count = len(video)
1235
+ video = np.stack(video, axis=0)
1236
+ height, width = video.shape[1:3]
1237
+ bucket_reso = (width, height) # already resized
1238
+
1239
+ crop_pos_and_frames = []
1240
+ if self.frame_extraction == "head":
1241
+ for target_frame in self.target_frames:
1242
+ if frame_count >= target_frame:
1243
+ crop_pos_and_frames.append((0, target_frame))
1244
+ elif self.frame_extraction == "chunk":
1245
+ # split by target_frames
1246
+ for target_frame in self.target_frames:
1247
+ for i in range(0, frame_count, target_frame):
1248
+ if i + target_frame <= frame_count:
1249
+ crop_pos_and_frames.append((i, target_frame))
1250
+ elif self.frame_extraction == "slide":
1251
+ # slide window
1252
+ for target_frame in self.target_frames:
1253
+ if frame_count >= target_frame:
1254
+ for i in range(0, frame_count - target_frame + 1, self.frame_stride):
1255
+ crop_pos_and_frames.append((i, target_frame))
1256
+ elif self.frame_extraction == "uniform":
1257
+ # select N frames uniformly
1258
+ for target_frame in self.target_frames:
1259
+ if frame_count >= target_frame:
1260
+ frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
1261
+ for i in frame_indices:
1262
+ crop_pos_and_frames.append((i, target_frame))
1263
+ else:
1264
+ raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
1265
+
1266
+ for crop_pos, target_frame in crop_pos_and_frames:
1267
+ cropped_video = video[crop_pos : crop_pos + target_frame]
1268
+ body, ext = os.path.splitext(video_key)
1269
+ item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
1270
+ batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
1271
+
1272
+ item_info = ItemInfo(
1273
+ item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
1274
+ )
1275
+ item_info.latent_cache_path = self.get_latent_cache_path(item_info)
1276
+
1277
+ batch = batches.get(batch_key, [])
1278
+ batch.append(item_info)
1279
+ batches[batch_key] = batch
1280
+
1281
+ futures.remove(future)
1282
+
1283
+ def submit_batch(flush: bool = False):
1284
+ for key in batches:
1285
+ if len(batches[key]) >= self.batch_size or flush:
1286
+ batch = batches[key][0 : self.batch_size]
1287
+ if len(batches[key]) > self.batch_size:
1288
+ batches[key] = batches[key][self.batch_size :]
1289
+ else:
1290
+ del batches[key]
1291
+ return key, batch
1292
+ return None, None
1293
+
1294
+ for operator in self.datasource:
1295
+
1296
+ def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
1297
+ video_key, video, caption = op()
1298
+ video: list[np.ndarray]
1299
+ frame_size = (video[0].shape[1], video[0].shape[0])
1300
+
1301
+ # resize if necessary
1302
+ bucket_reso = buckset_selector.get_bucket_resolution(frame_size)
1303
+ video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
1304
+
1305
+ return frame_size, video_key, video, caption
1306
+
1307
+ future = executor.submit(fetch_and_resize, operator)
1308
+ futures.append(future)
1309
+ aggregate_future()
1310
+ while True:
1311
+ key, batch = submit_batch()
1312
+ if key is None:
1313
+ break
1314
+ yield key, batch
1315
+
1316
+ aggregate_future(consume_all=True)
1317
+ while True:
1318
+ key, batch = submit_batch(flush=True)
1319
+ if key is None:
1320
+ break
1321
+ yield key, batch
1322
+
1323
+ executor.shutdown()
1324
+
1325
+ def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
1326
+ return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
1327
+
1328
+ def prepare_for_training(self):
1329
+ bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
1330
+
1331
+ # glob cache files
1332
+ latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
1333
+
1334
+ # assign cache files to item info
1335
+ bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
1336
+ for cache_file in latent_cache_files:
1337
+ tokens = os.path.basename(cache_file).split("_")
1338
+
1339
+ image_size = tokens[-2] # 0000x0000
1340
+ image_width, image_height = map(int, image_size.split("x"))
1341
+ image_size = (image_width, image_height)
1342
+
1343
+ frame_pos, frame_count = tokens[-3].split("-")
1344
+ frame_pos, frame_count = int(frame_pos), int(frame_count)
1345
+
1346
+ item_key = "_".join(tokens[:-3])
1347
+ text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
1348
+ if not os.path.exists(text_encoder_output_cache_file):
1349
+ logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
1350
+ continue
1351
+
1352
+ bucket_reso = bucket_selector.get_bucket_resolution(image_size)
1353
+ bucket_reso = (*bucket_reso, frame_count)
1354
+ item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
1355
+ item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
1356
+
1357
+ bucket = bucketed_item_info.get(bucket_reso, [])
1358
+ for _ in range(self.num_repeats):
1359
+ bucket.append(item_info)
1360
+ bucketed_item_info[bucket_reso] = bucket
1361
+
1362
+ # prepare batch manager
1363
+ self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
1364
+ self.batch_manager.show_bucket_info()
1365
+
1366
+ self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
1367
+
1368
+ def shuffle_buckets(self):
1369
+ # set random seed for this epoch
1370
+ random.seed(self.seed + self.current_epoch)
1371
+ self.batch_manager.shuffle()
1372
+
1373
+ def __len__(self):
1374
+ if self.batch_manager is None:
1375
+ return 100 # dummy value
1376
+ return len(self.batch_manager)
1377
+
1378
+ def __getitem__(self, idx):
1379
+ return self.batch_manager[idx]
1380
+
1381
+
1382
+ class DatasetGroup(torch.utils.data.ConcatDataset):
1383
+ def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
1384
+ super().__init__(datasets)
1385
+ self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
1386
+ self.num_train_items = 0
1387
+ for dataset in self.datasets:
1388
+ self.num_train_items += dataset.num_train_items
1389
+
1390
+ def set_current_epoch(self, epoch):
1391
+ for dataset in self.datasets:
1392
+ dataset.set_current_epoch(epoch)
1393
+
1394
+ def set_current_step(self, step):
1395
+ for dataset in self.datasets:
1396
+ dataset.set_current_step(step)
1397
+
1398
+ def set_max_train_steps(self, max_train_steps):
1399
+ for dataset in self.datasets:
1400
+ dataset.set_max_train_steps(max_train_steps)
docs/advanced_config.md ADDED
@@ -0,0 +1,151 @@
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ # Advanced configuration / 高度な設定
4
+
5
+ ## How to specify `network_args` / `network_args`の指定方法
6
+
7
+ The `--network_args` option specifies detailed arguments for LoRA. Pass the arguments to `--network_args` in the form `key=value`.
8
+
9
+ <details>
10
+ <summary>日本語</summary>
11
+ `--network_args`オプションは、LoRAへの詳細な引数を指定するためのオプションです。`--network_args`には、`key=value`の形式で引数を指定します。
12
+ </details>
13
+
14
+ ### Example / 記述例
15
+
16
+ If you specify it on the command line, write as follows. / コマンドラインで指定する場合は以下のように記述します。
17
+
18
+ ```bash
19
+ accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
20
+ --network_module networks.lora --network_dim 32
21
+ --network_args "key1=value1" "key2=value2" ...
22
+ ```
23
+
24
+ If you specify it in the configuration file, write as follows. / 設定ファイルで指定する場合は以下のように記述します。
25
+
26
+ ```toml
27
+ network_args = ["key1=value1", "key2=value2", ...]
28
+ ```
29
+
30
+ If you specify `"verbose=True"`, detailed information of LoRA will be displayed. / `"verbose=True"`を指定するとLoRAの詳細な情報が表示されます。
31
+
32
+ ```bash
33
+ --network_args "verbose=True" "key1=value1" "key2=value2" ...
34
+ ```
35
+
36
+ ## LoRA+
37
+
38
+ LoRA+ is a method that improves training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiplier applied to the learning rate. The original paper recommends 16, but adjust as needed; starting from around 4 seems to work well. For details, please refer to the [related PR of sd-scripts](https://github.com/kohya-ss/sd-scripts/pull/1233).
39
+
40
+ Specify `loraplus_lr_ratio` with `--network_args`.
41
+
42
+ <details>
43
+ <summary>日本語</summary>
44
+
45
+ LoRA+は、LoRAのUP側(LoRA-B)の学習率を上げることで学習速度を向上させる手法です。学習率に対する倍率を指定します。元論文では16を推奨していますが、必要に応じて調整してください。4程度から始めるとよいようです。詳細は[sd-scriptsの関連PR](https://github.com/kohya-ss/sd-scripts/pull/1233)を参照してください。
46
+
47
+ `--network_args`で`loraplus_lr_ratio`を指定します。
48
+ </details>
49
+
50
+ ### Example / 記述例
51
+
52
+ ```bash
53
+ accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
54
+ --network_module networks.lora --network_dim 32 --network_args "loraplus_lr_ratio=4" ...
55
+ ```
56
+
57
+ ## Select the target modules of LoRA / LoRAの対象モジュールを選択する
58
+
59
+ *This feature is highly experimental and the specification may change. / この機能は特に実験的なもので、仕様は変更される可能性があります。*
60
+
61
+ By specifying `exclude_patterns` and `include_patterns` with `--network_args`, you can select the target modules of LoRA.
62
+
63
+ `exclude_patterns` excludes modules that match the specified pattern. `include_patterns` targets only modules that match the specified pattern.
64
+
65
+ Specify the values as a list. For example, `"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`.
66
+
67
+ The pattern is a regular expression for the module name. The module name is in the form of `double_blocks.0.img_mod.linear` or `single_blocks.39.modulation.linear`. The regular expression is not a partial match but a complete match.
68
+
69
+ The patterns are applied in the order of `exclude_patterns`→`include_patterns`. By default, the Linear layers of `img_mod`, `txt_mod`, and `modulation` of double blocks and single blocks are excluded.
70
+
71
+ (`.*(img_mod|txt_mod|modulation).*` is specified.)
72
+
73
+ <details>
74
+ <summary>日本語</summary>
75
+
76
+ `--network_args`で`exclude_patterns`と`include_patterns`を指定することで、LoRAの対象モジュールを選択することができます。
77
+
78
+ `exclude_patterns`は、指定したパターンに一致するモジュールを除外します。`include_patterns`は、指定したパターンに一致するモジュールのみを対象とします。
79
+
80
+ 値は、リストで指定します。`"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`のようになります。
81
+
82
+ パターンは、モジュール名に対する正規表現です。モジュール名は、たとえば`double_blocks.0.img_mod.linear`や`single_blocks.39.modulation.linear`のような形式です。正規表現は部分一致ではなく完全一致です。
83
+
84
+ パターンは、`exclude_patterns`→`include_patterns`の順で適用されます。デフォルトは、double blocksとsingle blocksのLinear層のうち、`img_mod`、`txt_mod`、`modulation`が除外されています。
85
+
86
+ (`.*(img_mod|txt_mod|modulation).*`が指定されています。)
87
+ </details>
88
+
89
+ ### Example / 記述例
90
+
91
+ Only the modules of double blocks / double blocksのモジュールのみを対象とする場合:
92
+
93
+ ```bash
94
+ --network_args "exclude_patterns=[r'.*single_blocks.*']"
95
+ ```
96
+
97
+ Only the Linear modules of single blocks from the 10th onward / single blocksの10番目以降のLinearモジュールのみを対象とする場合:
98
+
99
+ ```bash
100
+ --network_args "exclude_patterns=[r'.*']" "include_patterns=[r'.*single_blocks\.\d{2}\.linear.*']"
101
+ ```
102
+
103
+ ## Save and view logs in TensorBoard format / TensorBoard形式のログの保存と参照
104
+
105
+ Specify the folder to save the logs with the `--logging_dir` option. Logs in TensorBoard format will be saved.
106
+
107
+ For example, if you specify `--logging_dir=logs`, a `logs` folder will be created in the working folder, and logs will be saved in the date folder inside it.
108
+
109
+ Also, if you specify the `--log_prefix` option, the specified string will be added before the date. For example, use `--logging_dir=logs --log_prefix=lora_setting1_` for identification.
110
+
111
+ To view logs in TensorBoard, open another command prompt and activate the virtual environment. Then enter the following in the working folder.
112
+
113
+ ```powershell
114
+ tensorboard --logdir=logs
115
+ ```
116
+
117
+ (tensorboard installation is required.)
118
+
119
+ Then open a browser and access http://localhost:6006/ to display it.
120
+
121
+ <details>
122
+ <summary>日本語</summary>
123
+ `--logging_dir`オプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
124
+
125
+ たとえば`--logging_dir=logs`と指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。
126
+
127
+ また`--log_prefix`オプションを指定すると、日時の前に指定した文字列が追加されます。`--logging_dir=logs --log_prefix=lora_setting1_`などとして識別用にお使いください。
128
+
129
+ TensorBoardでログを確認するには、別のコマンドプロンプトを開き、仮想環境を有効にしてから、作業フォルダで以下のように入力します。
130
+
131
+ ```powershell
132
+ tensorboard --logdir=logs
133
+ ```
134
+
135
+ (tensorboardのインストールが必要です。)
136
+
137
+ その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
138
+ </details>
139
+
140
+ ## Save and view logs in wandb / wandbでログの保存と参照
141
+
142
+ The `--log_with wandb` option saves logs in wandb format. `tensorboard` and `all` are also available; the default is `tensorboard`.
143
+
144
+ Specify the project name with `--log_tracker_name` when using wandb.
145
+
146
+ <details>
147
+ <summary>日本語</summary>
148
+ `--log_with wandb`オプションを指定するとwandb形式でログを保存することができます。`tensorboard`や`all`も指定可能です。デフォルトは`tensorboard`です。
149
+
150
+ wandbを使用する場合は、`--log_tracker_name`でプロジェクト名を指定してください。
151
+ </details>
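+
+ ### Example / 記述例
+
+ A minimal sketch of the wandb-related options (the project name `my_lora_project` is just a placeholder; add these to your usual training command):
+
+ ```bash
+ --log_with wandb --log_tracker_name my_lora_project
+ ```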
docs/sampling_during_training.md ADDED
@@ -0,0 +1,108 @@
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ # Sampling during training / 学習中のサンプル画像生成
4
+
5
+ By preparing a prompt file, you can generate sample images during training.
6
+
7
+ Please be aware that it consumes a considerable amount of VRAM, so be careful when generating sample images for videos with a large number of frames. Also, since it takes time to generate, adjust the frequency of sample image generation as needed.
8
+
9
+ <details>
10
+ <summary>日本語</summary>
11
+
12
+ プロンプトファイルを用意することで、学習中にサンプル画像を生成することができます。
13
+
14
+ VRAMをそれなりに消費しますので、特にフレーム数が多い動画を生成する場合は注意してください。また生成には時間がかかりますので、サンプル画像生成の頻度は適宜調整してください。
15
+ </details>
16
+
17
+ ## How to use / 使い方
18
+
19
+ ### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
20
+
21
+ Example of command line options for training with sampling / 記述例:
22
+
23
+ ```bash
24
+ --vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt
25
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
26
+ --text_encoder1 path/to/ckpts/text_encoder
27
+ --text_encoder2 path/to/ckpts/text_encoder_2
28
+ --sample_prompts /path/to/prompt_file.txt
29
+ --sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
30
+ ```
31
+
32
+ `--vae`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size`, `--text_encoder1`, `--text_encoder2` are the same as when generating images, so please refer to [here](/README.md#inference) for details. `--fp8_llm` can also be specified.
33
+
34
+ `--sample_prompts` specifies the path to the prompt file used for sample image generation. Details are described below.
35
+
36
+ `--sample_every_n_epochs` specifies how often to generate sample images in epochs, and `--sample_every_n_steps` specifies how often to generate sample images in steps.
37
+
38
+ `--sample_at_first` is specified when generating sample images at the beginning of training.
39
+
40
+ Sample images and videos are saved in the `sample` directory in the directory specified by `--output_dir`. They are saved as `.png` for still images and `.mp4` for videos.
41
+
42
+ <details>
43
+ <summary>日本語</summary>
44
+
45
+ `--vae`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`、`--text_encoder1`、`--text_encoder2`は、画像生成時と同様ですので、詳細は[こちら](/README.ja.md#推論)を参照してください。`--fp8_llm`も指定可能です。
46
+
47
+ `--sample_prompts`は、サンプル画像生成に使用するプロンプトファイルのパスを指定します。詳細は後述します。
48
+
49
+ `--sample_every_n_epochs`は、何エポックごとにサンプル画像を生成するかを、`--sample_every_n_steps`は、何ステップごとにサンプル画像を生成するかを指定します。
50
+
51
+ `--sample_at_first`は、学習開始時にサンプル画像を生成する場合に指定します。
52
+
53
+ サンプル画像、動画は、`--output_dir`で指定したディレクトリ内の、`sample`ディレクトリに保存されます。静止画の場合は`.png`、動画の場合は`.mp4`で保存されます。
54
+ </details>
55
+
56
+ ### Prompt file / プロンプトファイル
57
+
58
+ The prompt file is a text file that contains the prompts for generating sample images. The example is as follows. / プロンプトファイルは、サンプル画像生成のためのプロンプトを記述したテキストファイルです。例は以下の通りです。
59
+
60
+ ```
61
+ # prompt 1: for generating a cat video
62
+ A cat walks on the grass, realistic style. --w 640 --h 480 --f 25 --d 1 --s 20
63
+
64
+ # prompt 2: for generating a dog image
65
+ A dog runs on the beach, realistic style. --w 960 --h 544 --f 1 --d 2 --s 20
66
+ ```
67
+
68
+ A line starting with `#` is a comment.
69
+
70
+ * `--w` specifies the width of the generated image or video. The default is 256.
71
+ * `--h` specifies the height. The default is 256.
72
+ * `--f` specifies the number of frames. The default is 1, which generates a still image.
73
+ * `--d` specifies the seed. The default is random.
74
+ * `--s` specifies the number of steps in generation. The default is 20.
75
+ * `--g` specifies the guidance scale. The default is 6.0, which is the default value during inference of HunyuanVideo. Specify 1.0 for SkyReels V1 models. Ignore this option for Wan2.1 models.
76
+ * `--fs` specifies the discrete flow shift. The default is 14.5, which corresponds to 20 steps. In the HunyuanVideo paper, 7.0 is recommended for 50 steps, and 17.0 is recommended for fewer than 20 steps (e.g. 10).
77
+
78
+ If you train I2V models, you can use the additional options below.
79
+
80
+ * `--i path/to/image.png`: the image path for image2video inference.
81
+
82
+ If you train the model with classifier free guidance, you can use the additional options below.
83
+
84
+ * `--n negative prompt...`: the negative prompt for the classifier free guidance.
85
+ * `--l 6.0`: the classifier free guidance scale. Should be set to 6.0 for SkyReels V1 models. 5.0 is the default value for Wan2.1 (if omitted).
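+
+ For example, a prompt line for an I2V model with a negative prompt might look like the following sketch (the image path and all values are placeholders):
+
+ ```
+ A robot starts to walk, anime style. --w 832 --h 480 --f 49 --d 1 --s 20 --i path/to/start_image.png --n low quality, blurry --l 6.0
+ ```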
86
+
87
+ <details>
88
+ <summary>日本語</summary>
89
+
90
+ `#` で始まる行はコメントです。
91
+
92
+ * `--w` 生成画像、動画の幅を指定します。省略時は256です。
93
+ * `--h` 高さを指定します。省略時は256です。
94
+ * `--f` フレーム数を指定します。省略時は1で、静止画を生成します。
95
+ * `--d` シードを指定します。省略時はランダムです。
96
+ * `--s` 生成におけるステップ数を指定します。省略時は20です。
97
+ * `--g` guidance scaleを指定します。省略時は6.0で、HunyuanVideoの推論時のデフォルト値です。
98
+ * `--fs` discrete flow shiftを指定します。省略時は14.5で、ステップ数20の場合に対応した値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。
99
+
100
+ I2Vモデルを学習する場合、以下の追加オプションを使用できます。
101
+
102
+ * `--i path/to/image.png`: image2video推論用の画像パス。
103
+
104
+ classifier free guidance(ネガティブプロンプト)を必要とするモデルを学習する場合、以下の追加オプションを使用できます。
105
+
106
+ * `--n negative prompt...`: classifier free guidance用のネガティブプロンプト。
107
+ * `--l 6.0`: classifier free guidance scale。SkyReels V1モデルの場合は6.0に設定してください。Wan2.1の場合はデフォルト値が5.0です(省略時)。
108
+ </details>
docs/wan.md ADDED
@@ -0,0 +1,241 @@
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ # Wan 2.1
4
+
5
+ ## Overview / 概要
6
+
7
+ This is an unofficial training and inference script for [Wan2.1](https://github.com/Wan-Video/Wan2.1). The features are as follows.
8
+
9
+ - fp8 support and memory reduction by block swap: inference of 720x1280x81-frame videos with 24GB VRAM, and training with 720x1280 images with 24GB VRAM
10
+ - Inference without installing Flash attention (using PyTorch's scaled dot product attention)
11
+ - Supports xformers and Sage attention
12
+
13
+ This feature is experimental.
14
+
15
+ <details>
16
+ <summary>日本語</summary>
17
+ [Wan2.1](https://github.com/Wan-Video/Wan2.1) の非公式の学習および推論スクリプトです。
18
+
19
+ 以下の特徴があります。
20
+
21
+ - fp8対応およびblock swapによる省メモリ化:720x1280x81framesの動画を24GB VRAMで推論可能、720x1280の画像での学習が24GB VRAMで可能
22
+ - Flash attentionのインストールなしでの実行(PyTorchのscaled dot product attentionを使用)
23
+ - xformersおよびSage attention対応
24
+
25
+ この機能は実験的なものです。
26
+ </details>
27
+
28
+ ## Download the model / モデルのダウンロード
29
+
30
+ Download the T5 `models_t5_umt5-xxl-enc-bf16.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
31
+
32
+ Download the VAE from the above page `Wan2.1_VAE.pth` or download `split_files/vae/wan_2.1_vae.safetensors` from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
33
+
34
+ Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
35
+
36
+ Please select the appropriate weights according to T2V, I2V, resolution, model size, etc. fp8 models can be used if `--fp8` is specified.
37
+
38
+ (Thanks to Comfy-Org for providing the repackaged weights.)
39
+ <details>
40
+ <summary>日本語</summary>
41
+ T5 `models_t5_umt5-xxl-enc-bf16.pth` およびCLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を、次のページからダウンロードしてください:https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
42
+
43
+ VAEは上のページから `Wan2.1_VAE.pth` をダウンロードするか、次のページから `split_files/vae/wan_2.1_vae.safetensors` をダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
44
+
45
+ DiTの重みを次のページからダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
46
+
47
+ T2VやI2V、解像度、モデルサイズなどにより適切な重みを選択してください。`--fp8`指定時はfp8モデルも使用できます。
48
+
49
+ (repackaged版の重みを提供してくださっているComfy-Orgに感謝いたします。)
50
+ </details>
51
+
52
+ ## Pre-caching / 事前キャッシュ
53
+
54
+ ### Latent Pre-caching
55
+
56
+ Latent pre-caching is almost the same as in HunyuanVideo. Create the cache using the following command:
57
+
58
+ ```bash
59
+ python wan_cache_latents.py --dataset_config path/to/toml --vae path/to/wan_2.1_vae.safetensors
60
+ ```
61
+
62
+ If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model. If not specified, the training will raise an error.
63
+
64
+ If you're running low on VRAM, specify `--vae_cache_cpu` to use the CPU for the VAE internal cache, which will reduce VRAM usage somewhat.
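+
+ For example, a latent-caching command for an I2V model on a low-VRAM system might combine these options as follows (paths are placeholders):
+
+ ```bash
+ python wan_cache_latents.py --dataset_config path/to/toml --vae path/to/wan_2.1_vae.safetensors --clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth --vae_cache_cpu
+ ```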
65
+
66
+ <details>
67
+ <summary>日本語</summary>
68
+ latentの事前キャッシングはHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
69
+
70
+ I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。指定しないと学習時にエラーが発生します。
71
+
72
+ VRAMが不足している場合は、`--vae_cache_cpu` を指定するとVAEの内部キャッシュにCPUを使うことで、使用VRAMを多少削減できます。
73
+ </details>
74
+
75
+ ### Text Encoder Output Pre-caching
76
+
77
+ Text encoder output pre-caching is also almost the same as in HunyuanVideo. Create the cache using the following command:
78
+
79
+ ```bash
80
+ python wan_cache_text_encoder_outputs.py --dataset_config path/to/toml --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --batch_size 16
81
+ ```
82
+
83
+ Adjust `--batch_size` according to your available VRAM.
84
+
85
+ For systems with limited VRAM (less than ~16GB), use `--fp8_t5` to run the T5 in fp8 mode.
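+
+ For example, on such a system the command above might become the following sketch (the batch size is illustrative):
+
+ ```bash
+ python wan_cache_text_encoder_outputs.py --dataset_config path/to/toml --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --batch_size 4 --fp8_t5
+ ```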
86
+
87
+ <details>
88
+ <summary>日本語</summary>
89
+ テキストエンコーダ出力の事前キャッシングもHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
90
+
91
+ 使用可能なVRAMに合わせて `--batch_size` を調整してください。
92
+
93
+ VRAMが限られているシステム(約16GB未満)の場合は、T5をfp8モードで実行するために `--fp8_t5` を使用してください。
94
+ </details>
95
+
96
+ ## Training / 学習
97
+
98
+ ### Training
99
+
100
+ Start training using the following command (input as a single line):
101
+
102
+ ```bash
103
+ accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 wan_train_network.py
104
+ --task t2v-1.3B
105
+ --dit path/to/wan2.1_xxx_bf16.safetensors
106
+ --dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base
107
+ --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing
108
+ --max_data_loader_n_workers 2 --persistent_data_loader_workers
109
+ --network_module networks.lora_wan --network_dim 32
110
+ --timestep_sampling shift --discrete_flow_shift 3.0
111
+ --max_train_epochs 16 --save_every_n_epochs 1 --seed 42
112
+ --output_dir path/to/output_dir --output_name name-of-lora
113
+ ```
114
+ The above is an example. The appropriate values for `timestep_sampling` and `discrete_flow_shift` need to be determined by experimentation.
115
+
116
+ For additional options, use `python wan_train_network.py --help` (note that many options are unverified).
117
+
118
+ `--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`. Specify the DiT weights for the task with `--dit`.
119
+
120
+ Don't forget to specify `--network_module networks.lora_wan`.
121
+
122
+ Other options are mostly the same as `hv_train_network.py`.
123
+
124
+ Use `convert_lora.py` for converting the LoRA weights after training, as in HunyuanVideo.
125
+
126
+ <details>
127
+ <summary>日本語</summary>
128
+ `timestep_sampling`や`discrete_flow_shift`は一例です。どのような値が適切かは実験が必要です。
129
+
130
+ その他のオプションについては `python wan_train_network.py --help` を使用してください(多くのオプションは未検証です)。
131
+
132
+ `--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。`--dit`に、taskに応じたDiTの重みを指定してください。
133
+
134
+ `--network_module` に `networks.lora_wan` を指定することを忘れないでください。
135
+
136
+ その他のオプションは、ほぼ`hv_train_network.py`と同様です。
137
+
138
+ 学習後のLoRAの重みの変換は、HunyuanVideoと同様に`convert_lora.py`を使用してください。
139
+ </details>
140
+
141
+ ### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
142
+
143
+ Example of command line options for training with sampling / 記述例:
144
+
145
+ ```bash
146
+ --vae path/to/wan_2.1_vae.safetensors
147
+ --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
148
+ --sample_prompts /path/to/prompt_file.txt
149
+ --sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
150
+ ```
151
+ Each option is the same as at inference time or as in HunyuanVideo. Please refer to [here](/docs/sampling_during_training.md) for details.
152
+
153
+ If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model.
154
+
155
+ You can specify the initial image and negative prompts in the prompt file. Please refer to [here](/docs/sampling_during_training.md#prompt-file--プロンプトファイル).
156
+
157
+ <details>
158
+ <summary>日本語</summary>
159
+ 各オプションは推論時、およびHunyuanVideoの場合と同様です。[こちら](/docs/sampling_during_training.md)を参照してください。
160
+
161
+ I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。
162
+
163
+ プロンプトファイルで、初期画像やネガティブプロンプト等を指定できます。[こちら](/docs/sampling_during_training.md#prompt-file--プロンプトファイル)を参照してください。
164
+ </details>
165
+
166
+
167
+ ## Inference / 推論
168
+
169
+ ### T2V Inference / T2V推論
170
+
171
+ The following is an example of T2V inference (input as a single line):
172
+
173
+ ```bash
174
+ python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 832 480 --video_length 81 --infer_steps 20
175
+ --prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
176
+ --dit path/to/wan2.1_t2v_1.3B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
177
+ --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
178
+ --attn_mode torch
179
+ ```
180
+
181
+ `--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`.
182
+
183
+ `--attn_mode` is `torch`, `sdpa` (same as `torch`), `xformers`, `sageattn`, `flash2`, `flash` (same as `flash2`) or `flash3`. `torch` is the default. Other options require the corresponding library to be installed. `flash3` (Flash attention 3) is not tested.
184
+
185
+ `--fp8_t5` can be used to specify the T5 model in fp8 format. This option reduces memory usage for the T5 model.
186
+
187
+ `--negative_prompt` can be used to specify a negative prompt. If omitted, the default negative prompt is used.
188
+
189
+ `--flow_shift` can be used to specify the flow shift (default 3.0 for I2V with 480p, 5.0 for others).
190
+
191
+ `--guidance_scale` can be used to specify the guidance scale for classifier free guidance (default 5.0).
192
+
193
+ `--blocks_to_swap` is the number of blocks to swap during inference. The default value is None (no block swap). The maximum value is 39 for the 14B model and 29 for the 1.3B model.
194
+
195
+ `--vae_cache_cpu` enables VAE cache in main memory. This reduces VRAM usage slightly but processing is slower.
196
+
197
+ Other options are the same as for `hv_generate_video.py` (some options are not supported; please check the help).
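+
+ As an illustration, the memory-saving options described above can be combined with the T2V command, for example (input as a single line; the number of swapped blocks and all paths are placeholders):
+
+ ```bash
+ python wan_generate_video.py --fp8 --fp8_t5 --blocks_to_swap 20 --vae_cache_cpu --task t2v-14B --video_size 832 480 --video_length 81 --infer_steps 20
+ --prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
+ --dit path/to/wan2.1_t2v_14B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
+ --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
+ --attn_mode torch
+ ```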
198
+
199
+ <details>
200
+ <summary>日本語</summary>
201
+ `--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。
202
+
203
+ `--attn_mode` には `torch`, `sdpa`(`torch`と同じ)、`xformers`, `sageattn`, `flash2`, `flash`(`flash2`と同じ), `flash3` のいずれかを指定します。デフォルトは `torch` です。その他のオプションを使用する場合は、対応するライブラリをインストールする必要があります。`flash3`(Flash attention 3)は未テストです。
204
+
205
+ `--fp8_t5` を指定するとT5モデルをfp8形式で実行します。T5モデル呼び出し時のメモリ使用量を削減します。
206
+
207
+ `--negative_prompt` でネガティブプロンプトを指定できます。省略した場合はデフォルトのネガティブプロンプトが使用されます。
208
+
209
+ `--flow_shift` でflow shiftを指定できます(480pのI2Vの場合はデフォルト3.0、それ以外は5.0)。
210
+
211
+ `--guidance_scale` でclassifier free guidanceのガイダンススケールを指定できます(デフォルト5.0)。
212
+
213
+ `--blocks_to_swap` は推論時のblock swapの数です。デフォルト値はNone(block swapなし)です。最大値は14Bモデルの場合39、1.3Bモデルの場合29です。
214
+
215
+ `--vae_cache_cpu` を有効にすると、VAEのキャッシュをメインメモリに保持します。VRAM使用量が多少減りますが、処理は遅くなります。
216
+
217
+ その他のオプションは `hv_generate_video.py` と同じです(一部のオプションはサポートされていないため、ヘルプを確認してください)。
218
+ </details>
219
+
220
+ ### I2V Inference / I2V推論
221
+
222
+ The following is an example of I2V inference (input as a single line):
223
+
224
+ ```bash
225
+ python wan_generate_video.py --fp8 --task i2v-14B --video_size 832 480 --video_length 81 --infer_steps 20
226
+ --prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
227
+ --dit path/to/wan2.1_i2v_480p_14B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
228
+ --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
229
+ --attn_mode torch --image_path path/to/image.jpg
230
+ ```
231
+
232
+ Add `--clip` to specify the CLIP model. `--image_path` is the path to the image to be used as the initial frame.
233
+
234
+ Other options are the same as for T2V inference.
235
+
236
+ <details>
237
+ <summary>日本語</summary>
238
+ `--clip` を追加してCLIPモデルを指定します。`--image_path` は初期フレームとして使用する画像のパスです。
239
+
240
+ その他のオプションはT2V推論と同じです。
241
+ </details>
hunyuan_model/__init__.py ADDED
File without changes
hunyuan_model/activation_layers.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_activation_layer(act_type):
5
+ """get activation layer
6
+
7
+ Args:
8
+ act_type (str): the activation type
9
+
10
+ Returns:
11
+ torch.nn.functional: the activation layer
12
+ """
13
+ if act_type == "gelu":
14
+ return lambda: nn.GELU()
15
+ elif act_type == "gelu_tanh":
16
+ # Approximate `tanh` requires torch >= 1.13
17
+ return lambda: nn.GELU(approximate="tanh")
18
+ elif act_type == "relu":
19
+ return nn.ReLU
20
+ elif act_type == "silu":
21
+ return nn.SiLU
22
+ else:
23
+ raise ValueError(f"Unknown activation type: {act_type}")
hunyuan_model/attention.py ADDED
@@ -0,0 +1,295 @@
1
+ import importlib.metadata
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ try:
9
+ import flash_attn
10
+ from flash_attn.flash_attn_interface import _flash_attn_forward
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ from flash_attn.flash_attn_interface import flash_attn_func
13
+ except ImportError:
14
+ flash_attn = None
15
+ flash_attn_varlen_func = None
16
+ _flash_attn_forward = None
17
+ flash_attn_func = None
18
+
19
+ try:
20
+ print(f"Trying to import sageattention")
21
+ from sageattention import sageattn_varlen, sageattn
22
+
23
+ print("Successfully imported sageattention")
24
+ except ImportError:
25
+ print(f"Failed to import sageattention")
26
+ sageattn_varlen = None
27
+ sageattn = None
28
+
29
+ try:
30
+ import xformers.ops as xops
31
+ except ImportError:
32
+ xops = None
33
+
34
+ MEMORY_LAYOUT = {
35
+ "flash": (
36
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
37
+ lambda x: x,
38
+ ),
39
+ "flash_fixlen": (
40
+ lambda x: x,
41
+ lambda x: x,
42
+ ),
43
+ "sageattn": (
44
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
45
+ lambda x: x,
46
+ ),
47
+ "sageattn_fixlen": (
48
+ lambda x: x.transpose(1, 2),
49
+ lambda x: x.transpose(1, 2),
50
+ ),
51
+ "torch": (
52
+ lambda x: x.transpose(1, 2),
53
+ lambda x: x.transpose(1, 2),
54
+ ),
55
+ "xformers": (
56
+ lambda x: x,
57
+ lambda x: x,
58
+ ),
59
+ "vanilla": (
60
+ lambda x: x.transpose(1, 2),
61
+ lambda x: x.transpose(1, 2),
62
+ ),
63
+ }
64
+
65
+
66
+ def get_cu_seqlens(text_mask, img_len):
67
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
68
+
69
+ Args:
70
+ text_mask (torch.Tensor): the mask of text
71
+ img_len (int): the length of image
72
+
73
+ Returns:
74
+ torch.Tensor: the calculated cu_seqlens for flash attention
75
+ """
76
+ batch_size = text_mask.shape[0]
77
+ text_len = text_mask.sum(dim=1)
78
+ max_len = text_mask.shape[1] + img_len
79
+
80
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
81
+
82
+ for i in range(batch_size):
83
+ s = text_len[i] + img_len
84
+ s1 = i * max_len + s
85
+ s2 = (i + 1) * max_len
86
+ cu_seqlens[2 * i + 1] = s1
87
+ cu_seqlens[2 * i + 2] = s2
88
+
89
+ return cu_seqlens
90
+
91
+
92
+ def attention(
93
+ q_or_qkv_list,
94
+ k=None,
95
+ v=None,
96
+ mode="flash",
97
+ drop_rate=0,
98
+ attn_mask=None,
99
+ total_len=None,
100
+ causal=False,
101
+ cu_seqlens_q=None,
102
+ cu_seqlens_kv=None,
103
+ max_seqlen_q=None,
104
+ max_seqlen_kv=None,
105
+ batch_size=1,
106
+ ):
107
+ """
108
+ Perform QKV self attention.
109
+
110
+ Args:
111
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
112
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
113
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
114
+ mode (str): Attention mode. Choose from 'flash', 'torch', 'vanilla', 'xformers', or 'sageattn'.
115
+ drop_rate (float): Dropout rate in attention map. (default: 0)
116
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
117
+ (default: None)
118
+ causal (bool): Whether to use causal attention. (default: False)
119
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
120
+ used to index into q.
121
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
122
+ used to index into kv.
123
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
124
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
125
+
126
+ Returns:
127
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
128
+ """
129
+ q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
130
+ if type(q_or_qkv_list) == list:
131
+ q_or_qkv_list.clear()
132
+ split_attn = total_len is not None
133
+ if split_attn and mode == "sageattn":
134
+ mode = "sageattn_fixlen"
135
+ elif split_attn and mode == "flash":
136
+ mode = "flash_fixlen"
137
+ # print(f"Attention mode: {mode}, split_attn: {split_attn}")
138
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
139
+
140
+ # trim the sequence length to the actual length instead of attn_mask
141
+ if split_attn:
142
+ trimmed_len = q.shape[1] - total_len
143
+ q = [q[i : i + 1, : total_len[i]] for i in range(len(q))]
144
+ k = [k[i : i + 1, : total_len[i]] for i in range(len(k))]
145
+ v = [v[i : i + 1, : total_len[i]] for i in range(len(v))]
146
+ q = [pre_attn_layout(q_i) for q_i in q]
147
+ k = [pre_attn_layout(k_i) for k_i in k]
148
+ v = [pre_attn_layout(v_i) for v_i in v]
149
+ # print(
150
+ # f"Trimming the sequence length to {total_len},trimmed_len: {trimmed_len}, q.shape: {[q_i.shape for q_i in q]}, mode: {mode}"
151
+ # )
152
+ else:
153
+ q = pre_attn_layout(q)
154
+ k = pre_attn_layout(k)
155
+ v = pre_attn_layout(v)
156
+
157
+ if mode == "torch":
158
+ if split_attn:
159
+ x = []
160
+ for i in range(len(q)):
161
+ x_i = F.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate, is_causal=causal)
162
+ q[i], k[i], v[i] = None, None, None
163
+ x.append(x_i)
164
+ del q, k, v
165
+ else:
166
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
167
+ attn_mask = attn_mask.to(q.dtype)
168
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
169
+ del q, k, v
170
+ del attn_mask
171
+
172
+ elif mode == "xformers":
173
+ # B, M, H, K: M is the sequence length, H is the number of heads, K is the dimension of the heads -> it is same as input dimension
174
+ # currently only support batch_size = 1
175
+ assert split_attn, "Xformers only supports splitting"
176
+ x = []
177
+ for i in range(len(q)):
178
+ x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) # , causal=causal)
179
+ q[i], k[i], v[i] = None, None, None
180
+ x.append(x_i)
181
+ del q, k, v
182
+
183
+ elif mode == "flash":
184
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
185
+ del q, k, v
186
+ # x with shape [(bxs), a, d]
187
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
188
+ elif mode == "flash_fixlen":
189
+ x = []
190
+ for i in range(len(q)):
191
+ # q: (batch_size, seqlen, nheads, headdim), k: (batch_size, seqlen, nheads_k, headdim), v: (batch_size, seqlen, nheads_k, headdim)
192
+ x_i = flash_attn_func(q[i], k[i], v[i], dropout_p=drop_rate, causal=causal)
193
+ q[i], k[i], v[i] = None, None, None
194
+ x.append(x_i)
195
+ del q, k, v
196
+ elif mode == "sageattn":
197
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
198
+ del q, k, v
199
+ # x with shape [(bxs), a, d]
200
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
201
+ elif mode == "sageattn_fixlen":
202
+ x = []
203
+ for i in range(len(q)):
204
+ # HND seems to cause an error
205
+ x_i = sageattn(q[i], k[i], v[i]) # (batch_size, seq_len, head_num, head_dim)
206
+ q[i], k[i], v[i] = None, None, None
207
+ x.append(x_i)
208
+ del q, k, v
209
+ elif mode == "vanilla":
210
+ assert not split_attn, "Vanilla attention does not support trimming"
211
+ scale_factor = 1 / math.sqrt(q.size(-1))
212
+
213
+ b, a, s, _ = q.shape
214
+ s1 = k.size(2)
215
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
216
+ if causal:
217
+ # Only applied to self attention
218
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
219
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
220
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
221
+ attn_bias = attn_bias.to(q.dtype)
222
+
223
+ if attn_mask is not None:
224
+ if attn_mask.dtype == torch.bool:
225
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
226
+ else:
227
+ attn_bias += attn_mask
228
+
229
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
230
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
231
+ attn += attn_bias
232
+ attn = attn.softmax(dim=-1)
233
+ attn = torch.dropout(attn, p=drop_rate, train=True)
234
+ x = attn @ v
235
+ else:
236
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
237
+
238
+ if split_attn:
239
+ x = [post_attn_layout(x_i) for x_i in x]
240
+ for i in range(len(x)):
241
+ x[i] = F.pad(x[i], (0, 0, 0, 0, 0, trimmed_len[i]))
242
+ x = torch.cat(x, dim=0)
243
+ else:
244
+ x = post_attn_layout(x)
245
+
246
+ b, s, a, d = x.shape
247
+ out = x.reshape(b, s, -1)
248
+ return out
249
+
250
+
251
+ def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
252
+ attn1 = hybrid_seq_parallel_attn(
253
+ None,
254
+ q[:, :img_q_len, :, :],
255
+ k[:, :img_kv_len, :, :],
256
+ v[:, :img_kv_len, :, :],
257
+ dropout_p=0.0,
258
+ causal=False,
259
+ joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
260
+ joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
261
+ joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
262
+ joint_strategy="rear",
263
+ )
264
+ if flash_attn.__version__ >= "2.7.0":
265
+ attn2, *_ = _flash_attn_forward(
266
+ q[:, cu_seqlens_q[1] :],
267
+ k[:, cu_seqlens_kv[1] :],
268
+ v[:, cu_seqlens_kv[1] :],
269
+ dropout_p=0.0,
270
+ softmax_scale=q.shape[-1] ** (-0.5),
271
+ causal=False,
272
+ window_size_left=-1,
273
+ window_size_right=-1,
274
+ softcap=0.0,
275
+ alibi_slopes=None,
276
+ return_softmax=False,
277
+ )
278
+ else:
279
+ attn2, *_ = _flash_attn_forward(
280
+ q[:, cu_seqlens_q[1] :],
281
+ k[:, cu_seqlens_kv[1] :],
282
+ v[:, cu_seqlens_kv[1] :],
283
+ dropout_p=0.0,
284
+ softmax_scale=q.shape[-1] ** (-0.5),
285
+ causal=False,
286
+ window_size=(-1, -1),
287
+ softcap=0.0,
288
+ alibi_slopes=None,
289
+ return_softmax=False,
290
+ )
291
+ attn = torch.cat([attn1, attn2], dim=1)
292
+ b, s, a, d = attn.shape
293
+ attn = attn.reshape(b, s, -1)
294
+
295
+ return attn
hunyuan_model/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,609 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ from typing import Dict, Optional, Tuple, Union
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+
27
+ # try:
28
+ # # This diffusers is modified and packed in the mirror.
29
+ # from diffusers.loaders import FromOriginalVAEMixin
30
+ # except ImportError:
31
+ # # Use this to be compatible with the original diffusers.
32
+ # from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+ from diffusers.models.attention_processor import (
35
+ ADDED_KV_ATTENTION_PROCESSORS,
36
+ CROSS_ATTENTION_PROCESSORS,
37
+ Attention,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
+
46
+
47
+ @dataclass
48
+ class DecoderOutput2(BaseOutput):
49
+ sample: torch.FloatTensor
50
+ posterior: Optional[DiagonalGaussianDistribution] = None
51
+
52
+
53
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin):
54
+ r"""
55
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
+
57
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
58
+ for all models (such as downloading or saving).
59
+ """
60
+
61
+ _supports_gradient_checkpointing = True
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
69
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
70
+ block_out_channels: Tuple[int] = (64,),
71
+ layers_per_block: int = 1,
72
+ act_fn: str = "silu",
73
+ latent_channels: int = 4,
74
+ norm_num_groups: int = 32,
75
+ sample_size: int = 32,
76
+ sample_tsize: int = 64,
77
+ scaling_factor: float = 0.18215,
78
+ force_upcast: float = True,
79
+ spatial_compression_ratio: int = 8,
80
+ time_compression_ratio: int = 4,
81
+ mid_block_add_attention: bool = True,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.time_compression_ratio = time_compression_ratio
86
+
87
+ self.encoder = EncoderCausal3D(
88
+ in_channels=in_channels,
89
+ out_channels=latent_channels,
90
+ down_block_types=down_block_types,
91
+ block_out_channels=block_out_channels,
92
+ layers_per_block=layers_per_block,
93
+ act_fn=act_fn,
94
+ norm_num_groups=norm_num_groups,
95
+ double_z=True,
96
+ time_compression_ratio=time_compression_ratio,
97
+ spatial_compression_ratio=spatial_compression_ratio,
98
+ mid_block_add_attention=mid_block_add_attention,
99
+ )
100
+
101
+ self.decoder = DecoderCausal3D(
102
+ in_channels=latent_channels,
103
+ out_channels=out_channels,
104
+ up_block_types=up_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ norm_num_groups=norm_num_groups,
108
+ act_fn=act_fn,
109
+ time_compression_ratio=time_compression_ratio,
110
+ spatial_compression_ratio=spatial_compression_ratio,
111
+ mid_block_add_attention=mid_block_add_attention,
112
+ )
113
+
114
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
115
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
116
+
117
+ self.use_slicing = False
118
+ self.use_spatial_tiling = False
119
+ self.use_temporal_tiling = False
120
+
121
+ # only relevant if vae tiling is enabled
122
+ self.tile_sample_min_tsize = sample_tsize
123
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
124
+
125
+ self.tile_sample_min_size = self.config.sample_size
126
+ sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
127
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
128
+ self.tile_overlap_factor = 0.25
129
+
130
+ def _set_gradient_checkpointing(self, module, value=False):
131
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
132
+ module.gradient_checkpointing = value
133
+
134
+ def enable_temporal_tiling(self, use_tiling: bool = True):
135
+ self.use_temporal_tiling = use_tiling
136
+
137
+ def disable_temporal_tiling(self):
138
+ self.enable_temporal_tiling(False)
139
+
140
+ def enable_spatial_tiling(self, use_tiling: bool = True):
141
+ self.use_spatial_tiling = use_tiling
142
+
143
+ def disable_spatial_tiling(self):
144
+ self.enable_spatial_tiling(False)
145
+
146
+ def enable_tiling(self, use_tiling: bool = True):
147
+ r"""
148
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
149
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
150
+ processing larger videos.
151
+ """
152
+ self.enable_spatial_tiling(use_tiling)
153
+ self.enable_temporal_tiling(use_tiling)
154
+
155
+ def disable_tiling(self):
156
+ r"""
157
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
158
+ decoding in one step.
159
+ """
160
+ self.disable_spatial_tiling()
161
+ self.disable_temporal_tiling()
162
+
163
+ def enable_slicing(self):
164
+ r"""
165
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
166
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
167
+ """
168
+ self.use_slicing = True
169
+
170
+ def disable_slicing(self):
171
+ r"""
172
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
173
+ decoding in one step.
174
+ """
175
+ self.use_slicing = False
176
+
177
+ def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
178
+ # set chunk_size to CausalConv3d recursively
179
+ def set_chunk_size(module):
180
+ if hasattr(module, "chunk_size"):
181
+ module.chunk_size = chunk_size
182
+
183
+ self.apply(set_chunk_size)
184
+
185
+ @property
186
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
187
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
188
+ r"""
189
+ Returns:
190
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
191
+ indexed by its weight name.
192
+ """
193
+ # set recursively
194
+ processors = {}
195
+
196
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
197
+ if hasattr(module, "get_processor"):
198
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
199
+
200
+ for sub_name, child in module.named_children():
201
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
202
+
203
+ return processors
204
+
205
+ for name, module in self.named_children():
206
+ fn_recursive_add_processors(name, module, processors)
207
+
208
+ return processors
209
+
210
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
211
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
212
+ r"""
213
+ Sets the attention processor to use to compute attention.
214
+
215
+ Parameters:
216
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
217
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
218
+ for **all** `Attention` layers.
219
+
220
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
221
+ processor. This is strongly recommended when setting trainable attention processors.
222
+
223
+ """
224
+ count = len(self.attn_processors.keys())
225
+
226
+ if isinstance(processor, dict) and len(processor) != count:
227
+ raise ValueError(
228
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
229
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
230
+ )
231
+
232
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
233
+ if hasattr(module, "set_processor"):
234
+ if not isinstance(processor, dict):
235
+ module.set_processor(processor, _remove_lora=_remove_lora)
236
+ else:
237
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
238
+
239
+ for sub_name, child in module.named_children():
240
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
241
+
242
+ for name, module in self.named_children():
243
+ fn_recursive_attn_processor(name, module, processor)
244
+
245
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
246
+ def set_default_attn_processor(self):
247
+ """
248
+ Disables custom attention processors and sets the default attention implementation.
249
+ """
250
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
251
+ processor = AttnAddedKVProcessor()
252
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
253
+ processor = AttnProcessor()
254
+ else:
255
+ raise ValueError(
256
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
257
+ )
258
+
259
+ self.set_attn_processor(processor, _remove_lora=True)
260
+
261
+ @apply_forward_hook
262
+ def encode(
263
+ self, x: torch.FloatTensor, return_dict: bool = True
264
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
265
+ """
266
+ Encode a batch of images/videos into latents.
267
+
268
+ Args:
269
+ x (`torch.FloatTensor`): Input batch of images/videos.
270
+ return_dict (`bool`, *optional*, defaults to `True`):
271
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
272
+
273
+ Returns:
274
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
275
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
276
+ """
277
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
278
+
279
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
280
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
281
+
282
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
283
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
284
+
285
+ if self.use_slicing and x.shape[0] > 1:
286
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
287
+ h = torch.cat(encoded_slices)
288
+ else:
289
+ h = self.encoder(x)
290
+
291
+ moments = self.quant_conv(h)
292
+ posterior = DiagonalGaussianDistribution(moments)
293
+
294
+ if not return_dict:
295
+ return (posterior,)
296
+
297
+ return AutoencoderKLOutput(latent_dist=posterior)
298
+
299
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
300
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
301
+
302
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
303
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
304
+
305
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
306
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
307
+
308
+ z = self.post_quant_conv(z)
309
+ dec = self.decoder(z)
310
+
311
+ if not return_dict:
312
+ return (dec,)
313
+
314
+ return DecoderOutput(sample=dec)
315
+
316
+ @apply_forward_hook
317
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
318
+ """
319
+ Decode a batch of images/videos.
320
+
321
+ Args:
322
+ z (`torch.FloatTensor`): Input batch of latent vectors.
323
+ return_dict (`bool`, *optional*, defaults to `True`):
324
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
325
+
326
+ Returns:
327
+ [`~models.vae.DecoderOutput`] or `tuple`:
328
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
329
+ returned.
330
+
331
+ """
332
+ if self.use_slicing and z.shape[0] > 1:
333
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
334
+ decoded = torch.cat(decoded_slices)
335
+ else:
336
+ decoded = self._decode(z).sample
337
+
338
+ if not return_dict:
339
+ return (decoded,)
340
+
341
+ return DecoderOutput(sample=decoded)
342
+
343
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
+ for y in range(blend_extent):
346
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
+ return b
348
+
349
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
+ for x in range(blend_extent):
352
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
+ return b
354
+
355
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
+ return b
360
+
361
+ def spatial_tiled_encode(
362
+ self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
363
+ ) -> AutoencoderKLOutput:
364
+ r"""Encode a batch of images/videos using a tiled encoder.
365
+
366
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
367
+ steps. This is useful to keep memory use constant regardless of image/video size. The end result of tiled encoding is
368
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
369
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
370
+ output, but they should be much less noticeable.
371
+
372
+ Args:
373
+ x (`torch.FloatTensor`): Input batch of images/videos.
374
+ return_dict (`bool`, *optional*, defaults to `True`):
375
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
376
+
377
+ Returns:
378
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
379
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
380
+ `tuple` is returned.
381
+ """
382
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
383
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
384
+ row_limit = self.tile_latent_min_size - blend_extent
385
+
386
+ # Split video into tiles and encode them separately.
387
+ rows = []
388
+ for i in range(0, x.shape[-2], overlap_size):
389
+ row = []
390
+ for j in range(0, x.shape[-1], overlap_size):
391
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
392
+ tile = self.encoder(tile)
393
+ tile = self.quant_conv(tile)
394
+ row.append(tile)
395
+ rows.append(row)
396
+ result_rows = []
397
+ for i, row in enumerate(rows):
398
+ result_row = []
399
+ for j, tile in enumerate(row):
400
+ # blend the above tile and the left tile
401
+ # to the current tile and add the current tile to the result row
402
+ if i > 0:
403
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
404
+ if j > 0:
405
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
406
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
407
+ result_rows.append(torch.cat(result_row, dim=-1))
408
+
409
+ moments = torch.cat(result_rows, dim=-2)
410
+ if return_moments:
411
+ return moments
412
+
413
+ posterior = DiagonalGaussianDistribution(moments)
414
+ if not return_dict:
415
+ return (posterior,)
416
+
417
+ return AutoencoderKLOutput(latent_dist=posterior)
418
+
419
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
420
+ r"""
421
+ Decode a batch of images/videos using a tiled decoder.
422
+
423
+ Args:
424
+ z (`torch.FloatTensor`): Input batch of latent vectors.
425
+ return_dict (`bool`, *optional*, defaults to `True`):
426
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
427
+
428
+ Returns:
429
+ [`~models.vae.DecoderOutput`] or `tuple`:
430
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
431
+ returned.
432
+ """
433
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
434
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
435
+ row_limit = self.tile_sample_min_size - blend_extent
436
+
437
+ # Split z into overlapping tiles and decode them separately.
438
+ # The tiles have an overlap to avoid seams between tiles.
439
+ rows = []
440
+ for i in range(0, z.shape[-2], overlap_size):
441
+ row = []
442
+ for j in range(0, z.shape[-1], overlap_size):
443
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
444
+ tile = self.post_quant_conv(tile)
445
+ decoded = self.decoder(tile)
446
+ row.append(decoded)
447
+ rows.append(row)
448
+ result_rows = []
449
+ for i, row in enumerate(rows):
450
+ result_row = []
451
+ for j, tile in enumerate(row):
452
+ # blend the above tile and the left tile
453
+ # to the current tile and add the current tile to the result row
454
+ if i > 0:
455
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
456
+ if j > 0:
457
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
458
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
459
+ result_rows.append(torch.cat(result_row, dim=-1))
460
+
461
+ dec = torch.cat(result_rows, dim=-2)
462
+ if not return_dict:
463
+ return (dec,)
464
+
465
+ return DecoderOutput(sample=dec)
466
+
467
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
468
+
469
+ B, C, T, H, W = x.shape
470
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
471
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
472
+ t_limit = self.tile_latent_min_tsize - blend_extent
473
+
474
+ # Split the video into tiles and encode them separately.
475
+ row = []
476
+ for i in range(0, T, overlap_size):
477
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
478
+ if self.use_spatial_tiling and (
479
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
480
+ ):
481
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
482
+ else:
483
+ tile = self.encoder(tile)
484
+ tile = self.quant_conv(tile)
485
+ if i > 0:
486
+ tile = tile[:, :, 1:, :, :]
487
+ row.append(tile)
488
+ result_row = []
489
+ for i, tile in enumerate(row):
490
+ if i > 0:
491
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
492
+ result_row.append(tile[:, :, :t_limit, :, :])
493
+ else:
494
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
495
+
496
+ moments = torch.cat(result_row, dim=2)
497
+ posterior = DiagonalGaussianDistribution(moments)
498
+
499
+ if not return_dict:
500
+ return (posterior,)
501
+
502
+ return AutoencoderKLOutput(latent_dist=posterior)
503
+
504
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
505
+ # Split z into overlapping tiles and decode them separately.
506
+
507
+ B, C, T, H, W = z.shape
508
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
509
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
510
+ t_limit = self.tile_sample_min_tsize - blend_extent
511
+
512
+ row = []
513
+ for i in range(0, T, overlap_size):
514
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
515
+ if self.use_spatial_tiling and (
516
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
517
+ ):
518
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
519
+ else:
520
+ tile = self.post_quant_conv(tile)
521
+ decoded = self.decoder(tile)
522
+ if i > 0:
523
+ decoded = decoded[:, :, 1:, :, :]
524
+ row.append(decoded)
525
+ result_row = []
526
+ for i, tile in enumerate(row):
527
+ if i > 0:
528
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
529
+ result_row.append(tile[:, :, :t_limit, :, :])
530
+ else:
531
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
532
+
533
+ dec = torch.cat(result_row, dim=2)
534
+ if not return_dict:
535
+ return (dec,)
536
+
537
+ return DecoderOutput(sample=dec)
538
+
539
+ def forward(
540
+ self,
541
+ sample: torch.FloatTensor,
542
+ sample_posterior: bool = False,
543
+ return_dict: bool = True,
544
+ return_posterior: bool = False,
545
+ generator: Optional[torch.Generator] = None,
546
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
547
+ r"""
548
+ Args:
549
+ sample (`torch.FloatTensor`): Input sample.
550
+ sample_posterior (`bool`, *optional*, defaults to `False`):
551
+ Whether to sample from the posterior.
552
+ return_dict (`bool`, *optional*, defaults to `True`):
553
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
554
+ """
555
+ x = sample
556
+ posterior = self.encode(x).latent_dist
557
+ if sample_posterior:
558
+ z = posterior.sample(generator=generator)
559
+ else:
560
+ z = posterior.mode()
561
+ dec = self.decode(z).sample
562
+
563
+ if not return_dict:
564
+ if return_posterior:
565
+ return (dec, posterior)
566
+ else:
567
+ return (dec,)
568
+ if return_posterior:
569
+ return DecoderOutput2(sample=dec, posterior=posterior)
570
+ else:
571
+ return DecoderOutput2(sample=dec)
572
+
573
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
574
+ def fuse_qkv_projections(self):
575
+ """
576
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
577
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
578
+
579
+ <Tip warning={true}>
580
+
581
+ This API is 🧪 experimental.
582
+
583
+ </Tip>
584
+ """
585
+ self.original_attn_processors = None
586
+
587
+ for _, attn_processor in self.attn_processors.items():
588
+ if "Added" in str(attn_processor.__class__.__name__):
589
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
590
+
591
+ self.original_attn_processors = self.attn_processors
592
+
593
+ for module in self.modules():
594
+ if isinstance(module, Attention):
595
+ module.fuse_projections(fuse=True)
596
+
597
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
598
+ def unfuse_qkv_projections(self):
599
+ """Disables the fused QKV projection if enabled.
600
+
601
+ <Tip warning={true}>
602
+
603
+ This API is 🧪 experimental.
604
+
605
+ </Tip>
606
+
607
+ """
608
+ if self.original_attn_processors is not None:
609
+ self.set_attn_processor(self.original_attn_processors)
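A hedged usage sketch for the tiling/slicing switches defined above. It assumes a `vae` instance of `AutoencoderKLCausal3D` has already been constructed and its weights loaded elsewhere (construction is outside this file), and it only calls methods defined in this class.

```python
# Hedged sketch: `vae` is assumed to be a loaded AutoencoderKLCausal3D.
import torch

def encode_decode_roundtrip(vae, video: torch.Tensor) -> torch.Tensor:
    # video: (B, C, T, H, W), already scaled to the range the encoder expects.
    vae.enable_tiling()    # spatial + temporal tiling to bound peak memory
    vae.enable_slicing()   # decode batch elements one at a time
    with torch.no_grad():
        latents = vae.encode(video).latent_dist.sample()
        recon = vae.decode(latents).sample
    return recon
```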
hunyuan_model/embed_layers.py ADDED
@@ -0,0 +1,132 @@
1
+ import collections
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange, repeat
6
+
7
+ from .helpers import to_2tuple
8
+
9
+ class PatchEmbed(nn.Module):
10
+ """2D Image to Patch Embedding
11
+
12
+ Image to Patch Embedding using Conv2d
13
+
14
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
15
+
16
+ Based on the impl in https://github.com/google-research/vision_transformer
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+
20
+ The _assert call is removed from the forward function to be compatible with multi-resolution images.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ patch_size=16,
26
+ in_chans=3,
27
+ embed_dim=768,
28
+ norm_layer=None,
29
+ flatten=True,
30
+ bias=True,
31
+ dtype=None,
32
+ device=None,
33
+ ):
34
+ factory_kwargs = {"dtype": dtype, "device": device}
35
+ super().__init__()
36
+ patch_size = to_2tuple(patch_size)
37
+ self.patch_size = patch_size
38
+ self.flatten = flatten
39
+
40
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
41
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
42
+ if bias:
43
+ nn.init.zeros_(self.proj.bias)
44
+
45
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
46
+
47
+ def forward(self, x):
48
+ x = self.proj(x)
49
+ if self.flatten:
50
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC (self.proj is Conv3d)
51
+ x = self.norm(x)
52
+ return x
53
+
54
+
55
+ class TextProjection(nn.Module):
56
+ """
57
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
58
+
59
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
60
+ """
61
+
62
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
63
+ factory_kwargs = {"dtype": dtype, "device": device}
64
+ super().__init__()
65
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
66
+ self.act_1 = act_layer()
67
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
68
+
69
+ def forward(self, caption):
70
+ hidden_states = self.linear_1(caption)
71
+ hidden_states = self.act_1(hidden_states)
72
+ hidden_states = self.linear_2(hidden_states)
73
+ return hidden_states
74
+
75
+
76
+ def timestep_embedding(t, dim, max_period=10000):
77
+ """
78
+ Create sinusoidal timestep embeddings.
79
+
80
+ Args:
81
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
82
+ dim (int): the dimension of the output.
83
+ max_period (int): controls the minimum frequency of the embeddings.
84
+
85
+ Returns:
86
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
87
+
88
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
89
+ """
90
+ half = dim // 2
91
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
92
+ args = t[:, None].float() * freqs[None]
93
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
94
+ if dim % 2:
95
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
96
+ return embedding
97
+
98
+
99
+ class TimestepEmbedder(nn.Module):
100
+ """
101
+ Embeds scalar timesteps into vector representations.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ hidden_size,
107
+ act_layer,
108
+ frequency_embedding_size=256,
109
+ max_period=10000,
110
+ out_size=None,
111
+ dtype=None,
112
+ device=None,
113
+ ):
114
+ factory_kwargs = {"dtype": dtype, "device": device}
115
+ super().__init__()
116
+ self.frequency_embedding_size = frequency_embedding_size
117
+ self.max_period = max_period
118
+ if out_size is None:
119
+ out_size = hidden_size
120
+
121
+ self.mlp = nn.Sequential(
122
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
123
+ act_layer(),
124
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
125
+ )
126
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
127
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
128
+
129
+ def forward(self, t):
130
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
131
+ t_emb = self.mlp(t_freq)
132
+ return t_emb
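A small shape check for the embedding utilities above; it assumes the module is importable as `hunyuan_model.embed_layers` (the package path used elsewhere in this upload), and the numbers are illustrative.

```python
# Hedged shape check for timestep_embedding / TimestepEmbedder.
import torch
import torch.nn as nn
from hunyuan_model.embed_layers import timestep_embedding, TimestepEmbedder

t = torch.tensor([0.0, 250.0, 999.0])     # one (possibly fractional) timestep per sample
emb = timestep_embedding(t, dim=256)      # sinusoidal table, shape (3, 256)
print(emb.shape)

embedder = TimestepEmbedder(hidden_size=512, act_layer=nn.SiLU)
print(embedder(t).shape)                  # (3, 512) after the two-layer MLP
```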
hunyuan_model/helpers.py ADDED
@@ -0,0 +1,40 @@
1
+ import collections.abc
2
+
3
+ from itertools import repeat
4
+
5
+
6
+ def _ntuple(n):
7
+ def parse(x):
8
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
9
+ x = tuple(x)
10
+ if len(x) == 1:
11
+ x = tuple(repeat(x[0], n))
12
+ return x
13
+ return tuple(repeat(x, n))
14
+ return parse
15
+
16
+
17
+ to_1tuple = _ntuple(1)
18
+ to_2tuple = _ntuple(2)
19
+ to_3tuple = _ntuple(3)
20
+ to_4tuple = _ntuple(4)
21
+
22
+
23
+ def as_tuple(x):
24
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25
+ return tuple(x)
26
+ if x is None or isinstance(x, (int, float, str)):
27
+ return (x,)
28
+ else:
29
+ raise ValueError(f"Unknown type {type(x)}")
30
+
31
+
32
+ def as_list_of_2tuple(x):
33
+ x = as_tuple(x)
34
+ if len(x) == 1:
35
+ x = (x[0], x[0])
36
+ assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
37
+ lst = []
38
+ for i in range(0, len(x), 2):
39
+ lst.append((x[i], x[i + 1]))
40
+ return lst
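Illustrative behavior of the tuple helpers above, assuming the module path `hunyuan_model.helpers`; the values are examples only.

```python
# Quick behavioral examples for the helpers defined above.
from hunyuan_model.helpers import to_2tuple, as_list_of_2tuple

print(to_2tuple(16))                     # (16, 16)
print(to_2tuple((1, 2)))                 # (1, 2)
print(as_list_of_2tuple(3))              # [(3, 3)]
print(as_list_of_2tuple((1, 2, 3, 4)))   # [(1, 2), (3, 4)]
```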
hunyuan_model/mlp_layers.py ADDED
@@ -0,0 +1,118 @@
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modulate_layers import modulate
10
+ from .helpers import to_2tuple
11
+
12
+
13
+ class MLP(nn.Module):
14
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15
+
16
+ def __init__(
17
+ self,
18
+ in_channels,
19
+ hidden_channels=None,
20
+ out_features=None,
21
+ act_layer=nn.GELU,
22
+ norm_layer=None,
23
+ bias=True,
24
+ drop=0.0,
25
+ use_conv=False,
26
+ device=None,
27
+ dtype=None,
28
+ ):
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ out_features = out_features or in_channels
32
+ hidden_channels = hidden_channels or in_channels
33
+ bias = to_2tuple(bias)
34
+ drop_probs = to_2tuple(drop)
35
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36
+
37
+ self.fc1 = linear_layer(
38
+ in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39
+ )
40
+ self.act = act_layer()
41
+ self.drop1 = nn.Dropout(drop_probs[0])
42
+ self.norm = (
43
+ norm_layer(hidden_channels, **factory_kwargs)
44
+ if norm_layer is not None
45
+ else nn.Identity()
46
+ )
47
+ self.fc2 = linear_layer(
48
+ hidden_channels, out_features, bias=bias[1], **factory_kwargs
49
+ )
50
+ self.drop2 = nn.Dropout(drop_probs[1])
51
+
52
+ def forward(self, x):
53
+ x = self.fc1(x)
54
+ x = self.act(x)
55
+ x = self.drop1(x)
56
+ x = self.norm(x)
57
+ x = self.fc2(x)
58
+ x = self.drop2(x)
59
+ return x
60
+
61
+
62
+ #
63
+ class MLPEmbedder(nn.Module):
64
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
65
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
66
+ factory_kwargs = {"device": device, "dtype": dtype}
67
+ super().__init__()
68
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
69
+ self.silu = nn.SiLU()
70
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return self.out_layer(self.silu(self.in_layer(x)))
74
+
75
+
76
+ class FinalLayer(nn.Module):
77
+ """The final layer of DiT."""
78
+
79
+ def __init__(
80
+ self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
81
+ ):
82
+ factory_kwargs = {"device": device, "dtype": dtype}
83
+ super().__init__()
84
+
85
+ # Just use LayerNorm for the final layer
86
+ self.norm_final = nn.LayerNorm(
87
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
88
+ )
89
+ if isinstance(patch_size, int):
90
+ self.linear = nn.Linear(
91
+ hidden_size,
92
+ patch_size * patch_size * out_channels,
93
+ bias=True,
94
+ **factory_kwargs
95
+ )
96
+ else:
97
+ self.linear = nn.Linear(
98
+ hidden_size,
99
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
100
+ bias=True,
101
+ )
102
+ nn.init.zeros_(self.linear.weight)
103
+ nn.init.zeros_(self.linear.bias)
104
+
105
+ # Here we don't distinguish between the modulate types. Just use the simple one.
106
+ self.adaLN_modulation = nn.Sequential(
107
+ act_layer(),
108
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
109
+ )
110
+ # Zero-initialize the modulation
111
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
112
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
113
+
114
+ def forward(self, x, c):
115
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
116
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
117
+ x = self.linear(x)
118
+ return x
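A hedged shape sketch for `FinalLayer` with a 3-D patch size, assuming the module path `hunyuan_model.mlp_layers` and the usual DiT-style broadcast inside `modulate` (shift/scale of shape `(B, hidden)` applied over the token dimension); the numbers are illustrative.

```python
# Hedged sketch: per-token output dimension of FinalLayer is
# patch_t * patch_h * patch_w * out_channels.
import torch
import torch.nn as nn
from hunyuan_model.mlp_layers import FinalLayer

layer = FinalLayer(hidden_size=128, patch_size=(1, 2, 2), out_channels=16, act_layer=nn.SiLU)
x = torch.randn(2, 10, 128)   # (batch, tokens, hidden)
c = torch.randn(2, 128)       # conditioning vector (timestep/text modulation)
out = layer(x, c)
print(out.shape)              # (2, 10, 1*2*2*16) = (2, 10, 64)
```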
hunyuan_model/models.py ADDED
@@ -0,0 +1,1044 @@
1
+ import os
2
+ from typing import Any, List, Tuple, Optional, Union, Dict
3
+ import accelerate
4
+ from einops import rearrange
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from .activation_layers import get_activation_layer
11
+ from .norm_layers import get_norm_layer
12
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
13
+ from .attention import attention, parallel_attention, get_cu_seqlens
14
+ from .posemb_layers import apply_rotary_emb
15
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
16
+ from .modulate_layers import ModulateDiT, modulate, apply_gate
17
+ from .token_refiner import SingleTokenRefiner
18
+ from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
19
+ from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
20
+
21
+ from utils.safetensors_utils import MemoryEfficientSafeOpen
22
+
23
+
24
+ class MMDoubleStreamBlock(nn.Module):
25
+ """
26
+ A multimodal DiT block with separate modulation for
27
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
28
+ (Flux.1): https://github.com/black-forest-labs/flux
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ hidden_size: int,
34
+ heads_num: int,
35
+ mlp_width_ratio: float,
36
+ mlp_act_type: str = "gelu_tanh",
37
+ qk_norm: bool = True,
38
+ qk_norm_type: str = "rms",
39
+ qkv_bias: bool = False,
40
+ dtype: Optional[torch.dtype] = None,
41
+ device: Optional[torch.device] = None,
42
+ attn_mode: str = "flash",
43
+ split_attn: bool = False,
44
+ ):
45
+ factory_kwargs = {"device": device, "dtype": dtype}
46
+ super().__init__()
47
+ self.attn_mode = attn_mode
48
+ self.split_attn = split_attn
49
+
50
+ self.deterministic = False
51
+ self.heads_num = heads_num
52
+ head_dim = hidden_size // heads_num
53
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
54
+
55
+ self.img_mod = ModulateDiT(
56
+ hidden_size,
57
+ factor=6,
58
+ act_layer=get_activation_layer("silu"),
59
+ **factory_kwargs,
60
+ )
61
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
62
+
63
+ self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
64
+ qk_norm_layer = get_norm_layer(qk_norm_type)
65
+ self.img_attn_q_norm = (
66
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
67
+ )
68
+ self.img_attn_k_norm = (
69
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
70
+ )
71
+ self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
72
+
73
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
74
+ self.img_mlp = MLP(
75
+ hidden_size,
76
+ mlp_hidden_dim,
77
+ act_layer=get_activation_layer(mlp_act_type),
78
+ bias=True,
79
+ **factory_kwargs,
80
+ )
81
+
82
+ self.txt_mod = ModulateDiT(
83
+ hidden_size,
84
+ factor=6,
85
+ act_layer=get_activation_layer("silu"),
86
+ **factory_kwargs,
87
+ )
88
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
89
+
90
+ self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
91
+ self.txt_attn_q_norm = (
92
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
93
+ )
94
+ self.txt_attn_k_norm = (
95
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
96
+ )
97
+ self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
98
+
99
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
100
+ self.txt_mlp = MLP(
101
+ hidden_size,
102
+ mlp_hidden_dim,
103
+ act_layer=get_activation_layer(mlp_act_type),
104
+ bias=True,
105
+ **factory_kwargs,
106
+ )
107
+ self.hybrid_seq_parallel_attn = None
108
+
109
+ self.gradient_checkpointing = False
110
+
111
+ def enable_deterministic(self):
112
+ self.deterministic = True
113
+
114
+ def disable_deterministic(self):
115
+ self.deterministic = False
116
+
117
+ def enable_gradient_checkpointing(self):
118
+ self.gradient_checkpointing = True
119
+
120
+ def disable_gradient_checkpointing(self):
121
+ self.gradient_checkpointing = False
122
+
123
+ def _forward(
124
+ self,
125
+ img: torch.Tensor,
126
+ txt: torch.Tensor,
127
+ vec: torch.Tensor,
128
+ attn_mask: Optional[torch.Tensor] = None,
129
+ total_len: Optional[torch.Tensor] = None,
130
+ cu_seqlens_q: Optional[torch.Tensor] = None,
131
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
132
+ max_seqlen_q: Optional[int] = None,
133
+ max_seqlen_kv: Optional[int] = None,
134
+ freqs_cis: tuple = None,
135
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
136
+ (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
137
+ 6, dim=-1
138
+ )
139
+ (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
140
+ 6, dim=-1
141
+ )
142
+
143
+ # Prepare image for attention.
144
+ img_modulated = self.img_norm1(img)
145
+ img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
146
+ img_qkv = self.img_attn_qkv(img_modulated)
147
+ img_modulated = None
148
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
149
+ img_qkv = None
150
+ # Apply QK-Norm if needed
151
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
152
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
153
+
154
+ # Apply RoPE if needed.
155
+ if freqs_cis is not None:
156
+ img_q_shape = img_q.shape
157
+ img_k_shape = img_k.shape
158
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
159
+ assert (
160
+ img_q.shape == img_q_shape and img_k.shape == img_k_shape
161
+ ), f"img_q: {img_q.shape} (expected {img_q_shape}), img_k: {img_k.shape} (expected {img_k_shape})"
162
+ # img_q, img_k = img_qq, img_kk
163
+
164
+ # Prepare txt for attention.
165
+ txt_modulated = self.txt_norm1(txt)
166
+ txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
167
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
168
+ txt_modulated = None
169
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
170
+ txt_qkv = None
171
+ # Apply QK-Norm if needed.
172
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
173
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
174
+
175
+ # Run actual attention.
176
+ img_q_len = img_q.shape[1]
177
+ img_kv_len = img_k.shape[1]
178
+ batch_size = img_k.shape[0]
179
+ q = torch.cat((img_q, txt_q), dim=1)
180
+ img_q = txt_q = None
181
+ k = torch.cat((img_k, txt_k), dim=1)
182
+ img_k = txt_k = None
183
+ v = torch.cat((img_v, txt_v), dim=1)
184
+ img_v = txt_v = None
185
+
186
+ assert (
187
+ cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
188
+ ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
189
+
190
+ # attention computation start
191
+ if not self.hybrid_seq_parallel_attn:
192
+ l = [q, k, v]
193
+ q = k = v = None
194
+ attn = attention(
195
+ l,
196
+ mode=self.attn_mode,
197
+ attn_mask=attn_mask,
198
+ total_len=total_len,
199
+ cu_seqlens_q=cu_seqlens_q,
200
+ cu_seqlens_kv=cu_seqlens_kv,
201
+ max_seqlen_q=max_seqlen_q,
202
+ max_seqlen_kv=max_seqlen_kv,
203
+ batch_size=batch_size,
204
+ )
205
+ else:
206
+ attn = parallel_attention(
207
+ self.hybrid_seq_parallel_attn,
208
+ q,
209
+ k,
210
+ v,
211
+ img_q_len=img_q_len,
212
+ img_kv_len=img_kv_len,
213
+ cu_seqlens_q=cu_seqlens_q,
214
+ cu_seqlens_kv=cu_seqlens_kv,
215
+ )
216
+
217
+ # attention computation end
218
+
219
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
220
+ attn = None
221
+
222
+ # Calculate the img blocks.
223
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
224
+ img_attn = None
225
+ img = img + apply_gate(
226
+ self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
227
+ gate=img_mod2_gate,
228
+ )
229
+
230
+ # Calculate the txt blocks.
231
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
232
+ txt_attn = None
233
+ txt = txt + apply_gate(
234
+ self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
235
+ gate=txt_mod2_gate,
236
+ )
237
+
238
+ return img, txt
239
+
240
+ # def forward(
241
+ # self,
242
+ # img: torch.Tensor,
243
+ # txt: torch.Tensor,
244
+ # vec: torch.Tensor,
245
+ # attn_mask: Optional[torch.Tensor] = None,
246
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
247
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
248
+ # max_seqlen_q: Optional[int] = None,
249
+ # max_seqlen_kv: Optional[int] = None,
250
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
251
+ # ) -> Tuple[torch.Tensor, torch.Tensor]:
252
+ def forward(self, *args, **kwargs):
253
+ if self.training and self.gradient_checkpointing:
254
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
255
+ else:
256
+ return self._forward(*args, **kwargs)
257
+
258
+
259
+ class MMSingleStreamBlock(nn.Module):
260
+ """
261
+ A DiT block with parallel linear layers as described in
262
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
263
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
264
+ (Flux.1): https://github.com/black-forest-labs/flux
265
+ """
266
+
267
+ def __init__(
268
+ self,
269
+ hidden_size: int,
270
+ heads_num: int,
271
+ mlp_width_ratio: float = 4.0,
272
+ mlp_act_type: str = "gelu_tanh",
273
+ qk_norm: bool = True,
274
+ qk_norm_type: str = "rms",
275
+ qk_scale: float = None,
276
+ dtype: Optional[torch.dtype] = None,
277
+ device: Optional[torch.device] = None,
278
+ attn_mode: str = "flash",
279
+ split_attn: bool = False,
280
+ ):
281
+ factory_kwargs = {"device": device, "dtype": dtype}
282
+ super().__init__()
283
+ self.attn_mode = attn_mode
284
+ self.split_attn = split_attn
285
+
286
+ self.deterministic = False
287
+ self.hidden_size = hidden_size
288
+ self.heads_num = heads_num
289
+ head_dim = hidden_size // heads_num
290
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
291
+ self.mlp_hidden_dim = mlp_hidden_dim
292
+ self.scale = qk_scale or head_dim**-0.5
293
+
294
+ # qkv and mlp_in
295
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
296
+ # proj and mlp_out
297
+ self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
298
+
299
+ qk_norm_layer = get_norm_layer(qk_norm_type)
300
+ self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
301
+ self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
302
+
303
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
304
+
305
+ self.mlp_act = get_activation_layer(mlp_act_type)()
306
+ self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
307
+ self.hybrid_seq_parallel_attn = None
308
+
309
+ self.gradient_checkpointing = False
310
+
311
+ def enable_deterministic(self):
312
+ self.deterministic = True
313
+
314
+ def disable_deterministic(self):
315
+ self.deterministic = False
316
+
317
+ def enable_gradient_checkpointing(self):
318
+ self.gradient_checkpointing = True
319
+
320
+ def disable_gradient_checkpointing(self):
321
+ self.gradient_checkpointing = False
322
+
323
+ def _forward(
324
+ self,
325
+ x: torch.Tensor,
326
+ vec: torch.Tensor,
327
+ txt_len: int,
328
+ attn_mask: Optional[torch.Tensor] = None,
329
+ total_len: Optional[torch.Tensor] = None,
330
+ cu_seqlens_q: Optional[torch.Tensor] = None,
331
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
332
+ max_seqlen_q: Optional[int] = None,
333
+ max_seqlen_kv: Optional[int] = None,
334
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
335
+ ) -> torch.Tensor:
336
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
337
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
338
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
339
+ x_mod = None
340
+ # mlp = mlp.to("cpu", non_blocking=True)
341
+ # clean_memory_on_device(x.device)
342
+
343
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
344
+ qkv = None
345
+
346
+ # Apply QK-Norm if needed.
347
+ q = self.q_norm(q).to(v)
348
+ k = self.k_norm(k).to(v)
349
+
350
+ # Apply RoPE if needed.
351
+ if freqs_cis is not None:
352
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
353
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
354
+ q = k = None
355
+ img_q_shape = img_q.shape
356
+ img_k_shape = img_k.shape
357
+ img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
358
+ assert (
359
+ img_q.shape == img_q_shape and img_k_shape == img_k.shape
360
+ ), f"img_q: {img_q.shape} (expected {img_q_shape}), img_k: {img_k.shape} (expected {img_k_shape})"
361
+ # img_q, img_k = img_qq, img_kk
362
+ # del img_qq, img_kk
363
+ q = torch.cat((img_q, txt_q), dim=1)
364
+ k = torch.cat((img_k, txt_k), dim=1)
365
+ del img_q, txt_q, img_k, txt_k
366
+
367
+ # Compute attention.
368
+ assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
369
+
370
+ # attention computation start
371
+ if not self.hybrid_seq_parallel_attn:
372
+ l = [q, k, v]
373
+ q = k = v = None
374
+ attn = attention(
375
+ l,
376
+ mode=self.attn_mode,
377
+ attn_mask=attn_mask,
378
+ total_len=total_len,
379
+ cu_seqlens_q=cu_seqlens_q,
380
+ cu_seqlens_kv=cu_seqlens_kv,
381
+ max_seqlen_q=max_seqlen_q,
382
+ max_seqlen_kv=max_seqlen_kv,
383
+ batch_size=x.shape[0],
384
+ )
385
+ else:
386
+ attn = parallel_attention(
387
+ self.hybrid_seq_parallel_attn,
388
+ q,
389
+ k,
390
+ v,
391
+ img_q_len=img_q.shape[1],
392
+ img_kv_len=img_k.shape[1],
393
+ cu_seqlens_q=cu_seqlens_q,
394
+ cu_seqlens_kv=cu_seqlens_kv,
395
+ )
396
+ # attention computation end
397
+
398
+ # Compute activation in mlp stream, cat again and run second linear layer.
399
+ # mlp = mlp.to(x.device)
400
+ mlp = self.mlp_act(mlp)
401
+ attn_mlp = torch.cat((attn, mlp), 2)
402
+ attn = None
403
+ mlp = None
404
+ output = self.linear2(attn_mlp)
405
+ attn_mlp = None
406
+ return x + apply_gate(output, gate=mod_gate)
407
+
408
+ # def forward(
409
+ # self,
410
+ # x: torch.Tensor,
411
+ # vec: torch.Tensor,
412
+ # txt_len: int,
413
+ # attn_mask: Optional[torch.Tensor] = None,
414
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
415
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
416
+ # max_seqlen_q: Optional[int] = None,
417
+ # max_seqlen_kv: Optional[int] = None,
418
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
419
+ # ) -> torch.Tensor:
420
+ def forward(self, *args, **kwargs):
421
+ if self.training and self.gradient_checkpointing:
422
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
423
+ else:
424
+ return self._forward(*args, **kwargs)
425
+
426
+
427
+ class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
428
+ """
429
+ HunyuanVideo Transformer backbone
430
+
431
+ Originally inherited from ModelMixin and ConfigMixin for compatibility with diffusers' StableDiffusionPipeline sampler; here it subclasses nn.Module directly.
432
+
433
+ Reference:
434
+ [1] Flux.1: https://github.com/black-forest-labs/flux
435
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
436
+
437
+ Parameters
438
+ ----------
439
+ args: argparse.Namespace
440
+ The arguments parsed by argparse.
441
+ patch_size: list
442
+ The size of the patch.
443
+ in_channels: int
444
+ The number of input channels.
445
+ out_channels: int
446
+ The number of output channels.
447
+ hidden_size: int
448
+ The hidden size of the transformer backbone.
449
+ heads_num: int
450
+ The number of attention heads.
451
+ mlp_width_ratio: float
452
+ The ratio of the hidden size of the MLP in the transformer block.
453
+ mlp_act_type: str
454
+ The activation function of the MLP in the transformer block.
455
+ mm_double_blocks_depth: int
456
+ The number of transformer blocks in the double blocks.
457
+ mm_single_blocks_depth: int
458
+ The number of transformer blocks in the single blocks.
459
+ rope_dim_list: list
460
+ The dimension of the rotary embedding for t, h, w.
461
+ qkv_bias: bool
462
+ Whether to use bias in the qkv linear layer.
463
+ qk_norm: bool
464
+ Whether to use qk norm.
465
+ qk_norm_type: str
466
+ The type of qk norm.
467
+ guidance_embed: bool
468
+ Whether to use guidance embedding for distillation.
469
+ text_projection: str
470
+ The type of the text projection, default is single_refiner.
471
+ use_attention_mask: bool
472
+ Whether to use attention mask for text encoder.
473
+ dtype: torch.dtype
474
+ The dtype of the model.
475
+ device: torch.device
476
+ The device of the model.
477
+ attn_mode: str
478
+ The mode of the attention, default is flash.
479
+ split_attn: bool
480
+ Whether to use split attention (process attention with batch size 1, one sample at a time).
481
+ """
482
+
483
+ # @register_to_config
484
+ def __init__(
485
+ self,
486
+ text_states_dim: int,
487
+ text_states_dim_2: int,
488
+ patch_size: list = [1, 2, 2],
489
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
490
+ out_channels: int = None,
491
+ hidden_size: int = 3072,
492
+ heads_num: int = 24,
493
+ mlp_width_ratio: float = 4.0,
494
+ mlp_act_type: str = "gelu_tanh",
495
+ mm_double_blocks_depth: int = 20,
496
+ mm_single_blocks_depth: int = 40,
497
+ rope_dim_list: List[int] = [16, 56, 56],
498
+ qkv_bias: bool = True,
499
+ qk_norm: bool = True,
500
+ qk_norm_type: str = "rms",
501
+ guidance_embed: bool = False, # For modulation.
502
+ text_projection: str = "single_refiner",
503
+ use_attention_mask: bool = True,
504
+ dtype: Optional[torch.dtype] = None,
505
+ device: Optional[torch.device] = None,
506
+ attn_mode: str = "flash",
507
+ split_attn: bool = False,
508
+ ):
509
+ factory_kwargs = {"device": device, "dtype": dtype}
510
+ super().__init__()
511
+
512
+ self.patch_size = patch_size
513
+ self.in_channels = in_channels
514
+ self.out_channels = in_channels if out_channels is None else out_channels
515
+ self.unpatchify_channels = self.out_channels
516
+ self.guidance_embed = guidance_embed
517
+ self.rope_dim_list = rope_dim_list
518
+
519
+ # Text projection. Default to linear projection.
520
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
521
+ self.use_attention_mask = use_attention_mask
522
+ self.text_projection = text_projection
523
+
524
+ self.text_states_dim = text_states_dim
525
+ self.text_states_dim_2 = text_states_dim_2
526
+
527
+ if hidden_size % heads_num != 0:
528
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
529
+ pe_dim = hidden_size // heads_num
530
+ if sum(rope_dim_list) != pe_dim:
531
+ raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
532
+ self.hidden_size = hidden_size
533
+ self.heads_num = heads_num
534
+
535
+ self.attn_mode = attn_mode
536
+ self.split_attn = split_attn
537
+ print(f"Using {self.attn_mode} attention mode, split_attn: {self.split_attn}")
538
+
539
+ # image projection
540
+ self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
541
+
542
+ # text projection
543
+ if self.text_projection == "linear":
544
+ self.txt_in = TextProjection(
545
+ self.text_states_dim,
546
+ self.hidden_size,
547
+ get_activation_layer("silu"),
548
+ **factory_kwargs,
549
+ )
550
+ elif self.text_projection == "single_refiner":
551
+ self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
552
+ else:
553
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
554
+
555
+ # time modulation
556
+ self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
557
+
558
+ # text modulation
559
+ self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
560
+
561
+ # guidance modulation
562
+ self.guidance_in = (
563
+ TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
564
+ )
565
+
566
+ # double blocks
567
+ self.double_blocks = nn.ModuleList(
568
+ [
569
+ MMDoubleStreamBlock(
570
+ self.hidden_size,
571
+ self.heads_num,
572
+ mlp_width_ratio=mlp_width_ratio,
573
+ mlp_act_type=mlp_act_type,
574
+ qk_norm=qk_norm,
575
+ qk_norm_type=qk_norm_type,
576
+ qkv_bias=qkv_bias,
577
+ attn_mode=attn_mode,
578
+ split_attn=split_attn,
579
+ **factory_kwargs,
580
+ )
581
+ for _ in range(mm_double_blocks_depth)
582
+ ]
583
+ )
584
+
585
+ # single blocks
586
+ self.single_blocks = nn.ModuleList(
587
+ [
588
+ MMSingleStreamBlock(
589
+ self.hidden_size,
590
+ self.heads_num,
591
+ mlp_width_ratio=mlp_width_ratio,
592
+ mlp_act_type=mlp_act_type,
593
+ qk_norm=qk_norm,
594
+ qk_norm_type=qk_norm_type,
595
+ attn_mode=attn_mode,
596
+ split_attn=split_attn,
597
+ **factory_kwargs,
598
+ )
599
+ for _ in range(mm_single_blocks_depth)
600
+ ]
601
+ )
602
+
603
+ self.final_layer = FinalLayer(
604
+ self.hidden_size,
605
+ self.patch_size,
606
+ self.out_channels,
607
+ get_activation_layer("silu"),
608
+ **factory_kwargs,
609
+ )
610
+
611
+ self.gradient_checkpointing = False
612
+ self.blocks_to_swap = None
613
+ self.offloader_double = None
614
+ self.offloader_single = None
615
+ self._enable_img_in_txt_in_offloading = False
616
+
617
+ @property
618
+ def device(self):
619
+ return next(self.parameters()).device
620
+
621
+ @property
622
+ def dtype(self):
623
+ return next(self.parameters()).dtype
624
+
625
+ def enable_gradient_checkpointing(self):
626
+ self.gradient_checkpointing = True
627
+
628
+ self.txt_in.enable_gradient_checkpointing()
629
+
630
+ for block in self.double_blocks + self.single_blocks:
631
+ block.enable_gradient_checkpointing()
632
+
633
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
634
+
635
+ def disable_gradient_checkpointing(self):
636
+ self.gradient_checkpointing = False
637
+
638
+ self.txt_in.disable_gradient_checkpointing()
639
+
640
+ for block in self.double_blocks + self.single_blocks:
641
+ block.disable_gradient_checkpointing()
642
+
643
+ print(f"HYVideoDiffusionTransformer: Gradient checkpointing disabled.")
644
+
645
+ def enable_img_in_txt_in_offloading(self):
646
+ self._enable_img_in_txt_in_offloading = True
647
+
648
+ def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
649
+ self.blocks_to_swap = num_blocks
650
+ self.num_double_blocks = len(self.double_blocks)
651
+ self.num_single_blocks = len(self.single_blocks)
652
+ double_blocks_to_swap = num_blocks // 2
653
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
654
+
655
+ assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
656
+ f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
657
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
658
+ )
659
+
660
+ self.offloader_double = ModelOffloader(
661
+ "double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
662
+ )
663
+ self.offloader_single = ModelOffloader(
664
+ "single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
665
+ )
666
+ print(
667
+ f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
668
+ )
669
+
670
+ def switch_block_swap_for_inference(self):
671
+ if self.blocks_to_swap:
672
+ self.offloader_double.set_forward_only(True)
673
+ self.offloader_single.set_forward_only(True)
674
+ self.prepare_block_swap_before_forward()
675
+ print(f"HYVideoDiffusionTransformer: Block swap set to forward only.")
676
+
677
+ def switch_block_swap_for_training(self):
678
+ if self.blocks_to_swap:
679
+ self.offloader_double.set_forward_only(False)
680
+ self.offloader_single.set_forward_only(False)
681
+ self.prepare_block_swap_before_forward()
682
+ print(f"HYVideoDiffusionTransformer: Block swap set to forward and backward.")
683
+
684
+ def move_to_device_except_swap_blocks(self, device: torch.device):
685
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
686
+ if self.blocks_to_swap:
687
+ save_double_blocks = self.double_blocks
688
+ save_single_blocks = self.single_blocks
689
+ self.double_blocks = None
690
+ self.single_blocks = None
691
+
692
+ self.to(device)
693
+
694
+ if self.blocks_to_swap:
695
+ self.double_blocks = save_double_blocks
696
+ self.single_blocks = save_single_blocks
697
+
698
+ def prepare_block_swap_before_forward(self):
699
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
700
+ return
701
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
702
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
703
+
704
+ def enable_deterministic(self):
705
+ for block in self.double_blocks:
706
+ block.enable_deterministic()
707
+ for block in self.single_blocks:
708
+ block.enable_deterministic()
709
+
710
+ def disable_deterministic(self):
711
+ for block in self.double_blocks:
712
+ block.disable_deterministic()
713
+ for block in self.single_blocks:
714
+ block.disable_deterministic()
715
+
716
+ def forward(
717
+ self,
718
+ x: torch.Tensor,
719
+ t: torch.Tensor, # Should be in range(0, 1000).
720
+ text_states: torch.Tensor = None,
721
+ text_mask: torch.Tensor = None, # Attention mask for text tokens; used for cu_seqlens and the token refiner.
722
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
723
+ freqs_cos: Optional[torch.Tensor] = None,
724
+ freqs_sin: Optional[torch.Tensor] = None,
725
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
726
+ return_dict: bool = True,
727
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
728
+ out = {}
729
+ img = x
730
+ txt = text_states
731
+ _, _, ot, oh, ow = x.shape
732
+ tt, th, tw = (
733
+ ot // self.patch_size[0],
734
+ oh // self.patch_size[1],
735
+ ow // self.patch_size[2],
736
+ )
737
+
738
+ # Prepare modulation vectors.
739
+ vec = self.time_in(t)
740
+
741
+ # text modulation
742
+ vec = vec + self.vector_in(text_states_2)
743
+
744
+ # guidance modulation
745
+ if self.guidance_embed:
746
+ if guidance is None:
747
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
748
+
749
+ # the timestep embedding is computed inside guidance_in (TimestepEmbedder)
750
+ vec = vec + self.guidance_in(guidance)
751
+
752
+ # Embed image and text.
753
+ if self._enable_img_in_txt_in_offloading:
754
+ self.img_in.to(x.device, non_blocking=True)
755
+ self.txt_in.to(x.device, non_blocking=True)
756
+ synchronize_device(x.device)
757
+
758
+ img = self.img_in(img)
759
+ if self.text_projection == "linear":
760
+ txt = self.txt_in(txt)
761
+ elif self.text_projection == "single_refiner":
762
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
763
+ else:
764
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
765
+
766
+ if self._enable_img_in_txt_in_offloading:
767
+ self.img_in.to(torch.device("cpu"), non_blocking=True)
768
+ self.txt_in.to(torch.device("cpu"), non_blocking=True)
769
+ synchronize_device(x.device)
770
+ clean_memory_on_device(x.device)
771
+
772
+ txt_seq_len = txt.shape[1]
773
+ img_seq_len = img.shape[1]
774
+
775
+ # Compute cu_seqlens and max_seqlen for flash attention
776
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
777
+ cu_seqlens_kv = cu_seqlens_q
778
+ max_seqlen_q = img_seq_len + txt_seq_len
779
+ max_seqlen_kv = max_seqlen_q
780
+
781
+ attn_mask = total_len = None
782
+ if self.split_attn or self.attn_mode == "torch":
783
+ # calculate text length and total length
784
+ text_len = text_mask.sum(dim=1) # (bs, )
785
+ total_len = img_seq_len + text_len # (bs, )
786
+ if self.attn_mode == "torch" and not self.split_attn:
787
+ # initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
788
+ bs = img.shape[0]
789
+ attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
790
+
791
+ # set attention mask with total_len
792
+ for i in range(bs):
793
+ attn_mask[i, :, : total_len[i], : total_len[i]] = True
794
+ total_len = None # not needed here; the dense attn_mask is used instead of split_attn
795
+
796
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
797
+ # --------------------- Pass through DiT blocks ------------------------
798
+ for block_idx, block in enumerate(self.double_blocks):
799
+ double_block_args = [
800
+ img,
801
+ txt,
802
+ vec,
803
+ attn_mask,
804
+ total_len,
805
+ cu_seqlens_q,
806
+ cu_seqlens_kv,
807
+ max_seqlen_q,
808
+ max_seqlen_kv,
809
+ freqs_cis,
810
+ ]
811
+
812
+ if self.blocks_to_swap:
813
+ self.offloader_double.wait_for_block(block_idx)
814
+
815
+ img, txt = block(*double_block_args)
816
+
817
+ if self.blocks_to_swap:
818
+ self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)
819
+
820
+ # Merge txt and img to pass through single stream blocks.
821
+ x = torch.cat((img, txt), 1)
822
+ if self.blocks_to_swap:
823
+ # delete img, txt to reduce memory usage
824
+ del img, txt
825
+ clean_memory_on_device(x.device)
826
+
827
+ if len(self.single_blocks) > 0:
828
+ for block_idx, block in enumerate(self.single_blocks):
829
+ single_block_args = [
830
+ x,
831
+ vec,
832
+ txt_seq_len,
833
+ attn_mask,
834
+ total_len,
835
+ cu_seqlens_q,
836
+ cu_seqlens_kv,
837
+ max_seqlen_q,
838
+ max_seqlen_kv,
839
+ freqs_cis,
840
+ ]
841
+ if self.blocks_to_swap:
842
+ self.offloader_single.wait_for_block(block_idx)
843
+
844
+ x = block(*single_block_args)
845
+
846
+ if self.blocks_to_swap:
847
+ self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)
848
+
849
+ img = x[:, :img_seq_len, ...]
850
+ x = None
851
+
852
+ # ---------------------------- Final layer ------------------------------
853
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
854
+
855
+ img = self.unpatchify(img, tt, th, tw)
856
+ if return_dict:
857
+ out["x"] = img
858
+ return out
859
+ return img
860
+
861
+ def unpatchify(self, x, t, h, w):
862
+ """
863
+ x: (N, T, patch_size**2 * C)
864
+ imgs: (N, C, T * pt, H * ph, W * pw)
865
+ """
866
+ c = self.unpatchify_channels
867
+ pt, ph, pw = self.patch_size
868
+ assert t * h * w == x.shape[1]
869
+
870
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
871
+ x = torch.einsum("nthwcopq->nctohpwq", x)
872
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
873
+
874
+ return imgs
875
+
876
+ def params_count(self):
877
+ counts = {
878
+ "double": sum(
879
+ [
880
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
881
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
882
+ + sum(p.numel() for p in block.img_mlp.parameters())
883
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
884
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
885
+ + sum(p.numel() for p in block.txt_mlp.parameters())
886
+ for block in self.double_blocks
887
+ ]
888
+ ),
889
+ "single": sum(
890
+ [
891
+ sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
892
+ for block in self.single_blocks
893
+ ]
894
+ ),
895
+ "total": sum(p.numel() for p in self.parameters()),
896
+ }
897
+ counts["attn+mlp"] = counts["double"] + counts["single"]
898
+ return counts
899
+
900
+
901
+ #################################################################################
902
+ # HunyuanVideo Configs #
903
+ #################################################################################
904
+
905
+ HUNYUAN_VIDEO_CONFIG = {
906
+ "HYVideo-T/2": {
907
+ "mm_double_blocks_depth": 20,
908
+ "mm_single_blocks_depth": 40,
909
+ "rope_dim_list": [16, 56, 56],
910
+ "hidden_size": 3072,
911
+ "heads_num": 24,
912
+ "mlp_width_ratio": 4,
913
+ },
914
+ "HYVideo-T/2-cfgdistill": {
915
+ "mm_double_blocks_depth": 20,
916
+ "mm_single_blocks_depth": 40,
917
+ "rope_dim_list": [16, 56, 56],
918
+ "hidden_size": 3072,
919
+ "heads_num": 24,
920
+ "mlp_width_ratio": 4,
921
+ "guidance_embed": True,
922
+ },
923
+ }
924
+
925
+
926
+ def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
927
+ """load hunyuan video model
928
+
929
+ NOTE: Only HYVideo-T/2-cfgdistill is supported for now.
930
+
931
+ Args:
932
+ text_states_dim (int): text state dimension
933
+ text_states_dim_2 (int): text state dimension 2
934
+ in_channels (int): input channels number
935
+ out_channels (int): output channels number
936
+ factor_kwargs (dict): factor kwargs
937
+
938
+ Returns:
939
+ model (nn.Module): The hunyuan video model
940
+ """
941
+ # if args.model in HUNYUAN_VIDEO_CONFIG.keys():
942
+ model = HYVideoDiffusionTransformer(
943
+ text_states_dim=text_states_dim,
944
+ text_states_dim_2=text_states_dim_2,
945
+ in_channels=in_channels,
946
+ out_channels=out_channels,
947
+ **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
948
+ **factor_kwargs,
949
+ )
950
+ return model
951
+ # else:
952
+ # raise NotImplementedError()
953
+
954
+
955
+ def load_state_dict(model, model_path):
956
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
957
+
958
+ load_key = "module"
959
+ if load_key in state_dict:
960
+ state_dict = state_dict[load_key]
961
+ else:
962
+ raise KeyError(
963
+ f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
964
+ f"are: {list(state_dict.keys())}."
965
+ )
966
+ model.load_state_dict(state_dict, strict=True, assign=True)
967
+ return model
968
+
969
+
970
+ def load_transformer(dit_path, attn_mode, split_attn, device, dtype, in_channels=16) -> HYVideoDiffusionTransformer:
971
+ # =========================== Build main model ===========================
972
+ factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode, "split_attn": split_attn}
973
+ latent_channels = 16
974
+ out_channels = latent_channels
975
+
976
+ with accelerate.init_empty_weights():
977
+ transformer = load_dit_model(
978
+ text_states_dim=4096,
979
+ text_states_dim_2=768,
980
+ in_channels=in_channels,
981
+ out_channels=out_channels,
982
+ factor_kwargs=factor_kwargs,
983
+ )
984
+
985
+ if os.path.splitext(dit_path)[-1] == ".safetensors":
986
+ # loading safetensors: may be already fp8
987
+ with MemoryEfficientSafeOpen(dit_path) as f:
988
+ state_dict = {}
989
+ for k in f.keys():
990
+ tensor = f.get_tensor(k)
991
+ tensor = tensor.to(device=device, dtype=dtype)
992
+ # TODO support comfy model
993
+ # if k.startswith("model.model."):
994
+ # k = convert_comfy_model_key(k)
995
+ state_dict[k] = tensor
996
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
997
+ else:
998
+ transformer = load_state_dict(transformer, dit_path)
999
+
1000
+ return transformer
1001
+
1002
+
1003
+ def get_rotary_pos_embed_by_shape(model, latents_size):
1004
+ target_ndim = 3
1005
+ ndim = 5 - 2  # 5D latents (B, C, T, H, W) minus batch and channel dims
1006
+
1007
+ if isinstance(model.patch_size, int):
1008
+ assert all(s % model.patch_size == 0 for s in latents_size), (
1009
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
1010
+ f"but got {latents_size}."
1011
+ )
1012
+ rope_sizes = [s // model.patch_size for s in latents_size]
1013
+ elif isinstance(model.patch_size, list):
1014
+ assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
1015
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
1016
+ f"but got {latents_size}."
1017
+ )
1018
+ rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
1019
+
1020
+ if len(rope_sizes) != target_ndim:
1021
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
1022
+ head_dim = model.hidden_size // model.heads_num
1023
+ rope_dim_list = model.rope_dim_list
1024
+ if rope_dim_list is None:
1025
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
1026
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the head_dim of the attention layer"
1027
+
1028
+ rope_theta = 256
1029
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
1030
+ rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
1031
+ )
1032
+ return freqs_cos, freqs_sin
1033
+
1034
+
1035
+ def get_rotary_pos_embed(vae_name, model, video_length, height, width):
1036
+ # 884
1037
+ if "884" in vae_name:
1038
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
1039
+ elif "888" in vae_name:
1040
+ latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
1041
+ else:
1042
+ latents_size = [video_length, height // 8, width // 8]
1043
+
1044
+ return get_rotary_pos_embed_by_shape(model, latents_size)
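For orientation, here is a minimal usage sketch (not part of the committed diff) of how the loader and RoPE helpers defined in `hunyuan_model/models.py` fit together. The checkpoint path and the VAE name below are illustrative assumptions, not files or constants shipped in this repository.

```python
import torch
from hunyuan_model.models import load_transformer, get_rotary_pos_embed

device = torch.device("cuda")
# Hypothetical checkpoint location; point this at your downloaded DiT weights.
dit_path = "ckpts/hunyuan_video_720_cfgdistill_bf16.safetensors"

# Build HYVideo-T/2-cfgdistill on empty weights, then load the state dict.
model = load_transformer(
    dit_path, attn_mode="torch", split_attn=False, device=device, dtype=torch.bfloat16
)
model.eval()

# RoPE frequencies for a 25-frame, 544x960 clip with an "884" VAE
# (temporal stride 4, spatial stride 8 -> latent size [7, 68, 120]).
freqs_cos, freqs_sin = get_rotary_pos_embed("884-16c-hy", model, 25, 544, 960)
```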
hunyuan_model/modulate_layers.py ADDED
@@ -0,0 +1,76 @@
 
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ModulateDiT(nn.Module):
8
+ """Modulation layer for DiT."""
9
+ def __init__(
10
+ self,
11
+ hidden_size: int,
12
+ factor: int,
13
+ act_layer: Callable,
14
+ dtype=None,
15
+ device=None,
16
+ ):
17
+ factory_kwargs = {"dtype": dtype, "device": device}
18
+ super().__init__()
19
+ self.act = act_layer()
20
+ self.linear = nn.Linear(
21
+ hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22
+ )
23
+ # Zero-initialize the modulation
24
+ nn.init.zeros_(self.linear.weight)
25
+ nn.init.zeros_(self.linear.bias)
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ return self.linear(self.act(x))
29
+
30
+
31
+ def modulate(x, shift=None, scale=None):
32
+ """modulate by shift and scale
33
+
34
+ Args:
35
+ x (torch.Tensor): input tensor.
36
+ shift (torch.Tensor, optional): shift tensor. Defaults to None.
37
+ scale (torch.Tensor, optional): scale tensor. Defaults to None.
38
+
39
+ Returns:
40
+ torch.Tensor: the modulated output tensor.
41
+ """
42
+ if scale is None and shift is None:
43
+ return x
44
+ elif shift is None:
45
+ return x * (1 + scale.unsqueeze(1))
46
+ elif scale is None:
47
+ return x + shift.unsqueeze(1)
48
+ else:
49
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50
+
51
+
52
+ def apply_gate(x, gate=None, tanh=False):
53
+ """AI is creating summary for apply_gate
54
+
55
+ Args:
56
+ x (torch.Tensor): input tensor.
57
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
58
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
59
+
60
+ Returns:
61
+ torch.Tensor: the gated output tensor.
62
+ """
63
+ if gate is None:
64
+ return x
65
+ if tanh:
66
+ return x * gate.unsqueeze(1).tanh()
67
+ else:
68
+ return x * gate.unsqueeze(1)
69
+
70
+
71
+ def ckpt_wrapper(module):
72
+ def ckpt_forward(*inputs):
73
+ outputs = module(*inputs)
74
+ return outputs
75
+
76
+ return ckpt_forward
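The helpers in `hunyuan_model/modulate_layers.py` are the adaLN-style building blocks used by the double- and single-stream blocks in `models.py`. The sketch below shows how they compose into a gated residual update; the LayerNorm, the chunk factor of 3, and the residual wiring are illustrative simplifications, not code copied from the actual blocks.

```python
import torch
import torch.nn as nn
from hunyuan_model.modulate_layers import ModulateDiT, modulate, apply_gate

hidden_size, batch, tokens = 3072, 1, 16
mod = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU)
norm = nn.LayerNorm(hidden_size, elementwise_affine=False)

x = torch.randn(batch, tokens, hidden_size)   # token sequence (B, N, C)
vec = torch.randn(batch, hidden_size)         # conditioning vector (B, C)

# The zero-initialized linear makes shift = scale = gate = 0 at the start of
# training, so the branch below is initially an identity residual update.
shift, scale, gate = mod(vec).chunk(3, dim=-1)
h = modulate(norm(x), shift=shift, scale=scale)   # x * (1 + scale) + shift
out = x + apply_gate(h, gate=gate)                # gated residual
print(out.shape)  # torch.Size([1, 16, 3072])
```

Because `ModulateDiT` zero-initializes its projection, a freshly constructed block leaves its input unchanged, which is the usual adaLN-Zero trick for stable DiT training.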