Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +5 -0
- .ipynb_checkpoints/README-checkpoint.md +147 -0
- .python-version +1 -0
- Gundam_outputs/Gundam_w1_3_lora-000001.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000002.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000003.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000004.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000005.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000006.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000007.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000008.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000009.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000010.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000011.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000012.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000013.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000014.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000015.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000016.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000017.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000018.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000019.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000020.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000021.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000022.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000023.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000024.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000025.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000026.safetensors +3 -0
- Gundam_outputs/Gundam_w1_3_lora-000027.safetensors +3 -0
- README.md +147 -0
- cache_latents.py +281 -0
- cache_text_encoder_outputs.py +214 -0
- convert_lora.py +135 -0
- dataset/__init__.py +0 -0
- dataset/config_utils.py +372 -0
- dataset/dataset_config.md +387 -0
- dataset/image_video_dataset.py +1400 -0
- docs/advanced_config.md +151 -0
- docs/sampling_during_training.md +108 -0
- docs/wan.md +241 -0
- hunyuan_model/__init__.py +0 -0
- hunyuan_model/activation_layers.py +23 -0
- hunyuan_model/attention.py +295 -0
- hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
- hunyuan_model/embed_layers.py +132 -0
- hunyuan_model/helpers.py +40 -0
- hunyuan_model/mlp_layers.py +118 -0
- hunyuan_model/models.py +1044 -0
- hunyuan_model/modulate_layers.py +76 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.venv
|
3 |
+
venv/
|
4 |
+
logs/
|
5 |
+
uv.lock
|
.ipynb_checkpoints/README-checkpoint.md
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Gundam Text-to-Video Generation
|
2 |
+
|
3 |
+
This repository contains the necessary steps and scripts to generate videos using the Gundam text-to-video model. The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create high-quality anime-style videos based on textual prompts.
|
4 |
+
|
5 |
+
## Prerequisites
|
6 |
+
|
7 |
+
Before proceeding, ensure that you have the following installed on your system:
|
8 |
+
|
9 |
+
• **Ubuntu** (or a compatible Linux distribution)
|
10 |
+
• **Python 3.x**
|
11 |
+
• **pip** (Python package manager)
|
12 |
+
• **Git**
|
13 |
+
• **Git LFS** (Git Large File Storage)
|
14 |
+
• **FFmpeg**
|
15 |
+
|
16 |
+
## Installation
|
17 |
+
|
18 |
+
1. **Update and Install Dependencies**
|
19 |
+
|
20 |
+
```bash
|
21 |
+
sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
|
22 |
+
```
|
23 |
+
|
24 |
+
2. **Clone the Repository**
|
25 |
+
|
26 |
+
```bash
|
27 |
+
git clone https://huggingface.co/svjack/Gundam_wan_2_1_1_3_B_text2video_lora
|
28 |
+
cd Gundam_wan_2_1_1_3_B_text2video_lora
|
29 |
+
```
|
30 |
+
|
31 |
+
3. **Install Python Dependencies**
|
32 |
+
|
33 |
+
```bash
|
34 |
+
pip install torch torchvision
|
35 |
+
pip install -r requirements.txt
|
36 |
+
pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
|
37 |
+
pip install moviepy==1.0.3
|
38 |
+
pip install sageattention==1.0.6
|
39 |
+
```
|
40 |
+
|
41 |
+
4. **Download Model Weights**
|
42 |
+
|
43 |
+
```bash
|
44 |
+
wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
|
45 |
+
wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
46 |
+
wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
|
47 |
+
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
|
48 |
+
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
|
49 |
+
```
|
50 |
+
|
51 |
+
## Usage
|
52 |
+
|
53 |
+
To generate a video, use the `wan_generate_video.py` script with the appropriate parameters. Below are examples of how to generate videos using the Gundam model.
|
54 |
+
|
55 |
+
#### Gundam Moon Background
|
56 |
+
|
57 |
+
```bash
|
58 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
|
59 |
+
--save_path save --output_type both \
|
60 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
61 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
62 |
+
--attn_mode torch \
|
63 |
+
--lora_weight Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
|
64 |
+
--lora_multiplier 1.0 \
|
65 |
+
--prompt "In the style of Gundam , The video features a large, black robot with a predominantly humanoid form. walk one step at a time forward, with the moon as the background"
|
66 |
+
|
67 |
+
```
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
#### Pink Butterfly Gundam
|
72 |
+
|
73 |
+
```bash
|
74 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
|
75 |
+
--save_path save --output_type both \
|
76 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
77 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
78 |
+
--attn_mode torch \
|
79 |
+
--lora_weight Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
|
80 |
+
--lora_multiplier 1.0 \
|
81 |
+
--prompt "In the style of Gundam , The video features a large, pink robot with a butterfly form. suggesting smooth and cute. The background is plain and light-colored, ensuring the focus remains on the robot."
|
82 |
+
|
83 |
+
```
|
84 |
+
|
85 |
+
|
86 |
+
#### Yellow Gundam
|
87 |
+
|
88 |
+
```bash
|
89 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
|
90 |
+
--save_path save --output_type both \
|
91 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
92 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
93 |
+
--attn_mode torch \
|
94 |
+
--lora_weight Gundam_outputs/Gundam_w1_3_lora-000025.safetensors \
|
95 |
+
--lora_multiplier 1.0 \
|
96 |
+
--prompt "In the style of Gundam , The video features a large, yellow robot with two slender arms suggesting strength and speed. The background is plain and light-colored, ensuring the focus remains on the robot."
|
97 |
+
|
98 |
+
|
99 |
+
```
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
## Parameters
|
104 |
+
|
105 |
+
* `--fp8`: Enable FP8 precision (optional).
|
106 |
+
* `--task`: Specify the task (e.g., `t2v-1.3B`).
|
107 |
+
* `--video_size`: Set the resolution of the generated video (e.g., `1024 1024`).
|
108 |
+
* `--video_length`: Define the length of the video in frames.
|
109 |
+
* `--infer_steps`: Number of inference steps.
|
110 |
+
* `--save_path`: Directory to save the generated video.
|
111 |
+
* `--output_type`: Output type (e.g., `both` for video and frames).
|
112 |
+
* `--dit`: Path to the diffusion model weights.
|
113 |
+
* `--vae`: Path to the VAE model weights.
|
114 |
+
* `--t5`: Path to the T5 model weights.
|
115 |
+
* `--attn_mode`: Attention mode (e.g., `torch`).
|
116 |
+
* `--lora_weight`: Path to the LoRA weights.
|
117 |
+
* `--lora_multiplier`: Multiplier for LoRA weights.
|
118 |
+
* `--prompt`: Textual prompt for video generation.
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
## Output
|
123 |
+
|
124 |
+
The generated video and frames will be saved in the specified `save_path` directory.
|
125 |
+
|
126 |
+
## Troubleshooting
|
127 |
+
|
128 |
+
• Ensure all dependencies are correctly installed.
|
129 |
+
• Verify that the model weights are downloaded and placed in the correct locations.
|
130 |
+
• Check for any missing Python packages and install them using `pip`.
|
131 |
+
|
132 |
+
## License
|
133 |
+
|
134 |
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
|
135 |
+
|
136 |
+
## Acknowledgments
|
137 |
+
|
138 |
+
• **Hugging Face** for hosting the model weights.
|
139 |
+
• **Wan-AI** for providing the pre-trained models.
|
140 |
+
• **DeepBeepMeep** for contributing to the model weights.
|
141 |
+
|
142 |
+
## Contact
|
143 |
+
|
144 |
+
For any questions or issues, please open an issue on the repository or contact the maintainer.
|
145 |
+
|
146 |
+
---
|
147 |
+
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.10
|
Gundam_outputs/Gundam_w1_3_lora-000001.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4bc379f7399e80d4ff4c1ea763b1abe8d06a8b368ebbf058151a88fabd50d29d
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000002.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cd616ff62de984e314cb00e2813fed91680f330dc3f29cd72c7dc0895b48d084
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000003.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10b9fd3544961a69e630ff3a9e3e42c6ea55c9d68c9b0c3d0bbc0aed64218b61
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83dd9043e893384c048eaaa4ce614ed20fd8c74908e767aa94df9b2f6148e32b
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000005.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b3c9e2a14d9587b7b7df26e839afcbdbc51df03c5744db181c36ed4441f7300
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000006.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e98b6d3d424de07a26a0cf4e3ddf8c9bea53f77dc2f0a687d5ea8dee528eb5bd
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90b40c48bef2fed6cb15225da1c2852fde299b5dfea1beed9868a0c1dc07569d
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000008.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc4a71c6cb63d54c9b03a9b015b3b23fef4887fb10425988cc536fbd6967230e
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000009.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8ae9e2d10ff2822e93daa67fe7db4cdcba7747744c3c97a5feee3fa55dc8b9c1
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000010.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dba6602ad083af939ba6133e9cb3c6bdb5a9332a79f6fa27d4a291600730bc9a
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000011.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ce0c3acb45bb190331dae7b90b1fbc90b0aa5dbea2a44a489ae8e352937d23a
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000012.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e71353404f12ce7487ba4b392198afbabe69f5eaba2fee5d1b5a10ec3bee22eb
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000013.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5178a67e3f6865d4d8a70755da4c40bf82efff8c668967f15f80290f0a343db0
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000014.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2e673899a7bd91058d92add206f39ffb497b961a7f2414fc121bed93ee68b230
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000015.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:44f25284fa970b0a247412ccc93c627f9cb6027bfb70fafc94510c15d2041f78
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000016.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a1b03ab55fc6fe90e7251784f293344c59f3a06dc2891bf964d4d9ab7d6a194a
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000017.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:89902b3ec65576f80c8c3eefd6c69e4e485f2a78d1a040ecd7b15d60e91eafa8
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000018.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dbd0c4f04b15dad756e85a4624849b13402e8f6bb1889f144928c9bf23cb0b8e
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000019.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7f47dd80d1ecb4b98c1b43f05fb34b2327cd015127446d53b65770a66c66109b
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000020.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:50efde2fd8b78d0d83ae24c059834464a38a2b5d695f70e1cec41093c1e1959f
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000021.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a580ac6b1c8e8c5ab8a028ce26e262f2c6af55e46639cd6c59088a6f9e984619
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000022.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a38e2e7d6ab1272da8dc8911df28e1a49dbbd602490828b7dfeacda135e0d96c
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000023.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c35d3c93ba0cf30b62dcd3956702dd462be9af09727fbe3316ae4b4b5ded2642
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000024.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a203c9a6106795e6895d16dc11cca70f62e094b82b09f9bda3e0d2129da0efe
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000025.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ee416f97d83fa66a95d64aa0995b13ae0bfd640a6591baf17b6f2dbb18624da9
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000026.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d546a7a0714094d1620fb11d5a6e3bbc6d42456f57c0778e710ac77c12e51b9b
|
3 |
+
size 87594680
|
Gundam_outputs/Gundam_w1_3_lora-000027.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ceaa7bd1f4ac61bcbf3a36163524d2ff6b9854e2d095940c98f79d73b4bcf778
|
3 |
+
size 87594680
|
README.md
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Gundam Text-to-Video Generation
|
2 |
+
|
3 |
+
This repository contains the necessary steps and scripts to generate videos using the Gundam text-to-video model. The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create high-quality anime-style videos based on textual prompts.
|
4 |
+
|
5 |
+
## Prerequisites
|
6 |
+
|
7 |
+
Before proceeding, ensure that you have the following installed on your system:
|
8 |
+
|
9 |
+
• **Ubuntu** (or a compatible Linux distribution)
|
10 |
+
• **Python 3.x**
|
11 |
+
• **pip** (Python package manager)
|
12 |
+
• **Git**
|
13 |
+
• **Git LFS** (Git Large File Storage)
|
14 |
+
• **FFmpeg**
|
15 |
+
|
16 |
+
## Installation
|
17 |
+
|
18 |
+
1. **Update and Install Dependencies**
|
19 |
+
|
20 |
+
```bash
|
21 |
+
sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
|
22 |
+
```
|
23 |
+
|
24 |
+
2. **Clone the Repository**
|
25 |
+
|
26 |
+
```bash
|
27 |
+
git clone https://huggingface.co/svjack/Gundam_wan_2_1_1_3_B_text2video_lora
|
28 |
+
cd Gundam_wan_2_1_1_3_B_text2video_lora
|
29 |
+
```
|
30 |
+
|
31 |
+
3. **Install Python Dependencies**
|
32 |
+
|
33 |
+
```bash
|
34 |
+
pip install torch torchvision
|
35 |
+
pip install -r requirements.txt
|
36 |
+
pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
|
37 |
+
pip install moviepy==1.0.3
|
38 |
+
pip install sageattention==1.0.6
|
39 |
+
```
|
40 |
+
|
41 |
+
4. **Download Model Weights**
|
42 |
+
|
43 |
+
```bash
|
44 |
+
wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/models_t5_umt5-xxl-enc-bf16.pth
|
45 |
+
wget https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
46 |
+
wget https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/resolve/main/Wan2.1_VAE.pth
|
47 |
+
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors
|
48 |
+
wget https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_t2v_14B_bf16.safetensors
|
49 |
+
```
|
50 |
+
|
51 |
+
## Usage
|
52 |
+
|
53 |
+
To generate a video, use the `wan_generate_video.py` script with the appropriate parameters. Below are examples of how to generate videos using the Gundam model.
|
54 |
+
|
55 |
+
#### Gundam Moon Background
|
56 |
+
|
57 |
+
```bash
|
58 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
|
59 |
+
--save_path save --output_type both \
|
60 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
61 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
62 |
+
--attn_mode torch \
|
63 |
+
--lora_weight Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
|
64 |
+
--lora_multiplier 1.0 \
|
65 |
+
--prompt "In the style of Gundam , The video features a large, black robot with a predominantly humanoid form. walk one step at a time forward, with the moon as the background"
|
66 |
+
|
67 |
+
```
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
#### Pink Butterfly Gundam
|
72 |
+
|
73 |
+
```bash
|
74 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
|
75 |
+
--save_path save --output_type both \
|
76 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
77 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
78 |
+
--attn_mode torch \
|
79 |
+
--lora_weight Gundam_outputs/Gundam_w1_3_lora-000027.safetensors \
|
80 |
+
--lora_multiplier 1.0 \
|
81 |
+
--prompt "In the style of Gundam , The video features a large, pink robot with a butterfly form. suggesting smooth and cute. The background is plain and light-colored, ensuring the focus remains on the robot."
|
82 |
+
|
83 |
+
```
|
84 |
+
|
85 |
+
|
86 |
+
#### Yellow Gundam
|
87 |
+
|
88 |
+
```bash
|
89 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 480 832 --video_length 81 --infer_steps 20 \
|
90 |
+
--save_path save --output_type both \
|
91 |
+
--dit wan2.1_t2v_1.3B_bf16.safetensors --vae Wan2.1_VAE.pth \
|
92 |
+
--t5 models_t5_umt5-xxl-enc-bf16.pth \
|
93 |
+
--attn_mode torch \
|
94 |
+
--lora_weight Gundam_outputs/Gundam_w1_3_lora-000025.safetensors \
|
95 |
+
--lora_multiplier 1.0 \
|
96 |
+
--prompt "In the style of Gundam , The video features a large, yellow robot with two slender arms suggesting strength and speed. The background is plain and light-colored, ensuring the focus remains on the robot."
|
97 |
+
|
98 |
+
|
99 |
+
```
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
## Parameters
|
104 |
+
|
105 |
+
* `--fp8`: Enable FP8 precision (optional).
|
106 |
+
* `--task`: Specify the task (e.g., `t2v-1.3B`).
|
107 |
+
* `--video_size`: Set the resolution of the generated video (e.g., `1024 1024`).
|
108 |
+
* `--video_length`: Define the length of the video in frames.
|
109 |
+
* `--infer_steps`: Number of inference steps.
|
110 |
+
* `--save_path`: Directory to save the generated video.
|
111 |
+
* `--output_type`: Output type (e.g., `both` for video and frames).
|
112 |
+
* `--dit`: Path to the diffusion model weights.
|
113 |
+
* `--vae`: Path to the VAE model weights.
|
114 |
+
* `--t5`: Path to the T5 model weights.
|
115 |
+
* `--attn_mode`: Attention mode (e.g., `torch`).
|
116 |
+
* `--lora_weight`: Path to the LoRA weights.
|
117 |
+
* `--lora_multiplier`: Multiplier for LoRA weights.
|
118 |
+
* `--prompt`: Textual prompt for video generation.
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
## Output
|
123 |
+
|
124 |
+
The generated video and frames will be saved in the specified `save_path` directory.
|
125 |
+
|
126 |
+
## Troubleshooting
|
127 |
+
|
128 |
+
• Ensure all dependencies are correctly installed.
|
129 |
+
• Verify that the model weights are downloaded and placed in the correct locations.
|
130 |
+
• Check for any missing Python packages and install them using `pip`.
|
131 |
+
|
132 |
+
## License
|
133 |
+
|
134 |
+
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
|
135 |
+
|
136 |
+
## Acknowledgments
|
137 |
+
|
138 |
+
• **Hugging Face** for hosting the model weights.
|
139 |
+
• **Wan-AI** for providing the pre-trained models.
|
140 |
+
• **DeepBeepMeep** for contributing to the model weights.
|
141 |
+
|
142 |
+
## Contact
|
143 |
+
|
144 |
+
For any questions or issues, please open an issue on the repository or contact the maintainer.
|
145 |
+
|
146 |
+
---
|
147 |
+
|
cache_latents.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
+
from typing import Optional, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from dataset import config_utils
|
11 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
import logging
|
15 |
+
|
16 |
+
from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache, ARCHITECTURE_HUNYUAN_VIDEO
|
17 |
+
from hunyuan_model.vae import load_vae
|
18 |
+
from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
19 |
+
from utils.model_utils import str_to_dtype
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
|
24 |
+
|
25 |
+
def show_image(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]]) -> int:
|
26 |
+
import cv2
|
27 |
+
|
28 |
+
imgs = (
|
29 |
+
[image]
|
30 |
+
if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
|
31 |
+
else [image[0], image[-1]]
|
32 |
+
)
|
33 |
+
if len(imgs) > 1:
|
34 |
+
print(f"Number of images: {len(image)}")
|
35 |
+
for i, img in enumerate(imgs):
|
36 |
+
if len(imgs) > 1:
|
37 |
+
print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
|
38 |
+
else:
|
39 |
+
print(f"Image: {img.shape}")
|
40 |
+
cv2_img = np.array(img) if isinstance(img, Image.Image) else img
|
41 |
+
cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
|
42 |
+
cv2.imshow("image", cv2_img)
|
43 |
+
k = cv2.waitKey(0)
|
44 |
+
cv2.destroyAllWindows()
|
45 |
+
if k == ord("q") or k == ord("d"):
|
46 |
+
return k
|
47 |
+
return k
|
48 |
+
|
49 |
+
|
50 |
+
def show_console(
|
51 |
+
image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]],
|
52 |
+
width: int,
|
53 |
+
back: str,
|
54 |
+
interactive: bool = False,
|
55 |
+
) -> int:
|
56 |
+
from ascii_magic import from_pillow_image, Back
|
57 |
+
|
58 |
+
back = None
|
59 |
+
if back is not None:
|
60 |
+
back = getattr(Back, back.upper())
|
61 |
+
|
62 |
+
k = None
|
63 |
+
imgs = (
|
64 |
+
[image]
|
65 |
+
if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
|
66 |
+
else [image[0], image[-1]]
|
67 |
+
)
|
68 |
+
if len(imgs) > 1:
|
69 |
+
print(f"Number of images: {len(image)}")
|
70 |
+
for i, img in enumerate(imgs):
|
71 |
+
if len(imgs) > 1:
|
72 |
+
print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
|
73 |
+
else:
|
74 |
+
print(f"Image: {img.shape}")
|
75 |
+
pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
|
76 |
+
ascii_img = from_pillow_image(pil_img)
|
77 |
+
ascii_img.to_terminal(columns=width, back=back)
|
78 |
+
|
79 |
+
if interactive:
|
80 |
+
k = input("Press q to quit, d to next dataset, other key to next: ")
|
81 |
+
if k == "q" or k == "d":
|
82 |
+
return ord(k)
|
83 |
+
|
84 |
+
if not interactive:
|
85 |
+
return ord(" ")
|
86 |
+
return ord(k) if k else ord(" ")
|
87 |
+
|
88 |
+
|
89 |
+
def show_datasets(
|
90 |
+
datasets: list[BaseDataset], debug_mode: str, console_width: int, console_back: str, console_num_images: Optional[int]
|
91 |
+
):
|
92 |
+
print(f"d: next dataset, q: quit")
|
93 |
+
|
94 |
+
num_workers = max(1, os.cpu_count() - 1)
|
95 |
+
for i, dataset in enumerate(datasets):
|
96 |
+
print(f"Dataset [{i}]")
|
97 |
+
batch_index = 0
|
98 |
+
num_images_to_show = console_num_images
|
99 |
+
k = None
|
100 |
+
for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
|
101 |
+
print(f"bucket resolution: {key}, count: {len(batch)}")
|
102 |
+
for j, item_info in enumerate(batch):
|
103 |
+
item_info: ItemInfo
|
104 |
+
print(f"{batch_index}-{j}: {item_info}")
|
105 |
+
if debug_mode == "image":
|
106 |
+
k = show_image(item_info.content)
|
107 |
+
elif debug_mode == "console":
|
108 |
+
k = show_console(item_info.content, console_width, console_back, console_num_images is None)
|
109 |
+
if num_images_to_show is not None:
|
110 |
+
num_images_to_show -= 1
|
111 |
+
if num_images_to_show == 0:
|
112 |
+
k = ord("d") # next dataset
|
113 |
+
|
114 |
+
if k == ord("q"):
|
115 |
+
return
|
116 |
+
elif k == ord("d"):
|
117 |
+
break
|
118 |
+
if k == ord("d"):
|
119 |
+
break
|
120 |
+
batch_index += 1
|
121 |
+
|
122 |
+
|
123 |
+
def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
|
124 |
+
contents = torch.stack([torch.from_numpy(item.content) for item in batch])
|
125 |
+
if len(contents.shape) == 4:
|
126 |
+
contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
|
127 |
+
|
128 |
+
contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
|
129 |
+
contents = contents.to(vae.device, dtype=vae.dtype)
|
130 |
+
contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
|
131 |
+
|
132 |
+
h, w = contents.shape[3], contents.shape[4]
|
133 |
+
if h < 8 or w < 8:
|
134 |
+
item = batch[0] # other items should have the same size
|
135 |
+
raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
|
136 |
+
|
137 |
+
# print(f"encode batch: {contents.shape}")
|
138 |
+
with torch.no_grad():
|
139 |
+
latent = vae.encode(contents).latent_dist.sample()
|
140 |
+
# latent = latent * vae.config.scaling_factor
|
141 |
+
|
142 |
+
# # debug: decode and save
|
143 |
+
# with torch.no_grad():
|
144 |
+
# latent_to_decode = latent / vae.config.scaling_factor
|
145 |
+
# images = vae.decode(latent_to_decode, return_dict=False)[0]
|
146 |
+
# images = (images / 2 + 0.5).clamp(0, 1)
|
147 |
+
# images = images.cpu().float().numpy()
|
148 |
+
# images = (images * 255).astype(np.uint8)
|
149 |
+
# images = images.transpose(0, 2, 3, 4, 1) # B, C, F, H, W -> B, F, H, W, C
|
150 |
+
# for b in range(images.shape[0]):
|
151 |
+
# for f in range(images.shape[1]):
|
152 |
+
# fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
|
153 |
+
# img = Image.fromarray(images[b, f])
|
154 |
+
# img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
|
155 |
+
|
156 |
+
for item, l in zip(batch, latent):
|
157 |
+
# print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
|
158 |
+
save_latent_cache(item, l)
|
159 |
+
|
160 |
+
|
161 |
+
def encode_datasets(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
|
162 |
+
num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
|
163 |
+
for i, dataset in enumerate(datasets):
|
164 |
+
logger.info(f"Encoding dataset [{i}]")
|
165 |
+
all_latent_cache_paths = []
|
166 |
+
for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
|
167 |
+
all_latent_cache_paths.extend([item.latent_cache_path for item in batch])
|
168 |
+
|
169 |
+
if args.skip_existing:
|
170 |
+
filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
|
171 |
+
if len(filtered_batch) == 0:
|
172 |
+
continue
|
173 |
+
batch = filtered_batch
|
174 |
+
|
175 |
+
bs = args.batch_size if args.batch_size is not None else len(batch)
|
176 |
+
for i in range(0, len(batch), bs):
|
177 |
+
encode(batch[i : i + bs])
|
178 |
+
|
179 |
+
# normalize paths
|
180 |
+
all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
|
181 |
+
all_latent_cache_paths = set(all_latent_cache_paths)
|
182 |
+
|
183 |
+
# remove old cache files not in the dataset
|
184 |
+
all_cache_files = dataset.get_all_latent_cache_files()
|
185 |
+
for cache_file in all_cache_files:
|
186 |
+
if os.path.normpath(cache_file) not in all_latent_cache_paths:
|
187 |
+
if args.keep_cache:
|
188 |
+
logger.info(f"Keep cache file not in the dataset: {cache_file}")
|
189 |
+
else:
|
190 |
+
os.remove(cache_file)
|
191 |
+
logger.info(f"Removed old cache file: {cache_file}")
|
192 |
+
|
193 |
+
|
194 |
+
def main(args):
|
195 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
196 |
+
device = torch.device(device)
|
197 |
+
|
198 |
+
# Load dataset config
|
199 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
|
200 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
201 |
+
user_config = config_utils.load_user_config(args.dataset_config)
|
202 |
+
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
|
203 |
+
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
204 |
+
|
205 |
+
datasets = train_dataset_group.datasets
|
206 |
+
|
207 |
+
if args.debug_mode is not None:
|
208 |
+
show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
|
209 |
+
return
|
210 |
+
|
211 |
+
assert args.vae is not None, "vae checkpoint is required"
|
212 |
+
|
213 |
+
# Load VAE model: HunyuanVideo VAE model is float16
|
214 |
+
vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
|
215 |
+
vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
|
216 |
+
vae.eval()
|
217 |
+
logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")
|
218 |
+
|
219 |
+
if args.vae_chunk_size is not None:
|
220 |
+
vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
|
221 |
+
logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
|
222 |
+
if args.vae_spatial_tile_sample_min_size is not None:
|
223 |
+
vae.enable_spatial_tiling(True)
|
224 |
+
vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
|
225 |
+
vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
|
226 |
+
elif args.vae_tiling:
|
227 |
+
vae.enable_spatial_tiling(True)
|
228 |
+
|
229 |
+
# Encode images
|
230 |
+
def encode(one_batch: list[ItemInfo]):
|
231 |
+
encode_and_save_batch(vae, one_batch)
|
232 |
+
|
233 |
+
encode_datasets(datasets, encode, args)
|
234 |
+
|
235 |
+
|
236 |
+
def setup_parser_common() -> argparse.ArgumentParser:
|
237 |
+
parser = argparse.ArgumentParser()
|
238 |
+
|
239 |
+
parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
|
240 |
+
parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
|
241 |
+
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
|
242 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
|
243 |
+
parser.add_argument(
|
244 |
+
"--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
|
245 |
+
)
|
246 |
+
parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
|
247 |
+
parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
|
248 |
+
parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
|
249 |
+
parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console"], help="debug mode")
|
250 |
+
parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
|
251 |
+
parser.add_argument(
|
252 |
+
"--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--console_num_images",
|
256 |
+
type=int,
|
257 |
+
default=None,
|
258 |
+
help="debug mode: not interactive, number of images to show for each dataset",
|
259 |
+
)
|
260 |
+
return parser
|
261 |
+
|
262 |
+
|
263 |
+
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
264 |
+
parser.add_argument(
|
265 |
+
"--vae_tiling",
|
266 |
+
action="store_true",
|
267 |
+
help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
|
268 |
+
)
|
269 |
+
parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
|
270 |
+
parser.add_argument(
|
271 |
+
"--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
|
272 |
+
)
|
273 |
+
return parser
|
274 |
+
|
275 |
+
|
276 |
+
if __name__ == "__main__":
|
277 |
+
parser = setup_parser_common()
|
278 |
+
parser = hv_setup_parser(parser)
|
279 |
+
|
280 |
+
args = parser.parse_args()
|
281 |
+
main(args)
|
cache_text_encoder_outputs.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from typing import Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from dataset import config_utils
|
10 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
11 |
+
import accelerate
|
12 |
+
|
13 |
+
from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, BaseDataset, ItemInfo, save_text_encoder_output_cache
|
14 |
+
from hunyuan_model import text_encoder as text_encoder_module
|
15 |
+
from hunyuan_model.text_encoder import TextEncoder
|
16 |
+
|
17 |
+
import logging
|
18 |
+
|
19 |
+
from utils.model_utils import str_to_dtype
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
|
24 |
+
|
25 |
+
def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
|
26 |
+
data_type = "video" # video only, image is not supported
|
27 |
+
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
|
28 |
+
|
29 |
+
with torch.no_grad():
|
30 |
+
prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
|
31 |
+
|
32 |
+
return prompt_outputs.hidden_state, prompt_outputs.attention_mask
|
33 |
+
|
34 |
+
|
35 |
+
def encode_and_save_batch(
|
36 |
+
text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
|
37 |
+
):
|
38 |
+
prompts = [item.caption for item in batch]
|
39 |
+
# print(prompts)
|
40 |
+
|
41 |
+
# encode prompt
|
42 |
+
if accelerator is not None:
|
43 |
+
with accelerator.autocast():
|
44 |
+
prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
|
45 |
+
else:
|
46 |
+
prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
|
47 |
+
|
48 |
+
# # convert to fp16 if needed
|
49 |
+
# if prompt_embeds.dtype == torch.float32 and text_encoder.dtype != torch.float32:
|
50 |
+
# prompt_embeds = prompt_embeds.to(text_encoder.dtype)
|
51 |
+
|
52 |
+
# save prompt cache
|
53 |
+
for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
|
54 |
+
save_text_encoder_output_cache(item, embed, mask, is_llm)
|
55 |
+
|
56 |
+
|
57 |
+
def prepare_cache_files_and_paths(datasets: list[BaseDataset]):
|
58 |
+
all_cache_files_for_dataset = [] # exisiting cache files
|
59 |
+
all_cache_paths_for_dataset = [] # all cache paths in the dataset
|
60 |
+
for dataset in datasets:
|
61 |
+
all_cache_files = [os.path.normpath(file) for file in dataset.get_all_text_encoder_output_cache_files()]
|
62 |
+
all_cache_files = set(all_cache_files)
|
63 |
+
all_cache_files_for_dataset.append(all_cache_files)
|
64 |
+
|
65 |
+
all_cache_paths_for_dataset.append(set())
|
66 |
+
return all_cache_files_for_dataset, all_cache_paths_for_dataset
|
67 |
+
|
68 |
+
|
69 |
+
def process_text_encoder_batches(
|
70 |
+
num_workers: Optional[int],
|
71 |
+
skip_existing: bool,
|
72 |
+
batch_size: int,
|
73 |
+
datasets: list[BaseDataset],
|
74 |
+
all_cache_files_for_dataset: list[set],
|
75 |
+
all_cache_paths_for_dataset: list[set],
|
76 |
+
encode: callable,
|
77 |
+
):
|
78 |
+
num_workers = num_workers if num_workers is not None else max(1, os.cpu_count() - 1)
|
79 |
+
for i, dataset in enumerate(datasets):
|
80 |
+
logger.info(f"Encoding dataset [{i}]")
|
81 |
+
all_cache_files = all_cache_files_for_dataset[i]
|
82 |
+
all_cache_paths = all_cache_paths_for_dataset[i]
|
83 |
+
for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
|
84 |
+
# update cache files (it's ok if we update it multiple times)
|
85 |
+
all_cache_paths.update([os.path.normpath(item.text_encoder_output_cache_path) for item in batch])
|
86 |
+
|
87 |
+
# skip existing cache files
|
88 |
+
if skip_existing:
|
89 |
+
filtered_batch = [
|
90 |
+
item for item in batch if not os.path.normpath(item.text_encoder_output_cache_path) in all_cache_files
|
91 |
+
]
|
92 |
+
# print(f"Filtered {len(batch) - len(filtered_batch)} existing cache files")
|
93 |
+
if len(filtered_batch) == 0:
|
94 |
+
continue
|
95 |
+
batch = filtered_batch
|
96 |
+
|
97 |
+
bs = batch_size if batch_size is not None else len(batch)
|
98 |
+
for i in range(0, len(batch), bs):
|
99 |
+
encode(batch[i : i + bs])
|
100 |
+
|
101 |
+
|
102 |
+
def post_process_cache_files(
|
103 |
+
datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set]
|
104 |
+
):
|
105 |
+
for i, dataset in enumerate(datasets):
|
106 |
+
all_cache_files = all_cache_files_for_dataset[i]
|
107 |
+
all_cache_paths = all_cache_paths_for_dataset[i]
|
108 |
+
for cache_file in all_cache_files:
|
109 |
+
if cache_file not in all_cache_paths:
|
110 |
+
if args.keep_cache:
|
111 |
+
logger.info(f"Keep cache file not in the dataset: {cache_file}")
|
112 |
+
else:
|
113 |
+
os.remove(cache_file)
|
114 |
+
logger.info(f"Removed old cache file: {cache_file}")
|
115 |
+
|
116 |
+
|
117 |
+
def main(args):
|
118 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
119 |
+
device = torch.device(device)
|
120 |
+
|
121 |
+
# Load dataset config
|
122 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
|
123 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
124 |
+
user_config = config_utils.load_user_config(args.dataset_config)
|
125 |
+
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
|
126 |
+
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
127 |
+
|
128 |
+
datasets = train_dataset_group.datasets
|
129 |
+
|
130 |
+
# define accelerator for fp8 inference
|
131 |
+
accelerator = None
|
132 |
+
if args.fp8_llm:
|
133 |
+
accelerator = accelerate.Accelerator(mixed_precision="fp16")
|
134 |
+
|
135 |
+
# prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
|
136 |
+
all_cache_files_for_dataset, all_cache_paths_for_dataset = prepare_cache_files_and_paths(datasets)
|
137 |
+
|
138 |
+
# Load Text Encoder 1
|
139 |
+
text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
|
140 |
+
logger.info(f"loading text encoder 1: {args.text_encoder1}")
|
141 |
+
text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
|
142 |
+
text_encoder_1.to(device=device)
|
143 |
+
|
144 |
+
# Encode with Text Encoder 1 (LLM)
|
145 |
+
logger.info("Encoding with Text Encoder 1")
|
146 |
+
|
147 |
+
def encode_for_text_encoder_1(batch: list[ItemInfo]):
|
148 |
+
encode_and_save_batch(text_encoder_1, batch, is_llm=True, accelerator=accelerator)
|
149 |
+
|
150 |
+
process_text_encoder_batches(
|
151 |
+
args.num_workers,
|
152 |
+
args.skip_existing,
|
153 |
+
args.batch_size,
|
154 |
+
datasets,
|
155 |
+
all_cache_files_for_dataset,
|
156 |
+
all_cache_paths_for_dataset,
|
157 |
+
encode_for_text_encoder_1,
|
158 |
+
)
|
159 |
+
del text_encoder_1
|
160 |
+
|
161 |
+
# Load Text Encoder 2
|
162 |
+
logger.info(f"loading text encoder 2: {args.text_encoder2}")
|
163 |
+
text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
|
164 |
+
text_encoder_2.to(device=device)
|
165 |
+
|
166 |
+
# Encode with Text Encoder 2
|
167 |
+
logger.info("Encoding with Text Encoder 2")
|
168 |
+
|
169 |
+
def encode_for_text_encoder_2(batch: list[ItemInfo]):
|
170 |
+
encode_and_save_batch(text_encoder_2, batch, is_llm=False, accelerator=None)
|
171 |
+
|
172 |
+
process_text_encoder_batches(
|
173 |
+
args.num_workers,
|
174 |
+
args.skip_existing,
|
175 |
+
args.batch_size,
|
176 |
+
datasets,
|
177 |
+
all_cache_files_for_dataset,
|
178 |
+
all_cache_paths_for_dataset,
|
179 |
+
encode_for_text_encoder_2,
|
180 |
+
)
|
181 |
+
del text_encoder_2
|
182 |
+
|
183 |
+
# remove cache files not in dataset
|
184 |
+
post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset)
|
185 |
+
|
186 |
+
|
187 |
+
def setup_parser_common():
|
188 |
+
parser = argparse.ArgumentParser()
|
189 |
+
|
190 |
+
parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
|
191 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")
|
192 |
+
parser.add_argument(
|
193 |
+
"--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
|
194 |
+
)
|
195 |
+
parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
|
196 |
+
parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
|
197 |
+
parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
|
198 |
+
return parser
|
199 |
+
|
200 |
+
|
201 |
+
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
202 |
+
parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
|
203 |
+
parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
|
204 |
+
parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
|
205 |
+
parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
|
206 |
+
return parser
|
207 |
+
|
208 |
+
|
209 |
+
if __name__ == "__main__":
|
210 |
+
parser = setup_parser_common()
|
211 |
+
parser = hv_setup_parser(parser)
|
212 |
+
|
213 |
+
args = parser.parse_args()
|
214 |
+
main(args)
|
convert_lora.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from safetensors.torch import load_file, save_file
|
5 |
+
from safetensors import safe_open
|
6 |
+
from utils import model_utils
|
7 |
+
|
8 |
+
import logging
|
9 |
+
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
logging.basicConfig(level=logging.INFO)
|
13 |
+
|
14 |
+
|
15 |
+
def convert_from_diffusers(prefix, weights_sd):
|
16 |
+
# convert from diffusers(?) to default LoRA
|
17 |
+
# Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
|
18 |
+
# default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
|
19 |
+
|
20 |
+
# note: Diffusers has no alpha, so alpha is set to rank
|
21 |
+
new_weights_sd = {}
|
22 |
+
lora_dims = {}
|
23 |
+
for key, weight in weights_sd.items():
|
24 |
+
diffusers_prefix, key_body = key.split(".", 1)
|
25 |
+
if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
|
26 |
+
logger.warning(f"unexpected key: {key} in diffusers format")
|
27 |
+
continue
|
28 |
+
|
29 |
+
new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
|
30 |
+
new_weights_sd[new_key] = weight
|
31 |
+
|
32 |
+
lora_name = new_key.split(".")[0] # before first dot
|
33 |
+
if lora_name not in lora_dims and "lora_down" in new_key:
|
34 |
+
lora_dims[lora_name] = weight.shape[0]
|
35 |
+
|
36 |
+
# add alpha with rank
|
37 |
+
for lora_name, dim in lora_dims.items():
|
38 |
+
new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
|
39 |
+
|
40 |
+
return new_weights_sd
|
41 |
+
|
42 |
+
|
43 |
+
def convert_to_diffusers(prefix, weights_sd):
|
44 |
+
# convert from default LoRA to diffusers
|
45 |
+
|
46 |
+
# get alphas
|
47 |
+
lora_alphas = {}
|
48 |
+
for key, weight in weights_sd.items():
|
49 |
+
if key.startswith(prefix):
|
50 |
+
lora_name = key.split(".", 1)[0] # before first dot
|
51 |
+
if lora_name not in lora_alphas and "alpha" in key:
|
52 |
+
lora_alphas[lora_name] = weight
|
53 |
+
|
54 |
+
new_weights_sd = {}
|
55 |
+
for key, weight in weights_sd.items():
|
56 |
+
if key.startswith(prefix):
|
57 |
+
if "alpha" in key:
|
58 |
+
continue
|
59 |
+
|
60 |
+
lora_name = key.split(".", 1)[0] # before first dot
|
61 |
+
|
62 |
+
module_name = lora_name[len(prefix) :] # remove "lora_unet_"
|
63 |
+
module_name = module_name.replace("_", ".") # replace "_" with "."
|
64 |
+
if ".cross.attn." in module_name or ".self.attn." in module_name:
|
65 |
+
# Wan2.1 lora name to module name: ugly but works
|
66 |
+
module_name = module_name.replace("cross.attn", "cross_attn") # fix cross attn
|
67 |
+
module_name = module_name.replace("self.attn", "self_attn") # fix self attn
|
68 |
+
else:
|
69 |
+
# HunyuanVideo lora name to module name: ugly but works
|
70 |
+
module_name = module_name.replace("double.blocks.", "double_blocks.") # fix double blocks
|
71 |
+
module_name = module_name.replace("single.blocks.", "single_blocks.") # fix single blocks
|
72 |
+
module_name = module_name.replace("img.", "img_") # fix img
|
73 |
+
module_name = module_name.replace("txt.", "txt_") # fix txt
|
74 |
+
module_name = module_name.replace("attn.", "attn_") # fix attn
|
75 |
+
|
76 |
+
diffusers_prefix = "diffusion_model"
|
77 |
+
if "lora_down" in key:
|
78 |
+
new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
|
79 |
+
dim = weight.shape[0]
|
80 |
+
elif "lora_up" in key:
|
81 |
+
new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
|
82 |
+
dim = weight.shape[1]
|
83 |
+
else:
|
84 |
+
logger.warning(f"unexpected key: {key} in default LoRA format")
|
85 |
+
continue
|
86 |
+
|
87 |
+
# scale weight by alpha
|
88 |
+
if lora_name in lora_alphas:
|
89 |
+
# we scale both down and up, so scale is sqrt
|
90 |
+
scale = lora_alphas[lora_name] / dim
|
91 |
+
scale = scale.sqrt()
|
92 |
+
weight = weight * scale
|
93 |
+
else:
|
94 |
+
logger.warning(f"missing alpha for {lora_name}")
|
95 |
+
|
96 |
+
new_weights_sd[new_key] = weight
|
97 |
+
|
98 |
+
return new_weights_sd
|
99 |
+
|
100 |
+
|
101 |
+
def convert(input_file, output_file, target_format):
|
102 |
+
logger.info(f"loading {input_file}")
|
103 |
+
weights_sd = load_file(input_file)
|
104 |
+
with safe_open(input_file, framework="pt") as f:
|
105 |
+
metadata = f.metadata()
|
106 |
+
|
107 |
+
logger.info(f"converting to {target_format}")
|
108 |
+
prefix = "lora_unet_"
|
109 |
+
if target_format == "default":
|
110 |
+
new_weights_sd = convert_from_diffusers(prefix, weights_sd)
|
111 |
+
metadata = metadata or {}
|
112 |
+
model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
|
113 |
+
elif target_format == "other":
|
114 |
+
new_weights_sd = convert_to_diffusers(prefix, weights_sd)
|
115 |
+
else:
|
116 |
+
raise ValueError(f"unknown target format: {target_format}")
|
117 |
+
|
118 |
+
logger.info(f"saving to {output_file}")
|
119 |
+
save_file(new_weights_sd, output_file, metadata=metadata)
|
120 |
+
|
121 |
+
logger.info("done")
|
122 |
+
|
123 |
+
|
124 |
+
def parse_args():
|
125 |
+
parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
|
126 |
+
parser.add_argument("--input", type=str, required=True, help="input model file")
|
127 |
+
parser.add_argument("--output", type=str, required=True, help="output model file")
|
128 |
+
parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
|
129 |
+
args = parser.parse_args()
|
130 |
+
return args
|
131 |
+
|
132 |
+
|
133 |
+
if __name__ == "__main__":
|
134 |
+
args = parse_args()
|
135 |
+
convert(args.input, args.output, args.target)
|
dataset/__init__.py
ADDED
File without changes
|
dataset/config_utils.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from dataclasses import (
|
3 |
+
asdict,
|
4 |
+
dataclass,
|
5 |
+
)
|
6 |
+
import functools
|
7 |
+
import random
|
8 |
+
from textwrap import dedent, indent
|
9 |
+
import json
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
# from toolz import curry
|
13 |
+
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
14 |
+
|
15 |
+
import toml
|
16 |
+
import voluptuous
|
17 |
+
from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
|
18 |
+
|
19 |
+
from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
|
20 |
+
|
21 |
+
import logging
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
logging.basicConfig(level=logging.INFO)
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class BaseDatasetParams:
|
29 |
+
resolution: Tuple[int, int] = (960, 544)
|
30 |
+
enable_bucket: bool = False
|
31 |
+
bucket_no_upscale: bool = False
|
32 |
+
caption_extension: Optional[str] = None
|
33 |
+
batch_size: int = 1
|
34 |
+
num_repeats: int = 1
|
35 |
+
cache_directory: Optional[str] = None
|
36 |
+
debug_dataset: bool = False
|
37 |
+
architecture: str = "no_default" # short style like "hv" or "wan"
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class ImageDatasetParams(BaseDatasetParams):
|
42 |
+
image_directory: Optional[str] = None
|
43 |
+
image_jsonl_file: Optional[str] = None
|
44 |
+
|
45 |
+
|
46 |
+
@dataclass
|
47 |
+
class VideoDatasetParams(BaseDatasetParams):
|
48 |
+
video_directory: Optional[str] = None
|
49 |
+
video_jsonl_file: Optional[str] = None
|
50 |
+
target_frames: Sequence[int] = (1,)
|
51 |
+
frame_extraction: Optional[str] = "head"
|
52 |
+
frame_stride: Optional[int] = 1
|
53 |
+
frame_sample: Optional[int] = 1
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class DatasetBlueprint:
|
58 |
+
is_image_dataset: bool
|
59 |
+
params: Union[ImageDatasetParams, VideoDatasetParams]
|
60 |
+
|
61 |
+
|
62 |
+
@dataclass
|
63 |
+
class DatasetGroupBlueprint:
|
64 |
+
datasets: Sequence[DatasetBlueprint]
|
65 |
+
|
66 |
+
|
67 |
+
@dataclass
|
68 |
+
class Blueprint:
|
69 |
+
dataset_group: DatasetGroupBlueprint
|
70 |
+
|
71 |
+
|
72 |
+
class ConfigSanitizer:
|
73 |
+
# @curry
|
74 |
+
@staticmethod
|
75 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
76 |
+
Schema(ExactSequence([klass, klass]))(value)
|
77 |
+
return tuple(value)
|
78 |
+
|
79 |
+
# @curry
|
80 |
+
@staticmethod
|
81 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
82 |
+
Schema(Any(klass, ExactSequence([klass, klass])))(value)
|
83 |
+
try:
|
84 |
+
Schema(klass)(value)
|
85 |
+
return (value, value)
|
86 |
+
except:
|
87 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
88 |
+
|
89 |
+
# datasets schema
|
90 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
91 |
+
"caption_extension": str,
|
92 |
+
"batch_size": int,
|
93 |
+
"num_repeats": int,
|
94 |
+
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
95 |
+
"enable_bucket": bool,
|
96 |
+
"bucket_no_upscale": bool,
|
97 |
+
}
|
98 |
+
IMAGE_DATASET_DISTINCT_SCHEMA = {
|
99 |
+
"image_directory": str,
|
100 |
+
"image_jsonl_file": str,
|
101 |
+
"cache_directory": str,
|
102 |
+
}
|
103 |
+
VIDEO_DATASET_DISTINCT_SCHEMA = {
|
104 |
+
"video_directory": str,
|
105 |
+
"video_jsonl_file": str,
|
106 |
+
"target_frames": [int],
|
107 |
+
"frame_extraction": str,
|
108 |
+
"frame_stride": int,
|
109 |
+
"frame_sample": int,
|
110 |
+
"cache_directory": str,
|
111 |
+
}
|
112 |
+
|
113 |
+
# options handled by argparse but not handled by user config
|
114 |
+
ARGPARSE_SPECIFIC_SCHEMA = {
|
115 |
+
"debug_dataset": bool,
|
116 |
+
}
|
117 |
+
|
118 |
+
def __init__(self) -> None:
|
119 |
+
self.image_dataset_schema = self.__merge_dict(
|
120 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
121 |
+
self.IMAGE_DATASET_DISTINCT_SCHEMA,
|
122 |
+
)
|
123 |
+
self.video_dataset_schema = self.__merge_dict(
|
124 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
125 |
+
self.VIDEO_DATASET_DISTINCT_SCHEMA,
|
126 |
+
)
|
127 |
+
|
128 |
+
def validate_flex_dataset(dataset_config: dict):
|
129 |
+
if "target_frames" in dataset_config:
|
130 |
+
return Schema(self.video_dataset_schema)(dataset_config)
|
131 |
+
else:
|
132 |
+
return Schema(self.image_dataset_schema)(dataset_config)
|
133 |
+
|
134 |
+
self.dataset_schema = validate_flex_dataset
|
135 |
+
|
136 |
+
self.general_schema = self.__merge_dict(
|
137 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
138 |
+
)
|
139 |
+
self.user_config_validator = Schema(
|
140 |
+
{
|
141 |
+
"general": self.general_schema,
|
142 |
+
"datasets": [self.dataset_schema],
|
143 |
+
}
|
144 |
+
)
|
145 |
+
self.argparse_schema = self.__merge_dict(
|
146 |
+
self.ARGPARSE_SPECIFIC_SCHEMA,
|
147 |
+
)
|
148 |
+
self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
|
149 |
+
|
150 |
+
def sanitize_user_config(self, user_config: dict) -> dict:
|
151 |
+
try:
|
152 |
+
return self.user_config_validator(user_config)
|
153 |
+
except MultipleInvalid:
|
154 |
+
# TODO: clarify the error message
|
155 |
+
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
156 |
+
raise
|
157 |
+
|
158 |
+
# NOTE: In nature, argument parser result is not needed to be sanitize
|
159 |
+
# However this will help us to detect program bug
|
160 |
+
def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
|
161 |
+
try:
|
162 |
+
return self.argparse_config_validator(argparse_namespace)
|
163 |
+
except MultipleInvalid:
|
164 |
+
# XXX: this should be a bug
|
165 |
+
logger.error(
|
166 |
+
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
167 |
+
)
|
168 |
+
raise
|
169 |
+
|
170 |
+
# NOTE: value would be overwritten by latter dict if there is already the same key
|
171 |
+
@staticmethod
|
172 |
+
def __merge_dict(*dict_list: dict) -> dict:
|
173 |
+
merged = {}
|
174 |
+
for schema in dict_list:
|
175 |
+
# merged |= schema
|
176 |
+
for k, v in schema.items():
|
177 |
+
merged[k] = v
|
178 |
+
return merged
|
179 |
+
|
180 |
+
|
181 |
+
class BlueprintGenerator:
|
182 |
+
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
|
183 |
+
|
184 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
185 |
+
self.sanitizer = sanitizer
|
186 |
+
|
187 |
+
# runtime_params is for parameters which is only configurable on runtime, such as tokenizer
|
188 |
+
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
|
189 |
+
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
|
190 |
+
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
|
191 |
+
|
192 |
+
argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
|
193 |
+
general_config = sanitized_user_config.get("general", {})
|
194 |
+
|
195 |
+
dataset_blueprints = []
|
196 |
+
for dataset_config in sanitized_user_config.get("datasets", []):
|
197 |
+
is_image_dataset = "target_frames" not in dataset_config
|
198 |
+
if is_image_dataset:
|
199 |
+
dataset_params_klass = ImageDatasetParams
|
200 |
+
else:
|
201 |
+
dataset_params_klass = VideoDatasetParams
|
202 |
+
|
203 |
+
params = self.generate_params_by_fallbacks(
|
204 |
+
dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
|
205 |
+
)
|
206 |
+
dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))
|
207 |
+
|
208 |
+
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
209 |
+
|
210 |
+
return Blueprint(dataset_group_blueprint)
|
211 |
+
|
212 |
+
@staticmethod
|
213 |
+
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
|
214 |
+
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
|
215 |
+
search_value = BlueprintGenerator.search_value
|
216 |
+
default_params = asdict(param_klass())
|
217 |
+
param_names = default_params.keys()
|
218 |
+
|
219 |
+
params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
|
220 |
+
|
221 |
+
return param_klass(**params)
|
222 |
+
|
223 |
+
@staticmethod
|
224 |
+
def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
|
225 |
+
for cand in fallbacks:
|
226 |
+
value = cand.get(key)
|
227 |
+
if value is not None:
|
228 |
+
return value
|
229 |
+
|
230 |
+
return default_value
|
231 |
+
|
232 |
+
|
233 |
+
# if training is True, it will return a dataset group for training, otherwise for caching
|
234 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
|
235 |
+
datasets: List[Union[ImageDataset, VideoDataset]] = []
|
236 |
+
|
237 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
238 |
+
if dataset_blueprint.is_image_dataset:
|
239 |
+
dataset_klass = ImageDataset
|
240 |
+
else:
|
241 |
+
dataset_klass = VideoDataset
|
242 |
+
|
243 |
+
dataset = dataset_klass(**asdict(dataset_blueprint.params))
|
244 |
+
datasets.append(dataset)
|
245 |
+
|
246 |
+
# assertion
|
247 |
+
cache_directories = [dataset.cache_directory for dataset in datasets]
|
248 |
+
num_of_unique_cache_directories = len(set(cache_directories))
|
249 |
+
if num_of_unique_cache_directories != len(cache_directories):
|
250 |
+
raise ValueError(
|
251 |
+
"cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)"
|
252 |
+
+ " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)"
|
253 |
+
)
|
254 |
+
|
255 |
+
# print info
|
256 |
+
info = ""
|
257 |
+
for i, dataset in enumerate(datasets):
|
258 |
+
is_image_dataset = isinstance(dataset, ImageDataset)
|
259 |
+
info += dedent(
|
260 |
+
f"""\
|
261 |
+
[Dataset {i}]
|
262 |
+
is_image_dataset: {is_image_dataset}
|
263 |
+
resolution: {dataset.resolution}
|
264 |
+
batch_size: {dataset.batch_size}
|
265 |
+
num_repeats: {dataset.num_repeats}
|
266 |
+
caption_extension: "{dataset.caption_extension}"
|
267 |
+
enable_bucket: {dataset.enable_bucket}
|
268 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
269 |
+
cache_directory: "{dataset.cache_directory}"
|
270 |
+
debug_dataset: {dataset.debug_dataset}
|
271 |
+
"""
|
272 |
+
)
|
273 |
+
|
274 |
+
if is_image_dataset:
|
275 |
+
info += indent(
|
276 |
+
dedent(
|
277 |
+
f"""\
|
278 |
+
image_directory: "{dataset.image_directory}"
|
279 |
+
image_jsonl_file: "{dataset.image_jsonl_file}"
|
280 |
+
\n"""
|
281 |
+
),
|
282 |
+
" ",
|
283 |
+
)
|
284 |
+
else:
|
285 |
+
info += indent(
|
286 |
+
dedent(
|
287 |
+
f"""\
|
288 |
+
video_directory: "{dataset.video_directory}"
|
289 |
+
video_jsonl_file: "{dataset.video_jsonl_file}"
|
290 |
+
target_frames: {dataset.target_frames}
|
291 |
+
frame_extraction: {dataset.frame_extraction}
|
292 |
+
frame_stride: {dataset.frame_stride}
|
293 |
+
frame_sample: {dataset.frame_sample}
|
294 |
+
\n"""
|
295 |
+
),
|
296 |
+
" ",
|
297 |
+
)
|
298 |
+
logger.info(f"{info}")
|
299 |
+
|
300 |
+
# make buckets first because it determines the length of dataset
|
301 |
+
# and set the same seed for all datasets
|
302 |
+
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
303 |
+
for i, dataset in enumerate(datasets):
|
304 |
+
# logger.info(f"[Dataset {i}]")
|
305 |
+
dataset.set_seed(seed)
|
306 |
+
if training:
|
307 |
+
dataset.prepare_for_training()
|
308 |
+
|
309 |
+
return DatasetGroup(datasets)
|
310 |
+
|
311 |
+
|
312 |
+
def load_user_config(file: str) -> dict:
|
313 |
+
file: Path = Path(file)
|
314 |
+
if not file.is_file():
|
315 |
+
raise ValueError(f"file not found / ファイルが見つかりません: {file}")
|
316 |
+
|
317 |
+
if file.name.lower().endswith(".json"):
|
318 |
+
try:
|
319 |
+
with open(file, "r", encoding="utf-8") as f:
|
320 |
+
config = json.load(f)
|
321 |
+
except Exception:
|
322 |
+
logger.error(
|
323 |
+
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
324 |
+
)
|
325 |
+
raise
|
326 |
+
elif file.name.lower().endswith(".toml"):
|
327 |
+
try:
|
328 |
+
config = toml.load(file)
|
329 |
+
except Exception:
|
330 |
+
logger.error(
|
331 |
+
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
332 |
+
)
|
333 |
+
raise
|
334 |
+
else:
|
335 |
+
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
336 |
+
|
337 |
+
return config
|
338 |
+
|
339 |
+
|
340 |
+
# for config test
|
341 |
+
if __name__ == "__main__":
|
342 |
+
parser = argparse.ArgumentParser()
|
343 |
+
parser.add_argument("dataset_config")
|
344 |
+
config_args, remain = parser.parse_known_args()
|
345 |
+
|
346 |
+
parser = argparse.ArgumentParser()
|
347 |
+
parser.add_argument("--debug_dataset", action="store_true")
|
348 |
+
argparse_namespace = parser.parse_args(remain)
|
349 |
+
|
350 |
+
logger.info("[argparse_namespace]")
|
351 |
+
logger.info(f"{vars(argparse_namespace)}")
|
352 |
+
|
353 |
+
user_config = load_user_config(config_args.dataset_config)
|
354 |
+
|
355 |
+
logger.info("")
|
356 |
+
logger.info("[user_config]")
|
357 |
+
logger.info(f"{user_config}")
|
358 |
+
|
359 |
+
sanitizer = ConfigSanitizer()
|
360 |
+
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
361 |
+
|
362 |
+
logger.info("")
|
363 |
+
logger.info("[sanitized_user_config]")
|
364 |
+
logger.info(f"{sanitized_user_config}")
|
365 |
+
|
366 |
+
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
367 |
+
|
368 |
+
logger.info("")
|
369 |
+
logger.info("[blueprint]")
|
370 |
+
logger.info(f"{blueprint}")
|
371 |
+
|
372 |
+
dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
dataset/dataset_config.md
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
2 |
+
|
3 |
+
## Dataset Configuration
|
4 |
+
|
5 |
+
<details>
|
6 |
+
<summary>English</summary>
|
7 |
+
|
8 |
+
Please create a TOML file for dataset configuration.
|
9 |
+
|
10 |
+
Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
|
11 |
+
|
12 |
+
The cache directory must be different for each dataset.
|
13 |
+
</details>
|
14 |
+
|
15 |
+
<details>
|
16 |
+
<summary>日本語</summary>
|
17 |
+
|
18 |
+
データセットの設定を行うためのTOMLファイルを作成してください。
|
19 |
+
|
20 |
+
画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。
|
21 |
+
|
22 |
+
キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。
|
23 |
+
</details>
|
24 |
+
|
25 |
+
### Sample for Image Dataset with Caption Text Files
|
26 |
+
|
27 |
+
```toml
|
28 |
+
# resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
29 |
+
# otherwise, the default values will be used for each item
|
30 |
+
|
31 |
+
# general configurations
|
32 |
+
[general]
|
33 |
+
resolution = [960, 544]
|
34 |
+
caption_extension = ".txt"
|
35 |
+
batch_size = 1
|
36 |
+
enable_bucket = true
|
37 |
+
bucket_no_upscale = false
|
38 |
+
|
39 |
+
[[datasets]]
|
40 |
+
image_directory = "/path/to/image_dir"
|
41 |
+
cache_directory = "/path/to/cache_directory"
|
42 |
+
num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
|
43 |
+
|
44 |
+
# other datasets can be added here. each dataset can have different configurations
|
45 |
+
```
|
46 |
+
|
47 |
+
<details>
|
48 |
+
<summary>English</summary>
|
49 |
+
|
50 |
+
`cache_directory` is optional, default is None to use the same directory as the image directory. However, we recommend to set the cache directory to avoid accidental sharing of the cache files between different datasets.
|
51 |
+
|
52 |
+
`num_repeats` is also available. It is optional, default is 1 (no repeat). It repeats the images (or videos) that many times to expand the dataset. For example, if `num_repeats = 2` and there are 20 images in the dataset, each image will be duplicated twice (with the same caption) to have a total of 40 images. It is useful to balance the multiple datasets with different sizes.
|
53 |
+
|
54 |
+
</details>
|
55 |
+
|
56 |
+
<details>
|
57 |
+
<summary>日本語</summary>
|
58 |
+
|
59 |
+
`cache_directory` はオプションです。デフォルトは画像ディレクトリと同じディレクトリに設定されます。ただし、異なるデータセット間でキャッシュファイルが共有されるのを防ぐために、明示的に別のキャッシュディレクトリを設定することをお勧めします。
|
60 |
+
|
61 |
+
`num_repeats` はオプションで、デフォルトは 1 です(繰り返しなし)。画像(や動画)を、その回数だけ単純に繰り返してデータセットを拡張します。たとえば`num_repeats = 2`としたとき、画像20枚のデータセットなら、各画像が2枚ずつ(同一のキャプションで)計40枚存在した場合と同じになります。異なるデータ数のデータセット間でバランスを取るために使用可能です。
|
62 |
+
|
63 |
+
resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
|
64 |
+
|
65 |
+
`[[datasets]]`以下を追加することで、他のデータセットを追加できます。各データセットには異なる設定を持てます。
|
66 |
+
</details>
|
67 |
+
|
68 |
+
### Sample for Image Dataset with Metadata JSONL File
|
69 |
+
|
70 |
+
```toml
|
71 |
+
# resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
72 |
+
# caption_extension is not required for metadata jsonl file
|
73 |
+
# cache_directory is required for each dataset with metadata jsonl file
|
74 |
+
|
75 |
+
# general configurations
|
76 |
+
[general]
|
77 |
+
resolution = [960, 544]
|
78 |
+
batch_size = 1
|
79 |
+
enable_bucket = true
|
80 |
+
bucket_no_upscale = false
|
81 |
+
|
82 |
+
[[datasets]]
|
83 |
+
image_jsonl_file = "/path/to/metadata.jsonl"
|
84 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
85 |
+
num_repeats = 1 # optional, default is 1. Same as above.
|
86 |
+
|
87 |
+
# other datasets can be added here. each dataset can have different configurations
|
88 |
+
```
|
89 |
+
|
90 |
+
JSONL file format for metadata:
|
91 |
+
|
92 |
+
```json
|
93 |
+
{"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
|
94 |
+
{"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
|
95 |
+
```
|
96 |
+
|
97 |
+
<details>
|
98 |
+
<summary>日本語</summary>
|
99 |
+
|
100 |
+
resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
|
101 |
+
|
102 |
+
metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須��す。
|
103 |
+
|
104 |
+
キャプションによるデータセットと同様に、複数のデータセットを追加できます。各データセットには異なる設定を持てます。
|
105 |
+
</details>
|
106 |
+
|
107 |
+
|
108 |
+
### Sample for Video Dataset with Caption Text Files
|
109 |
+
|
110 |
+
```toml
|
111 |
+
# resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample,
|
112 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
113 |
+
# num_repeats is also available for video dataset, example is not shown here
|
114 |
+
|
115 |
+
# general configurations
|
116 |
+
[general]
|
117 |
+
resolution = [960, 544]
|
118 |
+
caption_extension = ".txt"
|
119 |
+
batch_size = 1
|
120 |
+
enable_bucket = true
|
121 |
+
bucket_no_upscale = false
|
122 |
+
|
123 |
+
[[datasets]]
|
124 |
+
video_directory = "/path/to/video_dir"
|
125 |
+
cache_directory = "/path/to/cache_directory" # recommended to set cache directory
|
126 |
+
target_frames = [1, 25, 45]
|
127 |
+
frame_extraction = "head"
|
128 |
+
|
129 |
+
# other datasets can be added here. each dataset can have different configurations
|
130 |
+
```
|
131 |
+
|
132 |
+
<details>
|
133 |
+
<summary>日本語</summary>
|
134 |
+
|
135 |
+
resolution, caption_extension, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
|
136 |
+
|
137 |
+
他の注意事項は画像データセットと同様です。
|
138 |
+
</details>
|
139 |
+
|
140 |
+
### Sample for Video Dataset with Metadata JSONL File
|
141 |
+
|
142 |
+
```toml
|
143 |
+
# resolution, target_frames, frame_extraction, frame_stride, frame_sample,
|
144 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
145 |
+
# caption_extension is not required for metadata jsonl file
|
146 |
+
# cache_directory is required for each dataset with metadata jsonl file
|
147 |
+
|
148 |
+
# general configurations
|
149 |
+
[general]
|
150 |
+
resolution = [960, 544]
|
151 |
+
batch_size = 1
|
152 |
+
enable_bucket = true
|
153 |
+
bucket_no_upscale = false
|
154 |
+
|
155 |
+
[[datasets]]
|
156 |
+
video_jsonl_file = "/path/to/metadata.jsonl"
|
157 |
+
target_frames = [1, 25, 45]
|
158 |
+
frame_extraction = "head"
|
159 |
+
cache_directory = "/path/to/cache_directory_head"
|
160 |
+
|
161 |
+
# same metadata jsonl file can be used for multiple datasets
|
162 |
+
[[datasets]]
|
163 |
+
video_jsonl_file = "/path/to/metadata.jsonl"
|
164 |
+
target_frames = [1]
|
165 |
+
frame_stride = 10
|
166 |
+
cache_directory = "/path/to/cache_directory_stride"
|
167 |
+
|
168 |
+
# other datasets can be added here. each dataset can have different configurations
|
169 |
+
```
|
170 |
+
|
171 |
+
JSONL file format for metadata:
|
172 |
+
|
173 |
+
```json
|
174 |
+
{"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
|
175 |
+
{"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
|
176 |
+
```
|
177 |
+
|
178 |
+
<details>
|
179 |
+
<summary>日本語</summary>
|
180 |
+
|
181 |
+
resolution, target_frames, frame_extraction, frame_stride, frame_sample, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。
|
182 |
+
|
183 |
+
metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
|
184 |
+
|
185 |
+
他の注意事項は今までのデータセットと同様です。
|
186 |
+
</details>
|
187 |
+
|
188 |
+
### frame_extraction Options
|
189 |
+
|
190 |
+
<details>
|
191 |
+
<summary>English</summary>
|
192 |
+
|
193 |
+
- `head`: Extract the first N frames from the video.
|
194 |
+
- `chunk`: Extract frames by splitting the video into chunks of N frames.
|
195 |
+
- `slide`: Extract frames from the video with a stride of `frame_stride`.
|
196 |
+
- `uniform`: Extract `frame_sample` samples uniformly from the video.
|
197 |
+
|
198 |
+
For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
|
199 |
+
</details>
|
200 |
+
|
201 |
+
<details>
|
202 |
+
<summary>日本語</summary>
|
203 |
+
|
204 |
+
- `head`: 動画から最初のNフレームを抽出します。
|
205 |
+
- `chunk`: 動画をNフレームずつに分割してフレームを抽出します。
|
206 |
+
- `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。
|
207 |
+
- `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。
|
208 |
+
|
209 |
+
例えば、40フレームの動画を例とした抽出について、以下の図で説明します。
|
210 |
+
</details>
|
211 |
+
|
212 |
+
```
|
213 |
+
Original Video, 40 frames: x = frame, o = no frame
|
214 |
+
oooooooooooooooooooooooooooooooooooooooo
|
215 |
+
|
216 |
+
head, target_frames = [1, 13, 25] -> extract head frames:
|
217 |
+
xooooooooooooooooooooooooooooooooooooooo
|
218 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
219 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
220 |
+
|
221 |
+
chunk, target_frames = [13, 25] -> extract frames by splitting into chunks, into 13 and 25 frames:
|
222 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
223 |
+
oooooooooooooxxxxxxxxxxxxxoooooooooooooo
|
224 |
+
ooooooooooooooooooooooooooxxxxxxxxxxxxxo
|
225 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
226 |
+
|
227 |
+
NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
|
228 |
+
注: frame_extraction "chunk" を使用する場合、target_frames に 1 を含めないでください。全てのフレームが抽出されてしまいます。
|
229 |
+
|
230 |
+
slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
|
231 |
+
xooooooooooooooooooooooooooooooooooooooo
|
232 |
+
ooooooooooxooooooooooooooooooooooooooooo
|
233 |
+
ooooooooooooooooooooxooooooooooooooooooo
|
234 |
+
ooooooooooooooooooooooooooooooxooooooooo
|
235 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
236 |
+
ooooooooooxxxxxxxxxxxxxooooooooooooooooo
|
237 |
+
ooooooooooooooooooooxxxxxxxxxxxxxooooooo
|
238 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
239 |
+
ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
|
240 |
+
|
241 |
+
uniform, target_frames =[1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
|
242 |
+
xooooooooooooooooooooooooooooooooooooooo
|
243 |
+
oooooooooooooxoooooooooooooooooooooooooo
|
244 |
+
oooooooooooooooooooooooooxoooooooooooooo
|
245 |
+
ooooooooooooooooooooooooooooooooooooooox
|
246 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
247 |
+
oooooooooxxxxxxxxxxxxxoooooooooooooooooo
|
248 |
+
ooooooooooooooooooxxxxxxxxxxxxxooooooooo
|
249 |
+
oooooooooooooooooooooooooooxxxxxxxxxxxxx
|
250 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
251 |
+
oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
|
252 |
+
ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
|
253 |
+
oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
|
254 |
+
```
|
255 |
+
|
256 |
+
## Specifications
|
257 |
+
|
258 |
+
```toml
|
259 |
+
# general configurations
|
260 |
+
[general]
|
261 |
+
resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
|
262 |
+
caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
|
263 |
+
batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
|
264 |
+
num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
|
265 |
+
enable_bucket = true # optional, default is false. Enable bucketing for datasets
|
266 |
+
bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
|
267 |
+
|
268 |
+
### Image Dataset
|
269 |
+
|
270 |
+
# sample image dataset with caption text files
|
271 |
+
[[datasets]]
|
272 |
+
image_directory = "/path/to/image_dir"
|
273 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
274 |
+
resolution = [960, 544] # required if general resolution is not set
|
275 |
+
batch_size = 4 # optional, overwrite the default batch size
|
276 |
+
num_repeats = 1 # optional, overwrite the default num_repeats
|
277 |
+
enable_bucket = false # optional, overwrite the default bucketing setting
|
278 |
+
bucket_no_upscale = true # optional, overwrite the default bucketing setting
|
279 |
+
cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
|
280 |
+
|
281 |
+
# sample image dataset with metadata **jsonl** file
|
282 |
+
[[datasets]]
|
283 |
+
image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
|
284 |
+
resolution = [960, 544] # required if general resolution is not set
|
285 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
286 |
+
# caption_extension is not required for metadata jsonl file
|
287 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
|
288 |
+
|
289 |
+
### Video Dataset
|
290 |
+
|
291 |
+
# sample video dataset with caption text files
|
292 |
+
[[datasets]]
|
293 |
+
video_directory = "/path/to/video_dir"
|
294 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
295 |
+
resolution = [960, 544] # required if general resolution is not set
|
296 |
+
|
297 |
+
target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
|
298 |
+
|
299 |
+
# NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
|
300 |
+
|
301 |
+
frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
|
302 |
+
frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
|
303 |
+
frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
|
304 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
|
305 |
+
|
306 |
+
# sample video dataset with metadata jsonl file
|
307 |
+
[[datasets]]
|
308 |
+
video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
|
309 |
+
|
310 |
+
target_frames = [1, 79]
|
311 |
+
|
312 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
313 |
+
# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
|
314 |
+
```
|
315 |
+
|
316 |
+
<!--
|
317 |
+
# sample image dataset with lance
|
318 |
+
[[datasets]]
|
319 |
+
image_lance_dataset = "/path/to/lance_dataset"
|
320 |
+
resolution = [960, 544] # required if general resolution is not set
|
321 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
|
322 |
+
-->
|
323 |
+
|
324 |
+
The metadata with .json file will be supported in the near future.
|
325 |
+
|
326 |
+
|
327 |
+
|
328 |
+
<!--
|
329 |
+
|
330 |
+
```toml
|
331 |
+
# general configurations
|
332 |
+
[general]
|
333 |
+
resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
|
334 |
+
caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
|
335 |
+
batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
|
336 |
+
enable_bucket = true # optional, default is false. Enable bucketing for datasets
|
337 |
+
bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
|
338 |
+
|
339 |
+
# sample image dataset with caption text files
|
340 |
+
[[datasets]]
|
341 |
+
image_directory = "/path/to/image_dir"
|
342 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
343 |
+
resolution = [960, 544] # required if general resolution is not set
|
344 |
+
batch_size = 4 # optional, overwrite the default batch size
|
345 |
+
enable_bucket = false # optional, overwrite the default bucketing setting
|
346 |
+
bucket_no_upscale = true # optional, overwrite the default bucketing setting
|
347 |
+
cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
|
348 |
+
|
349 |
+
# sample image dataset with metadata **jsonl** file
|
350 |
+
[[datasets]]
|
351 |
+
image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
|
352 |
+
resolution = [960, 544] # required if general resolution is not set
|
353 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
354 |
+
# caption_extension is not required for metadata jsonl file
|
355 |
+
# batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
|
356 |
+
|
357 |
+
# sample video dataset with caption text files
|
358 |
+
[[datasets]]
|
359 |
+
video_directory = "/path/to/video_dir"
|
360 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
361 |
+
resolution = [960, 544] # required if general resolution is not set
|
362 |
+
target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
|
363 |
+
frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
|
364 |
+
frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
|
365 |
+
frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
|
366 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
|
367 |
+
|
368 |
+
# sample video dataset with metadata jsonl file
|
369 |
+
[[datasets]]
|
370 |
+
video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
|
371 |
+
target_frames = [1, 79]
|
372 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
373 |
+
# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
|
374 |
+
```
|
375 |
+
|
376 |
+
# sample image dataset with lance
|
377 |
+
[[datasets]]
|
378 |
+
image_lance_dataset = "/path/to/lance_dataset"
|
379 |
+
resolution = [960, 544] # required if general resolution is not set
|
380 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
|
381 |
+
|
382 |
+
The metadata with .json file will be supported in the near future.
|
383 |
+
|
384 |
+
|
385 |
+
|
386 |
+
|
387 |
+
-->
|
dataset/image_video_dataset.py
ADDED
@@ -0,0 +1,1400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import ThreadPoolExecutor
|
2 |
+
import glob
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
from typing import Optional, Sequence, Tuple, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from safetensors.torch import save_file, load_file
|
13 |
+
from safetensors import safe_open
|
14 |
+
from PIL import Image
|
15 |
+
import cv2
|
16 |
+
import av
|
17 |
+
|
18 |
+
from utils import safetensors_utils
|
19 |
+
from utils.model_utils import dtype_to_str
|
20 |
+
|
21 |
+
import logging
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
logging.basicConfig(level=logging.INFO)
|
25 |
+
|
26 |
+
|
27 |
+
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]
|
28 |
+
|
29 |
+
try:
|
30 |
+
import pillow_avif
|
31 |
+
|
32 |
+
IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
|
33 |
+
except:
|
34 |
+
pass
|
35 |
+
|
36 |
+
# JPEG-XL on Linux
|
37 |
+
try:
|
38 |
+
from jxlpy import JXLImagePlugin
|
39 |
+
|
40 |
+
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
41 |
+
except:
|
42 |
+
pass
|
43 |
+
|
44 |
+
# JPEG-XL on Windows
|
45 |
+
try:
|
46 |
+
import pillow_jxl
|
47 |
+
|
48 |
+
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
49 |
+
except:
|
50 |
+
pass
|
51 |
+
|
52 |
+
VIDEO_EXTENSIONS = [
|
53 |
+
".mp4",
|
54 |
+
".webm",
|
55 |
+
".avi",
|
56 |
+
".mkv",
|
57 |
+
".mov",
|
58 |
+
".flv",
|
59 |
+
".wmv",
|
60 |
+
".m4v",
|
61 |
+
".mpg",
|
62 |
+
".mpeg",
|
63 |
+
".MP4",
|
64 |
+
".WEBM",
|
65 |
+
".AVI",
|
66 |
+
".MKV",
|
67 |
+
".MOV",
|
68 |
+
".FLV",
|
69 |
+
".WMV",
|
70 |
+
".M4V",
|
71 |
+
".MPG",
|
72 |
+
".MPEG",
|
73 |
+
] # some of them are not tested
|
74 |
+
|
75 |
+
ARCHITECTURE_HUNYUAN_VIDEO = "hv"
|
76 |
+
ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
|
77 |
+
ARCHITECTURE_WAN = "wan"
|
78 |
+
ARCHITECTURE_WAN_FULL = "wan"
|
79 |
+
|
80 |
+
|
81 |
+
def glob_images(directory, base="*"):
|
82 |
+
img_paths = []
|
83 |
+
for ext in IMAGE_EXTENSIONS:
|
84 |
+
if base == "*":
|
85 |
+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
86 |
+
else:
|
87 |
+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
88 |
+
img_paths = list(set(img_paths)) # remove duplicates
|
89 |
+
img_paths.sort()
|
90 |
+
return img_paths
|
91 |
+
|
92 |
+
|
93 |
+
def glob_videos(directory, base="*"):
|
94 |
+
video_paths = []
|
95 |
+
for ext in VIDEO_EXTENSIONS:
|
96 |
+
if base == "*":
|
97 |
+
video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
98 |
+
else:
|
99 |
+
video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
100 |
+
video_paths = list(set(video_paths)) # remove duplicates
|
101 |
+
video_paths.sort()
|
102 |
+
return video_paths
|
103 |
+
|
104 |
+
|
105 |
+
def divisible_by(num: int, divisor: int) -> int:
|
106 |
+
return num - num % divisor
|
107 |
+
|
108 |
+
|
109 |
+
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
|
110 |
+
"""
|
111 |
+
Resize the image to the bucket resolution.
|
112 |
+
"""
|
113 |
+
is_pil_image = isinstance(image, Image.Image)
|
114 |
+
if is_pil_image:
|
115 |
+
image_width, image_height = image.size
|
116 |
+
else:
|
117 |
+
image_height, image_width = image.shape[:2]
|
118 |
+
|
119 |
+
if bucket_reso == (image_width, image_height):
|
120 |
+
return np.array(image) if is_pil_image else image
|
121 |
+
|
122 |
+
bucket_width, bucket_height = bucket_reso
|
123 |
+
if bucket_width == image_width or bucket_height == image_height:
|
124 |
+
image = np.array(image) if is_pil_image else image
|
125 |
+
else:
|
126 |
+
# resize the image to the bucket resolution to match the short side
|
127 |
+
scale_width = bucket_width / image_width
|
128 |
+
scale_height = bucket_height / image_height
|
129 |
+
scale = max(scale_width, scale_height)
|
130 |
+
image_width = int(image_width * scale + 0.5)
|
131 |
+
image_height = int(image_height * scale + 0.5)
|
132 |
+
|
133 |
+
if scale > 1:
|
134 |
+
image = Image.fromarray(image) if not is_pil_image else image
|
135 |
+
image = image.resize((image_width, image_height), Image.LANCZOS)
|
136 |
+
image = np.array(image)
|
137 |
+
else:
|
138 |
+
image = np.array(image) if is_pil_image else image
|
139 |
+
image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
|
140 |
+
|
141 |
+
# crop the image to the bucket resolution
|
142 |
+
crop_left = (image_width - bucket_width) // 2
|
143 |
+
crop_top = (image_height - bucket_height) // 2
|
144 |
+
image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
|
145 |
+
return image
|
146 |
+
|
147 |
+
|
148 |
+
class ItemInfo:
|
149 |
+
def __init__(
|
150 |
+
self,
|
151 |
+
item_key: str,
|
152 |
+
caption: str,
|
153 |
+
original_size: tuple[int, int],
|
154 |
+
bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
|
155 |
+
frame_count: Optional[int] = None,
|
156 |
+
content: Optional[np.ndarray] = None,
|
157 |
+
latent_cache_path: Optional[str] = None,
|
158 |
+
) -> None:
|
159 |
+
self.item_key = item_key
|
160 |
+
self.caption = caption
|
161 |
+
self.original_size = original_size
|
162 |
+
self.bucket_size = bucket_size
|
163 |
+
self.frame_count = frame_count
|
164 |
+
self.content = content
|
165 |
+
self.latent_cache_path = latent_cache_path
|
166 |
+
self.text_encoder_output_cache_path: Optional[str] = None
|
167 |
+
|
168 |
+
def __str__(self) -> str:
|
169 |
+
return (
|
170 |
+
f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
|
171 |
+
+ f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
|
172 |
+
+ f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})"
|
173 |
+
)
|
174 |
+
|
175 |
+
|
176 |
+
# We use simple if-else approach to support multiple architectures.
|
177 |
+
# Maybe we can use a plugin system in the future.
|
178 |
+
|
179 |
+
# the keys of the dict are `<content_type>_FxHxW_<dtype>` for latents
|
180 |
+
# and `<content_type>_<dtype|mask>` for other tensors
|
181 |
+
|
182 |
+
|
183 |
+
def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
|
184 |
+
"""HunyuanVideo architecture only"""
|
185 |
+
assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
|
186 |
+
|
187 |
+
_, F, H, W = latent.shape
|
188 |
+
dtype_str = dtype_to_str(latent.dtype)
|
189 |
+
sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
|
190 |
+
|
191 |
+
save_latent_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
|
192 |
+
|
193 |
+
|
194 |
+
def save_latent_cache_wan(
|
195 |
+
item_info: ItemInfo, latent: torch.Tensor, clip_embed: Optional[torch.Tensor], image_latent: Optional[torch.Tensor]
|
196 |
+
):
|
197 |
+
"""Wan architecture only"""
|
198 |
+
assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"
|
199 |
+
|
200 |
+
_, F, H, W = latent.shape
|
201 |
+
dtype_str = dtype_to_str(latent.dtype)
|
202 |
+
sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}
|
203 |
+
|
204 |
+
if clip_embed is not None:
|
205 |
+
sd[f"clip_{dtype_str}"] = clip_embed.detach().cpu()
|
206 |
+
|
207 |
+
if image_latent is not None:
|
208 |
+
sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()
|
209 |
+
|
210 |
+
save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
|
211 |
+
|
212 |
+
|
213 |
+
def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
|
214 |
+
metadata = {
|
215 |
+
"architecture": arch_fullname,
|
216 |
+
"width": f"{item_info.original_size[0]}",
|
217 |
+
"height": f"{item_info.original_size[1]}",
|
218 |
+
"format_version": "1.0.1",
|
219 |
+
}
|
220 |
+
if item_info.frame_count is not None:
|
221 |
+
metadata["frame_count"] = f"{item_info.frame_count}"
|
222 |
+
|
223 |
+
for key, value in sd.items():
|
224 |
+
# NaN check and show warning, replace NaN with 0
|
225 |
+
if torch.isnan(value).any():
|
226 |
+
logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
|
227 |
+
value[torch.isnan(value)] = 0
|
228 |
+
|
229 |
+
latent_dir = os.path.dirname(item_info.latent_cache_path)
|
230 |
+
os.makedirs(latent_dir, exist_ok=True)
|
231 |
+
|
232 |
+
save_file(sd, item_info.latent_cache_path, metadata=metadata)
|
233 |
+
|
234 |
+
|
235 |
+
def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
|
236 |
+
"""HunyuanVideo architecture only"""
|
237 |
+
assert (
|
238 |
+
embed.dim() == 1 or embed.dim() == 2
|
239 |
+
), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
|
240 |
+
assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"
|
241 |
+
|
242 |
+
sd = {}
|
243 |
+
dtype_str = dtype_to_str(embed.dtype)
|
244 |
+
text_encoder_type = "llm" if is_llm else "clipL"
|
245 |
+
sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
|
246 |
+
if mask is not None:
|
247 |
+
sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()
|
248 |
+
|
249 |
+
save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
|
250 |
+
|
251 |
+
|
252 |
+
def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor):
|
253 |
+
"""Wan architecture only. Wan2.1 only has a single text encoder"""
|
254 |
+
|
255 |
+
sd = {}
|
256 |
+
dtype_str = dtype_to_str(embed.dtype)
|
257 |
+
text_encoder_type = "t5"
|
258 |
+
sd[f"varlen_{text_encoder_type}_{dtype_str}"] = embed.detach().cpu()
|
259 |
+
|
260 |
+
save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
|
261 |
+
|
262 |
+
|
263 |
+
def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
|
264 |
+
for key, value in sd.items():
|
265 |
+
# NaN check and show warning, replace NaN with 0
|
266 |
+
if torch.isnan(value).any():
|
267 |
+
logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
|
268 |
+
value[torch.isnan(value)] = 0
|
269 |
+
|
270 |
+
metadata = {
|
271 |
+
"architecture": arch_fullname,
|
272 |
+
"caption1": item_info.caption,
|
273 |
+
"format_version": "1.0.1",
|
274 |
+
}
|
275 |
+
|
276 |
+
if os.path.exists(item_info.text_encoder_output_cache_path):
|
277 |
+
# load existing cache and update metadata
|
278 |
+
with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
|
279 |
+
existing_metadata = f.metadata()
|
280 |
+
for key in f.keys():
|
281 |
+
if key not in sd: # avoid overwriting by existing cache, we keep the new one
|
282 |
+
sd[key] = f.get_tensor(key)
|
283 |
+
|
284 |
+
assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
|
285 |
+
if existing_metadata["caption1"] != metadata["caption1"]:
|
286 |
+
logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
|
287 |
+
# TODO verify format_version
|
288 |
+
|
289 |
+
existing_metadata.pop("caption1", None)
|
290 |
+
existing_metadata.pop("format_version", None)
|
291 |
+
metadata.update(existing_metadata) # copy existing metadata except caption and format_version
|
292 |
+
else:
|
293 |
+
text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
|
294 |
+
os.makedirs(text_encoder_output_dir, exist_ok=True)
|
295 |
+
|
296 |
+
safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
|
297 |
+
|
298 |
+
|
299 |
+
class BucketSelector:
|
300 |
+
RESOLUTION_STEPS_HUNYUAN = 16
|
301 |
+
RESOLUTION_STEPS_WAN = 16
|
302 |
+
|
303 |
+
def __init__(
|
304 |
+
self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
|
305 |
+
):
|
306 |
+
self.resolution = resolution
|
307 |
+
self.bucket_area = resolution[0] * resolution[1]
|
308 |
+
self.architecture = architecture
|
309 |
+
|
310 |
+
if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
|
311 |
+
self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
|
312 |
+
elif self.architecture == ARCHITECTURE_WAN:
|
313 |
+
self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
|
314 |
+
else:
|
315 |
+
raise ValueError(f"Invalid architecture: {self.architecture}")
|
316 |
+
|
317 |
+
if not enable_bucket:
|
318 |
+
# only define one bucket
|
319 |
+
self.bucket_resolutions = [resolution]
|
320 |
+
self.no_upscale = False
|
321 |
+
else:
|
322 |
+
# prepare bucket resolution
|
323 |
+
self.no_upscale = no_upscale
|
324 |
+
sqrt_size = int(math.sqrt(self.bucket_area))
|
325 |
+
min_size = divisible_by(sqrt_size // 2, self.reso_steps)
|
326 |
+
self.bucket_resolutions = []
|
327 |
+
for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
|
328 |
+
h = divisible_by(self.bucket_area // w, self.reso_steps)
|
329 |
+
self.bucket_resolutions.append((w, h))
|
330 |
+
self.bucket_resolutions.append((h, w))
|
331 |
+
|
332 |
+
self.bucket_resolutions = list(set(self.bucket_resolutions))
|
333 |
+
self.bucket_resolutions.sort()
|
334 |
+
|
335 |
+
# calculate aspect ratio to find the nearest resolution
|
336 |
+
self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])
|
337 |
+
|
338 |
+
def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
|
339 |
+
"""
|
340 |
+
return the bucket resolution for the given image size, (width, height)
|
341 |
+
"""
|
342 |
+
area = image_size[0] * image_size[1]
|
343 |
+
if self.no_upscale and area <= self.bucket_area:
|
344 |
+
w, h = image_size
|
345 |
+
w = divisible_by(w, self.reso_steps)
|
346 |
+
h = divisible_by(h, self.reso_steps)
|
347 |
+
return w, h
|
348 |
+
|
349 |
+
aspect_ratio = image_size[0] / image_size[1]
|
350 |
+
ar_errors = self.aspect_ratios - aspect_ratio
|
351 |
+
bucket_id = np.abs(ar_errors).argmin()
|
352 |
+
return self.bucket_resolutions[bucket_id]
|
353 |
+
|
354 |
+
|
355 |
+
def load_video(
|
356 |
+
video_path: str,
|
357 |
+
start_frame: Optional[int] = None,
|
358 |
+
end_frame: Optional[int] = None,
|
359 |
+
bucket_selector: Optional[BucketSelector] = None,
|
360 |
+
bucket_reso: Optional[tuple[int, int]] = None,
|
361 |
+
) -> list[np.ndarray]:
|
362 |
+
"""
|
363 |
+
bucket_reso: if given, resize the video to the bucket resolution, (width, height)
|
364 |
+
"""
|
365 |
+
container = av.open(video_path)
|
366 |
+
video = []
|
367 |
+
for i, frame in enumerate(container.decode(video=0)):
|
368 |
+
if start_frame is not None and i < start_frame:
|
369 |
+
continue
|
370 |
+
if end_frame is not None and i >= end_frame:
|
371 |
+
break
|
372 |
+
frame = frame.to_image()
|
373 |
+
|
374 |
+
if bucket_selector is not None and bucket_reso is None:
|
375 |
+
bucket_reso = bucket_selector.get_bucket_resolution(frame.size)
|
376 |
+
|
377 |
+
if bucket_reso is not None:
|
378 |
+
frame = resize_image_to_bucket(frame, bucket_reso)
|
379 |
+
else:
|
380 |
+
frame = np.array(frame)
|
381 |
+
|
382 |
+
video.append(frame)
|
383 |
+
container.close()
|
384 |
+
return video
|
385 |
+
|
386 |
+
|
387 |
+
class BucketBatchManager:
|
388 |
+
|
389 |
+
def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
|
390 |
+
self.batch_size = batch_size
|
391 |
+
self.buckets = bucketed_item_info
|
392 |
+
self.bucket_resos = list(self.buckets.keys())
|
393 |
+
self.bucket_resos.sort()
|
394 |
+
|
395 |
+
self.bucket_batch_indices = []
|
396 |
+
for bucket_reso in self.bucket_resos:
|
397 |
+
bucket = self.buckets[bucket_reso]
|
398 |
+
num_batches = math.ceil(len(bucket) / self.batch_size)
|
399 |
+
for i in range(num_batches):
|
400 |
+
self.bucket_batch_indices.append((bucket_reso, i))
|
401 |
+
|
402 |
+
self.shuffle()
|
403 |
+
|
404 |
+
def show_bucket_info(self):
|
405 |
+
for bucket_reso in self.bucket_resos:
|
406 |
+
bucket = self.buckets[bucket_reso]
|
407 |
+
logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")
|
408 |
+
|
409 |
+
logger.info(f"total batches: {len(self)}")
|
410 |
+
|
411 |
+
def shuffle(self):
|
412 |
+
for bucket in self.buckets.values():
|
413 |
+
random.shuffle(bucket)
|
414 |
+
random.shuffle(self.bucket_batch_indices)
|
415 |
+
|
416 |
+
def __len__(self):
|
417 |
+
return len(self.bucket_batch_indices)
|
418 |
+
|
419 |
+
def __getitem__(self, idx):
|
420 |
+
bucket_reso, batch_idx = self.bucket_batch_indices[idx]
|
421 |
+
bucket = self.buckets[bucket_reso]
|
422 |
+
start = batch_idx * self.batch_size
|
423 |
+
end = min(start + self.batch_size, len(bucket))
|
424 |
+
|
425 |
+
batch_tensor_data = {}
|
426 |
+
varlen_keys = set()
|
427 |
+
for item_info in bucket[start:end]:
|
428 |
+
sd_latent = load_file(item_info.latent_cache_path)
|
429 |
+
sd_te = load_file(item_info.text_encoder_output_cache_path)
|
430 |
+
sd = {**sd_latent, **sd_te}
|
431 |
+
|
432 |
+
# TODO refactor this
|
433 |
+
for key in sd.keys():
|
434 |
+
is_varlen_key = key.startswith("varlen_") # varlen keys are not stacked
|
435 |
+
content_key = key
|
436 |
+
|
437 |
+
if is_varlen_key:
|
438 |
+
content_key = content_key.replace("varlen_", "")
|
439 |
+
|
440 |
+
if content_key.endswith("_mask"):
|
441 |
+
pass
|
442 |
+
else:
|
443 |
+
content_key = content_key.rsplit("_", 1)[0] # remove dtype
|
444 |
+
if content_key.startswith("latents_"):
|
445 |
+
content_key = content_key.rsplit("_", 1)[0] # remove FxHxW
|
446 |
+
|
447 |
+
if content_key not in batch_tensor_data:
|
448 |
+
batch_tensor_data[content_key] = []
|
449 |
+
batch_tensor_data[content_key].append(sd[key])
|
450 |
+
|
451 |
+
if is_varlen_key:
|
452 |
+
varlen_keys.add(content_key)
|
453 |
+
|
454 |
+
for key in batch_tensor_data.keys():
|
455 |
+
if key not in varlen_keys:
|
456 |
+
batch_tensor_data[key] = torch.stack(batch_tensor_data[key])
|
457 |
+
|
458 |
+
return batch_tensor_data
|
459 |
+
|
460 |
+
|
461 |
+
class ContentDatasource:
|
462 |
+
def __init__(self):
|
463 |
+
self.caption_only = False
|
464 |
+
|
465 |
+
def set_caption_only(self, caption_only: bool):
|
466 |
+
self.caption_only = caption_only
|
467 |
+
|
468 |
+
def is_indexable(self):
|
469 |
+
return False
|
470 |
+
|
471 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
472 |
+
"""
|
473 |
+
Returns caption. May not be called if is_indexable() returns False.
|
474 |
+
"""
|
475 |
+
raise NotImplementedError
|
476 |
+
|
477 |
+
def __len__(self):
|
478 |
+
raise NotImplementedError
|
479 |
+
|
480 |
+
def __iter__(self):
|
481 |
+
raise NotImplementedError
|
482 |
+
|
483 |
+
def __next__(self):
|
484 |
+
raise NotImplementedError
|
485 |
+
|
486 |
+
|
487 |
+
class ImageDatasource(ContentDatasource):
|
488 |
+
def __init__(self):
|
489 |
+
super().__init__()
|
490 |
+
|
491 |
+
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
|
492 |
+
"""
|
493 |
+
Returns image data as a tuple of image path, image, and caption for the given index.
|
494 |
+
Key must be unique and valid as a file name.
|
495 |
+
May not be called if is_indexable() returns False.
|
496 |
+
"""
|
497 |
+
raise NotImplementedError
|
498 |
+
|
499 |
+
|
500 |
+
class ImageDirectoryDatasource(ImageDatasource):
|
501 |
+
def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
|
502 |
+
super().__init__()
|
503 |
+
self.image_directory = image_directory
|
504 |
+
self.caption_extension = caption_extension
|
505 |
+
self.current_idx = 0
|
506 |
+
|
507 |
+
# glob images
|
508 |
+
logger.info(f"glob images in {self.image_directory}")
|
509 |
+
self.image_paths = glob_images(self.image_directory)
|
510 |
+
logger.info(f"found {len(self.image_paths)} images")
|
511 |
+
|
512 |
+
def is_indexable(self):
|
513 |
+
return True
|
514 |
+
|
515 |
+
def __len__(self):
|
516 |
+
return len(self.image_paths)
|
517 |
+
|
518 |
+
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
|
519 |
+
image_path = self.image_paths[idx]
|
520 |
+
image = Image.open(image_path).convert("RGB")
|
521 |
+
|
522 |
+
_, caption = self.get_caption(idx)
|
523 |
+
|
524 |
+
return image_path, image, caption
|
525 |
+
|
526 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
527 |
+
image_path = self.image_paths[idx]
|
528 |
+
caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
|
529 |
+
with open(caption_path, "r", encoding="utf-8") as f:
|
530 |
+
caption = f.read().strip()
|
531 |
+
return image_path, caption
|
532 |
+
|
533 |
+
def __iter__(self):
|
534 |
+
self.current_idx = 0
|
535 |
+
return self
|
536 |
+
|
537 |
+
def __next__(self) -> callable:
|
538 |
+
"""
|
539 |
+
Returns a fetcher function that returns image data.
|
540 |
+
"""
|
541 |
+
if self.current_idx >= len(self.image_paths):
|
542 |
+
raise StopIteration
|
543 |
+
|
544 |
+
if self.caption_only:
|
545 |
+
|
546 |
+
def create_caption_fetcher(index):
|
547 |
+
return lambda: self.get_caption(index)
|
548 |
+
|
549 |
+
fetcher = create_caption_fetcher(self.current_idx)
|
550 |
+
else:
|
551 |
+
|
552 |
+
def create_image_fetcher(index):
|
553 |
+
return lambda: self.get_image_data(index)
|
554 |
+
|
555 |
+
fetcher = create_image_fetcher(self.current_idx)
|
556 |
+
|
557 |
+
self.current_idx += 1
|
558 |
+
return fetcher
|
559 |
+
|
560 |
+
|
561 |
+
class ImageJsonlDatasource(ImageDatasource):
|
562 |
+
def __init__(self, image_jsonl_file: str):
|
563 |
+
super().__init__()
|
564 |
+
self.image_jsonl_file = image_jsonl_file
|
565 |
+
self.current_idx = 0
|
566 |
+
|
567 |
+
# load jsonl
|
568 |
+
logger.info(f"load image jsonl from {self.image_jsonl_file}")
|
569 |
+
self.data = []
|
570 |
+
with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
|
571 |
+
for line in f:
|
572 |
+
try:
|
573 |
+
data = json.loads(line)
|
574 |
+
except json.JSONDecodeError:
|
575 |
+
logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
|
576 |
+
raise
|
577 |
+
self.data.append(data)
|
578 |
+
logger.info(f"loaded {len(self.data)} images")
|
579 |
+
|
580 |
+
def is_indexable(self):
|
581 |
+
return True
|
582 |
+
|
583 |
+
def __len__(self):
|
584 |
+
return len(self.data)
|
585 |
+
|
586 |
+
def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
|
587 |
+
data = self.data[idx]
|
588 |
+
image_path = data["image_path"]
|
589 |
+
image = Image.open(image_path).convert("RGB")
|
590 |
+
|
591 |
+
caption = data["caption"]
|
592 |
+
|
593 |
+
return image_path, image, caption
|
594 |
+
|
595 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
596 |
+
data = self.data[idx]
|
597 |
+
image_path = data["image_path"]
|
598 |
+
caption = data["caption"]
|
599 |
+
return image_path, caption
|
600 |
+
|
601 |
+
def __iter__(self):
|
602 |
+
self.current_idx = 0
|
603 |
+
return self
|
604 |
+
|
605 |
+
def __next__(self) -> callable:
|
606 |
+
if self.current_idx >= len(self.data):
|
607 |
+
raise StopIteration
|
608 |
+
|
609 |
+
if self.caption_only:
|
610 |
+
|
611 |
+
def create_caption_fetcher(index):
|
612 |
+
return lambda: self.get_caption(index)
|
613 |
+
|
614 |
+
fetcher = create_caption_fetcher(self.current_idx)
|
615 |
+
|
616 |
+
else:
|
617 |
+
|
618 |
+
def create_fetcher(index):
|
619 |
+
return lambda: self.get_image_data(index)
|
620 |
+
|
621 |
+
fetcher = create_fetcher(self.current_idx)
|
622 |
+
|
623 |
+
self.current_idx += 1
|
624 |
+
return fetcher
|
625 |
+
|
626 |
+
|
627 |
+
class VideoDatasource(ContentDatasource):
|
628 |
+
def __init__(self):
|
629 |
+
super().__init__()
|
630 |
+
|
631 |
+
# None means all frames
|
632 |
+
self.start_frame = None
|
633 |
+
self.end_frame = None
|
634 |
+
|
635 |
+
self.bucket_selector = None
|
636 |
+
|
637 |
+
def __len__(self):
|
638 |
+
raise NotImplementedError
|
639 |
+
|
640 |
+
def get_video_data_from_path(
|
641 |
+
self,
|
642 |
+
video_path: str,
|
643 |
+
start_frame: Optional[int] = None,
|
644 |
+
end_frame: Optional[int] = None,
|
645 |
+
bucket_selector: Optional[BucketSelector] = None,
|
646 |
+
) -> tuple[str, list[Image.Image], str]:
|
647 |
+
# this method can resize the video if bucket_selector is given to reduce the memory usage
|
648 |
+
|
649 |
+
start_frame = start_frame if start_frame is not None else self.start_frame
|
650 |
+
end_frame = end_frame if end_frame is not None else self.end_frame
|
651 |
+
bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector
|
652 |
+
|
653 |
+
video = load_video(video_path, start_frame, end_frame, bucket_selector)
|
654 |
+
return video
|
655 |
+
|
656 |
+
def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
|
657 |
+
self.start_frame = start_frame
|
658 |
+
self.end_frame = end_frame
|
659 |
+
|
660 |
+
def set_bucket_selector(self, bucket_selector: BucketSelector):
|
661 |
+
self.bucket_selector = bucket_selector
|
662 |
+
|
663 |
+
def __iter__(self):
|
664 |
+
raise NotImplementedError
|
665 |
+
|
666 |
+
def __next__(self):
|
667 |
+
raise NotImplementedError
|
668 |
+
|
669 |
+
|
670 |
+
class VideoDirectoryDatasource(VideoDatasource):
|
671 |
+
def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
|
672 |
+
super().__init__()
|
673 |
+
self.video_directory = video_directory
|
674 |
+
self.caption_extension = caption_extension
|
675 |
+
self.current_idx = 0
|
676 |
+
|
677 |
+
# glob images
|
678 |
+
logger.info(f"glob images in {self.video_directory}")
|
679 |
+
self.video_paths = glob_videos(self.video_directory)
|
680 |
+
logger.info(f"found {len(self.video_paths)} videos")
|
681 |
+
|
682 |
+
def is_indexable(self):
|
683 |
+
return True
|
684 |
+
|
685 |
+
def __len__(self):
|
686 |
+
return len(self.video_paths)
|
687 |
+
|
688 |
+
def get_video_data(
|
689 |
+
self,
|
690 |
+
idx: int,
|
691 |
+
start_frame: Optional[int] = None,
|
692 |
+
end_frame: Optional[int] = None,
|
693 |
+
bucket_selector: Optional[BucketSelector] = None,
|
694 |
+
) -> tuple[str, list[Image.Image], str]:
|
695 |
+
video_path = self.video_paths[idx]
|
696 |
+
video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
|
697 |
+
|
698 |
+
_, caption = self.get_caption(idx)
|
699 |
+
|
700 |
+
return video_path, video, caption
|
701 |
+
|
702 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
703 |
+
video_path = self.video_paths[idx]
|
704 |
+
caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
|
705 |
+
with open(caption_path, "r", encoding="utf-8") as f:
|
706 |
+
caption = f.read().strip()
|
707 |
+
return video_path, caption
|
708 |
+
|
709 |
+
def __iter__(self):
|
710 |
+
self.current_idx = 0
|
711 |
+
return self
|
712 |
+
|
713 |
+
def __next__(self):
|
714 |
+
if self.current_idx >= len(self.video_paths):
|
715 |
+
raise StopIteration
|
716 |
+
|
717 |
+
if self.caption_only:
|
718 |
+
|
719 |
+
def create_caption_fetcher(index):
|
720 |
+
return lambda: self.get_caption(index)
|
721 |
+
|
722 |
+
fetcher = create_caption_fetcher(self.current_idx)
|
723 |
+
|
724 |
+
else:
|
725 |
+
|
726 |
+
def create_fetcher(index):
|
727 |
+
return lambda: self.get_video_data(index)
|
728 |
+
|
729 |
+
fetcher = create_fetcher(self.current_idx)
|
730 |
+
|
731 |
+
self.current_idx += 1
|
732 |
+
return fetcher
|
733 |
+
|
734 |
+
|
735 |
+
class VideoJsonlDatasource(VideoDatasource):
|
736 |
+
def __init__(self, video_jsonl_file: str):
|
737 |
+
super().__init__()
|
738 |
+
self.video_jsonl_file = video_jsonl_file
|
739 |
+
self.current_idx = 0
|
740 |
+
|
741 |
+
# load jsonl
|
742 |
+
logger.info(f"load video jsonl from {self.video_jsonl_file}")
|
743 |
+
self.data = []
|
744 |
+
with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
|
745 |
+
for line in f:
|
746 |
+
data = json.loads(line)
|
747 |
+
self.data.append(data)
|
748 |
+
logger.info(f"loaded {len(self.data)} videos")
|
749 |
+
|
750 |
+
def is_indexable(self):
|
751 |
+
return True
|
752 |
+
|
753 |
+
def __len__(self):
|
754 |
+
return len(self.data)
|
755 |
+
|
756 |
+
def get_video_data(
|
757 |
+
self,
|
758 |
+
idx: int,
|
759 |
+
start_frame: Optional[int] = None,
|
760 |
+
end_frame: Optional[int] = None,
|
761 |
+
bucket_selector: Optional[BucketSelector] = None,
|
762 |
+
) -> tuple[str, list[Image.Image], str]:
|
763 |
+
data = self.data[idx]
|
764 |
+
video_path = data["video_path"]
|
765 |
+
video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)
|
766 |
+
|
767 |
+
caption = data["caption"]
|
768 |
+
|
769 |
+
return video_path, video, caption
|
770 |
+
|
771 |
+
def get_caption(self, idx: int) -> tuple[str, str]:
|
772 |
+
data = self.data[idx]
|
773 |
+
video_path = data["video_path"]
|
774 |
+
caption = data["caption"]
|
775 |
+
return video_path, caption
|
776 |
+
|
777 |
+
def __iter__(self):
|
778 |
+
self.current_idx = 0
|
779 |
+
return self
|
780 |
+
|
781 |
+
def __next__(self):
|
782 |
+
if self.current_idx >= len(self.data):
|
783 |
+
raise StopIteration
|
784 |
+
|
785 |
+
if self.caption_only:
|
786 |
+
|
787 |
+
def create_caption_fetcher(index):
|
788 |
+
return lambda: self.get_caption(index)
|
789 |
+
|
790 |
+
fetcher = create_caption_fetcher(self.current_idx)
|
791 |
+
|
792 |
+
else:
|
793 |
+
|
794 |
+
def create_fetcher(index):
|
795 |
+
return lambda: self.get_video_data(index)
|
796 |
+
|
797 |
+
fetcher = create_fetcher(self.current_idx)
|
798 |
+
|
799 |
+
self.current_idx += 1
|
800 |
+
return fetcher
|
801 |
+
|
802 |
+
|
803 |
+
class BaseDataset(torch.utils.data.Dataset):
|
804 |
+
def __init__(
|
805 |
+
self,
|
806 |
+
resolution: Tuple[int, int] = (960, 544),
|
807 |
+
caption_extension: Optional[str] = None,
|
808 |
+
batch_size: int = 1,
|
809 |
+
num_repeats: int = 1,
|
810 |
+
enable_bucket: bool = False,
|
811 |
+
bucket_no_upscale: bool = False,
|
812 |
+
cache_directory: Optional[str] = None,
|
813 |
+
debug_dataset: bool = False,
|
814 |
+
architecture: str = "no_default",
|
815 |
+
):
|
816 |
+
self.resolution = resolution
|
817 |
+
self.caption_extension = caption_extension
|
818 |
+
self.batch_size = batch_size
|
819 |
+
self.num_repeats = num_repeats
|
820 |
+
self.enable_bucket = enable_bucket
|
821 |
+
self.bucket_no_upscale = bucket_no_upscale
|
822 |
+
self.cache_directory = cache_directory
|
823 |
+
self.debug_dataset = debug_dataset
|
824 |
+
self.architecture = architecture
|
825 |
+
self.seed = None
|
826 |
+
self.current_epoch = 0
|
827 |
+
|
828 |
+
if not self.enable_bucket:
|
829 |
+
self.bucket_no_upscale = False
|
830 |
+
|
831 |
+
def get_metadata(self) -> dict:
|
832 |
+
metadata = {
|
833 |
+
"resolution": self.resolution,
|
834 |
+
"caption_extension": self.caption_extension,
|
835 |
+
"batch_size_per_device": self.batch_size,
|
836 |
+
"num_repeats": self.num_repeats,
|
837 |
+
"enable_bucket": bool(self.enable_bucket),
|
838 |
+
"bucket_no_upscale": bool(self.bucket_no_upscale),
|
839 |
+
}
|
840 |
+
return metadata
|
841 |
+
|
842 |
+
def get_all_latent_cache_files(self):
|
843 |
+
return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
|
844 |
+
|
845 |
+
def get_all_text_encoder_output_cache_files(self):
|
846 |
+
return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}_te.safetensors"))
|
847 |
+
|
848 |
+
def get_latent_cache_path(self, item_info: ItemInfo) -> str:
|
849 |
+
"""
|
850 |
+
Returns the cache path for the latent tensor.
|
851 |
+
|
852 |
+
item_info: ItemInfo object
|
853 |
+
|
854 |
+
Returns:
|
855 |
+
str: cache path
|
856 |
+
|
857 |
+
cache_path is based on the item_key and the resolution.
|
858 |
+
"""
|
859 |
+
w, h = item_info.original_size
|
860 |
+
basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
|
861 |
+
assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
|
862 |
+
return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors")
|
863 |
+
|
864 |
+
def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
|
865 |
+
basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
|
866 |
+
assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
|
867 |
+
return os.path.join(self.cache_directory, f"{basename}_{self.architecture}_te.safetensors")
|
868 |
+
|
869 |
+
def retrieve_latent_cache_batches(self, num_workers: int):
|
870 |
+
raise NotImplementedError
|
871 |
+
|
872 |
+
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
|
873 |
+
raise NotImplementedError
|
874 |
+
|
875 |
+
def prepare_for_training(self):
|
876 |
+
pass
|
877 |
+
|
878 |
+
def set_seed(self, seed: int):
|
879 |
+
self.seed = seed
|
880 |
+
|
881 |
+
def set_current_epoch(self, epoch):
|
882 |
+
if not self.current_epoch == epoch: # shuffle buckets when epoch is incremented
|
883 |
+
if epoch > self.current_epoch:
|
884 |
+
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
885 |
+
num_epochs = epoch - self.current_epoch
|
886 |
+
for _ in range(num_epochs):
|
887 |
+
self.current_epoch += 1
|
888 |
+
self.shuffle_buckets()
|
889 |
+
# self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
|
890 |
+
else:
|
891 |
+
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
892 |
+
self.current_epoch = epoch
|
893 |
+
|
894 |
+
def set_current_step(self, step):
|
895 |
+
self.current_step = step
|
896 |
+
|
897 |
+
def set_max_train_steps(self, max_train_steps):
|
898 |
+
self.max_train_steps = max_train_steps
|
899 |
+
|
900 |
+
def shuffle_buckets(self):
|
901 |
+
raise NotImplementedError
|
902 |
+
|
903 |
+
def __len__(self):
|
904 |
+
return NotImplementedError
|
905 |
+
|
906 |
+
def __getitem__(self, idx):
|
907 |
+
raise NotImplementedError
|
908 |
+
|
909 |
+
def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
|
910 |
+
datasource.set_caption_only(True)
|
911 |
+
executor = ThreadPoolExecutor(max_workers=num_workers)
|
912 |
+
|
913 |
+
data: list[ItemInfo] = []
|
914 |
+
futures = []
|
915 |
+
|
916 |
+
def aggregate_future(consume_all: bool = False):
|
917 |
+
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
|
918 |
+
completed_futures = [future for future in futures if future.done()]
|
919 |
+
if len(completed_futures) == 0:
|
920 |
+
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
|
921 |
+
time.sleep(0.1)
|
922 |
+
continue
|
923 |
+
else:
|
924 |
+
break # submit batch if possible
|
925 |
+
|
926 |
+
for future in completed_futures:
|
927 |
+
item_key, caption = future.result()
|
928 |
+
item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
|
929 |
+
item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
|
930 |
+
data.append(item_info)
|
931 |
+
|
932 |
+
futures.remove(future)
|
933 |
+
|
934 |
+
def submit_batch(flush: bool = False):
|
935 |
+
nonlocal data
|
936 |
+
if len(data) >= batch_size or (len(data) > 0 and flush):
|
937 |
+
batch = data[0:batch_size]
|
938 |
+
if len(data) > batch_size:
|
939 |
+
data = data[batch_size:]
|
940 |
+
else:
|
941 |
+
data = []
|
942 |
+
return batch
|
943 |
+
return None
|
944 |
+
|
945 |
+
for fetch_op in datasource:
|
946 |
+
future = executor.submit(fetch_op)
|
947 |
+
futures.append(future)
|
948 |
+
aggregate_future()
|
949 |
+
while True:
|
950 |
+
batch = submit_batch()
|
951 |
+
if batch is None:
|
952 |
+
break
|
953 |
+
yield batch
|
954 |
+
|
955 |
+
aggregate_future(consume_all=True)
|
956 |
+
while True:
|
957 |
+
batch = submit_batch(flush=True)
|
958 |
+
if batch is None:
|
959 |
+
break
|
960 |
+
yield batch
|
961 |
+
|
962 |
+
executor.shutdown()
|
963 |
+
|
964 |
+
|
965 |
+
class ImageDataset(BaseDataset):
|
966 |
+
def __init__(
|
967 |
+
self,
|
968 |
+
resolution: Tuple[int, int],
|
969 |
+
caption_extension: Optional[str],
|
970 |
+
batch_size: int,
|
971 |
+
num_repeats: int,
|
972 |
+
enable_bucket: bool,
|
973 |
+
bucket_no_upscale: bool,
|
974 |
+
image_directory: Optional[str] = None,
|
975 |
+
image_jsonl_file: Optional[str] = None,
|
976 |
+
cache_directory: Optional[str] = None,
|
977 |
+
debug_dataset: bool = False,
|
978 |
+
architecture: str = "no_default",
|
979 |
+
):
|
980 |
+
super(ImageDataset, self).__init__(
|
981 |
+
resolution,
|
982 |
+
caption_extension,
|
983 |
+
batch_size,
|
984 |
+
num_repeats,
|
985 |
+
enable_bucket,
|
986 |
+
bucket_no_upscale,
|
987 |
+
cache_directory,
|
988 |
+
debug_dataset,
|
989 |
+
architecture,
|
990 |
+
)
|
991 |
+
self.image_directory = image_directory
|
992 |
+
self.image_jsonl_file = image_jsonl_file
|
993 |
+
if image_directory is not None:
|
994 |
+
self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
|
995 |
+
elif image_jsonl_file is not None:
|
996 |
+
self.datasource = ImageJsonlDatasource(image_jsonl_file)
|
997 |
+
else:
|
998 |
+
raise ValueError("image_directory or image_jsonl_file must be specified")
|
999 |
+
|
1000 |
+
if self.cache_directory is None:
|
1001 |
+
self.cache_directory = self.image_directory
|
1002 |
+
|
1003 |
+
self.batch_manager = None
|
1004 |
+
self.num_train_items = 0
|
1005 |
+
|
1006 |
+
def get_metadata(self):
|
1007 |
+
metadata = super().get_metadata()
|
1008 |
+
if self.image_directory is not None:
|
1009 |
+
metadata["image_directory"] = os.path.basename(self.image_directory)
|
1010 |
+
if self.image_jsonl_file is not None:
|
1011 |
+
metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
|
1012 |
+
return metadata
|
1013 |
+
|
1014 |
+
def get_total_image_count(self):
|
1015 |
+
return len(self.datasource) if self.datasource.is_indexable() else None
|
1016 |
+
|
1017 |
+
def retrieve_latent_cache_batches(self, num_workers: int):
|
1018 |
+
buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
|
1019 |
+
executor = ThreadPoolExecutor(max_workers=num_workers)
|
1020 |
+
|
1021 |
+
batches: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
|
1022 |
+
futures = []
|
1023 |
+
|
1024 |
+
# aggregate futures and sort by bucket resolution
|
1025 |
+
def aggregate_future(consume_all: bool = False):
|
1026 |
+
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
|
1027 |
+
completed_futures = [future for future in futures if future.done()]
|
1028 |
+
if len(completed_futures) == 0:
|
1029 |
+
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
|
1030 |
+
time.sleep(0.1)
|
1031 |
+
continue
|
1032 |
+
else:
|
1033 |
+
break # submit batch if possible
|
1034 |
+
|
1035 |
+
for future in completed_futures:
|
1036 |
+
original_size, item_key, image, caption = future.result()
|
1037 |
+
bucket_height, bucket_width = image.shape[:2]
|
1038 |
+
bucket_reso = (bucket_width, bucket_height)
|
1039 |
+
|
1040 |
+
item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
|
1041 |
+
item_info.latent_cache_path = self.get_latent_cache_path(item_info)
|
1042 |
+
|
1043 |
+
if bucket_reso not in batches:
|
1044 |
+
batches[bucket_reso] = []
|
1045 |
+
batches[bucket_reso].append(item_info)
|
1046 |
+
|
1047 |
+
futures.remove(future)
|
1048 |
+
|
1049 |
+
# submit batch if some bucket has enough items
|
1050 |
+
def submit_batch(flush: bool = False):
|
1051 |
+
for key in batches:
|
1052 |
+
if len(batches[key]) >= self.batch_size or flush:
|
1053 |
+
batch = batches[key][0 : self.batch_size]
|
1054 |
+
if len(batches[key]) > self.batch_size:
|
1055 |
+
batches[key] = batches[key][self.batch_size :]
|
1056 |
+
else:
|
1057 |
+
del batches[key]
|
1058 |
+
return key, batch
|
1059 |
+
return None, None
|
1060 |
+
|
1061 |
+
for fetch_op in self.datasource:
|
1062 |
+
|
1063 |
+
# fetch and resize image in a separate thread
|
1064 |
+
def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str]:
|
1065 |
+
image_key, image, caption = op()
|
1066 |
+
image: Image.Image
|
1067 |
+
image_size = image.size
|
1068 |
+
|
1069 |
+
bucket_reso = buckset_selector.get_bucket_resolution(image_size)
|
1070 |
+
image = resize_image_to_bucket(image, bucket_reso)
|
1071 |
+
return image_size, image_key, image, caption
|
1072 |
+
|
1073 |
+
future = executor.submit(fetch_and_resize, fetch_op)
|
1074 |
+
futures.append(future)
|
1075 |
+
aggregate_future()
|
1076 |
+
while True:
|
1077 |
+
key, batch = submit_batch()
|
1078 |
+
if key is None:
|
1079 |
+
break
|
1080 |
+
yield key, batch
|
1081 |
+
|
1082 |
+
aggregate_future(consume_all=True)
|
1083 |
+
while True:
|
1084 |
+
key, batch = submit_batch(flush=True)
|
1085 |
+
if key is None:
|
1086 |
+
break
|
1087 |
+
yield key, batch
|
1088 |
+
|
1089 |
+
executor.shutdown()
|
1090 |
+
|
1091 |
+
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
|
1092 |
+
return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
|
1093 |
+
|
1094 |
+
def prepare_for_training(self):
|
1095 |
+
bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
|
1096 |
+
|
1097 |
+
# glob cache files
|
1098 |
+
latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
|
1099 |
+
|
1100 |
+
# assign cache files to item info
|
1101 |
+
bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo]
|
1102 |
+
for cache_file in latent_cache_files:
|
1103 |
+
tokens = os.path.basename(cache_file).split("_")
|
1104 |
+
|
1105 |
+
image_size = tokens[-2] # 0000x0000
|
1106 |
+
image_width, image_height = map(int, image_size.split("x"))
|
1107 |
+
image_size = (image_width, image_height)
|
1108 |
+
|
1109 |
+
item_key = "_".join(tokens[:-2])
|
1110 |
+
text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
|
1111 |
+
if not os.path.exists(text_encoder_output_cache_file):
|
1112 |
+
logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
|
1113 |
+
continue
|
1114 |
+
|
1115 |
+
bucket_reso = bucket_selector.get_bucket_resolution(image_size)
|
1116 |
+
item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
|
1117 |
+
item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
|
1118 |
+
|
1119 |
+
bucket = bucketed_item_info.get(bucket_reso, [])
|
1120 |
+
for _ in range(self.num_repeats):
|
1121 |
+
bucket.append(item_info)
|
1122 |
+
bucketed_item_info[bucket_reso] = bucket
|
1123 |
+
|
1124 |
+
# prepare batch manager
|
1125 |
+
self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
|
1126 |
+
self.batch_manager.show_bucket_info()
|
1127 |
+
|
1128 |
+
self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
|
1129 |
+
|
1130 |
+
def shuffle_buckets(self):
|
1131 |
+
# set random seed for this epoch
|
1132 |
+
random.seed(self.seed + self.current_epoch)
|
1133 |
+
self.batch_manager.shuffle()
|
1134 |
+
|
1135 |
+
def __len__(self):
|
1136 |
+
if self.batch_manager is None:
|
1137 |
+
return 100 # dummy value
|
1138 |
+
return len(self.batch_manager)
|
1139 |
+
|
1140 |
+
def __getitem__(self, idx):
|
1141 |
+
return self.batch_manager[idx]
|
1142 |
+
|
1143 |
+
|
1144 |
+
class VideoDataset(BaseDataset):
|
1145 |
+
def __init__(
|
1146 |
+
self,
|
1147 |
+
resolution: Tuple[int, int],
|
1148 |
+
caption_extension: Optional[str],
|
1149 |
+
batch_size: int,
|
1150 |
+
num_repeats: int,
|
1151 |
+
enable_bucket: bool,
|
1152 |
+
bucket_no_upscale: bool,
|
1153 |
+
frame_extraction: Optional[str] = "head",
|
1154 |
+
frame_stride: Optional[int] = 1,
|
1155 |
+
frame_sample: Optional[int] = 1,
|
1156 |
+
target_frames: Optional[list[int]] = None,
|
1157 |
+
video_directory: Optional[str] = None,
|
1158 |
+
video_jsonl_file: Optional[str] = None,
|
1159 |
+
cache_directory: Optional[str] = None,
|
1160 |
+
debug_dataset: bool = False,
|
1161 |
+
architecture: str = "no_default",
|
1162 |
+
):
|
1163 |
+
super(VideoDataset, self).__init__(
|
1164 |
+
resolution,
|
1165 |
+
caption_extension,
|
1166 |
+
batch_size,
|
1167 |
+
num_repeats,
|
1168 |
+
enable_bucket,
|
1169 |
+
bucket_no_upscale,
|
1170 |
+
cache_directory,
|
1171 |
+
debug_dataset,
|
1172 |
+
architecture,
|
1173 |
+
)
|
1174 |
+
self.video_directory = video_directory
|
1175 |
+
self.video_jsonl_file = video_jsonl_file
|
1176 |
+
self.target_frames = target_frames
|
1177 |
+
self.frame_extraction = frame_extraction
|
1178 |
+
self.frame_stride = frame_stride
|
1179 |
+
self.frame_sample = frame_sample
|
1180 |
+
|
1181 |
+
if video_directory is not None:
|
1182 |
+
self.datasource = VideoDirectoryDatasource(video_directory, caption_extension)
|
1183 |
+
elif video_jsonl_file is not None:
|
1184 |
+
self.datasource = VideoJsonlDatasource(video_jsonl_file)
|
1185 |
+
|
1186 |
+
if self.frame_extraction == "uniform" and self.frame_sample == 1:
|
1187 |
+
self.frame_extraction = "head"
|
1188 |
+
logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
|
1189 |
+
if self.frame_extraction == "head":
|
1190 |
+
# head extraction. we can limit the number of frames to be extracted
|
1191 |
+
self.datasource.set_start_and_end_frame(0, max(self.target_frames))
|
1192 |
+
|
1193 |
+
if self.cache_directory is None:
|
1194 |
+
self.cache_directory = self.video_directory
|
1195 |
+
|
1196 |
+
self.batch_manager = None
|
1197 |
+
self.num_train_items = 0
|
1198 |
+
|
1199 |
+
def get_metadata(self):
|
1200 |
+
metadata = super().get_metadata()
|
1201 |
+
if self.video_directory is not None:
|
1202 |
+
metadata["video_directory"] = os.path.basename(self.video_directory)
|
1203 |
+
if self.video_jsonl_file is not None:
|
1204 |
+
metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
|
1205 |
+
metadata["frame_extraction"] = self.frame_extraction
|
1206 |
+
metadata["frame_stride"] = self.frame_stride
|
1207 |
+
metadata["frame_sample"] = self.frame_sample
|
1208 |
+
metadata["target_frames"] = self.target_frames
|
1209 |
+
return metadata
|
1210 |
+
|
1211 |
+
def retrieve_latent_cache_batches(self, num_workers: int):
|
1212 |
+
buckset_selector = BucketSelector(self.resolution, architecture=self.architecture)
|
1213 |
+
self.datasource.set_bucket_selector(buckset_selector)
|
1214 |
+
|
1215 |
+
executor = ThreadPoolExecutor(max_workers=num_workers)
|
1216 |
+
|
1217 |
+
# key: (width, height, frame_count), value: [ItemInfo]
|
1218 |
+
batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
|
1219 |
+
futures = []
|
1220 |
+
|
1221 |
+
def aggregate_future(consume_all: bool = False):
|
1222 |
+
while len(futures) >= num_workers or (consume_all and len(futures) > 0):
|
1223 |
+
completed_futures = [future for future in futures if future.done()]
|
1224 |
+
if len(completed_futures) == 0:
|
1225 |
+
if len(futures) >= num_workers or consume_all: # to avoid adding too many futures
|
1226 |
+
time.sleep(0.1)
|
1227 |
+
continue
|
1228 |
+
else:
|
1229 |
+
break # submit batch if possible
|
1230 |
+
|
1231 |
+
for future in completed_futures:
|
1232 |
+
original_frame_size, video_key, video, caption = future.result()
|
1233 |
+
|
1234 |
+
frame_count = len(video)
|
1235 |
+
video = np.stack(video, axis=0)
|
1236 |
+
height, width = video.shape[1:3]
|
1237 |
+
bucket_reso = (width, height) # already resized
|
1238 |
+
|
1239 |
+
crop_pos_and_frames = []
|
1240 |
+
if self.frame_extraction == "head":
|
1241 |
+
for target_frame in self.target_frames:
|
1242 |
+
if frame_count >= target_frame:
|
1243 |
+
crop_pos_and_frames.append((0, target_frame))
|
1244 |
+
elif self.frame_extraction == "chunk":
|
1245 |
+
# split by target_frames
|
1246 |
+
for target_frame in self.target_frames:
|
1247 |
+
for i in range(0, frame_count, target_frame):
|
1248 |
+
if i + target_frame <= frame_count:
|
1249 |
+
crop_pos_and_frames.append((i, target_frame))
|
1250 |
+
elif self.frame_extraction == "slide":
|
1251 |
+
# slide window
|
1252 |
+
for target_frame in self.target_frames:
|
1253 |
+
if frame_count >= target_frame:
|
1254 |
+
for i in range(0, frame_count - target_frame + 1, self.frame_stride):
|
1255 |
+
crop_pos_and_frames.append((i, target_frame))
|
1256 |
+
elif self.frame_extraction == "uniform":
|
1257 |
+
# select N frames uniformly
|
1258 |
+
for target_frame in self.target_frames:
|
1259 |
+
if frame_count >= target_frame:
|
1260 |
+
frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
|
1261 |
+
for i in frame_indices:
|
1262 |
+
crop_pos_and_frames.append((i, target_frame))
|
1263 |
+
else:
|
1264 |
+
raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")
|
1265 |
+
|
1266 |
+
for crop_pos, target_frame in crop_pos_and_frames:
|
1267 |
+
cropped_video = video[crop_pos : crop_pos + target_frame]
|
1268 |
+
body, ext = os.path.splitext(video_key)
|
1269 |
+
item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
|
1270 |
+
batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count
|
1271 |
+
|
1272 |
+
item_info = ItemInfo(
|
1273 |
+
item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
|
1274 |
+
)
|
1275 |
+
item_info.latent_cache_path = self.get_latent_cache_path(item_info)
|
1276 |
+
|
1277 |
+
batch = batches.get(batch_key, [])
|
1278 |
+
batch.append(item_info)
|
1279 |
+
batches[batch_key] = batch
|
1280 |
+
|
1281 |
+
futures.remove(future)
|
1282 |
+
|
1283 |
+
def submit_batch(flush: bool = False):
|
1284 |
+
for key in batches:
|
1285 |
+
if len(batches[key]) >= self.batch_size or flush:
|
1286 |
+
batch = batches[key][0 : self.batch_size]
|
1287 |
+
if len(batches[key]) > self.batch_size:
|
1288 |
+
batches[key] = batches[key][self.batch_size :]
|
1289 |
+
else:
|
1290 |
+
del batches[key]
|
1291 |
+
return key, batch
|
1292 |
+
return None, None
|
1293 |
+
|
1294 |
+
for operator in self.datasource:
|
1295 |
+
|
1296 |
+
def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str]:
|
1297 |
+
video_key, video, caption = op()
|
1298 |
+
video: list[np.ndarray]
|
1299 |
+
frame_size = (video[0].shape[1], video[0].shape[0])
|
1300 |
+
|
1301 |
+
# resize if necessary
|
1302 |
+
bucket_reso = buckset_selector.get_bucket_resolution(frame_size)
|
1303 |
+
video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]
|
1304 |
+
|
1305 |
+
return frame_size, video_key, video, caption
|
1306 |
+
|
1307 |
+
future = executor.submit(fetch_and_resize, operator)
|
1308 |
+
futures.append(future)
|
1309 |
+
aggregate_future()
|
1310 |
+
while True:
|
1311 |
+
key, batch = submit_batch()
|
1312 |
+
if key is None:
|
1313 |
+
break
|
1314 |
+
yield key, batch
|
1315 |
+
|
1316 |
+
aggregate_future(consume_all=True)
|
1317 |
+
while True:
|
1318 |
+
key, batch = submit_batch(flush=True)
|
1319 |
+
if key is None:
|
1320 |
+
break
|
1321 |
+
yield key, batch
|
1322 |
+
|
1323 |
+
executor.shutdown()
|
1324 |
+
|
1325 |
+
def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
|
1326 |
+
return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)
|
1327 |
+
|
1328 |
+
def prepare_for_training(self):
|
1329 |
+
bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
|
1330 |
+
|
1331 |
+
# glob cache files
|
1332 |
+
latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))
|
1333 |
+
|
1334 |
+
# assign cache files to item info
|
1335 |
+
bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {} # (width, height, frame_count) -> [ItemInfo]
|
1336 |
+
for cache_file in latent_cache_files:
|
1337 |
+
tokens = os.path.basename(cache_file).split("_")
|
1338 |
+
|
1339 |
+
image_size = tokens[-2] # 0000x0000
|
1340 |
+
image_width, image_height = map(int, image_size.split("x"))
|
1341 |
+
image_size = (image_width, image_height)
|
1342 |
+
|
1343 |
+
frame_pos, frame_count = tokens[-3].split("-")
|
1344 |
+
frame_pos, frame_count = int(frame_pos), int(frame_count)
|
1345 |
+
|
1346 |
+
item_key = "_".join(tokens[:-3])
|
1347 |
+
text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
|
1348 |
+
if not os.path.exists(text_encoder_output_cache_file):
|
1349 |
+
logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
|
1350 |
+
continue
|
1351 |
+
|
1352 |
+
bucket_reso = bucket_selector.get_bucket_resolution(image_size)
|
1353 |
+
bucket_reso = (*bucket_reso, frame_count)
|
1354 |
+
item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
|
1355 |
+
item_info.text_encoder_output_cache_path = text_encoder_output_cache_file
|
1356 |
+
|
1357 |
+
bucket = bucketed_item_info.get(bucket_reso, [])
|
1358 |
+
for _ in range(self.num_repeats):
|
1359 |
+
bucket.append(item_info)
|
1360 |
+
bucketed_item_info[bucket_reso] = bucket
|
1361 |
+
|
1362 |
+
# prepare batch manager
|
1363 |
+
self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
|
1364 |
+
self.batch_manager.show_bucket_info()
|
1365 |
+
|
1366 |
+
self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])
|
1367 |
+
|
1368 |
+
def shuffle_buckets(self):
|
1369 |
+
# set random seed for this epoch
|
1370 |
+
random.seed(self.seed + self.current_epoch)
|
1371 |
+
self.batch_manager.shuffle()
|
1372 |
+
|
1373 |
+
def __len__(self):
|
1374 |
+
if self.batch_manager is None:
|
1375 |
+
return 100 # dummy value
|
1376 |
+
return len(self.batch_manager)
|
1377 |
+
|
1378 |
+
def __getitem__(self, idx):
|
1379 |
+
return self.batch_manager[idx]
|
1380 |
+
|
1381 |
+
|
1382 |
+
class DatasetGroup(torch.utils.data.ConcatDataset):
|
1383 |
+
def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
|
1384 |
+
super().__init__(datasets)
|
1385 |
+
self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
|
1386 |
+
self.num_train_items = 0
|
1387 |
+
for dataset in self.datasets:
|
1388 |
+
self.num_train_items += dataset.num_train_items
|
1389 |
+
|
1390 |
+
def set_current_epoch(self, epoch):
|
1391 |
+
for dataset in self.datasets:
|
1392 |
+
dataset.set_current_epoch(epoch)
|
1393 |
+
|
1394 |
+
def set_current_step(self, step):
|
1395 |
+
for dataset in self.datasets:
|
1396 |
+
dataset.set_current_step(step)
|
1397 |
+
|
1398 |
+
def set_max_train_steps(self, max_train_steps):
|
1399 |
+
for dataset in self.datasets:
|
1400 |
+
dataset.set_max_train_steps(max_train_steps)
|
docs/advanced_config.md
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
2 |
+
|
3 |
+
# Advanced configuration / 高度な設定
|
4 |
+
|
5 |
+
## How to specify `network_args` / `network_args`の指定方法
|
6 |
+
|
7 |
+
The `--network_args` option is an option for specifying detailed arguments to LoRA. Specify the arguments in the form of `key=value` in `--network_args`.
|
8 |
+
|
9 |
+
<details>
|
10 |
+
<summary>日本語</summary>
|
11 |
+
`--network_args`オプションは、LoRAへの詳細な引数を指定するためのオプションです。`--network_args`には、`key=value`の形式で引数を指定します。
|
12 |
+
</details>
|
13 |
+
|
14 |
+
### Example / 記述例
|
15 |
+
|
16 |
+
If you specify it on the command line, write as follows. / コマンドラインで指定する場合は以下のように記述します。
|
17 |
+
|
18 |
+
```bash
|
19 |
+
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
|
20 |
+
--network_module networks.lora --network_dim 32
|
21 |
+
--network_args "key1=value1" "key2=value2" ...
|
22 |
+
```
|
23 |
+
|
24 |
+
If you specify it in the configuration file, write as follows. / 設定ファイルで指定する場合は以下のように記述します。
|
25 |
+
|
26 |
+
```toml
|
27 |
+
network_args = ["key1=value1", "key2=value2", ...]
|
28 |
+
```
|
29 |
+
|
30 |
+
If you specify `"verbose=True"`, detailed information of LoRA will be displayed. / `"verbose=True"`を指定するとLoRAの詳細な情報が表示されます。
|
31 |
+
|
32 |
+
```bash
|
33 |
+
--network_args "verbose=True" "key1=value1" "key2=value2" ...
|
34 |
+
```
|
35 |
+
|
36 |
+
## LoRA+
|
37 |
+
|
38 |
+
LoRA+ is a method to improve the training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiplier for the learning rate. The original paper recommends 16, but adjust as needed. It seems to be good to start from around 4. For details, please refer to the [related PR of sd-scripts](https://github.com/kohya-ss/sd-scripts/pull/1233).
|
39 |
+
|
40 |
+
Specify `loraplus_lr_ratio` with `--network_args`.
|
41 |
+
|
42 |
+
<details>
|
43 |
+
<summary>日本語</summary>
|
44 |
+
|
45 |
+
LoRA+は、LoRAのUP側(LoRA-B)の学習率を上げることで学習速度を向上させる手法です。学習率に対する倍率を指定します。元論文では16を推奨していますが、必要に応じて調整してください。4程度から始めるとよいようです。詳細は[sd-scriptsの関連PR]https://github.com/kohya-ss/sd-scripts/pull/1233)を参照してください。
|
46 |
+
|
47 |
+
`--network_args`で`loraplus_lr_ratio`を指定します。
|
48 |
+
</details>
|
49 |
+
|
50 |
+
### Example / 記述例
|
51 |
+
|
52 |
+
```bash
|
53 |
+
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 hv_train_network.py --dit ...
|
54 |
+
--network_module networks.lora --network_dim 32 --network_args "loraplus_lr_ratio=4" ...
|
55 |
+
```
|
56 |
+
|
57 |
+
## Select the target modules of LoRA / LoRAの対象モジュールを選択する
|
58 |
+
|
59 |
+
*This feature is highly experimental and the specification may change. / この機能は特に実験的なもので、仕様は変更される可能性があります。*
|
60 |
+
|
61 |
+
By specifying `exclude_patterns` and `include_patterns` with `--network_args`, you can select the target modules of LoRA.
|
62 |
+
|
63 |
+
`exclude_patterns` excludes modules that match the specified pattern. `include_patterns` targets only modules that match the specified pattern.
|
64 |
+
|
65 |
+
Specify the values as a list. For example, `"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`.
|
66 |
+
|
67 |
+
The pattern is a regular expression for the module name. The module name is in the form of `double_blocks.0.img_mod.linear` or `single_blocks.39.modulation.linear`. The regular expression is not a partial match but a complete match.
|
68 |
+
|
69 |
+
The patterns are applied in the order of `exclude_patterns`→`include_patterns`. By default, the Linear layers of `img_mod`, `txt_mod`, and `modulation` of double blocks and single blocks are excluded.
|
70 |
+
|
71 |
+
(`.*(img_mod|txt_mod|modulation).*` is specified.)
|
72 |
+
|
73 |
+
<details>
|
74 |
+
<summary>日本語</summary>
|
75 |
+
|
76 |
+
`--network_args`で`exclude_patterns`と`include_patterns`を指定することで、LoRAの対象モジュールを選択することができます。
|
77 |
+
|
78 |
+
`exclude_patterns`は、指定したパターンに一致するモジュールを除外します。`include_patterns`は、指定したパターンに一致するモジュールのみを対象とします。
|
79 |
+
|
80 |
+
値は、リストで指定します。`"exclude_patterns=[r'.*single_blocks.*', r'.*double_blocks\.[0-9]\..*']"`のようになります。
|
81 |
+
|
82 |
+
パターンは、モジュール名に対する正規表現です。モジュール名は、たとえば`double_blocks.0.img_mod.linear`や`single_blocks.39.modulation.linear`のような形式です。正規表現は部分一致ではなく完全一致です。
|
83 |
+
|
84 |
+
パターンは、`exclude_patterns`→`include_patterns`の順で適用されます。デフォルトは、double blocksとsingle blocksのLinear層のうち、`img_mod`、`txt_mod`、`modulation`が除外されています。
|
85 |
+
|
86 |
+
(`.*(img_mod|txt_mod|modulation).*`が指定されています。)
|
87 |
+
</details>
|
88 |
+
|
89 |
+
### Example / 記述例
|
90 |
+
|
91 |
+
Only the modules of double blocks / double blocksのモジュールのみを対象とする場合:
|
92 |
+
|
93 |
+
```bash
|
94 |
+
--network_args "exclude_patterns=[r'.*single_blocks.*']"
|
95 |
+
```
|
96 |
+
|
97 |
+
Only the modules of single blocks from the 10th / single blocksの10番目以降のLinearモジュールのみを対象とする場合:
|
98 |
+
|
99 |
+
```bash
|
100 |
+
--network_args "exclude_patterns=[r'.*']" "include_patterns=[r'.*single_blocks\.\d{2}\.linear.*']"
|
101 |
+
```
|
102 |
+
|
103 |
+
## Save and view logs in TensorBoard format / TensorBoard形式のログの保存と参照
|
104 |
+
|
105 |
+
Specify the folder to save the logs with the `--logging_dir` option. Logs in TensorBoard format will be saved.
|
106 |
+
|
107 |
+
For example, if you specify `--logging_dir=logs`, a `logs` folder will be created in the working folder, and logs will be saved in the date folder inside it.
|
108 |
+
|
109 |
+
Also, if you specify the `--log_prefix` option, the specified string will be added before the date. For example, use `--logging_dir=logs --log_prefix=lora_setting1_` for identification.
|
110 |
+
|
111 |
+
To view logs in TensorBoard, open another command prompt and activate the virtual environment. Then enter the following in the working folder.
|
112 |
+
|
113 |
+
```powershell
|
114 |
+
tensorboard --logdir=logs
|
115 |
+
```
|
116 |
+
|
117 |
+
(tensorboard installation is required.)
|
118 |
+
|
119 |
+
Then open a browser and access http://localhost:6006/ to display it.
|
120 |
+
|
121 |
+
<details>
|
122 |
+
<summary>日本語</summary>
|
123 |
+
`--logging_dir`オプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。
|
124 |
+
|
125 |
+
たとえば`--logging_dir=logs`と指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。
|
126 |
+
|
127 |
+
また`--log_prefix`オプションを指定すると、日時の前に指定した文字列が追加されます。`--logging_dir=logs --log_prefix=lora_setting1_`などとして識別用にお使いください。
|
128 |
+
|
129 |
+
TensorBoardでログを確認するには、別のコマンドプロンプトを開き、仮想環境を有効にしてから、作業フォルダで以下のように入力します。
|
130 |
+
|
131 |
+
```powershell
|
132 |
+
tensorboard --logdir=logs
|
133 |
+
```
|
134 |
+
|
135 |
+
(tensorboardのインストールが必要です。)
|
136 |
+
|
137 |
+
その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
|
138 |
+
</details>
|
139 |
+
|
140 |
+
## Save and view logs in wandb / wandbでログの保存と参照
|
141 |
+
|
142 |
+
`--log_with wandb` option is available to save logs in wandb format. `tensorboard` or `all` is also available. The default is `tensorboard`.
|
143 |
+
|
144 |
+
Specify the project name with `--log_tracker_name` when using wandb.
|
145 |
+
|
146 |
+
<details>
|
147 |
+
<summary>日本語</summary>
|
148 |
+
`--log_with wandb`オプションを指定するとwandb形式でログを保存することができます。`tensorboard`や`all`も指定可能です。デフォルトは`tensorboard`です。
|
149 |
+
|
150 |
+
wandbを使用する場合は、`--log_tracker_name`でプロジェクト名を指定してください。
|
151 |
+
</details>
|
docs/sampling_during_training.md
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
2 |
+
|
3 |
+
# Sampling during training / 学習中のサンプル画像生成
|
4 |
+
|
5 |
+
By preparing a prompt file, you can generate sample images during training.
|
6 |
+
|
7 |
+
Please be aware that it consumes a considerable amount of VRAM, so be careful when generating sample images for videos with a large number of frames. Also, since it takes time to generate, adjust the frequency of sample image generation as needed.
|
8 |
+
|
9 |
+
<details>
|
10 |
+
<summary>日本語</summary>
|
11 |
+
|
12 |
+
プロンプトファイルを用意することで、学習中にサンプル画像を生成することができます。
|
13 |
+
|
14 |
+
VRAMをそれなりに消費しますので、特にフレーム数が多い動画を生成する場合は注意してください。また生成には時間がかかりますので、サンプル画像生成の頻度は適宜調整してください。
|
15 |
+
</details>
|
16 |
+
|
17 |
+
## How to use / 使い方
|
18 |
+
|
19 |
+
### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
|
20 |
+
|
21 |
+
Example of command line options for training with sampling / 記述例:
|
22 |
+
|
23 |
+
```bash
|
24 |
+
--vae path/to/ckpts/hunyuan-video-t2v-720p/vae/pytorch_model.pt
|
25 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128
|
26 |
+
--text_encoder1 path/to/ckpts/text_encoder
|
27 |
+
--text_encoder2 path/to/ckpts/text_encoder_2
|
28 |
+
--sample_prompts /path/to/prompt_file.txt
|
29 |
+
--sample_every_n_epochs 1 --sample_every_n_steps 1000 --sample_at_first
|
30 |
+
```
|
31 |
+
|
32 |
+
`--vae`, `--vae_chunk_size`, `--vae_spatial_tile_sample_min_size`, `--text_encoder1`, `--text_encoder2` are the same as when generating images, so please refer to [here](/README.md#inference) for details. `--fp8_llm` can also be specified.
|
33 |
+
|
34 |
+
`--sample_prompts` specifies the path to the prompt file used for sample image generation. Details are described below.
|
35 |
+
|
36 |
+
`--sample_every_n_epochs` specifies how often to generate sample images in epochs, and `--sample_every_n_steps` specifies how often to generate sample images in steps.
|
37 |
+
|
38 |
+
`--sample_at_first` is specified when generating sample images at the beginning of training.
|
39 |
+
|
40 |
+
Sample images and videos are saved in the `sample` directory in the directory specified by `--output_dir`. They are saved as `.png` for still images and `.mp4` for videos.
|
41 |
+
|
42 |
+
<details>
|
43 |
+
<summary>日本語</summary>
|
44 |
+
|
45 |
+
`--vae`、`--vae_chunk_size`、`--vae_spatial_tile_sample_min_size`、`--text_encoder1`、`--text_encoder2`は、画像生成時と同様ですので、詳細は[こちら](/README.ja.md#推論)を参照してください。`--fp8_llm`も指定可能です。
|
46 |
+
|
47 |
+
`--sample_prompts`は、サンプル画像生成に使用するプロンプトファイルのパスを指定します。詳細は後述します。
|
48 |
+
|
49 |
+
`--sample_every_n_epochs`は、何エポックごとにサンプル画像を生成するかを、`--sample_every_n_steps`は、何ステップごとにサンプル画像を生成するかを指定します。
|
50 |
+
|
51 |
+
`--sample_at_first`は、学習開始時にサンプル画像を生成する場合に指定します。
|
52 |
+
|
53 |
+
サンプル画像、動画は、`--output_dir`で指定したディレクトリ内の、`sample`ディレクトリに保存されます。静止画の場合は`.png`、動画の場合は`.mp4`で保存されます。
|
54 |
+
</details>
|
55 |
+
|
56 |
+
### Prompt file / プロンプトファイル
|
57 |
+
|
58 |
+
The prompt file is a text file that contains the prompts for generating sample images. The example is as follows. / プロンプトファイルは、サンプル画像生成のためのプロンプトを記述したテキストファイルです。例は以下の通りです。
|
59 |
+
|
60 |
+
```
|
61 |
+
# prompt 1: for generating a cat video
|
62 |
+
A cat walks on the grass, realistic style. --w 640 --h 480 --f 25 --d 1 --s 20
|
63 |
+
|
64 |
+
# prompt 2: for generating a dog image
|
65 |
+
A dog runs on the beach, realistic style. --w 960 --h 544 --f 1 --d 2 --s 20
|
66 |
+
```
|
67 |
+
|
68 |
+
A line starting with `#` is a comment.
|
69 |
+
|
70 |
+
* `--w` specifies the width of the generated image or video. The default is 256.
|
71 |
+
* `--h` specifies the height. The default is 256.
|
72 |
+
* `--f` specifies the number of frames. The default is 1, which generates a still image.
|
73 |
+
* `--d` specifies the seed. The default is random.
|
74 |
+
* `--s` specifies the number of steps in generation. The default is 20.
|
75 |
+
* `--g` specifies the guidance scale. The default is 6.0, which is the default value during inference of HunyuanVideo. Specify 1.0 for SkyReels V1 models. Ignore this option for Wan2.1 models.
|
76 |
+
* `--fs` specifies the discrete flow shift. The default is 14.5, which corresponds to the number of steps 20. In the HunyuanVideo paper, 7.0 is recommended for 50 steps, and 17.0 is recommended for less than 20 steps (e.g. 10).
|
77 |
+
|
78 |
+
If you train I2V models, you can use the additional options below.
|
79 |
+
|
80 |
+
* `--i path/to/image.png`: the image path for image2video inference.
|
81 |
+
|
82 |
+
If you train the model with classifier free guidance, you can use the additional options below.
|
83 |
+
|
84 |
+
*`--n negative prompt...`: the negative prompt for the classifier free guidance.
|
85 |
+
*`--l 6.0`: the classifier free guidance scale. Should be set to 6.0 for SkyReels V1 models. 5.0 is the default value for Wan2.1 (if omitted).
|
86 |
+
|
87 |
+
<details>
|
88 |
+
<summary>日本語</summary>
|
89 |
+
|
90 |
+
`#` で始まる行はコメントです。
|
91 |
+
|
92 |
+
* `--w` 生成画像、動画の幅を指定します。省略時は256です。
|
93 |
+
* `--h` 高さを指定します。省略時は256です。
|
94 |
+
* `--f` フレーム数を指定します。省略時は1で、静止画を生成します。
|
95 |
+
* `--d` シードを指定します。省略時はランダムです。
|
96 |
+
* `--s` 生成におけるステップ数を指定します。省略時は20です。
|
97 |
+
* `--g` guidance scaleを指定します。省略時は6.0で、HunyuanVideoの推論時のデフォルト値です。
|
98 |
+
* `--fs` discrete flow shiftを指定します。省略時は14.5で、ステップ数20の場合に対応した値です。HunyuanVideoの論文では、ステップ数50の場合は7.0、ステップ数20未満(10など)で17.0が推奨されています。
|
99 |
+
|
100 |
+
I2Vモデルを学習する場合、以下の追加オプションを使用できます。
|
101 |
+
|
102 |
+
* `--i path/to/image.png`: image2video推論用の画像パス。
|
103 |
+
|
104 |
+
classifier free guidance(ネガティブプロンプト)を必要とするモデルを学習する場合、以下の追加オプションを使用できます。
|
105 |
+
|
106 |
+
*`--n negative prompt...`: classifier free guidance用のネガティブプロンプト。
|
107 |
+
*`--l 6.0`: classifier free guidance scale。SkyReels V1モデルの場合は6.0に設定してください。Wan2.1の場合はデフォルト値が5.0です(省略時)。
|
108 |
+
</details>
|
docs/wan.md
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
2 |
+
|
3 |
+
# Wan 2.1
|
4 |
+
|
5 |
+
## Overview / 概要
|
6 |
+
|
7 |
+
This is an unofficial training and inference script for [Wan2.1](https://github.com/Wan-Video/Wan2.1). The features are as follows.
|
8 |
+
|
9 |
+
- fp8 support and memory reduction by block swap: Inference of a 720x1280x81frames videos with 24GB VRAM, training with 720x1280 images with 24GB VRAM
|
10 |
+
- Inference without installing Flash attention (using PyTorch's scaled dot product attention)
|
11 |
+
- Supports xformers and Sage attention
|
12 |
+
|
13 |
+
This feature is experimental.
|
14 |
+
|
15 |
+
<details>
|
16 |
+
<summary>日本語</summary>
|
17 |
+
[Wan2.1](https://github.com/Wan-Video/Wan2.1) の非公式の学習および推論スクリプトです。
|
18 |
+
|
19 |
+
以下の特徴があります。
|
20 |
+
|
21 |
+
- fp8対応およびblock swapによる省メモリ化:720x1280x81framesの動画を24GB VRAMで推論可能、720x1280の画像での学習が24GB VRAMで可能
|
22 |
+
- Flash attentionのインストールなしでの実行(PyTorchのscaled dot product attentionを使用)
|
23 |
+
- xformersおよびSage attention対応
|
24 |
+
|
25 |
+
この機能は実験的なものです。
|
26 |
+
</details>
|
27 |
+
|
28 |
+
## Download the model / モデルのダウンロード
|
29 |
+
|
30 |
+
Download the T5 `models_t5_umt5-xxl-enc-bf16.pth` and CLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` from the following page: https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
|
31 |
+
|
32 |
+
Download the VAE from the above page `Wan2.1_VAE.pth` or download `split_files/vae/wan_2.1_vae.safetensors` from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
|
33 |
+
|
34 |
+
Download the DiT weights from the following page: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
|
35 |
+
|
36 |
+
Please select the appropriate weights according to T2V, I2V, resolution, model size, etc. fp8 models can be used if `--fp8` is specified.
|
37 |
+
|
38 |
+
(Thanks to Comfy-Org for providing the repackaged weights.)
|
39 |
+
<details>
|
40 |
+
<summary>日本語</summary>
|
41 |
+
T5 `models_t5_umt5-xxl-enc-bf16.pth` およびCLIP `models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を、次のページからダウンロードしてください:https://huggingface.co/Wan-AI/Wan2.1-T2V-14B/tree/main
|
42 |
+
|
43 |
+
VAEは上のページから `Wan2.1_VAE.pth` をダウンロードするか、次のページから `split_files/vae/wan_2.1_vae.safetensors` をダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/vae
|
44 |
+
|
45 |
+
DiTの重みを次のページからダウンロードしてください:https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models
|
46 |
+
|
47 |
+
T2VやI2V、解像度、モデルサイズなどにより適切な重みを選択してください。`--fp8`指定時はfp8モデルも使用できます。
|
48 |
+
|
49 |
+
(repackaged版の重みを提供してくださっているComfy-Orgに感謝いたします。)
|
50 |
+
</details>
|
51 |
+
|
52 |
+
## Pre-caching / 事前キャッシュ
|
53 |
+
|
54 |
+
### Latent Pre-caching
|
55 |
+
|
56 |
+
Latent pre-caching is almost the same as in HunyuanVideo. Create the cache using the following command:
|
57 |
+
|
58 |
+
```bash
|
59 |
+
python wan_cache_latents.py --dataset_config path/to/toml --vae path/to/wan_2.1_vae.safetensors
|
60 |
+
```
|
61 |
+
|
62 |
+
If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model. If not specified, the training will raise an error.
|
63 |
+
|
64 |
+
If you're running low on VRAM, specify `--vae_cache_cpu` to use the CPU for the VAE internal cache, which will reduce VRAM usage somewhat.
|
65 |
+
|
66 |
+
<details>
|
67 |
+
<summary>日本語</summary>
|
68 |
+
latentの事前キャッシングはHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
|
69 |
+
|
70 |
+
I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。指定しないと学習時にエラーが発生します。
|
71 |
+
|
72 |
+
VRAMが不足している場合は、`--vae_cache_cpu` を指定するとVAEの内部キャッシュにCPUを使うことで、使用VRAMを多少削減できます。
|
73 |
+
</details>
|
74 |
+
|
75 |
+
### Text Encoder Output Pre-caching
|
76 |
+
|
77 |
+
Text encoder output pre-caching is also almost the same as in HunyuanVideo. Create the cache using the following command:
|
78 |
+
|
79 |
+
```bash
|
80 |
+
python wan_cache_text_encoder_outputs.py --dataset_config path/to/toml --t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --batch_size 16
|
81 |
+
```
|
82 |
+
|
83 |
+
Adjust `--batch_size` according to your available VRAM.
|
84 |
+
|
85 |
+
For systems with limited VRAM (less than ~16GB), use `--fp8_t5` to run the T5 in fp8 mode.
|
86 |
+
|
87 |
+
<details>
|
88 |
+
<summary>日本語</summary>
|
89 |
+
テキストエンコーダ出力の事前キャッシングもHunyuanVideoとほぼ同じです。上のコマンド例を使用してキャッシュを作成してください。
|
90 |
+
|
91 |
+
使用可能なVRAMに合わせて `--batch_size` を調整してください。
|
92 |
+
|
93 |
+
VRAMが限られているシステム(約16GB未満)の場合は、T5をfp8モードで実行するために `--fp8_t5` を使用してください。
|
94 |
+
</details>
|
95 |
+
|
96 |
+
## Training / 学習
|
97 |
+
|
98 |
+
### Training
|
99 |
+
|
100 |
+
Start training using the following command (input as a single line):
|
101 |
+
|
102 |
+
```bash
|
103 |
+
accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 wan_train_network.py
|
104 |
+
--task t2v-1.3B
|
105 |
+
--dit path/to/wan2.1_xxx_bf16.safetensors
|
106 |
+
--dataset_config path/to/toml --sdpa --mixed_precision bf16 --fp8_base
|
107 |
+
--optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing
|
108 |
+
--max_data_loader_n_workers 2 --persistent_data_loader_workers
|
109 |
+
--network_module networks.lora_wan --network_dim 32
|
110 |
+
--timestep_sampling shift --discrete_flow_shift 3.0
|
111 |
+
--max_train_epochs 16 --save_every_n_epochs 1 --seed 42
|
112 |
+
--output_dir path/to/output_dir --output_name name-of-lora
|
113 |
+
```
|
114 |
+
The above is an example. The appropriate values for `timestep_sampling` and `discrete_flow_shift` need to be determined by experimentation.
|
115 |
+
|
116 |
+
For additional options, use `python wan_train_network.py --help` (note that many options are unverified).
|
117 |
+
|
118 |
+
`--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`. Specify the DiT weights for the task with `--dit`.
|
119 |
+
|
120 |
+
Don't forget to specify `--network_module networks.lora_wan`.
|
121 |
+
|
122 |
+
Other options are mostly the same as `hv_train_network.py`.
|
123 |
+
|
124 |
+
Use `convert_lora.py` for converting the LoRA weights after training, as in HunyuanVideo.
|
125 |
+
|
126 |
+
<details>
|
127 |
+
<summary>日本語</summary>
|
128 |
+
`timestep_sampling`や`discrete_flow_shift`は一例です。どのような値が適切かは実験が必要です。
|
129 |
+
|
130 |
+
その他のオプションについては `python wan_train_network.py --help` を使用してください(多くのオプションは未検証です)。
|
131 |
+
|
132 |
+
`--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。`--dit`に、taskに応じたDiTの重みを指定してください。
|
133 |
+
|
134 |
+
`--network_module` に `networks.lora_wan` を指定することを忘れないでください。
|
135 |
+
|
136 |
+
その他のオプションは、ほぼ`hv_train_network.py`と同様です。
|
137 |
+
|
138 |
+
学習後のLoRAの重みの変換は、HunyuanVideoと同様に`convert_lora.py`を使用してください。
|
139 |
+
</details>
|
140 |
+
|
141 |
+
### Command line options for training with sampling / サンプル画像生成に関連する学習時のコマンドラインオプション
|
142 |
+
|
143 |
+
Example of command line options for training with sampling / 記述例:
|
144 |
+
|
145 |
+
```bash
|
146 |
+
--vae path/to/wan_2.1_vae.safetensors
|
147 |
+
--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
|
148 |
+
--sample_prompts /path/to/prompt_file.txt
|
149 |
+
--sample_every_n_epochs 1 --sample_every_n_steps 1000 -- sample_at_first
|
150 |
+
```
|
151 |
+
Each option is the same as when generating images or as HunyuanVideo. Please refer to [here](/docs/sampling_during_training.md) for details.
|
152 |
+
|
153 |
+
If you train I2V models, add `--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` to specify the CLIP model.
|
154 |
+
|
155 |
+
You can specify the initial image and negative prompts in the prompt file. Please refer to [here](/docs/sampling_during_training.md#prompt-file--プロンプトファイル).
|
156 |
+
|
157 |
+
<details>
|
158 |
+
<summary>日本語</summary>
|
159 |
+
各オプションは推論時、およびHunyuanVideoの場合と同様です。[こちら](/docs/sampling_during_training.md)を参照してください。
|
160 |
+
|
161 |
+
I2Vモデルを学習する場合は、`--clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth` を追加してCLIPモデルを指定してください。
|
162 |
+
|
163 |
+
プロンプトファイルで、初期画像やネガティブプロンプト等を指定できます。[こちら](/docs/sampling_during_training.md#prompt-file--プロンプトファイル)を参照してください。
|
164 |
+
</details>
|
165 |
+
|
166 |
+
|
167 |
+
## Inference / 推論
|
168 |
+
|
169 |
+
### T2V Inference / T2V推論
|
170 |
+
|
171 |
+
The following is an example of T2V inference (input as a single line):
|
172 |
+
|
173 |
+
```bash
|
174 |
+
python wan_generate_video.py --fp8 --task t2v-1.3B --video_size 832 480 --video_length 81 --infer_steps 20
|
175 |
+
--prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
|
176 |
+
--dit path/to/wan2.1_t2v_1.3B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
|
177 |
+
--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth
|
178 |
+
--attn_mode torch
|
179 |
+
```
|
180 |
+
|
181 |
+
`--task` is one of `t2v-1.3B`, `t2v-14B`, `i2v-14B` and `t2i-14B`.
|
182 |
+
|
183 |
+
`--attn_mode` is `torch`, `sdpa` (same as `torch`), `xformers`, `sageattn`,`flash2`, `flash` (same as `flash2`) or `flash3`. `torch` is the default. Other options require the corresponding library to be installed. `flash3` (Flash attention 3) is not tested.
|
184 |
+
|
185 |
+
`--fp8_t5` can be used to specify the T5 model in fp8 format. This option reduces memory usage for the T5 model.
|
186 |
+
|
187 |
+
`--negative_prompt` can be used to specify a negative prompt. If omitted, the default negative prompt is used.
|
188 |
+
|
189 |
+
` --flow_shift` can be used to specify the flow shift (default 3.0 for I2V with 480p, 5.0 for others).
|
190 |
+
|
191 |
+
`--guidance_scale` can be used to specify the guidance scale for classifier free guiance (default 5.0).
|
192 |
+
|
193 |
+
`--blocks_to_swap` is the number of blocks to swap during inference. The default value is None (no block swap). The maximum value is 39 for 14B model and 29 for 1.3B model.
|
194 |
+
|
195 |
+
`--vae_cache_cpu` enables VAE cache in main memory. This reduces VRAM usage slightly but processing is slower.
|
196 |
+
|
197 |
+
Other options are same as `hv_generate_video.py` (some options are not supported, please check the help).
|
198 |
+
|
199 |
+
<details>
|
200 |
+
<summary>日本語</summary>
|
201 |
+
`--task` には `t2v-1.3B`, `t2v-14B`, `i2v-14B`, `t2i-14B` のいずれかを指定します。
|
202 |
+
|
203 |
+
`--attn_mode` には `torch`, `sdpa`(`torch`と同じ)、`xformers`, `sageattn`, `flash2`, `flash`(`flash2`と同じ), `flash3` のいずれかを指定します。デフォルトは `torch` です。その他のオプションを使用する場合は、対応するライブラリをインストールする必要があります。`flash3`(Flash attention 3)は未テストです。
|
204 |
+
|
205 |
+
`--fp8_t5` を指定するとT5モデルをfp8形式で実行します。T5モデル呼び出し時のメモリ使用量を削減します。
|
206 |
+
|
207 |
+
`--negative_prompt` でネガティブプロンプトを指定できます。省略した場合はデフォルトのネガティブプロンプトが使用されます。
|
208 |
+
|
209 |
+
`--flow_shift` でflow shiftを指定できます(480pのI2Vの場合はデフォルト3.0、それ以外は5.0)。
|
210 |
+
|
211 |
+
`--guidance_scale` でclassifier free guianceのガイダンススケールを指定できます(デフォルト5.0)。
|
212 |
+
|
213 |
+
`--blocks_to_swap` は推論時のblock swapの数です。デフォルト値はNone(block swapなし)です。最大値は14Bモデルの場合39、1.3Bモデルの場合29です。
|
214 |
+
|
215 |
+
`--vae_cache_cpu` を有効にすると、VAEのキャッシュをメインメモリに保持します。VRAM使用量が多少減りますが、処理は遅くなります。
|
216 |
+
|
217 |
+
その他のオプションは `hv_generate_video.py` と同じです(一部のオプションはサポートされていないため、ヘルプを確認してください)。
|
218 |
+
</details>
|
219 |
+
|
220 |
+
### I2V Inference / I2V推論
|
221 |
+
|
222 |
+
The following is an example of I2V inference (input as a single line):
|
223 |
+
|
224 |
+
```bash
|
225 |
+
python wan_generate_video.py --fp8 --task i2v-14B --video_size 832 480 --video_length 81 --infer_steps 20
|
226 |
+
--prompt "prompt for the video" --save_path path/to/save.mp4 --output_type both
|
227 |
+
--dit path/to/wan2.1_i2v_480p_14B_bf16_etc.safetensors --vae path/to/wan_2.1_vae.safetensors
|
228 |
+
--t5 path/to/models_t5_umt5-xxl-enc-bf16.pth --clip path/to/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
|
229 |
+
--attn_mode torch --image_path path/to/image.jpg
|
230 |
+
```
|
231 |
+
|
232 |
+
Add `--clip` to specify the CLIP model. `--image_path` is the path to the image to be used as the initial frame.
|
233 |
+
|
234 |
+
Other options are same as T2V inference.
|
235 |
+
|
236 |
+
<details>
|
237 |
+
<summary>日本語</summary>
|
238 |
+
`--clip` を追加してCLIPモデルを指定します。`--image_path` は初期フレームとして使用する画像のパスです。
|
239 |
+
|
240 |
+
その他のオプションはT2V推論と同じです。
|
241 |
+
</details>
|
hunyuan_model/__init__.py
ADDED
File without changes
|
hunyuan_model/activation_layers.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def get_activation_layer(act_type):
|
5 |
+
"""get activation layer
|
6 |
+
|
7 |
+
Args:
|
8 |
+
act_type (str): the activation type
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
torch.nn.functional: the activation layer
|
12 |
+
"""
|
13 |
+
if act_type == "gelu":
|
14 |
+
return lambda: nn.GELU()
|
15 |
+
elif act_type == "gelu_tanh":
|
16 |
+
# Approximate `tanh` requires torch >= 1.13
|
17 |
+
return lambda: nn.GELU(approximate="tanh")
|
18 |
+
elif act_type == "relu":
|
19 |
+
return nn.ReLU
|
20 |
+
elif act_type == "silu":
|
21 |
+
return nn.SiLU
|
22 |
+
else:
|
23 |
+
raise ValueError(f"Unknown activation type: {act_type}")
|
hunyuan_model/attention.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.metadata
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
try:
|
9 |
+
import flash_attn
|
10 |
+
from flash_attn.flash_attn_interface import _flash_attn_forward
|
11 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
12 |
+
from flash_attn.flash_attn_interface import flash_attn_func
|
13 |
+
except ImportError:
|
14 |
+
flash_attn = None
|
15 |
+
flash_attn_varlen_func = None
|
16 |
+
_flash_attn_forward = None
|
17 |
+
flash_attn_func = None
|
18 |
+
|
19 |
+
try:
|
20 |
+
print(f"Trying to import sageattention")
|
21 |
+
from sageattention import sageattn_varlen, sageattn
|
22 |
+
|
23 |
+
print("Successfully imported sageattention")
|
24 |
+
except ImportError:
|
25 |
+
print(f"Failed to import sageattention")
|
26 |
+
sageattn_varlen = None
|
27 |
+
sageattn = None
|
28 |
+
|
29 |
+
try:
|
30 |
+
import xformers.ops as xops
|
31 |
+
except ImportError:
|
32 |
+
xops = None
|
33 |
+
|
34 |
+
MEMORY_LAYOUT = {
|
35 |
+
"flash": (
|
36 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
37 |
+
lambda x: x,
|
38 |
+
),
|
39 |
+
"flash_fixlen": (
|
40 |
+
lambda x: x,
|
41 |
+
lambda x: x,
|
42 |
+
),
|
43 |
+
"sageattn": (
|
44 |
+
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
|
45 |
+
lambda x: x,
|
46 |
+
),
|
47 |
+
"sageattn_fixlen": (
|
48 |
+
lambda x: x.transpose(1, 2),
|
49 |
+
lambda x: x.transpose(1, 2),
|
50 |
+
),
|
51 |
+
"torch": (
|
52 |
+
lambda x: x.transpose(1, 2),
|
53 |
+
lambda x: x.transpose(1, 2),
|
54 |
+
),
|
55 |
+
"xformers": (
|
56 |
+
lambda x: x,
|
57 |
+
lambda x: x,
|
58 |
+
),
|
59 |
+
"vanilla": (
|
60 |
+
lambda x: x.transpose(1, 2),
|
61 |
+
lambda x: x.transpose(1, 2),
|
62 |
+
),
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def get_cu_seqlens(text_mask, img_len):
|
67 |
+
"""Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
|
68 |
+
|
69 |
+
Args:
|
70 |
+
text_mask (torch.Tensor): the mask of text
|
71 |
+
img_len (int): the length of image
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
torch.Tensor: the calculated cu_seqlens for flash attention
|
75 |
+
"""
|
76 |
+
batch_size = text_mask.shape[0]
|
77 |
+
text_len = text_mask.sum(dim=1)
|
78 |
+
max_len = text_mask.shape[1] + img_len
|
79 |
+
|
80 |
+
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
|
81 |
+
|
82 |
+
for i in range(batch_size):
|
83 |
+
s = text_len[i] + img_len
|
84 |
+
s1 = i * max_len + s
|
85 |
+
s2 = (i + 1) * max_len
|
86 |
+
cu_seqlens[2 * i + 1] = s1
|
87 |
+
cu_seqlens[2 * i + 2] = s2
|
88 |
+
|
89 |
+
return cu_seqlens
|
90 |
+
|
91 |
+
|
92 |
+
def attention(
|
93 |
+
q_or_qkv_list,
|
94 |
+
k=None,
|
95 |
+
v=None,
|
96 |
+
mode="flash",
|
97 |
+
drop_rate=0,
|
98 |
+
attn_mask=None,
|
99 |
+
total_len=None,
|
100 |
+
causal=False,
|
101 |
+
cu_seqlens_q=None,
|
102 |
+
cu_seqlens_kv=None,
|
103 |
+
max_seqlen_q=None,
|
104 |
+
max_seqlen_kv=None,
|
105 |
+
batch_size=1,
|
106 |
+
):
|
107 |
+
"""
|
108 |
+
Perform QKV self attention.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
112 |
+
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
113 |
+
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
114 |
+
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
115 |
+
drop_rate (float): Dropout rate in attention map. (default: 0)
|
116 |
+
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
117 |
+
(default: None)
|
118 |
+
causal (bool): Whether to use causal attention. (default: False)
|
119 |
+
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
120 |
+
used to index into q.
|
121 |
+
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
122 |
+
used to index into kv.
|
123 |
+
max_seqlen_q (int): The maximum sequence length in the batch of q.
|
124 |
+
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
128 |
+
"""
|
129 |
+
q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
|
130 |
+
if type(q_or_qkv_list) == list:
|
131 |
+
q_or_qkv_list.clear()
|
132 |
+
split_attn = total_len is not None
|
133 |
+
if split_attn and mode == "sageattn":
|
134 |
+
mode = "sageattn_fixlen"
|
135 |
+
elif split_attn and mode == "flash":
|
136 |
+
mode = "flash_fixlen"
|
137 |
+
# print(f"Attention mode: {mode}, split_attn: {split_attn}")
|
138 |
+
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
139 |
+
|
140 |
+
# trim the sequence length to the actual length instead of attn_mask
|
141 |
+
if split_attn:
|
142 |
+
trimmed_len = q.shape[1] - total_len
|
143 |
+
q = [q[i : i + 1, : total_len[i]] for i in range(len(q))]
|
144 |
+
k = [k[i : i + 1, : total_len[i]] for i in range(len(k))]
|
145 |
+
v = [v[i : i + 1, : total_len[i]] for i in range(len(v))]
|
146 |
+
q = [pre_attn_layout(q_i) for q_i in q]
|
147 |
+
k = [pre_attn_layout(k_i) for k_i in k]
|
148 |
+
v = [pre_attn_layout(v_i) for v_i in v]
|
149 |
+
# print(
|
150 |
+
# f"Trimming the sequence length to {total_len},trimmed_len: {trimmed_len}, q.shape: {[q_i.shape for q_i in q]}, mode: {mode}"
|
151 |
+
# )
|
152 |
+
else:
|
153 |
+
q = pre_attn_layout(q)
|
154 |
+
k = pre_attn_layout(k)
|
155 |
+
v = pre_attn_layout(v)
|
156 |
+
|
157 |
+
if mode == "torch":
|
158 |
+
if split_attn:
|
159 |
+
x = []
|
160 |
+
for i in range(len(q)):
|
161 |
+
x_i = F.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate, is_causal=causal)
|
162 |
+
q[i], k[i], v[i] = None, None, None
|
163 |
+
x.append(x_i)
|
164 |
+
del q, k, v
|
165 |
+
else:
|
166 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
167 |
+
attn_mask = attn_mask.to(q.dtype)
|
168 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
169 |
+
del q, k, v
|
170 |
+
del attn_mask
|
171 |
+
|
172 |
+
elif mode == "xformers":
|
173 |
+
# B, M, H, K: M is the sequence length, H is the number of heads, K is the dimension of the heads -> it is same as input dimension
|
174 |
+
# currently only support batch_size = 1
|
175 |
+
assert split_attn, "Xformers only supports splitting"
|
176 |
+
x = []
|
177 |
+
for i in range(len(q)):
|
178 |
+
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) # , causal=causal)
|
179 |
+
q[i], k[i], v[i] = None, None, None
|
180 |
+
x.append(x_i)
|
181 |
+
del q, k, v
|
182 |
+
|
183 |
+
elif mode == "flash":
|
184 |
+
x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
185 |
+
del q, k, v
|
186 |
+
# x with shape [(bxs), a, d]
|
187 |
+
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
|
188 |
+
elif mode == "flash_fixlen":
|
189 |
+
x = []
|
190 |
+
for i in range(len(q)):
|
191 |
+
# q: (batch_size, seqlen, nheads, headdim), k: (batch_size, seqlen, nheads_k, headdim), v: (batch_size, seqlen, nheads_k, headdim)
|
192 |
+
x_i = flash_attn_func(q[i], k[i], v[i], dropout_p=drop_rate, causal=causal)
|
193 |
+
q[i], k[i], v[i] = None, None, None
|
194 |
+
x.append(x_i)
|
195 |
+
del q, k, v
|
196 |
+
elif mode == "sageattn":
|
197 |
+
x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
|
198 |
+
del q, k, v
|
199 |
+
# x with shape [(bxs), a, d]
|
200 |
+
x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
|
201 |
+
elif mode == "sageattn_fixlen":
|
202 |
+
x = []
|
203 |
+
for i in range(len(q)):
|
204 |
+
# HND seems to cause an error
|
205 |
+
x_i = sageattn(q[i], k[i], v[i]) # (batch_size, seq_len, head_num, head_dim)
|
206 |
+
q[i], k[i], v[i] = None, None, None
|
207 |
+
x.append(x_i)
|
208 |
+
del q, k, v
|
209 |
+
elif mode == "vanilla":
|
210 |
+
assert not split_attn, "Vanilla attention does not support trimming"
|
211 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
212 |
+
|
213 |
+
b, a, s, _ = q.shape
|
214 |
+
s1 = k.size(2)
|
215 |
+
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
216 |
+
if causal:
|
217 |
+
# Only applied to self attention
|
218 |
+
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
|
219 |
+
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
|
220 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
221 |
+
attn_bias.to(q.dtype)
|
222 |
+
|
223 |
+
if attn_mask is not None:
|
224 |
+
if attn_mask.dtype == torch.bool:
|
225 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
226 |
+
else:
|
227 |
+
attn_bias += attn_mask
|
228 |
+
|
229 |
+
# TODO: Maybe force q and k to be float32 to avoid numerical overflow
|
230 |
+
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
231 |
+
attn += attn_bias
|
232 |
+
attn = attn.softmax(dim=-1)
|
233 |
+
attn = torch.dropout(attn, p=drop_rate, train=True)
|
234 |
+
x = attn @ v
|
235 |
+
else:
|
236 |
+
raise NotImplementedError(f"Unsupported attention mode: {mode}")
|
237 |
+
|
238 |
+
if split_attn:
|
239 |
+
x = [post_attn_layout(x_i) for x_i in x]
|
240 |
+
for i in range(len(x)):
|
241 |
+
x[i] = F.pad(x[i], (0, 0, 0, 0, 0, trimmed_len[i]))
|
242 |
+
x = torch.cat(x, dim=0)
|
243 |
+
else:
|
244 |
+
x = post_attn_layout(x)
|
245 |
+
|
246 |
+
b, s, a, d = x.shape
|
247 |
+
out = x.reshape(b, s, -1)
|
248 |
+
return out
|
249 |
+
|
250 |
+
|
251 |
+
def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
|
252 |
+
attn1 = hybrid_seq_parallel_attn(
|
253 |
+
None,
|
254 |
+
q[:, :img_q_len, :, :],
|
255 |
+
k[:, :img_kv_len, :, :],
|
256 |
+
v[:, :img_kv_len, :, :],
|
257 |
+
dropout_p=0.0,
|
258 |
+
causal=False,
|
259 |
+
joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
|
260 |
+
joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
|
261 |
+
joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
|
262 |
+
joint_strategy="rear",
|
263 |
+
)
|
264 |
+
if flash_attn.__version__ >= "2.7.0":
|
265 |
+
attn2, *_ = _flash_attn_forward(
|
266 |
+
q[:, cu_seqlens_q[1] :],
|
267 |
+
k[:, cu_seqlens_kv[1] :],
|
268 |
+
v[:, cu_seqlens_kv[1] :],
|
269 |
+
dropout_p=0.0,
|
270 |
+
softmax_scale=q.shape[-1] ** (-0.5),
|
271 |
+
causal=False,
|
272 |
+
window_size_left=-1,
|
273 |
+
window_size_right=-1,
|
274 |
+
softcap=0.0,
|
275 |
+
alibi_slopes=None,
|
276 |
+
return_softmax=False,
|
277 |
+
)
|
278 |
+
else:
|
279 |
+
attn2, *_ = _flash_attn_forward(
|
280 |
+
q[:, cu_seqlens_q[1] :],
|
281 |
+
k[:, cu_seqlens_kv[1] :],
|
282 |
+
v[:, cu_seqlens_kv[1] :],
|
283 |
+
dropout_p=0.0,
|
284 |
+
softmax_scale=q.shape[-1] ** (-0.5),
|
285 |
+
causal=False,
|
286 |
+
window_size=(-1, -1),
|
287 |
+
softcap=0.0,
|
288 |
+
alibi_slopes=None,
|
289 |
+
return_softmax=False,
|
290 |
+
)
|
291 |
+
attn = torch.cat([attn1, attn2], dim=1)
|
292 |
+
b, s, a, d = attn.shape
|
293 |
+
attn = attn.reshape(b, s, -1)
|
294 |
+
|
295 |
+
return attn
|
hunyuan_model/autoencoder_kl_causal_3d.py
ADDED
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
from typing import Dict, Optional, Tuple, Union
|
20 |
+
from dataclasses import dataclass
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
|
27 |
+
# try:
|
28 |
+
# # This diffusers is modified and packed in the mirror.
|
29 |
+
# from diffusers.loaders import FromOriginalVAEMixin
|
30 |
+
# except ImportError:
|
31 |
+
# # Use this to be compatible with the original diffusers.
|
32 |
+
# from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
|
33 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
34 |
+
from diffusers.models.attention_processor import (
|
35 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
36 |
+
CROSS_ATTENTION_PROCESSORS,
|
37 |
+
Attention,
|
38 |
+
AttentionProcessor,
|
39 |
+
AttnAddedKVProcessor,
|
40 |
+
AttnProcessor,
|
41 |
+
)
|
42 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
43 |
+
from diffusers.models.modeling_utils import ModelMixin
|
44 |
+
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
|
45 |
+
|
46 |
+
|
47 |
+
@dataclass
|
48 |
+
class DecoderOutput2(BaseOutput):
|
49 |
+
sample: torch.FloatTensor
|
50 |
+
posterior: Optional[DiagonalGaussianDistribution] = None
|
51 |
+
|
52 |
+
|
53 |
+
class AutoencoderKLCausal3D(ModelMixin, ConfigMixin):
|
54 |
+
r"""
|
55 |
+
A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
|
56 |
+
|
57 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
58 |
+
for all models (such as downloading or saving).
|
59 |
+
"""
|
60 |
+
|
61 |
+
_supports_gradient_checkpointing = True
|
62 |
+
|
63 |
+
@register_to_config
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
in_channels: int = 3,
|
67 |
+
out_channels: int = 3,
|
68 |
+
down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
|
69 |
+
up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
|
70 |
+
block_out_channels: Tuple[int] = (64,),
|
71 |
+
layers_per_block: int = 1,
|
72 |
+
act_fn: str = "silu",
|
73 |
+
latent_channels: int = 4,
|
74 |
+
norm_num_groups: int = 32,
|
75 |
+
sample_size: int = 32,
|
76 |
+
sample_tsize: int = 64,
|
77 |
+
scaling_factor: float = 0.18215,
|
78 |
+
force_upcast: float = True,
|
79 |
+
spatial_compression_ratio: int = 8,
|
80 |
+
time_compression_ratio: int = 4,
|
81 |
+
mid_block_add_attention: bool = True,
|
82 |
+
):
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
self.time_compression_ratio = time_compression_ratio
|
86 |
+
|
87 |
+
self.encoder = EncoderCausal3D(
|
88 |
+
in_channels=in_channels,
|
89 |
+
out_channels=latent_channels,
|
90 |
+
down_block_types=down_block_types,
|
91 |
+
block_out_channels=block_out_channels,
|
92 |
+
layers_per_block=layers_per_block,
|
93 |
+
act_fn=act_fn,
|
94 |
+
norm_num_groups=norm_num_groups,
|
95 |
+
double_z=True,
|
96 |
+
time_compression_ratio=time_compression_ratio,
|
97 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
98 |
+
mid_block_add_attention=mid_block_add_attention,
|
99 |
+
)
|
100 |
+
|
101 |
+
self.decoder = DecoderCausal3D(
|
102 |
+
in_channels=latent_channels,
|
103 |
+
out_channels=out_channels,
|
104 |
+
up_block_types=up_block_types,
|
105 |
+
block_out_channels=block_out_channels,
|
106 |
+
layers_per_block=layers_per_block,
|
107 |
+
norm_num_groups=norm_num_groups,
|
108 |
+
act_fn=act_fn,
|
109 |
+
time_compression_ratio=time_compression_ratio,
|
110 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
111 |
+
mid_block_add_attention=mid_block_add_attention,
|
112 |
+
)
|
113 |
+
|
114 |
+
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
|
115 |
+
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
|
116 |
+
|
117 |
+
self.use_slicing = False
|
118 |
+
self.use_spatial_tiling = False
|
119 |
+
self.use_temporal_tiling = False
|
120 |
+
|
121 |
+
# only relevant if vae tiling is enabled
|
122 |
+
self.tile_sample_min_tsize = sample_tsize
|
123 |
+
self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
|
124 |
+
|
125 |
+
self.tile_sample_min_size = self.config.sample_size
|
126 |
+
sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
|
127 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
128 |
+
self.tile_overlap_factor = 0.25
|
129 |
+
|
130 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
131 |
+
if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
|
132 |
+
module.gradient_checkpointing = value
|
133 |
+
|
134 |
+
def enable_temporal_tiling(self, use_tiling: bool = True):
|
135 |
+
self.use_temporal_tiling = use_tiling
|
136 |
+
|
137 |
+
def disable_temporal_tiling(self):
|
138 |
+
self.enable_temporal_tiling(False)
|
139 |
+
|
140 |
+
def enable_spatial_tiling(self, use_tiling: bool = True):
|
141 |
+
self.use_spatial_tiling = use_tiling
|
142 |
+
|
143 |
+
def disable_spatial_tiling(self):
|
144 |
+
self.enable_spatial_tiling(False)
|
145 |
+
|
146 |
+
def enable_tiling(self, use_tiling: bool = True):
|
147 |
+
r"""
|
148 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
149 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
150 |
+
processing larger videos.
|
151 |
+
"""
|
152 |
+
self.enable_spatial_tiling(use_tiling)
|
153 |
+
self.enable_temporal_tiling(use_tiling)
|
154 |
+
|
155 |
+
def disable_tiling(self):
|
156 |
+
r"""
|
157 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
158 |
+
decoding in one step.
|
159 |
+
"""
|
160 |
+
self.disable_spatial_tiling()
|
161 |
+
self.disable_temporal_tiling()
|
162 |
+
|
163 |
+
def enable_slicing(self):
|
164 |
+
r"""
|
165 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
166 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
167 |
+
"""
|
168 |
+
self.use_slicing = True
|
169 |
+
|
170 |
+
def disable_slicing(self):
|
171 |
+
r"""
|
172 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
173 |
+
decoding in one step.
|
174 |
+
"""
|
175 |
+
self.use_slicing = False
|
176 |
+
|
177 |
+
def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
|
178 |
+
# set chunk_size to CausalConv3d recursively
|
179 |
+
def set_chunk_size(module):
|
180 |
+
if hasattr(module, "chunk_size"):
|
181 |
+
module.chunk_size = chunk_size
|
182 |
+
|
183 |
+
self.apply(set_chunk_size)
|
184 |
+
|
185 |
+
@property
|
186 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
187 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
188 |
+
r"""
|
189 |
+
Returns:
|
190 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
191 |
+
indexed by its weight name.
|
192 |
+
"""
|
193 |
+
# set recursively
|
194 |
+
processors = {}
|
195 |
+
|
196 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
197 |
+
if hasattr(module, "get_processor"):
|
198 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
199 |
+
|
200 |
+
for sub_name, child in module.named_children():
|
201 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
202 |
+
|
203 |
+
return processors
|
204 |
+
|
205 |
+
for name, module in self.named_children():
|
206 |
+
fn_recursive_add_processors(name, module, processors)
|
207 |
+
|
208 |
+
return processors
|
209 |
+
|
210 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
211 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
|
212 |
+
r"""
|
213 |
+
Sets the attention processor to use to compute attention.
|
214 |
+
|
215 |
+
Parameters:
|
216 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
217 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
218 |
+
for **all** `Attention` layers.
|
219 |
+
|
220 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
221 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
222 |
+
|
223 |
+
"""
|
224 |
+
count = len(self.attn_processors.keys())
|
225 |
+
|
226 |
+
if isinstance(processor, dict) and len(processor) != count:
|
227 |
+
raise ValueError(
|
228 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
229 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
230 |
+
)
|
231 |
+
|
232 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
233 |
+
if hasattr(module, "set_processor"):
|
234 |
+
if not isinstance(processor, dict):
|
235 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
236 |
+
else:
|
237 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
238 |
+
|
239 |
+
for sub_name, child in module.named_children():
|
240 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
241 |
+
|
242 |
+
for name, module in self.named_children():
|
243 |
+
fn_recursive_attn_processor(name, module, processor)
|
244 |
+
|
245 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
246 |
+
def set_default_attn_processor(self):
|
247 |
+
"""
|
248 |
+
Disables custom attention processors and sets the default attention implementation.
|
249 |
+
"""
|
250 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
251 |
+
processor = AttnAddedKVProcessor()
|
252 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
253 |
+
processor = AttnProcessor()
|
254 |
+
else:
|
255 |
+
raise ValueError(
|
256 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
257 |
+
)
|
258 |
+
|
259 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
260 |
+
|
261 |
+
@apply_forward_hook
|
262 |
+
def encode(
|
263 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
264 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
265 |
+
"""
|
266 |
+
Encode a batch of images/videos into latents.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
x (`torch.FloatTensor`): Input batch of images/videos.
|
270 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
271 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
272 |
+
|
273 |
+
Returns:
|
274 |
+
The latent representations of the encoded images/videos. If `return_dict` is True, a
|
275 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
276 |
+
"""
|
277 |
+
assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
|
278 |
+
|
279 |
+
if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
|
280 |
+
return self.temporal_tiled_encode(x, return_dict=return_dict)
|
281 |
+
|
282 |
+
if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
283 |
+
return self.spatial_tiled_encode(x, return_dict=return_dict)
|
284 |
+
|
285 |
+
if self.use_slicing and x.shape[0] > 1:
|
286 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
287 |
+
h = torch.cat(encoded_slices)
|
288 |
+
else:
|
289 |
+
h = self.encoder(x)
|
290 |
+
|
291 |
+
moments = self.quant_conv(h)
|
292 |
+
posterior = DiagonalGaussianDistribution(moments)
|
293 |
+
|
294 |
+
if not return_dict:
|
295 |
+
return (posterior,)
|
296 |
+
|
297 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
298 |
+
|
299 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
300 |
+
assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
|
301 |
+
|
302 |
+
if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
|
303 |
+
return self.temporal_tiled_decode(z, return_dict=return_dict)
|
304 |
+
|
305 |
+
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
306 |
+
return self.spatial_tiled_decode(z, return_dict=return_dict)
|
307 |
+
|
308 |
+
z = self.post_quant_conv(z)
|
309 |
+
dec = self.decoder(z)
|
310 |
+
|
311 |
+
if not return_dict:
|
312 |
+
return (dec,)
|
313 |
+
|
314 |
+
return DecoderOutput(sample=dec)
|
315 |
+
|
316 |
+
@apply_forward_hook
|
317 |
+
def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
|
318 |
+
"""
|
319 |
+
Decode a batch of images/videos.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
323 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
324 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
325 |
+
|
326 |
+
Returns:
|
327 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
328 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
329 |
+
returned.
|
330 |
+
|
331 |
+
"""
|
332 |
+
if self.use_slicing and z.shape[0] > 1:
|
333 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
334 |
+
decoded = torch.cat(decoded_slices)
|
335 |
+
else:
|
336 |
+
decoded = self._decode(z).sample
|
337 |
+
|
338 |
+
if not return_dict:
|
339 |
+
return (decoded,)
|
340 |
+
|
341 |
+
return DecoderOutput(sample=decoded)
|
342 |
+
|
343 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
344 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
345 |
+
for y in range(blend_extent):
|
346 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
|
347 |
+
return b
|
348 |
+
|
349 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
350 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
351 |
+
for x in range(blend_extent):
|
352 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
|
353 |
+
return b
|
354 |
+
|
355 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
356 |
+
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
357 |
+
for x in range(blend_extent):
|
358 |
+
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
|
359 |
+
return b
|
360 |
+
|
361 |
+
def spatial_tiled_encode(
|
362 |
+
self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
|
363 |
+
) -> AutoencoderKLOutput:
|
364 |
+
r"""Encode a batch of images/videos using a tiled encoder.
|
365 |
+
|
366 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
367 |
+
steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
|
368 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
369 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
370 |
+
output, but they should be much less noticeable.
|
371 |
+
|
372 |
+
Args:
|
373 |
+
x (`torch.FloatTensor`): Input batch of images/videos.
|
374 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
375 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
379 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
380 |
+
`tuple` is returned.
|
381 |
+
"""
|
382 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
383 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
384 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
385 |
+
|
386 |
+
# Split video into tiles and encode them separately.
|
387 |
+
rows = []
|
388 |
+
for i in range(0, x.shape[-2], overlap_size):
|
389 |
+
row = []
|
390 |
+
for j in range(0, x.shape[-1], overlap_size):
|
391 |
+
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
392 |
+
tile = self.encoder(tile)
|
393 |
+
tile = self.quant_conv(tile)
|
394 |
+
row.append(tile)
|
395 |
+
rows.append(row)
|
396 |
+
result_rows = []
|
397 |
+
for i, row in enumerate(rows):
|
398 |
+
result_row = []
|
399 |
+
for j, tile in enumerate(row):
|
400 |
+
# blend the above tile and the left tile
|
401 |
+
# to the current tile and add the current tile to the result row
|
402 |
+
if i > 0:
|
403 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
404 |
+
if j > 0:
|
405 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
406 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
407 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
408 |
+
|
409 |
+
moments = torch.cat(result_rows, dim=-2)
|
410 |
+
if return_moments:
|
411 |
+
return moments
|
412 |
+
|
413 |
+
posterior = DiagonalGaussianDistribution(moments)
|
414 |
+
if not return_dict:
|
415 |
+
return (posterior,)
|
416 |
+
|
417 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
418 |
+
|
419 |
+
def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
420 |
+
r"""
|
421 |
+
Decode a batch of images/videos using a tiled decoder.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
425 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
426 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
427 |
+
|
428 |
+
Returns:
|
429 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
430 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
431 |
+
returned.
|
432 |
+
"""
|
433 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
434 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
435 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
436 |
+
|
437 |
+
# Split z into overlapping tiles and decode them separately.
|
438 |
+
# The tiles have an overlap to avoid seams between tiles.
|
439 |
+
rows = []
|
440 |
+
for i in range(0, z.shape[-2], overlap_size):
|
441 |
+
row = []
|
442 |
+
for j in range(0, z.shape[-1], overlap_size):
|
443 |
+
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
444 |
+
tile = self.post_quant_conv(tile)
|
445 |
+
decoded = self.decoder(tile)
|
446 |
+
row.append(decoded)
|
447 |
+
rows.append(row)
|
448 |
+
result_rows = []
|
449 |
+
for i, row in enumerate(rows):
|
450 |
+
result_row = []
|
451 |
+
for j, tile in enumerate(row):
|
452 |
+
# blend the above tile and the left tile
|
453 |
+
# to the current tile and add the current tile to the result row
|
454 |
+
if i > 0:
|
455 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
456 |
+
if j > 0:
|
457 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
458 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
459 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
460 |
+
|
461 |
+
dec = torch.cat(result_rows, dim=-2)
|
462 |
+
if not return_dict:
|
463 |
+
return (dec,)
|
464 |
+
|
465 |
+
return DecoderOutput(sample=dec)
|
466 |
+
|
467 |
+
def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
468 |
+
|
469 |
+
B, C, T, H, W = x.shape
|
470 |
+
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
|
471 |
+
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
|
472 |
+
t_limit = self.tile_latent_min_tsize - blend_extent
|
473 |
+
|
474 |
+
# Split the video into tiles and encode them separately.
|
475 |
+
row = []
|
476 |
+
for i in range(0, T, overlap_size):
|
477 |
+
tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
|
478 |
+
if self.use_spatial_tiling and (
|
479 |
+
tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
|
480 |
+
):
|
481 |
+
tile = self.spatial_tiled_encode(tile, return_moments=True)
|
482 |
+
else:
|
483 |
+
tile = self.encoder(tile)
|
484 |
+
tile = self.quant_conv(tile)
|
485 |
+
if i > 0:
|
486 |
+
tile = tile[:, :, 1:, :, :]
|
487 |
+
row.append(tile)
|
488 |
+
result_row = []
|
489 |
+
for i, tile in enumerate(row):
|
490 |
+
if i > 0:
|
491 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
492 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
493 |
+
else:
|
494 |
+
result_row.append(tile[:, :, : t_limit + 1, :, :])
|
495 |
+
|
496 |
+
moments = torch.cat(result_row, dim=2)
|
497 |
+
posterior = DiagonalGaussianDistribution(moments)
|
498 |
+
|
499 |
+
if not return_dict:
|
500 |
+
return (posterior,)
|
501 |
+
|
502 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
503 |
+
|
504 |
+
def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
505 |
+
# Split z into overlapping tiles and decode them separately.
|
506 |
+
|
507 |
+
B, C, T, H, W = z.shape
|
508 |
+
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
|
509 |
+
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
|
510 |
+
t_limit = self.tile_sample_min_tsize - blend_extent
|
511 |
+
|
512 |
+
row = []
|
513 |
+
for i in range(0, T, overlap_size):
|
514 |
+
tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
|
515 |
+
if self.use_spatial_tiling and (
|
516 |
+
tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
|
517 |
+
):
|
518 |
+
decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
|
519 |
+
else:
|
520 |
+
tile = self.post_quant_conv(tile)
|
521 |
+
decoded = self.decoder(tile)
|
522 |
+
if i > 0:
|
523 |
+
decoded = decoded[:, :, 1:, :, :]
|
524 |
+
row.append(decoded)
|
525 |
+
result_row = []
|
526 |
+
for i, tile in enumerate(row):
|
527 |
+
if i > 0:
|
528 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
529 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
530 |
+
else:
|
531 |
+
result_row.append(tile[:, :, : t_limit + 1, :, :])
|
532 |
+
|
533 |
+
dec = torch.cat(result_row, dim=2)
|
534 |
+
if not return_dict:
|
535 |
+
return (dec,)
|
536 |
+
|
537 |
+
return DecoderOutput(sample=dec)
|
538 |
+
|
539 |
+
def forward(
|
540 |
+
self,
|
541 |
+
sample: torch.FloatTensor,
|
542 |
+
sample_posterior: bool = False,
|
543 |
+
return_dict: bool = True,
|
544 |
+
return_posterior: bool = False,
|
545 |
+
generator: Optional[torch.Generator] = None,
|
546 |
+
) -> Union[DecoderOutput2, torch.FloatTensor]:
|
547 |
+
r"""
|
548 |
+
Args:
|
549 |
+
sample (`torch.FloatTensor`): Input sample.
|
550 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
551 |
+
Whether to sample from the posterior.
|
552 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
553 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
554 |
+
"""
|
555 |
+
x = sample
|
556 |
+
posterior = self.encode(x).latent_dist
|
557 |
+
if sample_posterior:
|
558 |
+
z = posterior.sample(generator=generator)
|
559 |
+
else:
|
560 |
+
z = posterior.mode()
|
561 |
+
dec = self.decode(z).sample
|
562 |
+
|
563 |
+
if not return_dict:
|
564 |
+
if return_posterior:
|
565 |
+
return (dec, posterior)
|
566 |
+
else:
|
567 |
+
return (dec,)
|
568 |
+
if return_posterior:
|
569 |
+
return DecoderOutput2(sample=dec, posterior=posterior)
|
570 |
+
else:
|
571 |
+
return DecoderOutput2(sample=dec)
|
572 |
+
|
573 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
574 |
+
def fuse_qkv_projections(self):
|
575 |
+
"""
|
576 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
577 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
578 |
+
|
579 |
+
<Tip warning={true}>
|
580 |
+
|
581 |
+
This API is 🧪 experimental.
|
582 |
+
|
583 |
+
</Tip>
|
584 |
+
"""
|
585 |
+
self.original_attn_processors = None
|
586 |
+
|
587 |
+
for _, attn_processor in self.attn_processors.items():
|
588 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
589 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
590 |
+
|
591 |
+
self.original_attn_processors = self.attn_processors
|
592 |
+
|
593 |
+
for module in self.modules():
|
594 |
+
if isinstance(module, Attention):
|
595 |
+
module.fuse_projections(fuse=True)
|
596 |
+
|
597 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
598 |
+
def unfuse_qkv_projections(self):
|
599 |
+
"""Disables the fused QKV projection if enabled.
|
600 |
+
|
601 |
+
<Tip warning={true}>
|
602 |
+
|
603 |
+
This API is 🧪 experimental.
|
604 |
+
|
605 |
+
</Tip>
|
606 |
+
|
607 |
+
"""
|
608 |
+
if self.original_attn_processors is not None:
|
609 |
+
self.set_attn_processor(self.original_attn_processors)
|
hunyuan_model/embed_layers.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
|
7 |
+
from .helpers import to_2tuple
|
8 |
+
|
9 |
+
class PatchEmbed(nn.Module):
|
10 |
+
"""2D Image to Patch Embedding
|
11 |
+
|
12 |
+
Image to Patch Embedding using Conv2d
|
13 |
+
|
14 |
+
A convolution based approach to patchifying a 2D image w/ embedding projection.
|
15 |
+
|
16 |
+
Based on the impl in https://github.com/google-research/vision_transformer
|
17 |
+
|
18 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
19 |
+
|
20 |
+
Remove the _assert function in forward function to be compatible with multi-resolution images.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
patch_size=16,
|
26 |
+
in_chans=3,
|
27 |
+
embed_dim=768,
|
28 |
+
norm_layer=None,
|
29 |
+
flatten=True,
|
30 |
+
bias=True,
|
31 |
+
dtype=None,
|
32 |
+
device=None,
|
33 |
+
):
|
34 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
35 |
+
super().__init__()
|
36 |
+
patch_size = to_2tuple(patch_size)
|
37 |
+
self.patch_size = patch_size
|
38 |
+
self.flatten = flatten
|
39 |
+
|
40 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
|
41 |
+
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
|
42 |
+
if bias:
|
43 |
+
nn.init.zeros_(self.proj.bias)
|
44 |
+
|
45 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
x = self.proj(x)
|
49 |
+
if self.flatten:
|
50 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
51 |
+
x = self.norm(x)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class TextProjection(nn.Module):
|
56 |
+
"""
|
57 |
+
Projects text embeddings. Also handles dropout for classifier-free guidance.
|
58 |
+
|
59 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
60 |
+
"""
|
61 |
+
|
62 |
+
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
|
63 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
64 |
+
super().__init__()
|
65 |
+
self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
|
66 |
+
self.act_1 = act_layer()
|
67 |
+
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
|
68 |
+
|
69 |
+
def forward(self, caption):
|
70 |
+
hidden_states = self.linear_1(caption)
|
71 |
+
hidden_states = self.act_1(hidden_states)
|
72 |
+
hidden_states = self.linear_2(hidden_states)
|
73 |
+
return hidden_states
|
74 |
+
|
75 |
+
|
76 |
+
def timestep_embedding(t, dim, max_period=10000):
|
77 |
+
"""
|
78 |
+
Create sinusoidal timestep embeddings.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
82 |
+
dim (int): the dimension of the output.
|
83 |
+
max_period (int): controls the minimum frequency of the embeddings.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
87 |
+
|
88 |
+
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
89 |
+
"""
|
90 |
+
half = dim // 2
|
91 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
92 |
+
args = t[:, None].float() * freqs[None]
|
93 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
94 |
+
if dim % 2:
|
95 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
96 |
+
return embedding
|
97 |
+
|
98 |
+
|
99 |
+
class TimestepEmbedder(nn.Module):
|
100 |
+
"""
|
101 |
+
Embeds scalar timesteps into vector representations.
|
102 |
+
"""
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
hidden_size,
|
107 |
+
act_layer,
|
108 |
+
frequency_embedding_size=256,
|
109 |
+
max_period=10000,
|
110 |
+
out_size=None,
|
111 |
+
dtype=None,
|
112 |
+
device=None,
|
113 |
+
):
|
114 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
115 |
+
super().__init__()
|
116 |
+
self.frequency_embedding_size = frequency_embedding_size
|
117 |
+
self.max_period = max_period
|
118 |
+
if out_size is None:
|
119 |
+
out_size = hidden_size
|
120 |
+
|
121 |
+
self.mlp = nn.Sequential(
|
122 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
|
123 |
+
act_layer(),
|
124 |
+
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
125 |
+
)
|
126 |
+
nn.init.normal_(self.mlp[0].weight, std=0.02)
|
127 |
+
nn.init.normal_(self.mlp[2].weight, std=0.02)
|
128 |
+
|
129 |
+
def forward(self, t):
|
130 |
+
t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
|
131 |
+
t_emb = self.mlp(t_freq)
|
132 |
+
return t_emb
|
hunyuan_model/helpers.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections.abc
|
2 |
+
|
3 |
+
from itertools import repeat
|
4 |
+
|
5 |
+
|
6 |
+
def _ntuple(n):
|
7 |
+
def parse(x):
|
8 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
9 |
+
x = tuple(x)
|
10 |
+
if len(x) == 1:
|
11 |
+
x = tuple(repeat(x[0], n))
|
12 |
+
return x
|
13 |
+
return tuple(repeat(x, n))
|
14 |
+
return parse
|
15 |
+
|
16 |
+
|
17 |
+
to_1tuple = _ntuple(1)
|
18 |
+
to_2tuple = _ntuple(2)
|
19 |
+
to_3tuple = _ntuple(3)
|
20 |
+
to_4tuple = _ntuple(4)
|
21 |
+
|
22 |
+
|
23 |
+
def as_tuple(x):
|
24 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
25 |
+
return tuple(x)
|
26 |
+
if x is None or isinstance(x, (int, float, str)):
|
27 |
+
return (x,)
|
28 |
+
else:
|
29 |
+
raise ValueError(f"Unknown type {type(x)}")
|
30 |
+
|
31 |
+
|
32 |
+
def as_list_of_2tuple(x):
|
33 |
+
x = as_tuple(x)
|
34 |
+
if len(x) == 1:
|
35 |
+
x = (x[0], x[0])
|
36 |
+
assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
|
37 |
+
lst = []
|
38 |
+
for i in range(0, len(x), 2):
|
39 |
+
lst.append((x[i], x[i + 1]))
|
40 |
+
return lst
|
hunyuan_model/mlp_layers.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from timm library:
|
2 |
+
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
|
3 |
+
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from .modulate_layers import modulate
|
10 |
+
from .helpers import to_2tuple
|
11 |
+
|
12 |
+
|
13 |
+
class MLP(nn.Module):
|
14 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
in_channels,
|
19 |
+
hidden_channels=None,
|
20 |
+
out_features=None,
|
21 |
+
act_layer=nn.GELU,
|
22 |
+
norm_layer=None,
|
23 |
+
bias=True,
|
24 |
+
drop=0.0,
|
25 |
+
use_conv=False,
|
26 |
+
device=None,
|
27 |
+
dtype=None,
|
28 |
+
):
|
29 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
30 |
+
super().__init__()
|
31 |
+
out_features = out_features or in_channels
|
32 |
+
hidden_channels = hidden_channels or in_channels
|
33 |
+
bias = to_2tuple(bias)
|
34 |
+
drop_probs = to_2tuple(drop)
|
35 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
36 |
+
|
37 |
+
self.fc1 = linear_layer(
|
38 |
+
in_channels, hidden_channels, bias=bias[0], **factory_kwargs
|
39 |
+
)
|
40 |
+
self.act = act_layer()
|
41 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
42 |
+
self.norm = (
|
43 |
+
norm_layer(hidden_channels, **factory_kwargs)
|
44 |
+
if norm_layer is not None
|
45 |
+
else nn.Identity()
|
46 |
+
)
|
47 |
+
self.fc2 = linear_layer(
|
48 |
+
hidden_channels, out_features, bias=bias[1], **factory_kwargs
|
49 |
+
)
|
50 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
x = self.fc1(x)
|
54 |
+
x = self.act(x)
|
55 |
+
x = self.drop1(x)
|
56 |
+
x = self.norm(x)
|
57 |
+
x = self.fc2(x)
|
58 |
+
x = self.drop2(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
#
|
63 |
+
class MLPEmbedder(nn.Module):
|
64 |
+
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
|
65 |
+
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
|
66 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
67 |
+
super().__init__()
|
68 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
|
69 |
+
self.silu = nn.SiLU()
|
70 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
|
71 |
+
|
72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
73 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
74 |
+
|
75 |
+
|
76 |
+
class FinalLayer(nn.Module):
|
77 |
+
"""The final layer of DiT."""
|
78 |
+
|
79 |
+
def __init__(
|
80 |
+
self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
|
81 |
+
):
|
82 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
83 |
+
super().__init__()
|
84 |
+
|
85 |
+
# Just use LayerNorm for the final layer
|
86 |
+
self.norm_final = nn.LayerNorm(
|
87 |
+
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
|
88 |
+
)
|
89 |
+
if isinstance(patch_size, int):
|
90 |
+
self.linear = nn.Linear(
|
91 |
+
hidden_size,
|
92 |
+
patch_size * patch_size * out_channels,
|
93 |
+
bias=True,
|
94 |
+
**factory_kwargs
|
95 |
+
)
|
96 |
+
else:
|
97 |
+
self.linear = nn.Linear(
|
98 |
+
hidden_size,
|
99 |
+
patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
|
100 |
+
bias=True,
|
101 |
+
)
|
102 |
+
nn.init.zeros_(self.linear.weight)
|
103 |
+
nn.init.zeros_(self.linear.bias)
|
104 |
+
|
105 |
+
# Here we don't distinguish between the modulate types. Just use the simple one.
|
106 |
+
self.adaLN_modulation = nn.Sequential(
|
107 |
+
act_layer(),
|
108 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
|
109 |
+
)
|
110 |
+
# Zero-initialize the modulation
|
111 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
112 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
113 |
+
|
114 |
+
def forward(self, x, c):
|
115 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
116 |
+
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
117 |
+
x = self.linear(x)
|
118 |
+
return x
|
hunyuan_model/models.py
ADDED
@@ -0,0 +1,1044 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, List, Tuple, Optional, Union, Dict
|
3 |
+
import accelerate
|
4 |
+
from einops import rearrange
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.utils.checkpoint import checkpoint
|
9 |
+
|
10 |
+
from .activation_layers import get_activation_layer
|
11 |
+
from .norm_layers import get_norm_layer
|
12 |
+
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
|
13 |
+
from .attention import attention, parallel_attention, get_cu_seqlens
|
14 |
+
from .posemb_layers import apply_rotary_emb
|
15 |
+
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
|
16 |
+
from .modulate_layers import ModulateDiT, modulate, apply_gate
|
17 |
+
from .token_refiner import SingleTokenRefiner
|
18 |
+
from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
|
19 |
+
from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
|
20 |
+
|
21 |
+
from utils.safetensors_utils import MemoryEfficientSafeOpen
|
22 |
+
|
23 |
+
|
24 |
+
class MMDoubleStreamBlock(nn.Module):
|
25 |
+
"""
|
26 |
+
A multimodal dit block with seperate modulation for
|
27 |
+
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
|
28 |
+
(Flux.1): https://github.com/black-forest-labs/flux
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
hidden_size: int,
|
34 |
+
heads_num: int,
|
35 |
+
mlp_width_ratio: float,
|
36 |
+
mlp_act_type: str = "gelu_tanh",
|
37 |
+
qk_norm: bool = True,
|
38 |
+
qk_norm_type: str = "rms",
|
39 |
+
qkv_bias: bool = False,
|
40 |
+
dtype: Optional[torch.dtype] = None,
|
41 |
+
device: Optional[torch.device] = None,
|
42 |
+
attn_mode: str = "flash",
|
43 |
+
split_attn: bool = False,
|
44 |
+
):
|
45 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
46 |
+
super().__init__()
|
47 |
+
self.attn_mode = attn_mode
|
48 |
+
self.split_attn = split_attn
|
49 |
+
|
50 |
+
self.deterministic = False
|
51 |
+
self.heads_num = heads_num
|
52 |
+
head_dim = hidden_size // heads_num
|
53 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
54 |
+
|
55 |
+
self.img_mod = ModulateDiT(
|
56 |
+
hidden_size,
|
57 |
+
factor=6,
|
58 |
+
act_layer=get_activation_layer("silu"),
|
59 |
+
**factory_kwargs,
|
60 |
+
)
|
61 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
62 |
+
|
63 |
+
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
64 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
65 |
+
self.img_attn_q_norm = (
|
66 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
67 |
+
)
|
68 |
+
self.img_attn_k_norm = (
|
69 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
70 |
+
)
|
71 |
+
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
72 |
+
|
73 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
74 |
+
self.img_mlp = MLP(
|
75 |
+
hidden_size,
|
76 |
+
mlp_hidden_dim,
|
77 |
+
act_layer=get_activation_layer(mlp_act_type),
|
78 |
+
bias=True,
|
79 |
+
**factory_kwargs,
|
80 |
+
)
|
81 |
+
|
82 |
+
self.txt_mod = ModulateDiT(
|
83 |
+
hidden_size,
|
84 |
+
factor=6,
|
85 |
+
act_layer=get_activation_layer("silu"),
|
86 |
+
**factory_kwargs,
|
87 |
+
)
|
88 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
89 |
+
|
90 |
+
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
91 |
+
self.txt_attn_q_norm = (
|
92 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
93 |
+
)
|
94 |
+
self.txt_attn_k_norm = (
|
95 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
96 |
+
)
|
97 |
+
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
98 |
+
|
99 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
100 |
+
self.txt_mlp = MLP(
|
101 |
+
hidden_size,
|
102 |
+
mlp_hidden_dim,
|
103 |
+
act_layer=get_activation_layer(mlp_act_type),
|
104 |
+
bias=True,
|
105 |
+
**factory_kwargs,
|
106 |
+
)
|
107 |
+
self.hybrid_seq_parallel_attn = None
|
108 |
+
|
109 |
+
self.gradient_checkpointing = False
|
110 |
+
|
111 |
+
def enable_deterministic(self):
|
112 |
+
self.deterministic = True
|
113 |
+
|
114 |
+
def disable_deterministic(self):
|
115 |
+
self.deterministic = False
|
116 |
+
|
117 |
+
def enable_gradient_checkpointing(self):
|
118 |
+
self.gradient_checkpointing = True
|
119 |
+
|
120 |
+
def disable_gradient_checkpointing(self):
|
121 |
+
self.gradient_checkpointing = False
|
122 |
+
|
123 |
+
def _forward(
|
124 |
+
self,
|
125 |
+
img: torch.Tensor,
|
126 |
+
txt: torch.Tensor,
|
127 |
+
vec: torch.Tensor,
|
128 |
+
attn_mask: Optional[torch.Tensor] = None,
|
129 |
+
total_len: Optional[torch.Tensor] = None,
|
130 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
131 |
+
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
132 |
+
max_seqlen_q: Optional[int] = None,
|
133 |
+
max_seqlen_kv: Optional[int] = None,
|
134 |
+
freqs_cis: tuple = None,
|
135 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
136 |
+
(img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
|
137 |
+
6, dim=-1
|
138 |
+
)
|
139 |
+
(txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
|
140 |
+
6, dim=-1
|
141 |
+
)
|
142 |
+
|
143 |
+
# Prepare image for attention.
|
144 |
+
img_modulated = self.img_norm1(img)
|
145 |
+
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
|
146 |
+
img_qkv = self.img_attn_qkv(img_modulated)
|
147 |
+
img_modulated = None
|
148 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
149 |
+
img_qkv = None
|
150 |
+
# Apply QK-Norm if needed
|
151 |
+
img_q = self.img_attn_q_norm(img_q).to(img_v)
|
152 |
+
img_k = self.img_attn_k_norm(img_k).to(img_v)
|
153 |
+
|
154 |
+
# Apply RoPE if needed.
|
155 |
+
if freqs_cis is not None:
|
156 |
+
img_q_shape = img_q.shape
|
157 |
+
img_k_shape = img_k.shape
|
158 |
+
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
159 |
+
assert (
|
160 |
+
img_q.shape == img_q_shape and img_k.shape == img_k_shape
|
161 |
+
), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}"
|
162 |
+
# img_q, img_k = img_qq, img_kk
|
163 |
+
|
164 |
+
# Prepare txt for attention.
|
165 |
+
txt_modulated = self.txt_norm1(txt)
|
166 |
+
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
|
167 |
+
txt_qkv = self.txt_attn_qkv(txt_modulated)
|
168 |
+
txt_modulated = None
|
169 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
170 |
+
txt_qkv = None
|
171 |
+
# Apply QK-Norm if needed.
|
172 |
+
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
|
173 |
+
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
|
174 |
+
|
175 |
+
# Run actual attention.
|
176 |
+
img_q_len = img_q.shape[1]
|
177 |
+
img_kv_len = img_k.shape[1]
|
178 |
+
batch_size = img_k.shape[0]
|
179 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
180 |
+
img_q = txt_q = None
|
181 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
182 |
+
img_k = txt_k = None
|
183 |
+
v = torch.cat((img_v, txt_v), dim=1)
|
184 |
+
img_v = txt_v = None
|
185 |
+
|
186 |
+
assert (
|
187 |
+
cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
|
188 |
+
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
|
189 |
+
|
190 |
+
# attention computation start
|
191 |
+
if not self.hybrid_seq_parallel_attn:
|
192 |
+
l = [q, k, v]
|
193 |
+
q = k = v = None
|
194 |
+
attn = attention(
|
195 |
+
l,
|
196 |
+
mode=self.attn_mode,
|
197 |
+
attn_mask=attn_mask,
|
198 |
+
total_len=total_len,
|
199 |
+
cu_seqlens_q=cu_seqlens_q,
|
200 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
201 |
+
max_seqlen_q=max_seqlen_q,
|
202 |
+
max_seqlen_kv=max_seqlen_kv,
|
203 |
+
batch_size=batch_size,
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
attn = parallel_attention(
|
207 |
+
self.hybrid_seq_parallel_attn,
|
208 |
+
q,
|
209 |
+
k,
|
210 |
+
v,
|
211 |
+
img_q_len=img_q_len,
|
212 |
+
img_kv_len=img_kv_len,
|
213 |
+
cu_seqlens_q=cu_seqlens_q,
|
214 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
215 |
+
)
|
216 |
+
|
217 |
+
# attention computation end
|
218 |
+
|
219 |
+
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
|
220 |
+
attn = None
|
221 |
+
|
222 |
+
# Calculate the img bloks.
|
223 |
+
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
|
224 |
+
img_attn = None
|
225 |
+
img = img + apply_gate(
|
226 |
+
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
|
227 |
+
gate=img_mod2_gate,
|
228 |
+
)
|
229 |
+
|
230 |
+
# Calculate the txt bloks.
|
231 |
+
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
|
232 |
+
txt_attn = None
|
233 |
+
txt = txt + apply_gate(
|
234 |
+
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
|
235 |
+
gate=txt_mod2_gate,
|
236 |
+
)
|
237 |
+
|
238 |
+
return img, txt
|
239 |
+
|
240 |
+
# def forward(
|
241 |
+
# self,
|
242 |
+
# img: torch.Tensor,
|
243 |
+
# txt: torch.Tensor,
|
244 |
+
# vec: torch.Tensor,
|
245 |
+
# attn_mask: Optional[torch.Tensor] = None,
|
246 |
+
# cu_seqlens_q: Optional[torch.Tensor] = None,
|
247 |
+
# cu_seqlens_kv: Optional[torch.Tensor] = None,
|
248 |
+
# max_seqlen_q: Optional[int] = None,
|
249 |
+
# max_seqlen_kv: Optional[int] = None,
|
250 |
+
# freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
251 |
+
# ) -> Tuple[torch.Tensor, torch.Tensor]:
|
252 |
+
def forward(self, *args, **kwargs):
|
253 |
+
if self.training and self.gradient_checkpointing:
|
254 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
255 |
+
else:
|
256 |
+
return self._forward(*args, **kwargs)
|
257 |
+
|
258 |
+
|
259 |
+
class MMSingleStreamBlock(nn.Module):
|
260 |
+
"""
|
261 |
+
A DiT block with parallel linear layers as described in
|
262 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
263 |
+
Also refer to (SD3): https://arxiv.org/abs/2403.03206
|
264 |
+
(Flux.1): https://github.com/black-forest-labs/flux
|
265 |
+
"""
|
266 |
+
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
hidden_size: int,
|
270 |
+
heads_num: int,
|
271 |
+
mlp_width_ratio: float = 4.0,
|
272 |
+
mlp_act_type: str = "gelu_tanh",
|
273 |
+
qk_norm: bool = True,
|
274 |
+
qk_norm_type: str = "rms",
|
275 |
+
qk_scale: float = None,
|
276 |
+
dtype: Optional[torch.dtype] = None,
|
277 |
+
device: Optional[torch.device] = None,
|
278 |
+
attn_mode: str = "flash",
|
279 |
+
split_attn: bool = False,
|
280 |
+
):
|
281 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
282 |
+
super().__init__()
|
283 |
+
self.attn_mode = attn_mode
|
284 |
+
self.split_attn = split_attn
|
285 |
+
|
286 |
+
self.deterministic = False
|
287 |
+
self.hidden_size = hidden_size
|
288 |
+
self.heads_num = heads_num
|
289 |
+
head_dim = hidden_size // heads_num
|
290 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
291 |
+
self.mlp_hidden_dim = mlp_hidden_dim
|
292 |
+
self.scale = qk_scale or head_dim**-0.5
|
293 |
+
|
294 |
+
# qkv and mlp_in
|
295 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
|
296 |
+
# proj and mlp_out
|
297 |
+
self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
|
298 |
+
|
299 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
300 |
+
self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
301 |
+
self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
|
302 |
+
|
303 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
304 |
+
|
305 |
+
self.mlp_act = get_activation_layer(mlp_act_type)()
|
306 |
+
self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
|
307 |
+
self.hybrid_seq_parallel_attn = None
|
308 |
+
|
309 |
+
self.gradient_checkpointing = False
|
310 |
+
|
311 |
+
def enable_deterministic(self):
|
312 |
+
self.deterministic = True
|
313 |
+
|
314 |
+
def disable_deterministic(self):
|
315 |
+
self.deterministic = False
|
316 |
+
|
317 |
+
def enable_gradient_checkpointing(self):
|
318 |
+
self.gradient_checkpointing = True
|
319 |
+
|
320 |
+
def disable_gradient_checkpointing(self):
|
321 |
+
self.gradient_checkpointing = False
|
322 |
+
|
323 |
+
def _forward(
|
324 |
+
self,
|
325 |
+
x: torch.Tensor,
|
326 |
+
vec: torch.Tensor,
|
327 |
+
txt_len: int,
|
328 |
+
attn_mask: Optional[torch.Tensor] = None,
|
329 |
+
total_len: Optional[torch.Tensor] = None,
|
330 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
331 |
+
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
332 |
+
max_seqlen_q: Optional[int] = None,
|
333 |
+
max_seqlen_kv: Optional[int] = None,
|
334 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
335 |
+
) -> torch.Tensor:
|
336 |
+
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
337 |
+
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
338 |
+
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
339 |
+
x_mod = None
|
340 |
+
# mlp = mlp.to("cpu", non_blocking=True)
|
341 |
+
# clean_memory_on_device(x.device)
|
342 |
+
|
343 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
344 |
+
qkv = None
|
345 |
+
|
346 |
+
# Apply QK-Norm if needed.
|
347 |
+
q = self.q_norm(q).to(v)
|
348 |
+
k = self.k_norm(k).to(v)
|
349 |
+
|
350 |
+
# Apply RoPE if needed.
|
351 |
+
if freqs_cis is not None:
|
352 |
+
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
353 |
+
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
354 |
+
q = k = None
|
355 |
+
img_q_shape = img_q.shape
|
356 |
+
img_k_shape = img_k.shape
|
357 |
+
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
358 |
+
assert (
|
359 |
+
img_q.shape == img_q_shape and img_k_shape == img_k.shape
|
360 |
+
), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}"
|
361 |
+
# img_q, img_k = img_qq, img_kk
|
362 |
+
# del img_qq, img_kk
|
363 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
364 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
365 |
+
del img_q, txt_q, img_k, txt_k
|
366 |
+
|
367 |
+
# Compute attention.
|
368 |
+
assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
|
369 |
+
|
370 |
+
# attention computation start
|
371 |
+
if not self.hybrid_seq_parallel_attn:
|
372 |
+
l = [q, k, v]
|
373 |
+
q = k = v = None
|
374 |
+
attn = attention(
|
375 |
+
l,
|
376 |
+
mode=self.attn_mode,
|
377 |
+
attn_mask=attn_mask,
|
378 |
+
total_len=total_len,
|
379 |
+
cu_seqlens_q=cu_seqlens_q,
|
380 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
381 |
+
max_seqlen_q=max_seqlen_q,
|
382 |
+
max_seqlen_kv=max_seqlen_kv,
|
383 |
+
batch_size=x.shape[0],
|
384 |
+
)
|
385 |
+
else:
|
386 |
+
attn = parallel_attention(
|
387 |
+
self.hybrid_seq_parallel_attn,
|
388 |
+
q,
|
389 |
+
k,
|
390 |
+
v,
|
391 |
+
img_q_len=img_q.shape[1],
|
392 |
+
img_kv_len=img_k.shape[1],
|
393 |
+
cu_seqlens_q=cu_seqlens_q,
|
394 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
395 |
+
)
|
396 |
+
# attention computation end
|
397 |
+
|
398 |
+
# Compute activation in mlp stream, cat again and run second linear layer.
|
399 |
+
# mlp = mlp.to(x.device)
|
400 |
+
mlp = self.mlp_act(mlp)
|
401 |
+
attn_mlp = torch.cat((attn, mlp), 2)
|
402 |
+
attn = None
|
403 |
+
mlp = None
|
404 |
+
output = self.linear2(attn_mlp)
|
405 |
+
attn_mlp = None
|
406 |
+
return x + apply_gate(output, gate=mod_gate)
|
407 |
+
|
408 |
+
# def forward(
|
409 |
+
# self,
|
410 |
+
# x: torch.Tensor,
|
411 |
+
# vec: torch.Tensor,
|
412 |
+
# txt_len: int,
|
413 |
+
# attn_mask: Optional[torch.Tensor] = None,
|
414 |
+
# cu_seqlens_q: Optional[torch.Tensor] = None,
|
415 |
+
# cu_seqlens_kv: Optional[torch.Tensor] = None,
|
416 |
+
# max_seqlen_q: Optional[int] = None,
|
417 |
+
# max_seqlen_kv: Optional[int] = None,
|
418 |
+
# freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
419 |
+
# ) -> torch.Tensor:
|
420 |
+
def forward(self, *args, **kwargs):
|
421 |
+
if self.training and self.gradient_checkpointing:
|
422 |
+
return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
|
423 |
+
else:
|
424 |
+
return self._forward(*args, **kwargs)
|
425 |
+
|
426 |
+
|
427 |
+
class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
|
428 |
+
"""
|
429 |
+
HunyuanVideo Transformer backbone
|
430 |
+
|
431 |
+
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
|
432 |
+
|
433 |
+
Reference:
|
434 |
+
[1] Flux.1: https://github.com/black-forest-labs/flux
|
435 |
+
[2] MMDiT: http://arxiv.org/abs/2403.03206
|
436 |
+
|
437 |
+
Parameters
|
438 |
+
----------
|
439 |
+
args: argparse.Namespace
|
440 |
+
The arguments parsed by argparse.
|
441 |
+
patch_size: list
|
442 |
+
The size of the patch.
|
443 |
+
in_channels: int
|
444 |
+
The number of input channels.
|
445 |
+
out_channels: int
|
446 |
+
The number of output channels.
|
447 |
+
hidden_size: int
|
448 |
+
The hidden size of the transformer backbone.
|
449 |
+
heads_num: int
|
450 |
+
The number of attention heads.
|
451 |
+
mlp_width_ratio: float
|
452 |
+
The ratio of the hidden size of the MLP in the transformer block.
|
453 |
+
mlp_act_type: str
|
454 |
+
The activation function of the MLP in the transformer block.
|
455 |
+
depth_double_blocks: int
|
456 |
+
The number of transformer blocks in the double blocks.
|
457 |
+
depth_single_blocks: int
|
458 |
+
The number of transformer blocks in the single blocks.
|
459 |
+
rope_dim_list: list
|
460 |
+
The dimension of the rotary embedding for t, h, w.
|
461 |
+
qkv_bias: bool
|
462 |
+
Whether to use bias in the qkv linear layer.
|
463 |
+
qk_norm: bool
|
464 |
+
Whether to use qk norm.
|
465 |
+
qk_norm_type: str
|
466 |
+
The type of qk norm.
|
467 |
+
guidance_embed: bool
|
468 |
+
Whether to use guidance embedding for distillation.
|
469 |
+
text_projection: str
|
470 |
+
The type of the text projection, default is single_refiner.
|
471 |
+
use_attention_mask: bool
|
472 |
+
Whether to use attention mask for text encoder.
|
473 |
+
dtype: torch.dtype
|
474 |
+
The dtype of the model.
|
475 |
+
device: torch.device
|
476 |
+
The device of the model.
|
477 |
+
attn_mode: str
|
478 |
+
The mode of the attention, default is flash.
|
479 |
+
split_attn: bool
|
480 |
+
Whether to use split attention (make attention as batch size 1).
|
481 |
+
"""
|
482 |
+
|
483 |
+
# @register_to_config
|
484 |
+
def __init__(
|
485 |
+
self,
|
486 |
+
text_states_dim: int,
|
487 |
+
text_states_dim_2: int,
|
488 |
+
patch_size: list = [1, 2, 2],
|
489 |
+
in_channels: int = 4, # Should be VAE.config.latent_channels.
|
490 |
+
out_channels: int = None,
|
491 |
+
hidden_size: int = 3072,
|
492 |
+
heads_num: int = 24,
|
493 |
+
mlp_width_ratio: float = 4.0,
|
494 |
+
mlp_act_type: str = "gelu_tanh",
|
495 |
+
mm_double_blocks_depth: int = 20,
|
496 |
+
mm_single_blocks_depth: int = 40,
|
497 |
+
rope_dim_list: List[int] = [16, 56, 56],
|
498 |
+
qkv_bias: bool = True,
|
499 |
+
qk_norm: bool = True,
|
500 |
+
qk_norm_type: str = "rms",
|
501 |
+
guidance_embed: bool = False, # For modulation.
|
502 |
+
text_projection: str = "single_refiner",
|
503 |
+
use_attention_mask: bool = True,
|
504 |
+
dtype: Optional[torch.dtype] = None,
|
505 |
+
device: Optional[torch.device] = None,
|
506 |
+
attn_mode: str = "flash",
|
507 |
+
split_attn: bool = False,
|
508 |
+
):
|
509 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
510 |
+
super().__init__()
|
511 |
+
|
512 |
+
self.patch_size = patch_size
|
513 |
+
self.in_channels = in_channels
|
514 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
515 |
+
self.unpatchify_channels = self.out_channels
|
516 |
+
self.guidance_embed = guidance_embed
|
517 |
+
self.rope_dim_list = rope_dim_list
|
518 |
+
|
519 |
+
# Text projection. Default to linear projection.
|
520 |
+
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
|
521 |
+
self.use_attention_mask = use_attention_mask
|
522 |
+
self.text_projection = text_projection
|
523 |
+
|
524 |
+
self.text_states_dim = text_states_dim
|
525 |
+
self.text_states_dim_2 = text_states_dim_2
|
526 |
+
|
527 |
+
if hidden_size % heads_num != 0:
|
528 |
+
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
|
529 |
+
pe_dim = hidden_size // heads_num
|
530 |
+
if sum(rope_dim_list) != pe_dim:
|
531 |
+
raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
|
532 |
+
self.hidden_size = hidden_size
|
533 |
+
self.heads_num = heads_num
|
534 |
+
|
535 |
+
self.attn_mode = attn_mode
|
536 |
+
self.split_attn = split_attn
|
537 |
+
print(f"Using {self.attn_mode} attention mode, split_attn: {self.split_attn}")
|
538 |
+
|
539 |
+
# image projection
|
540 |
+
self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
|
541 |
+
|
542 |
+
# text projection
|
543 |
+
if self.text_projection == "linear":
|
544 |
+
self.txt_in = TextProjection(
|
545 |
+
self.text_states_dim,
|
546 |
+
self.hidden_size,
|
547 |
+
get_activation_layer("silu"),
|
548 |
+
**factory_kwargs,
|
549 |
+
)
|
550 |
+
elif self.text_projection == "single_refiner":
|
551 |
+
self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
|
552 |
+
else:
|
553 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
554 |
+
|
555 |
+
# time modulation
|
556 |
+
self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
|
557 |
+
|
558 |
+
# text modulation
|
559 |
+
self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
|
560 |
+
|
561 |
+
# guidance modulation
|
562 |
+
self.guidance_in = (
|
563 |
+
TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
|
564 |
+
)
|
565 |
+
|
566 |
+
# double blocks
|
567 |
+
self.double_blocks = nn.ModuleList(
|
568 |
+
[
|
569 |
+
MMDoubleStreamBlock(
|
570 |
+
self.hidden_size,
|
571 |
+
self.heads_num,
|
572 |
+
mlp_width_ratio=mlp_width_ratio,
|
573 |
+
mlp_act_type=mlp_act_type,
|
574 |
+
qk_norm=qk_norm,
|
575 |
+
qk_norm_type=qk_norm_type,
|
576 |
+
qkv_bias=qkv_bias,
|
577 |
+
attn_mode=attn_mode,
|
578 |
+
split_attn=split_attn,
|
579 |
+
**factory_kwargs,
|
580 |
+
)
|
581 |
+
for _ in range(mm_double_blocks_depth)
|
582 |
+
]
|
583 |
+
)
|
584 |
+
|
585 |
+
# single blocks
|
586 |
+
self.single_blocks = nn.ModuleList(
|
587 |
+
[
|
588 |
+
MMSingleStreamBlock(
|
589 |
+
self.hidden_size,
|
590 |
+
self.heads_num,
|
591 |
+
mlp_width_ratio=mlp_width_ratio,
|
592 |
+
mlp_act_type=mlp_act_type,
|
593 |
+
qk_norm=qk_norm,
|
594 |
+
qk_norm_type=qk_norm_type,
|
595 |
+
attn_mode=attn_mode,
|
596 |
+
split_attn=split_attn,
|
597 |
+
**factory_kwargs,
|
598 |
+
)
|
599 |
+
for _ in range(mm_single_blocks_depth)
|
600 |
+
]
|
601 |
+
)
|
602 |
+
|
603 |
+
self.final_layer = FinalLayer(
|
604 |
+
self.hidden_size,
|
605 |
+
self.patch_size,
|
606 |
+
self.out_channels,
|
607 |
+
get_activation_layer("silu"),
|
608 |
+
**factory_kwargs,
|
609 |
+
)
|
610 |
+
|
611 |
+
self.gradient_checkpointing = False
|
612 |
+
self.blocks_to_swap = None
|
613 |
+
self.offloader_double = None
|
614 |
+
self.offloader_single = None
|
615 |
+
self._enable_img_in_txt_in_offloading = False
|
616 |
+
|
617 |
+
@property
|
618 |
+
def device(self):
|
619 |
+
return next(self.parameters()).device
|
620 |
+
|
621 |
+
@property
|
622 |
+
def dtype(self):
|
623 |
+
return next(self.parameters()).dtype
|
624 |
+
|
625 |
+
def enable_gradient_checkpointing(self):
|
626 |
+
self.gradient_checkpointing = True
|
627 |
+
|
628 |
+
self.txt_in.enable_gradient_checkpointing()
|
629 |
+
|
630 |
+
for block in self.double_blocks + self.single_blocks:
|
631 |
+
block.enable_gradient_checkpointing()
|
632 |
+
|
633 |
+
print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
|
634 |
+
|
635 |
+
def disable_gradient_checkpointing(self):
|
636 |
+
self.gradient_checkpointing = False
|
637 |
+
|
638 |
+
self.txt_in.disable_gradient_checkpointing()
|
639 |
+
|
640 |
+
for block in self.double_blocks + self.single_blocks:
|
641 |
+
block.disable_gradient_checkpointing()
|
642 |
+
|
643 |
+
print(f"HYVideoDiffusionTransformer: Gradient checkpointing disabled.")
|
644 |
+
|
645 |
+
def enable_img_in_txt_in_offloading(self):
|
646 |
+
self._enable_img_in_txt_in_offloading = True
|
647 |
+
|
648 |
+
def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
|
649 |
+
self.blocks_to_swap = num_blocks
|
650 |
+
self.num_double_blocks = len(self.double_blocks)
|
651 |
+
self.num_single_blocks = len(self.single_blocks)
|
652 |
+
double_blocks_to_swap = num_blocks // 2
|
653 |
+
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
|
654 |
+
|
655 |
+
assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
|
656 |
+
f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
|
657 |
+
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
|
658 |
+
)
|
659 |
+
|
660 |
+
self.offloader_double = ModelOffloader(
|
661 |
+
"double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
|
662 |
+
)
|
663 |
+
self.offloader_single = ModelOffloader(
|
664 |
+
"single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
|
665 |
+
)
|
666 |
+
print(
|
667 |
+
f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
668 |
+
)
|
669 |
+
|
670 |
+
def switch_block_swap_for_inference(self):
|
671 |
+
if self.blocks_to_swap:
|
672 |
+
self.offloader_double.set_forward_only(True)
|
673 |
+
self.offloader_single.set_forward_only(True)
|
674 |
+
self.prepare_block_swap_before_forward()
|
675 |
+
print(f"HYVideoDiffusionTransformer: Block swap set to forward only.")
|
676 |
+
|
677 |
+
def switch_block_swap_for_training(self):
|
678 |
+
if self.blocks_to_swap:
|
679 |
+
self.offloader_double.set_forward_only(False)
|
680 |
+
self.offloader_single.set_forward_only(False)
|
681 |
+
self.prepare_block_swap_before_forward()
|
682 |
+
print(f"HYVideoDiffusionTransformer: Block swap set to forward and backward.")
|
683 |
+
|
684 |
+
def move_to_device_except_swap_blocks(self, device: torch.device):
|
685 |
+
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
|
686 |
+
if self.blocks_to_swap:
|
687 |
+
save_double_blocks = self.double_blocks
|
688 |
+
save_single_blocks = self.single_blocks
|
689 |
+
self.double_blocks = None
|
690 |
+
self.single_blocks = None
|
691 |
+
|
692 |
+
self.to(device)
|
693 |
+
|
694 |
+
if self.blocks_to_swap:
|
695 |
+
self.double_blocks = save_double_blocks
|
696 |
+
self.single_blocks = save_single_blocks
|
697 |
+
|
698 |
+
def prepare_block_swap_before_forward(self):
|
699 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
700 |
+
return
|
701 |
+
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
702 |
+
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
703 |
+
|
704 |
+
def enable_deterministic(self):
|
705 |
+
for block in self.double_blocks:
|
706 |
+
block.enable_deterministic()
|
707 |
+
for block in self.single_blocks:
|
708 |
+
block.enable_deterministic()
|
709 |
+
|
710 |
+
def disable_deterministic(self):
|
711 |
+
for block in self.double_blocks:
|
712 |
+
block.disable_deterministic()
|
713 |
+
for block in self.single_blocks:
|
714 |
+
block.disable_deterministic()
|
715 |
+
|
716 |
+
def forward(
|
717 |
+
self,
|
718 |
+
x: torch.Tensor,
|
719 |
+
t: torch.Tensor, # Should be in range(0, 1000).
|
720 |
+
text_states: torch.Tensor = None,
|
721 |
+
text_mask: torch.Tensor = None, # Now we don't use it.
|
722 |
+
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
|
723 |
+
freqs_cos: Optional[torch.Tensor] = None,
|
724 |
+
freqs_sin: Optional[torch.Tensor] = None,
|
725 |
+
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
|
726 |
+
return_dict: bool = True,
|
727 |
+
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
728 |
+
out = {}
|
729 |
+
img = x
|
730 |
+
txt = text_states
|
731 |
+
_, _, ot, oh, ow = x.shape
|
732 |
+
tt, th, tw = (
|
733 |
+
ot // self.patch_size[0],
|
734 |
+
oh // self.patch_size[1],
|
735 |
+
ow // self.patch_size[2],
|
736 |
+
)
|
737 |
+
|
738 |
+
# Prepare modulation vectors.
|
739 |
+
vec = self.time_in(t)
|
740 |
+
|
741 |
+
# text modulation
|
742 |
+
vec = vec + self.vector_in(text_states_2)
|
743 |
+
|
744 |
+
# guidance modulation
|
745 |
+
if self.guidance_embed:
|
746 |
+
if guidance is None:
|
747 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
748 |
+
|
749 |
+
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
|
750 |
+
vec = vec + self.guidance_in(guidance)
|
751 |
+
|
752 |
+
# Embed image and text.
|
753 |
+
if self._enable_img_in_txt_in_offloading:
|
754 |
+
self.img_in.to(x.device, non_blocking=True)
|
755 |
+
self.txt_in.to(x.device, non_blocking=True)
|
756 |
+
synchronize_device(x.device)
|
757 |
+
|
758 |
+
img = self.img_in(img)
|
759 |
+
if self.text_projection == "linear":
|
760 |
+
txt = self.txt_in(txt)
|
761 |
+
elif self.text_projection == "single_refiner":
|
762 |
+
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
|
763 |
+
else:
|
764 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
765 |
+
|
766 |
+
if self._enable_img_in_txt_in_offloading:
|
767 |
+
self.img_in.to(torch.device("cpu"), non_blocking=True)
|
768 |
+
self.txt_in.to(torch.device("cpu"), non_blocking=True)
|
769 |
+
synchronize_device(x.device)
|
770 |
+
clean_memory_on_device(x.device)
|
771 |
+
|
772 |
+
txt_seq_len = txt.shape[1]
|
773 |
+
img_seq_len = img.shape[1]
|
774 |
+
|
775 |
+
# Compute cu_squlens and max_seqlen for flash attention
|
776 |
+
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
|
777 |
+
cu_seqlens_kv = cu_seqlens_q
|
778 |
+
max_seqlen_q = img_seq_len + txt_seq_len
|
779 |
+
max_seqlen_kv = max_seqlen_q
|
780 |
+
|
781 |
+
attn_mask = total_len = None
|
782 |
+
if self.split_attn or self.attn_mode == "torch":
|
783 |
+
# calculate text length and total length
|
784 |
+
text_len = text_mask.sum(dim=1) # (bs, )
|
785 |
+
total_len = img_seq_len + text_len # (bs, )
|
786 |
+
if self.attn_mode == "torch" and not self.split_attn:
|
787 |
+
# initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
|
788 |
+
bs = img.shape[0]
|
789 |
+
attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)
|
790 |
+
|
791 |
+
# set attention mask with total_len
|
792 |
+
for i in range(bs):
|
793 |
+
attn_mask[i, :, : total_len[i], : total_len[i]] = True
|
794 |
+
total_len = None # means we don't use split_attn
|
795 |
+
|
796 |
+
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
|
797 |
+
# --------------------- Pass through DiT blocks ------------------------
|
798 |
+
for block_idx, block in enumerate(self.double_blocks):
|
799 |
+
double_block_args = [
|
800 |
+
img,
|
801 |
+
txt,
|
802 |
+
vec,
|
803 |
+
attn_mask,
|
804 |
+
total_len,
|
805 |
+
cu_seqlens_q,
|
806 |
+
cu_seqlens_kv,
|
807 |
+
max_seqlen_q,
|
808 |
+
max_seqlen_kv,
|
809 |
+
freqs_cis,
|
810 |
+
]
|
811 |
+
|
812 |
+
if self.blocks_to_swap:
|
813 |
+
self.offloader_double.wait_for_block(block_idx)
|
814 |
+
|
815 |
+
img, txt = block(*double_block_args)
|
816 |
+
|
817 |
+
if self.blocks_to_swap:
|
818 |
+
self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)
|
819 |
+
|
820 |
+
# Merge txt and img to pass through single stream blocks.
|
821 |
+
x = torch.cat((img, txt), 1)
|
822 |
+
if self.blocks_to_swap:
|
823 |
+
# delete img, txt to reduce memory usage
|
824 |
+
del img, txt
|
825 |
+
clean_memory_on_device(x.device)
|
826 |
+
|
827 |
+
if len(self.single_blocks) > 0:
|
828 |
+
for block_idx, block in enumerate(self.single_blocks):
|
829 |
+
single_block_args = [
|
830 |
+
x,
|
831 |
+
vec,
|
832 |
+
txt_seq_len,
|
833 |
+
attn_mask,
|
834 |
+
total_len,
|
835 |
+
cu_seqlens_q,
|
836 |
+
cu_seqlens_kv,
|
837 |
+
max_seqlen_q,
|
838 |
+
max_seqlen_kv,
|
839 |
+
freqs_cis,
|
840 |
+
]
|
841 |
+
if self.blocks_to_swap:
|
842 |
+
self.offloader_single.wait_for_block(block_idx)
|
843 |
+
|
844 |
+
x = block(*single_block_args)
|
845 |
+
|
846 |
+
if self.blocks_to_swap:
|
847 |
+
self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)
|
848 |
+
|
849 |
+
img = x[:, :img_seq_len, ...]
|
850 |
+
x = None
|
851 |
+
|
852 |
+
# ---------------------------- Final layer ------------------------------
|
853 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
854 |
+
|
855 |
+
img = self.unpatchify(img, tt, th, tw)
|
856 |
+
if return_dict:
|
857 |
+
out["x"] = img
|
858 |
+
return out
|
859 |
+
return img
|
860 |
+
|
861 |
+
def unpatchify(self, x, t, h, w):
|
862 |
+
"""
|
863 |
+
x: (N, T, patch_size**2 * C)
|
864 |
+
imgs: (N, H, W, C)
|
865 |
+
"""
|
866 |
+
c = self.unpatchify_channels
|
867 |
+
pt, ph, pw = self.patch_size
|
868 |
+
assert t * h * w == x.shape[1]
|
869 |
+
|
870 |
+
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
|
871 |
+
x = torch.einsum("nthwcopq->nctohpwq", x)
|
872 |
+
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
873 |
+
|
874 |
+
return imgs
|
875 |
+
|
876 |
+
def params_count(self):
|
877 |
+
counts = {
|
878 |
+
"double": sum(
|
879 |
+
[
|
880 |
+
sum(p.numel() for p in block.img_attn_qkv.parameters())
|
881 |
+
+ sum(p.numel() for p in block.img_attn_proj.parameters())
|
882 |
+
+ sum(p.numel() for p in block.img_mlp.parameters())
|
883 |
+
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
|
884 |
+
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
|
885 |
+
+ sum(p.numel() for p in block.txt_mlp.parameters())
|
886 |
+
for block in self.double_blocks
|
887 |
+
]
|
888 |
+
),
|
889 |
+
"single": sum(
|
890 |
+
[
|
891 |
+
sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
|
892 |
+
for block in self.single_blocks
|
893 |
+
]
|
894 |
+
),
|
895 |
+
"total": sum(p.numel() for p in self.parameters()),
|
896 |
+
}
|
897 |
+
counts["attn+mlp"] = counts["double"] + counts["single"]
|
898 |
+
return counts
|
899 |
+
|
900 |
+
|
901 |
+
#################################################################################
|
902 |
+
# HunyuanVideo Configs #
|
903 |
+
#################################################################################
|
904 |
+
|
905 |
+
HUNYUAN_VIDEO_CONFIG = {
|
906 |
+
"HYVideo-T/2": {
|
907 |
+
"mm_double_blocks_depth": 20,
|
908 |
+
"mm_single_blocks_depth": 40,
|
909 |
+
"rope_dim_list": [16, 56, 56],
|
910 |
+
"hidden_size": 3072,
|
911 |
+
"heads_num": 24,
|
912 |
+
"mlp_width_ratio": 4,
|
913 |
+
},
|
914 |
+
"HYVideo-T/2-cfgdistill": {
|
915 |
+
"mm_double_blocks_depth": 20,
|
916 |
+
"mm_single_blocks_depth": 40,
|
917 |
+
"rope_dim_list": [16, 56, 56],
|
918 |
+
"hidden_size": 3072,
|
919 |
+
"heads_num": 24,
|
920 |
+
"mlp_width_ratio": 4,
|
921 |
+
"guidance_embed": True,
|
922 |
+
},
|
923 |
+
}
|
924 |
+
|
925 |
+
|
926 |
+
def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
|
927 |
+
"""load hunyuan video model
|
928 |
+
|
929 |
+
NOTE: Only support HYVideo-T/2-cfgdistill now.
|
930 |
+
|
931 |
+
Args:
|
932 |
+
text_state_dim (int): text state dimension
|
933 |
+
text_state_dim_2 (int): text state dimension 2
|
934 |
+
in_channels (int): input channels number
|
935 |
+
out_channels (int): output channels number
|
936 |
+
factor_kwargs (dict): factor kwargs
|
937 |
+
|
938 |
+
Returns:
|
939 |
+
model (nn.Module): The hunyuan video model
|
940 |
+
"""
|
941 |
+
# if args.model in HUNYUAN_VIDEO_CONFIG.keys():
|
942 |
+
model = HYVideoDiffusionTransformer(
|
943 |
+
text_states_dim=text_states_dim,
|
944 |
+
text_states_dim_2=text_states_dim_2,
|
945 |
+
in_channels=in_channels,
|
946 |
+
out_channels=out_channels,
|
947 |
+
**HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
|
948 |
+
**factor_kwargs,
|
949 |
+
)
|
950 |
+
return model
|
951 |
+
# else:
|
952 |
+
# raise NotImplementedError()
|
953 |
+
|
954 |
+
|
955 |
+
def load_state_dict(model, model_path):
|
956 |
+
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
|
957 |
+
|
958 |
+
load_key = "module"
|
959 |
+
if load_key in state_dict:
|
960 |
+
state_dict = state_dict[load_key]
|
961 |
+
else:
|
962 |
+
raise KeyError(
|
963 |
+
f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
|
964 |
+
f"are: {list(state_dict.keys())}."
|
965 |
+
)
|
966 |
+
model.load_state_dict(state_dict, strict=True, assign=True)
|
967 |
+
return model
|
968 |
+
|
969 |
+
|
970 |
+
def load_transformer(dit_path, attn_mode, split_attn, device, dtype, in_channels=16) -> HYVideoDiffusionTransformer:
|
971 |
+
# =========================== Build main model ===========================
|
972 |
+
factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode, "split_attn": split_attn}
|
973 |
+
latent_channels = 16
|
974 |
+
out_channels = latent_channels
|
975 |
+
|
976 |
+
with accelerate.init_empty_weights():
|
977 |
+
transformer = load_dit_model(
|
978 |
+
text_states_dim=4096,
|
979 |
+
text_states_dim_2=768,
|
980 |
+
in_channels=in_channels,
|
981 |
+
out_channels=out_channels,
|
982 |
+
factor_kwargs=factor_kwargs,
|
983 |
+
)
|
984 |
+
|
985 |
+
if os.path.splitext(dit_path)[-1] == ".safetensors":
|
986 |
+
# loading safetensors: may be already fp8
|
987 |
+
with MemoryEfficientSafeOpen(dit_path) as f:
|
988 |
+
state_dict = {}
|
989 |
+
for k in f.keys():
|
990 |
+
tensor = f.get_tensor(k)
|
991 |
+
tensor = tensor.to(device=device, dtype=dtype)
|
992 |
+
# TODO support comfy model
|
993 |
+
# if k.startswith("model.model."):
|
994 |
+
# k = convert_comfy_model_key(k)
|
995 |
+
state_dict[k] = tensor
|
996 |
+
transformer.load_state_dict(state_dict, strict=True, assign=True)
|
997 |
+
else:
|
998 |
+
transformer = load_state_dict(transformer, dit_path)
|
999 |
+
|
1000 |
+
return transformer
|
1001 |
+
|
1002 |
+
|
1003 |
+
def get_rotary_pos_embed_by_shape(model, latents_size):
|
1004 |
+
target_ndim = 3
|
1005 |
+
ndim = 5 - 2
|
1006 |
+
|
1007 |
+
if isinstance(model.patch_size, int):
|
1008 |
+
assert all(s % model.patch_size == 0 for s in latents_size), (
|
1009 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
|
1010 |
+
f"but got {latents_size}."
|
1011 |
+
)
|
1012 |
+
rope_sizes = [s // model.patch_size for s in latents_size]
|
1013 |
+
elif isinstance(model.patch_size, list):
|
1014 |
+
assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
|
1015 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
|
1016 |
+
f"but got {latents_size}."
|
1017 |
+
)
|
1018 |
+
rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
|
1019 |
+
|
1020 |
+
if len(rope_sizes) != target_ndim:
|
1021 |
+
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
|
1022 |
+
head_dim = model.hidden_size // model.heads_num
|
1023 |
+
rope_dim_list = model.rope_dim_list
|
1024 |
+
if rope_dim_list is None:
|
1025 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
1026 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
1027 |
+
|
1028 |
+
rope_theta = 256
|
1029 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
1030 |
+
rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
|
1031 |
+
)
|
1032 |
+
return freqs_cos, freqs_sin
|
1033 |
+
|
1034 |
+
|
1035 |
+
def get_rotary_pos_embed(vae_name, model, video_length, height, width):
|
1036 |
+
# 884
|
1037 |
+
if "884" in vae_name:
|
1038 |
+
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
|
1039 |
+
elif "888" in vae_name:
|
1040 |
+
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
|
1041 |
+
else:
|
1042 |
+
latents_size = [video_length, height // 8, width // 8]
|
1043 |
+
|
1044 |
+
return get_rotary_pos_embed_by_shape(model, latents_size)
|
hunyuan_model/modulate_layers.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class ModulateDiT(nn.Module):
|
8 |
+
"""Modulation layer for DiT."""
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
hidden_size: int,
|
12 |
+
factor: int,
|
13 |
+
act_layer: Callable,
|
14 |
+
dtype=None,
|
15 |
+
device=None,
|
16 |
+
):
|
17 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
18 |
+
super().__init__()
|
19 |
+
self.act = act_layer()
|
20 |
+
self.linear = nn.Linear(
|
21 |
+
hidden_size, factor * hidden_size, bias=True, **factory_kwargs
|
22 |
+
)
|
23 |
+
# Zero-initialize the modulation
|
24 |
+
nn.init.zeros_(self.linear.weight)
|
25 |
+
nn.init.zeros_(self.linear.bias)
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
return self.linear(self.act(x))
|
29 |
+
|
30 |
+
|
31 |
+
def modulate(x, shift=None, scale=None):
|
32 |
+
"""modulate by shift and scale
|
33 |
+
|
34 |
+
Args:
|
35 |
+
x (torch.Tensor): input tensor.
|
36 |
+
shift (torch.Tensor, optional): shift tensor. Defaults to None.
|
37 |
+
scale (torch.Tensor, optional): scale tensor. Defaults to None.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: the output tensor after modulate.
|
41 |
+
"""
|
42 |
+
if scale is None and shift is None:
|
43 |
+
return x
|
44 |
+
elif shift is None:
|
45 |
+
return x * (1 + scale.unsqueeze(1))
|
46 |
+
elif scale is None:
|
47 |
+
return x + shift.unsqueeze(1)
|
48 |
+
else:
|
49 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
50 |
+
|
51 |
+
|
52 |
+
def apply_gate(x, gate=None, tanh=False):
|
53 |
+
"""AI is creating summary for apply_gate
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x (torch.Tensor): input tensor.
|
57 |
+
gate (torch.Tensor, optional): gate tensor. Defaults to None.
|
58 |
+
tanh (bool, optional): whether to use tanh function. Defaults to False.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
torch.Tensor: the output tensor after apply gate.
|
62 |
+
"""
|
63 |
+
if gate is None:
|
64 |
+
return x
|
65 |
+
if tanh:
|
66 |
+
return x * gate.unsqueeze(1).tanh()
|
67 |
+
else:
|
68 |
+
return x * gate.unsqueeze(1)
|
69 |
+
|
70 |
+
|
71 |
+
def ckpt_wrapper(module):
|
72 |
+
def ckpt_forward(*inputs):
|
73 |
+
outputs = module(*inputs)
|
74 |
+
return outputs
|
75 |
+
|
76 |
+
return ckpt_forward
|