linoyts (HF Staff) committed
Commit 35563ae · verified · 1 parent: a545bcc

Upload 28 files

app.py ADDED
@@ -0,0 +1,373 @@
+ from diffusers_helper.hf_login import login
+
+ import os
+
+ os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
+
+ import gradio as gr
+ import torch
+ import traceback
+ import einops
+ import safetensors.torch as sf
+ import numpy as np
+ import math
+ import spaces
+
+ from PIL import Image
+ from diffusers import AutoencoderKLHunyuanVideo
+ from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
+ from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
+ from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
+ from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
+ from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
+ from diffusers_helper.thread_utils import AsyncStream, async_run
+ from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
+ from transformers import SiglipImageProcessor, SiglipVisionModel
+ from diffusers_helper.clip_vision import hf_clip_vision_encode
+ from diffusers_helper.bucket_tools import find_nearest_bucket
+
+
+ free_mem_gb = get_cuda_free_memory_gb(gpu)
+ high_vram = free_mem_gb > 60
+
+ print(f'Free VRAM {free_mem_gb} GB')
+ print(f'High-VRAM Mode: {high_vram}')
+
+ text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
+ text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
+ tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
+ vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()
+
+ feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
+ image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()
+
+ transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained('lllyasviel/FramePack_F1_I2V_HY_20250503', torch_dtype=torch.bfloat16).cpu()
+
+ vae.eval()
+ text_encoder.eval()
+ text_encoder_2.eval()
+ image_encoder.eval()
+ transformer.eval()
+
+ if not high_vram:
+     vae.enable_slicing()
+     vae.enable_tiling()
+
+ transformer.high_quality_fp32_output_for_inference = True
+ print('transformer.high_quality_fp32_output_for_inference = True')
+
+ transformer.to(dtype=torch.bfloat16)
+ vae.to(dtype=torch.float16)
+ image_encoder.to(dtype=torch.float16)
+ text_encoder.to(dtype=torch.float16)
+ text_encoder_2.to(dtype=torch.float16)
+
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ text_encoder_2.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ transformer.requires_grad_(False)
+
+ if not high_vram:
+     # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
+     DynamicSwapInstaller.install_model(transformer, device=gpu)
+     DynamicSwapInstaller.install_model(text_encoder, device=gpu)
+ else:
+     text_encoder.to(gpu)
+     text_encoder_2.to(gpu)
+     image_encoder.to(gpu)
+     vae.to(gpu)
+     transformer.to(gpu)
+
+ stream = AsyncStream()
+
+ outputs_folder = './outputs/'
+ os.makedirs(outputs_folder, exist_ok=True)
+
+
+ @torch.no_grad()
+ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
+     total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
+     total_latent_sections = int(max(round(total_latent_sections), 1))
+
+     job_id = generate_timestamp()
+
+     stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
+
+     try:
+         # Clean GPU
+         if not high_vram:
+             unload_complete_models(
+                 text_encoder, text_encoder_2, image_encoder, vae, transformer
+             )
+
+         # Text encoding
+
+         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
+
+         if not high_vram:
+             fake_diffusers_current_device(text_encoder, gpu)  # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
+             load_model_as_complete(text_encoder_2, target_device=gpu)
+
+         llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
+
+         if cfg == 1:
+             llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
+         else:
+             llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
+
+         llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
+         llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
+
+         # Processing input image
+
+         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
+
+         H, W, C = input_image.shape
+         height, width = find_nearest_bucket(H, W, resolution=640)
+         input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
+
+         Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
+
+         input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
+         input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
+
+         # VAE encoding
+
+         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
+
+         if not high_vram:
+             load_model_as_complete(vae, target_device=gpu)
+
+         start_latent = vae_encode(input_image_pt, vae)
+
+         # CLIP Vision
+
+         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
+
+         if not high_vram:
+             load_model_as_complete(image_encoder, target_device=gpu)
+
+         image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
+         image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
+
+         # Dtype
+
+         llama_vec = llama_vec.to(transformer.dtype)
+         llama_vec_n = llama_vec_n.to(transformer.dtype)
+         clip_l_pooler = clip_l_pooler.to(transformer.dtype)
+         clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
+         image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
+
+         # Sampling
+
+         stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
+
+         rnd = torch.Generator("cpu").manual_seed(seed)
+
+         history_latents = torch.zeros(size=(1, 16, 16 + 2 + 1, height // 8, width // 8), dtype=torch.float32).cpu()
+         history_pixels = None
+
+         history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
+         total_generated_latent_frames = 1
+
+         for section_index in range(total_latent_sections):
+             if stream.input_queue.top() == 'end':
+                 stream.output_queue.push(('end', None))
+                 return
+
+             print(f'section_index = {section_index}, total_latent_sections = {total_latent_sections}')
+
+             if not high_vram:
+                 unload_complete_models()
+                 move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
+
+             if use_teacache:
+                 transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
+             else:
+                 transformer.initialize_teacache(enable_teacache=False)
+
+             def callback(d):
+                 preview = d['denoised']
+                 preview = vae_decode_fake(preview)
+
+                 preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
+                 preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
+
+                 if stream.input_queue.top() == 'end':
+                     stream.output_queue.push(('end', None))
+                     raise KeyboardInterrupt('User ends the task.')
+
+                 current_step = d['i'] + 1
+                 percentage = int(100.0 * current_step / steps)
+                 hint = f'Sampling {current_step}/{steps}'
+                 desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ...'
+                 stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
+                 return
+
+             indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
+             clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
+             clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
+
+             clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
+             clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
+
+             generated_latents = sample_hunyuan(
+                 transformer=transformer,
+                 sampler='unipc',
+                 width=width,
+                 height=height,
+                 frames=latent_window_size * 4 - 3,
+                 real_guidance_scale=cfg,
+                 distilled_guidance_scale=gs,
+                 guidance_rescale=rs,
+                 # shift=3.0,
+                 num_inference_steps=steps,
+                 generator=rnd,
+                 prompt_embeds=llama_vec,
+                 prompt_embeds_mask=llama_attention_mask,
+                 prompt_poolers=clip_l_pooler,
+                 negative_prompt_embeds=llama_vec_n,
+                 negative_prompt_embeds_mask=llama_attention_mask_n,
+                 negative_prompt_poolers=clip_l_pooler_n,
+                 device=gpu,
+                 dtype=torch.bfloat16,
+                 image_embeddings=image_encoder_last_hidden_state,
+                 latent_indices=latent_indices,
+                 clean_latents=clean_latents,
+                 clean_latent_indices=clean_latent_indices,
+                 clean_latents_2x=clean_latents_2x,
+                 clean_latent_2x_indices=clean_latent_2x_indices,
+                 clean_latents_4x=clean_latents_4x,
+                 clean_latent_4x_indices=clean_latent_4x_indices,
+                 callback=callback,
+             )
+
+             total_generated_latent_frames += int(generated_latents.shape[2])
+             history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
+
+             if not high_vram:
+                 offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
+                 load_model_as_complete(vae, target_device=gpu)
+
+             real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
+
+             if history_pixels is None:
+                 history_pixels = vae_decode(real_history_latents, vae).cpu()
+             else:
+                 section_latent_frames = latent_window_size * 2
+                 overlapped_frames = latent_window_size * 4 - 3
+
+                 current_pixels = vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()
+                 history_pixels = soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
+
+             if not high_vram:
+                 unload_complete_models()
+
+             output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
+
+             save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=mp4_crf)
+
+             print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
+
+             stream.output_queue.push(('file', output_filename))
+     except:
+         traceback.print_exc()
+
+         if not high_vram:
+             unload_complete_models(
+                 text_encoder, text_encoder_2, image_encoder, vae, transformer
+             )
+
+     stream.output_queue.push(('end', None))
+     return
+
+
+ @spaces.GPU
+ def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
+     global stream
+     assert input_image is not None, 'No input image!'
+
+     yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
+
+     stream = AsyncStream()
+
+     async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf)
+
+     output_filename = None
+
+     while True:
+         flag, data = stream.output_queue.next()
+
+         if flag == 'file':
+             output_filename = data
+             yield output_filename, gr.update(), gr.update(), gr.update(), gr.update(interactive=False), gr.update(interactive=True)
+
+         if flag == 'progress':
+             preview, desc, html = data
+             yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
+
+         if flag == 'end':
+             yield output_filename, gr.update(visible=False), gr.update(), '', gr.update(interactive=True), gr.update(interactive=False)
+             break
+
+
+ def end_process():
+     stream.input_queue.push('end')
+
+
+ quick_prompts = [
+     'The girl dances gracefully, with clear movements, full of charm.',
+     'A character doing some simple body movements.',
+ ]
+ quick_prompts = [[x] for x in quick_prompts]
+
+
+ css = make_progress_bar_css()
+ block = gr.Blocks(css=css).queue()
+ with block:
+     gr.Markdown('# FramePack-F1')
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(sources='upload', type="numpy", label="Image", height=320)
+             prompt = gr.Textbox(label="Prompt", value='')
+             example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Quick List', samples_per_page=1000, components=[prompt])
+             example_quick_prompts.click(lambda x: x[0], inputs=[example_quick_prompts], outputs=prompt, show_progress=False, queue=False)
+
+             with gr.Row():
+                 start_button = gr.Button(value="Start Generation")
+                 end_button = gr.Button(value="End Generation", interactive=False)
+
+             with gr.Group():
+                 use_teacache = gr.Checkbox(label='Use TeaCache', value=True, info='Faster speed, but often makes hands and fingers slightly worse.')
+
+                 n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=False)  # Not used
+                 seed = gr.Number(label="Seed", value=31337, precision=0)
+
+                 total_second_length = gr.Slider(label="Total Video Length (Seconds)", minimum=1, maximum=5, value=2, step=0.1)
+                 latent_window_size = gr.Slider(label="Latent Window Size", minimum=1, maximum=33, value=9, step=1, visible=False)  # Should not change
+                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1, info='Changing this value is not recommended.')
+
+                 cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01, visible=False)  # Should not change
+                 gs = gr.Slider(label="Distilled CFG Scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01, info='Changing this value is not recommended.')
+                 rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False)  # Should not change
+
+                 gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.")
+
+                 mp4_crf = gr.Slider(label="MP4 Compression", minimum=0, maximum=100, value=16, step=1, info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs. ")
+
+         with gr.Column():
+             preview_image = gr.Image(label="Next Latents", height=200, visible=False)
+             result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=512, loop=True)
+             progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
+             progress_bar = gr.HTML('', elem_classes='no-generating-animation')
+
+     gr.HTML('<div style="text-align:center; margin-top:20px;">Share your results and find ideas at the <a href="https://x.com/search?q=framepack&f=live" target="_blank">FramePack Twitter (X) thread</a></div>')
+
+     ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf]
+     start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button])
+     end_button.click(fn=end_process)
+
+
+ block.launch(share=True)
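
Editor's note: the UI never blocks on the sampler. `worker` runs on a background thread via `async_run` and pushes `('progress' | 'file' | 'end', payload)` tuples through `stream.output_queue`, while `process` drains that queue and turns each message into Gradio updates. `thread_utils.py` itself is not shown in this commit view (only its compiled `.pyc` appears below), so the following is only a minimal stand-in sketch of that producer/consumer contract using the standard library; the real `AsyncStream`/`async_run` are assumed to behave roughly like this.

# Hypothetical sketch of the worker/process streaming pattern in app.py,
# approximated with queue + threading (not the real diffusers_helper.thread_utils).
import queue
import threading


class FakeAsyncStream:
    def __init__(self):
        self.input_queue = queue.Queue()
        self.output_queue = queue.Queue()


def fake_worker(stream):
    for step in range(3):
        stream.output_queue.put(('progress', f'step {step + 1}/3'))
    stream.output_queue.put(('file', 'outputs/demo.mp4'))
    stream.output_queue.put(('end', None))


stream = FakeAsyncStream()
threading.Thread(target=fake_worker, args=(stream,), daemon=True).start()

while True:
    flag, data = stream.output_queue.get()  # process() calls stream.output_queue.next()
    print(flag, data)
    if flag == 'end':
        break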
diffusers_helper/__pycache__/bucket_tools.cpython-310.pyc ADDED
Binary file (646 Bytes).
 
diffusers_helper/__pycache__/clip_vision.cpython-310.pyc ADDED
Binary file (601 Bytes).
 
diffusers_helper/__pycache__/dit_common.cpython-310.pyc ADDED
Binary file (1.7 kB).
 
diffusers_helper/__pycache__/hf_login.cpython-310.pyc ADDED
Binary file (598 Bytes).
 
diffusers_helper/__pycache__/hunyuan.cpython-310.pyc ADDED
Binary file (3.35 kB).
 
diffusers_helper/__pycache__/memory.cpython-310.pyc ADDED
Binary file (4.11 kB).
 
diffusers_helper/__pycache__/thread_utils.cpython-310.pyc ADDED
Binary file (2.76 kB).
 
diffusers_helper/__pycache__/utils.cpython-310.pyc ADDED
Binary file (18 kB).
 
diffusers_helper/bucket_tools.py ADDED
@@ -0,0 +1,30 @@
+ bucket_options = {
+     640: [
+         (416, 960),
+         (448, 864),
+         (480, 832),
+         (512, 768),
+         (544, 704),
+         (576, 672),
+         (608, 640),
+         (640, 608),
+         (672, 576),
+         (704, 544),
+         (768, 512),
+         (832, 480),
+         (864, 448),
+         (960, 416),
+     ],
+ }
+
+
+ def find_nearest_bucket(h, w, resolution=640):
+     min_metric = float('inf')
+     best_bucket = None
+     for (bucket_h, bucket_w) in bucket_options[resolution]:
+         metric = abs(h * bucket_w - w * bucket_h)
+         if metric <= min_metric:
+             min_metric = metric
+             best_bucket = (bucket_h, bucket_w)
+     return best_bucket
+
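
Editor's note: `find_nearest_bucket` picks the 640-resolution bucket whose aspect ratio best matches the input by minimizing |h * bucket_w - w * bucket_h|; ties go to the later bucket because of the `<=` comparison. A quick check:

from diffusers_helper.bucket_tools import find_nearest_bucket

# A 720-by-1280 (height, width) frame maps to the (480, 832) bucket.
print(find_nearest_bucket(720, 1280, resolution=640))  # -> (480, 832)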
diffusers_helper/clip_vision.py ADDED
@@ -0,0 +1,12 @@
+ import numpy as np
+
+
+ def hf_clip_vision_encode(image, feature_extractor, image_encoder):
+     assert isinstance(image, np.ndarray)
+     assert image.ndim == 3 and image.shape[2] == 3
+     assert image.dtype == np.uint8
+
+     preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
+     image_encoder_output = image_encoder(**preprocessed)
+
+     return image_encoder_output
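
Editor's note: `hf_clip_vision_encode` expects an HWC uint8 RGB array and runs it through the SigLIP processor/encoder pair that app.py loads from `lllyasviel/flux_redux_bfl`. A usage sketch, assuming those two models are instantiated as in app.py (the dummy array is only illustrative):

import numpy as np
import torch
from transformers import SiglipImageProcessor, SiglipVisionModel

from diffusers_helper.clip_vision import hf_clip_vision_encode

feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cuda()

frame = np.zeros((480, 832, 3), dtype=np.uint8)  # any HWC uint8 RGB frame
with torch.no_grad():
    out = hf_clip_vision_encode(frame, feature_extractor, image_encoder)
print(out.last_hidden_state.shape)  # (1, num_patches, hidden_dim)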
diffusers_helper/dit_common.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import accelerate.accelerator
+
+ from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
+
+
+ accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
+
+
+ def LayerNorm_forward(self, x):
+     return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
+
+
+ LayerNorm.forward = LayerNorm_forward
+ torch.nn.LayerNorm.forward = LayerNorm_forward
+
+
+ def FP32LayerNorm_forward(self, x):
+     origin_dtype = x.dtype
+     return torch.nn.functional.layer_norm(
+         x.float(),
+         self.normalized_shape,
+         self.weight.float() if self.weight is not None else None,
+         self.bias.float() if self.bias is not None else None,
+         self.eps,
+     ).to(origin_dtype)
+
+
+ FP32LayerNorm.forward = FP32LayerNorm_forward
+
+
+ def RMSNorm_forward(self, hidden_states):
+     input_dtype = hidden_states.dtype
+     variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+     hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+     if self.weight is None:
+         return hidden_states.to(input_dtype)
+
+     return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
+
+
+ RMSNorm.forward = RMSNorm_forward
+
+
+ def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
+     emb = self.linear(self.silu(conditioning_embedding))
+     scale, shift = emb.chunk(2, dim=1)
+     x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+     return x
+
+
+ AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
diffusers_helper/gradio/__pycache__/progress_bar.cpython-310.pyc ADDED
Binary file (2.45 kB).
 
diffusers_helper/gradio/progress_bar.py ADDED
@@ -0,0 +1,86 @@
+ progress_html = '''
+ <div class="loader-container">
+   <div class="loader"></div>
+   <div class="progress-container">
+     <progress value="*number*" max="100"></progress>
+   </div>
+   <span>*text*</span>
+ </div>
+ '''
+
+ css = '''
+ .loader-container {
+   display: flex; /* Use flex to align items horizontally */
+   align-items: center; /* Center items vertically within the container */
+   white-space: nowrap; /* Prevent line breaks within the container */
+ }
+
+ .loader {
+   border: 8px solid #f3f3f3; /* Light grey */
+   border-top: 8px solid #3498db; /* Blue */
+   border-radius: 50%;
+   width: 30px;
+   height: 30px;
+   animation: spin 2s linear infinite;
+ }
+
+ @keyframes spin {
+   0% { transform: rotate(0deg); }
+   100% { transform: rotate(360deg); }
+ }
+
+ /* Style the progress bar */
+ progress {
+   appearance: none; /* Remove default styling */
+   height: 20px; /* Set the height of the progress bar */
+   border-radius: 5px; /* Round the corners of the progress bar */
+   background-color: #f3f3f3; /* Light grey background */
+   width: 100%;
+   vertical-align: middle !important;
+ }
+
+ /* Style the progress bar container */
+ .progress-container {
+   margin-left: 20px;
+   margin-right: 20px;
+   flex-grow: 1; /* Allow the progress container to take up remaining space */
+ }
+
+ /* Set the color of the progress bar fill */
+ progress::-webkit-progress-value {
+   background-color: #3498db; /* Blue color for the fill */
+ }
+
+ progress::-moz-progress-bar {
+   background-color: #3498db; /* Blue color for the fill in Firefox */
+ }
+
+ /* Style the text on the progress bar */
+ progress::after {
+   content: attr(value '%'); /* Display the progress value followed by '%' */
+   position: absolute;
+   top: 50%;
+   left: 50%;
+   transform: translate(-50%, -50%);
+   color: white; /* Set text color */
+   font-size: 14px; /* Set font size */
+ }
+
+ /* Style other texts */
+ .loader-container > span {
+   margin-left: 5px; /* Add spacing between the progress bar and the text */
+ }
+
+ .no-generating-animation > .generating {
+   display: none !important;
+ }
+
+ '''
+
+
+ def make_progress_bar_html(number, text):
+     return progress_html.replace('*number*', str(number)).replace('*text*', text)
+
+
+ def make_progress_bar_css():
+     return css
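
Editor's note: `make_progress_bar_html` only substitutes the percentage and caption into the template above, and `make_progress_bar_css` returns the stylesheet that app.py passes to `gr.Blocks(css=...)`. A minimal sketch of using them outside app.py:

import gradio as gr

from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html

with gr.Blocks(css=make_progress_bar_css()) as demo:
    # Render a static 42% bar with a caption, styled by the CSS above.
    gr.HTML(make_progress_bar_html(42, 'Sampling 10/25'), elem_classes='no-generating-animation')

demo.launch()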
diffusers_helper/hf_login.py ADDED
@@ -0,0 +1,21 @@
+ import os
+
+
+ def login(token):
+     from huggingface_hub import login
+     import time
+
+     while True:
+         try:
+             login(token)
+             print('HF login ok.')
+             break
+         except Exception as e:
+             print(f'HF login failed: {e}. Retrying')
+             time.sleep(0.5)
+
+
+ hf_token = os.environ.get('HF_TOKEN', None)
+
+ if hf_token is not None:
+     login(hf_token)
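
Editor's note: this module logs in at import time when `HF_TOKEN` is set, retrying every 0.5 s until `huggingface_hub.login` succeeds, so app.py only needs the import at its top. A usage sketch (the token value is a placeholder; on Spaces it would normally come from a secret):

import os

os.environ['HF_TOKEN'] = 'hf_xxx'  # placeholder token, set via environment/secret in practice

# Importing triggers the retry loop above; nothing happens if HF_TOKEN is unset.
from diffusers_helper.hf_login import login  # noqa: F401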
diffusers_helper/hunyuan.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
+ from diffusers_helper.utils import crop_or_pad_yield_mask
+
+
+ @torch.no_grad()
+ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
+     assert isinstance(prompt, str)
+
+     prompt = [prompt]
+
+     # LLAMA
+
+     prompt_llama = [DEFAULT_PROMPT_TEMPLATE["template"].format(p) for p in prompt]
+     crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]
+
+     llama_inputs = tokenizer(
+         prompt_llama,
+         padding="max_length",
+         max_length=max_length + crop_start,
+         truncation=True,
+         return_tensors="pt",
+         return_length=False,
+         return_overflowing_tokens=False,
+         return_attention_mask=True,
+     )
+
+     llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
+     llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
+     llama_attention_length = int(llama_attention_mask.sum())
+
+     llama_outputs = text_encoder(
+         input_ids=llama_input_ids,
+         attention_mask=llama_attention_mask,
+         output_hidden_states=True,
+     )
+
+     llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
+     # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
+     llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
+
+     assert torch.all(llama_attention_mask.bool())
+
+     # CLIP
+
+     clip_l_input_ids = tokenizer_2(
+         prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_overflowing_tokens=False,
+         return_length=False,
+         return_tensors="pt",
+     ).input_ids
+     clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
+
+     return llama_vec, clip_l_pooler
+
+
+ @torch.no_grad()
+ def vae_decode_fake(latents):
+     latent_rgb_factors = [
+         [-0.0395, -0.0331, 0.0445],
+         [0.0696, 0.0795, 0.0518],
+         [0.0135, -0.0945, -0.0282],
+         [0.0108, -0.0250, -0.0765],
+         [-0.0209, 0.0032, 0.0224],
+         [-0.0804, -0.0254, -0.0639],
+         [-0.0991, 0.0271, -0.0669],
+         [-0.0646, -0.0422, -0.0400],
+         [-0.0696, -0.0595, -0.0894],
+         [-0.0799, -0.0208, -0.0375],
+         [0.1166, 0.1627, 0.0962],
+         [0.1165, 0.0432, 0.0407],
+         [-0.2315, -0.1920, -0.1355],
+         [-0.0270, 0.0401, -0.0821],
+         [-0.0616, -0.0997, -0.0727],
+         [0.0249, -0.0469, -0.1703]
+     ]  # From comfyui
+
+     latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
+
+     weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
+     bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
+
+     images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
+     images = images.clamp(0.0, 1.0)
+
+     return images
+
+
+ @torch.no_grad()
+ def vae_decode(latents, vae, image_mode=False):
+     latents = latents / vae.config.scaling_factor
+
+     if not image_mode:
+         image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
+     else:
+         latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
+         image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
+         image = torch.cat(image, dim=2)
+
+     return image
+
+
+ @torch.no_grad()
+ def vae_encode(image, vae):
+     latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
+     latents = latents * vae.config.scaling_factor
+     return latents
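
Editor's note: `vae_decode_fake` is only a cheap linear preview (a 1x1x1 conv3d with fixed per-channel RGB factors), which is why app.py's sampling callback can show a preview at every step without touching the real VAE. A quick shape check on random latents:

import torch

from diffusers_helper.hunyuan import vae_decode_fake

latents = torch.randn(1, 16, 3, 60, 104)           # (B, C=16, T, H/8, W/8)
preview = vae_decode_fake(latents)
print(preview.shape)                               # torch.Size([1, 3, 3, 60, 104]), values clamped to [0, 1]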
diffusers_helper/k_diffusion/__pycache__/uni_pc_fm.cpython-310.pyc ADDED
Binary file (3.22 kB).
 
diffusers_helper/k_diffusion/__pycache__/wrapper.cpython-310.pyc ADDED
Binary file (1.59 kB).
 
diffusers_helper/k_diffusion/uni_pc_fm.py ADDED
@@ -0,0 +1,141 @@
+ # Better Flow Matching UniPC by Lvmin Zhang
+ # (c) 2025
+ # CC BY-SA 4.0
+ # Attribution-ShareAlike 4.0 International Licence
+
+
+ import torch
+
+ from tqdm.auto import trange
+
+
+ def expand_dims(v, dims):
+     return v[(...,) + (None,) * (dims - 1)]
+
+
+ class FlowMatchUniPC:
+     def __init__(self, model, extra_args, variant='bh1'):
+         self.model = model
+         self.variant = variant
+         self.extra_args = extra_args
+
+     def model_fn(self, x, t):
+         return self.model(x, t, **self.extra_args)
+
+     def update_fn(self, x, model_prev_list, t_prev_list, t, order):
+         assert order <= len(model_prev_list)
+         dims = x.dim()
+
+         t_prev_0 = t_prev_list[-1]
+         lambda_prev_0 = - torch.log(t_prev_0)
+         lambda_t = - torch.log(t)
+         model_prev_0 = model_prev_list[-1]
+
+         h = lambda_t - lambda_prev_0
+
+         rks = []
+         D1s = []
+         for i in range(1, order):
+             t_prev_i = t_prev_list[-(i + 1)]
+             model_prev_i = model_prev_list[-(i + 1)]
+             lambda_prev_i = - torch.log(t_prev_i)
+             rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
+             rks.append(rk)
+             D1s.append((model_prev_i - model_prev_0) / rk)
+
+         rks.append(1.)
+         rks = torch.tensor(rks, device=x.device)
+
+         R = []
+         b = []
+
+         hh = -h[0]
+         h_phi_1 = torch.expm1(hh)
+         h_phi_k = h_phi_1 / hh - 1
+
+         factorial_i = 1
+
+         if self.variant == 'bh1':
+             B_h = hh
+         elif self.variant == 'bh2':
+             B_h = torch.expm1(hh)
+         else:
+             raise NotImplementedError('Bad variant!')
+
+         for i in range(1, order + 1):
+             R.append(torch.pow(rks, i - 1))
+             b.append(h_phi_k * factorial_i / B_h)
+             factorial_i *= (i + 1)
+             h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+         R = torch.stack(R)
+         b = torch.tensor(b, device=x.device)
+
+         use_predictor = len(D1s) > 0
+
+         if use_predictor:
+             D1s = torch.stack(D1s, dim=1)
+             if order == 2:
+                 rhos_p = torch.tensor([0.5], device=b.device)
+             else:
+                 rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+         else:
+             D1s = None
+             rhos_p = None
+
+         if order == 1:
+             rhos_c = torch.tensor([0.5], device=b.device)
+         else:
+             rhos_c = torch.linalg.solve(R, b)
+
+         x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
+
+         if use_predictor:
+             pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
+         else:
+             pred_res = 0
+
+         x_t = x_t_ - expand_dims(B_h, dims) * pred_res
+         model_t = self.model_fn(x_t, t)
+
+         if D1s is not None:
+             corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
+         else:
+             corr_res = 0
+
+         D1_t = (model_t - model_prev_0)
+         x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+
+         return x_t, model_t
+
+     def sample(self, x, sigmas, callback=None, disable_pbar=False):
+         order = min(3, len(sigmas) - 2)
+         model_prev_list, t_prev_list = [], []
+         for i in trange(len(sigmas) - 1, disable=disable_pbar):
+             vec_t = sigmas[i].expand(x.shape[0])
+
+             if i == 0:
+                 model_prev_list = [self.model_fn(x, vec_t)]
+                 t_prev_list = [vec_t]
+             elif i < order:
+                 init_order = i
+                 x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
+                 model_prev_list.append(model_x)
+                 t_prev_list.append(vec_t)
+             else:
+                 x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
+                 model_prev_list.append(model_x)
+                 t_prev_list.append(vec_t)
+
+             model_prev_list = model_prev_list[-order:]
+             t_prev_list = t_prev_list[-order:]
+
+             if callback is not None:
+                 callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
+
+         return model_prev_list[-1]
+
+
+ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
+     assert variant in ['bh1', 'bh2']
+     return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
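
Editor's note: `sample_unipc` only needs a callable `model(x, t, **extra_args)` that returns the current denoised estimate, plus a decreasing sigma schedule; the loop never evaluates the final sigma, so a schedule ending at 0 is fine. A toy run with a dummy model, purely to show the calling convention (not a meaningful sampler output):

import torch

from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc


def dummy_model(x, t, **kwargs):
    # Pretend the clean sample is always zero.
    return torch.zeros_like(x)


noise = torch.randn(1, 4, 8, 8)
sigmas = torch.linspace(1.0, 0.0, 26)              # 25 steps, flow-matching style schedule
out = sample_unipc(dummy_model, noise, sigmas, extra_args={}, disable=True)
print(out.shape)                                   # torch.Size([1, 4, 8, 8])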
diffusers_helper/k_diffusion/wrapper.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+
+
+ def append_dims(x, target_dims):
+     return x[(...,) + (None,) * (target_dims - x.ndim)]
+
+
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
+     if guidance_rescale == 0:
+         return noise_cfg
+
+     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+     noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+     noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
+     return noise_cfg
+
+
+ def fm_wrapper(transformer, t_scale=1000.0):
+     def k_model(x, sigma, **extra_args):
+         dtype = extra_args['dtype']
+         cfg_scale = extra_args['cfg_scale']
+         cfg_rescale = extra_args['cfg_rescale']
+         concat_latent = extra_args['concat_latent']
+
+         original_dtype = x.dtype
+         sigma = sigma.float()
+
+         x = x.to(dtype)
+         timestep = (sigma * t_scale).to(dtype)
+
+         if concat_latent is None:
+             hidden_states = x
+         else:
+             hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
+
+         pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
+
+         if cfg_scale == 1.0:
+             pred_negative = torch.zeros_like(pred_positive)
+         else:
+             pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
+
+         pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
+         pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
+
+         x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
+
+         return x0.to(dtype=original_dtype)
+
+     return k_model
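
Editor's note: `rescale_noise_cfg` is the usual CFG-rescale trick: it scales the guided prediction so its per-sample standard deviation matches the text-conditioned prediction, then blends the two by `guidance_rescale` (the `rs` slider in app.py). A quick numerical sketch:

import torch

from diffusers_helper.k_diffusion.wrapper import rescale_noise_cfg

torch.manual_seed(0)
noise_pred_text = torch.randn(2, 4, 8, 8)
noise_cfg = noise_pred_text * 3.0                  # deliberately over-amplified guided prediction

rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
# The rescaled std lands between the two, much closer to noise_pred_text's.
print(noise_cfg.std().item(), rescaled.std().item())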
diffusers_helper/memory.py ADDED
@@ -0,0 +1,134 @@
+ # By lllyasviel
+
+
+ import torch
+
+
+ cpu = torch.device('cpu')
+ gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
+ gpu_complete_modules = []
+
+
+ class DynamicSwapInstaller:
+     @staticmethod
+     def _install_module(module: torch.nn.Module, **kwargs):
+         original_class = module.__class__
+         module.__dict__['forge_backup_original_class'] = original_class
+
+         def hacked_get_attr(self, name: str):
+             if '_parameters' in self.__dict__:
+                 _parameters = self.__dict__['_parameters']
+                 if name in _parameters:
+                     p = _parameters[name]
+                     if p is None:
+                         return None
+                     if p.__class__ == torch.nn.Parameter:
+                         return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
+                     else:
+                         return p.to(**kwargs)
+             if '_buffers' in self.__dict__:
+                 _buffers = self.__dict__['_buffers']
+                 if name in _buffers:
+                     return _buffers[name].to(**kwargs)
+             return super(original_class, self).__getattr__(name)
+
+         module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
+             '__getattr__': hacked_get_attr,
+         })
+
+         return
+
+     @staticmethod
+     def _uninstall_module(module: torch.nn.Module):
+         if 'forge_backup_original_class' in module.__dict__:
+             module.__class__ = module.__dict__.pop('forge_backup_original_class')
+         return
+
+     @staticmethod
+     def install_model(model: torch.nn.Module, **kwargs):
+         for m in model.modules():
+             DynamicSwapInstaller._install_module(m, **kwargs)
+         return
+
+     @staticmethod
+     def uninstall_model(model: torch.nn.Module):
+         for m in model.modules():
+             DynamicSwapInstaller._uninstall_module(m)
+         return
+
+
+ def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
+     if hasattr(model, 'scale_shift_table'):
+         model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
+         return
+
+     for k, p in model.named_modules():
+         if hasattr(p, 'weight'):
+             p.to(target_device)
+     return
+
+
+ def get_cuda_free_memory_gb(device=None):
+     if device is None:
+         device = gpu
+
+     memory_stats = torch.cuda.memory_stats(device)
+     bytes_active = memory_stats['active_bytes.all.current']
+     bytes_reserved = memory_stats['reserved_bytes.all.current']
+     bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
+     bytes_inactive_reserved = bytes_reserved - bytes_active
+     bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
+     return bytes_total_available / (1024 ** 3)
+
+
+ def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
+     print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
+
+     for m in model.modules():
+         if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
+             torch.cuda.empty_cache()
+             return
+
+         if hasattr(m, 'weight'):
+             m.to(device=target_device)
+
+     model.to(device=target_device)
+     torch.cuda.empty_cache()
+     return
+
+
+ def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
+     print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
+
+     for m in model.modules():
+         if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
+             torch.cuda.empty_cache()
+             return
+
+         if hasattr(m, 'weight'):
+             m.to(device=cpu)
+
+     model.to(device=cpu)
+     torch.cuda.empty_cache()
+     return
+
+
+ def unload_complete_models(*args):
+     for m in gpu_complete_modules + list(args):
+         m.to(device=cpu)
+         print(f'Unloaded {m.__class__.__name__} as complete.')
+
+     gpu_complete_modules.clear()
+     torch.cuda.empty_cache()
+     return
+
+
+ def load_model_as_complete(model, target_device, unload=True):
+     if unload:
+         unload_complete_models()
+
+     model.to(device=target_device)
+     print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
+
+     gpu_complete_modules.append(model)
+     return
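
Editor's note: in app.py's low-VRAM path the typical flow is to check free memory, `load_model_as_complete` whichever encoder is needed next (which first unloads everything tracked in `gpu_complete_modules`), and call `unload_complete_models()` when done. A small sketch with a toy module, assuming a CUDA device is available:

import torch

from diffusers_helper.memory import (
    gpu, get_cuda_free_memory_gb, load_model_as_complete, unload_complete_models,
)

print(f'Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB')

toy = torch.nn.Linear(8, 8)
load_model_as_complete(toy, target_device=gpu)   # moves it to the GPU and registers it
unload_complete_models()                         # moves every registered module back to CPU
print(next(toy.parameters()).device)             # cpu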
diffusers_helper/models/__pycache__/hunyuan_video_packed.cpython-310.pyc ADDED
Binary file (29.2 kB).
 
diffusers_helper/models/hunyuan_video_packed.py ADDED
@@ -0,0 +1,1035 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import einops
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ from diffusers.loaders import FromOriginalModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders import PeftAdapterMixin
11
+ from diffusers.utils import logging
12
+ from diffusers.models.attention import FeedForward
13
+ from diffusers.models.attention_processor import Attention
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers_helper.dit_common import LayerNorm
18
+ from diffusers_helper.utils import zero_module
19
+
20
+
21
+ enabled_backends = []
22
+
23
+ if torch.backends.cuda.flash_sdp_enabled():
24
+ enabled_backends.append("flash")
25
+ if torch.backends.cuda.math_sdp_enabled():
26
+ enabled_backends.append("math")
27
+ if torch.backends.cuda.mem_efficient_sdp_enabled():
28
+ enabled_backends.append("mem_efficient")
29
+ if torch.backends.cuda.cudnn_sdp_enabled():
30
+ enabled_backends.append("cudnn")
31
+
32
+ print("Currently enabled native sdp backends:", enabled_backends)
33
+
34
+ try:
35
+ # raise NotImplementedError
36
+ from xformers.ops import memory_efficient_attention as xformers_attn_func
37
+ print('Xformers is installed!')
38
+ except:
39
+ print('Xformers is not installed!')
40
+ xformers_attn_func = None
41
+
42
+ try:
43
+ # raise NotImplementedError
44
+ from flash_attn import flash_attn_varlen_func, flash_attn_func
45
+ print('Flash Attn is installed!')
46
+ except:
47
+ print('Flash Attn is not installed!')
48
+ flash_attn_varlen_func = None
49
+ flash_attn_func = None
50
+
51
+ try:
52
+ # raise NotImplementedError
53
+ from sageattention import sageattn_varlen, sageattn
54
+ print('Sage Attn is installed!')
55
+ except:
56
+ print('Sage Attn is not installed!')
57
+ sageattn_varlen = None
58
+ sageattn = None
59
+
60
+
61
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
+
63
+
64
+ def pad_for_3d_conv(x, kernel_size):
65
+ b, c, t, h, w = x.shape
66
+ pt, ph, pw = kernel_size
67
+ pad_t = (pt - (t % pt)) % pt
68
+ pad_h = (ph - (h % ph)) % ph
69
+ pad_w = (pw - (w % pw)) % pw
70
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71
+
72
+
73
+ def center_down_sample_3d(x, kernel_size):
74
+ # pt, ph, pw = kernel_size
75
+ # cp = (pt * ph * pw) // 2
76
+ # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77
+ # xc = xp[cp]
78
+ # return xc
79
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80
+
81
+
82
+ def get_cu_seqlens(text_mask, img_len):
83
+ batch_size = text_mask.shape[0]
84
+ text_len = text_mask.sum(dim=1)
85
+ max_len = text_mask.shape[1] + img_len
86
+
87
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88
+
89
+ for i in range(batch_size):
90
+ s = text_len[i] + img_len
91
+ s1 = i * max_len + s
92
+ s2 = (i + 1) * max_len
93
+ cu_seqlens[2 * i + 1] = s1
94
+ cu_seqlens[2 * i + 2] = s2
95
+
96
+ return cu_seqlens
97
+
98
+
99
+ def apply_rotary_emb_transposed(x, freqs_cis):
100
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101
+ x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103
+ out = x.float() * cos + x_rotated.float() * sin
104
+ out = out.to(x)
105
+ return out
106
+
107
+
108
+ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109
+ if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110
+ if sageattn is not None:
111
+ x = sageattn(q, k, v, tensor_layout='NHD')
112
+ return x
113
+
114
+ if flash_attn_func is not None:
115
+ x = flash_attn_func(q, k, v)
116
+ return x
117
+
118
+ if xformers_attn_func is not None:
119
+ x = xformers_attn_func(q, k, v)
120
+ return x
121
+
122
+ x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123
+ return x
124
+
125
+ B, L, H, C = q.shape
126
+
127
+ q = q.flatten(0, 1)
128
+ k = k.flatten(0, 1)
129
+ v = v.flatten(0, 1)
130
+
131
+ if sageattn_varlen is not None:
132
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133
+ elif flash_attn_varlen_func is not None:
134
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
135
+ else:
136
+ raise NotImplementedError('No Attn Installed!')
137
+
138
+ x = x.unflatten(0, (B, L))
139
+
140
+ return x
141
+
142
+
143
+ class HunyuanAttnProcessorFlashAttnDouble:
144
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
145
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
146
+
147
+ query = attn.to_q(hidden_states)
148
+ key = attn.to_k(hidden_states)
149
+ value = attn.to_v(hidden_states)
150
+
151
+ query = query.unflatten(2, (attn.heads, -1))
152
+ key = key.unflatten(2, (attn.heads, -1))
153
+ value = value.unflatten(2, (attn.heads, -1))
154
+
155
+ query = attn.norm_q(query)
156
+ key = attn.norm_k(key)
157
+
158
+ query = apply_rotary_emb_transposed(query, image_rotary_emb)
159
+ key = apply_rotary_emb_transposed(key, image_rotary_emb)
160
+
161
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
162
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
163
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
164
+
165
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
166
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
167
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
168
+
169
+ encoder_query = attn.norm_added_q(encoder_query)
170
+ encoder_key = attn.norm_added_k(encoder_key)
171
+
172
+ query = torch.cat([query, encoder_query], dim=1)
173
+ key = torch.cat([key, encoder_key], dim=1)
174
+ value = torch.cat([value, encoder_value], dim=1)
175
+
176
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
177
+ hidden_states = hidden_states.flatten(-2)
178
+
179
+ txt_length = encoder_hidden_states.shape[1]
180
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
181
+
182
+ hidden_states = attn.to_out[0](hidden_states)
183
+ hidden_states = attn.to_out[1](hidden_states)
184
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
185
+
186
+ return hidden_states, encoder_hidden_states
187
+
188
+
189
+ class HunyuanAttnProcessorFlashAttnSingle:
190
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
191
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
192
+
193
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
194
+
195
+ query = attn.to_q(hidden_states)
196
+ key = attn.to_k(hidden_states)
197
+ value = attn.to_v(hidden_states)
198
+
199
+ query = query.unflatten(2, (attn.heads, -1))
200
+ key = key.unflatten(2, (attn.heads, -1))
201
+ value = value.unflatten(2, (attn.heads, -1))
202
+
203
+ query = attn.norm_q(query)
204
+ key = attn.norm_k(key)
205
+
206
+ txt_length = encoder_hidden_states.shape[1]
207
+
208
+ query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
209
+ key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
210
+
211
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
212
+ hidden_states = hidden_states.flatten(-2)
213
+
214
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
215
+
216
+ return hidden_states, encoder_hidden_states
217
+
218
+
219
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
220
+ def __init__(self, embedding_dim, pooled_projection_dim):
221
+ super().__init__()
222
+
223
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
224
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
225
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
226
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
227
+
228
+ def forward(self, timestep, guidance, pooled_projection):
229
+ timesteps_proj = self.time_proj(timestep)
230
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
231
+
232
+ guidance_proj = self.time_proj(guidance)
233
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
234
+
235
+ time_guidance_emb = timesteps_emb + guidance_emb
236
+
237
+ pooled_projections = self.text_embedder(pooled_projection)
238
+ conditioning = time_guidance_emb + pooled_projections
239
+
240
+ return conditioning
241
+
242
+
243
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
244
+ def __init__(self, embedding_dim, pooled_projection_dim):
245
+ super().__init__()
246
+
247
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
248
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
249
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
250
+
251
+ def forward(self, timestep, pooled_projection):
252
+ timesteps_proj = self.time_proj(timestep)
253
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
254
+
255
+ pooled_projections = self.text_embedder(pooled_projection)
256
+
257
+ conditioning = timesteps_emb + pooled_projections
258
+
259
+ return conditioning
260
+
261
+
262
+ class HunyuanVideoAdaNorm(nn.Module):
263
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
264
+ super().__init__()
265
+
266
+ out_features = out_features or 2 * in_features
267
+ self.linear = nn.Linear(in_features, out_features)
268
+ self.nonlinearity = nn.SiLU()
269
+
270
+ def forward(
271
+ self, temb: torch.Tensor
272
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
273
+ temb = self.linear(self.nonlinearity(temb))
274
+ gate_msa, gate_mlp = temb.chunk(2, dim=-1)
275
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
276
+ return gate_msa, gate_mlp
277
+
278
+
279
+ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
280
+ def __init__(
281
+ self,
282
+ num_attention_heads: int,
283
+ attention_head_dim: int,
284
+ mlp_width_ratio: str = 4.0,
285
+ mlp_drop_rate: float = 0.0,
286
+ attention_bias: bool = True,
287
+ ) -> None:
288
+ super().__init__()
289
+
290
+ hidden_size = num_attention_heads * attention_head_dim
291
+
292
+ self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
293
+ self.attn = Attention(
294
+ query_dim=hidden_size,
295
+ cross_attention_dim=None,
296
+ heads=num_attention_heads,
297
+ dim_head=attention_head_dim,
298
+ bias=attention_bias,
299
+ )
300
+
301
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
302
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
303
+
304
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ temb: torch.Tensor,
310
+ attention_mask: Optional[torch.Tensor] = None,
311
+ ) -> torch.Tensor:
312
+ norm_hidden_states = self.norm1(hidden_states)
313
+
314
+ attn_output = self.attn(
315
+ hidden_states=norm_hidden_states,
316
+ encoder_hidden_states=None,
317
+ attention_mask=attention_mask,
318
+ )
319
+
320
+ gate_msa, gate_mlp = self.norm_out(temb)
321
+ hidden_states = hidden_states + attn_output * gate_msa
322
+
323
+ ff_output = self.ff(self.norm2(hidden_states))
324
+ hidden_states = hidden_states + ff_output * gate_mlp
325
+
326
+ return hidden_states
327
+
328
+
329
+ class HunyuanVideoIndividualTokenRefiner(nn.Module):
330
+ def __init__(
331
+ self,
332
+ num_attention_heads: int,
333
+ attention_head_dim: int,
334
+ num_layers: int,
335
+ mlp_width_ratio: float = 4.0,
336
+ mlp_drop_rate: float = 0.0,
337
+ attention_bias: bool = True,
338
+ ) -> None:
339
+ super().__init__()
340
+
341
+ self.refiner_blocks = nn.ModuleList(
342
+ [
343
+ HunyuanVideoIndividualTokenRefinerBlock(
344
+ num_attention_heads=num_attention_heads,
345
+ attention_head_dim=attention_head_dim,
346
+ mlp_width_ratio=mlp_width_ratio,
347
+ mlp_drop_rate=mlp_drop_rate,
348
+ attention_bias=attention_bias,
349
+ )
350
+ for _ in range(num_layers)
351
+ ]
352
+ )
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states: torch.Tensor,
357
+ temb: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ ) -> None:
360
+ self_attn_mask = None
361
+ if attention_mask is not None:
362
+ batch_size = attention_mask.shape[0]
363
+ seq_len = attention_mask.shape[1]
364
+ attention_mask = attention_mask.to(hidden_states.device).bool()
365
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
366
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
367
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
368
+ self_attn_mask[:, :, :, 0] = True
369
+
370
+ for block in self.refiner_blocks:
371
+ hidden_states = block(hidden_states, temb, self_attn_mask)
372
+
373
+ return hidden_states
374
+
375
+
376
+ class HunyuanVideoTokenRefiner(nn.Module):
377
+ def __init__(
378
+ self,
379
+ in_channels: int,
380
+ num_attention_heads: int,
381
+ attention_head_dim: int,
382
+ num_layers: int,
383
+ mlp_ratio: float = 4.0,
384
+ mlp_drop_rate: float = 0.0,
385
+ attention_bias: bool = True,
386
+ ) -> None:
387
+ super().__init__()
388
+
389
+ hidden_size = num_attention_heads * attention_head_dim
390
+
391
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
392
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
393
+ )
394
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
395
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
396
+ num_attention_heads=num_attention_heads,
397
+ attention_head_dim=attention_head_dim,
398
+ num_layers=num_layers,
399
+ mlp_width_ratio=mlp_ratio,
400
+ mlp_drop_rate=mlp_drop_rate,
401
+ attention_bias=attention_bias,
402
+ )
403
+
404
+ def forward(
405
+ self,
406
+ hidden_states: torch.Tensor,
407
+ timestep: torch.LongTensor,
408
+ attention_mask: Optional[torch.LongTensor] = None,
409
+ ) -> torch.Tensor:
410
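+ # Pool the text tokens (masked mean when an attention mask is given) to condition the timestep
+ # embedding, then refine the projected tokens with the individual token refiner blocks.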
+ if attention_mask is None:
411
+ pooled_projections = hidden_states.mean(dim=1)
412
+ else:
413
+ original_dtype = hidden_states.dtype
414
+ mask_float = attention_mask.float().unsqueeze(-1)
415
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
416
+ pooled_projections = pooled_projections.to(original_dtype)
417
+
418
+ temb = self.time_text_embed(timestep, pooled_projections)
419
+ hidden_states = self.proj_in(hidden_states)
420
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
421
+
422
+ return hidden_states
423
+
424
+
425
+ class HunyuanVideoRotaryPosEmbed(nn.Module):
426
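+ # 3D rotary position embedding: independent frequency bands for the time (DT), height (DY) and
+ # width (DX) axes are built per token and concatenated into a single cos/sin table.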
+ def __init__(self, rope_dim, theta):
427
+ super().__init__()
428
+ self.DT, self.DY, self.DX = rope_dim
429
+ self.theta = theta
430
+
431
+ @torch.no_grad()
432
+ def get_frequency(self, dim, pos):
433
+ T, H, W = pos.shape
434
+ freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
435
+ freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
436
+ return freqs.cos(), freqs.sin()
437
+
438
+ @torch.no_grad()
439
+ def forward_inner(self, frame_indices, height, width, device):
440
+ GT, GY, GX = torch.meshgrid(
441
+ frame_indices.to(device=device, dtype=torch.float32),
442
+ torch.arange(0, height, device=device, dtype=torch.float32),
443
+ torch.arange(0, width, device=device, dtype=torch.float32),
444
+ indexing="ij"
445
+ )
446
+
447
+ FCT, FST = self.get_frequency(self.DT, GT)
448
+ FCY, FSY = self.get_frequency(self.DY, GY)
449
+ FCX, FSX = self.get_frequency(self.DX, GX)
450
+
451
+ result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
452
+
453
+ return result.to(device)
454
+
455
+ @torch.no_grad()
456
+ def forward(self, frame_indices, height, width, device):
457
+ frame_indices = frame_indices.unbind(0)
458
+ results = [self.forward_inner(f, height, width, device) for f in frame_indices]
459
+ results = torch.stack(results, dim=0)
460
+ return results
461
+
462
+
463
+ class AdaLayerNormZero(nn.Module):
464
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
465
+ super().__init__()
466
+ self.silu = nn.SiLU()
467
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
468
+ if norm_type == "layer_norm":
469
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
470
+ else:
471
+ raise ValueError(f"unknown norm_type {norm_type}")
472
+
473
+ def forward(
474
+ self,
475
+ x: torch.Tensor,
476
+ emb: Optional[torch.Tensor] = None,
477
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
478
+ emb = emb.unsqueeze(-2)
479
+ emb = self.linear(self.silu(emb))
480
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
481
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
482
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
483
+
484
+
485
+ class AdaLayerNormZeroSingle(nn.Module):
486
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
487
+ super().__init__()
488
+
489
+ self.silu = nn.SiLU()
490
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
491
+ if norm_type == "layer_norm":
492
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
493
+ else:
494
+ raise ValueError(f"unknown norm_type {norm_type}")
495
+
496
+ def forward(
497
+ self,
498
+ x: torch.Tensor,
499
+ emb: Optional[torch.Tensor] = None,
500
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
501
+ emb = emb.unsqueeze(-2)
502
+ emb = self.linear(self.silu(emb))
503
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
504
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
505
+ return x, gate_msa
506
+
507
+
508
+ class AdaLayerNormContinuous(nn.Module):
509
+ def __init__(
510
+ self,
511
+ embedding_dim: int,
512
+ conditioning_embedding_dim: int,
513
+ elementwise_affine=True,
514
+ eps=1e-5,
515
+ bias=True,
516
+ norm_type="layer_norm",
517
+ ):
518
+ super().__init__()
519
+ self.silu = nn.SiLU()
520
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
521
+ if norm_type == "layer_norm":
522
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
523
+ else:
524
+ raise ValueError(f"unknown norm_type {norm_type}")
525
+
526
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
527
+ emb = emb.unsqueeze(-2)
528
+ emb = self.linear(self.silu(emb))
529
+ scale, shift = emb.chunk(2, dim=-1)
530
+ x = self.norm(x) * (1 + scale) + shift
531
+ return x
532
+
533
+
534
+ class HunyuanVideoSingleTransformerBlock(nn.Module):
535
+ def __init__(
536
+ self,
537
+ num_attention_heads: int,
538
+ attention_head_dim: int,
539
+ mlp_ratio: float = 4.0,
540
+ qk_norm: str = "rms_norm",
541
+ ) -> None:
542
+ super().__init__()
543
+
544
+ hidden_size = num_attention_heads * attention_head_dim
545
+ mlp_dim = int(hidden_size * mlp_ratio)
546
+
547
+ self.attn = Attention(
548
+ query_dim=hidden_size,
549
+ cross_attention_dim=None,
550
+ dim_head=attention_head_dim,
551
+ heads=num_attention_heads,
552
+ out_dim=hidden_size,
553
+ bias=True,
554
+ processor=HunyuanAttnProcessorFlashAttnSingle(),
555
+ qk_norm=qk_norm,
556
+ eps=1e-6,
557
+ pre_only=True,
558
+ )
559
+
560
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
561
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
562
+ self.act_mlp = nn.GELU(approximate="tanh")
563
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
564
+
565
+ def forward(
566
+ self,
567
+ hidden_states: torch.Tensor,
568
+ encoder_hidden_states: torch.Tensor,
569
+ temb: torch.Tensor,
570
+ attention_mask: Optional[torch.Tensor] = None,
571
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
572
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
573
+ text_seq_length = encoder_hidden_states.shape[1]
574
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
575
+
576
+ residual = hidden_states
577
+
578
+ # 1. Input normalization
579
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
580
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
581
+
582
+ norm_hidden_states, norm_encoder_hidden_states = (
583
+ norm_hidden_states[:, :-text_seq_length, :],
584
+ norm_hidden_states[:, -text_seq_length:, :],
585
+ )
586
+
587
+ # 2. Attention
588
+ attn_output, context_attn_output = self.attn(
589
+ hidden_states=norm_hidden_states,
590
+ encoder_hidden_states=norm_encoder_hidden_states,
591
+ attention_mask=attention_mask,
592
+ image_rotary_emb=image_rotary_emb,
593
+ )
594
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
595
+
596
+ # 3. Modulation and residual connection
597
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
598
+ hidden_states = gate * self.proj_out(hidden_states)
599
+ hidden_states = hidden_states + residual
600
+
601
+ hidden_states, encoder_hidden_states = (
602
+ hidden_states[:, :-text_seq_length, :],
603
+ hidden_states[:, -text_seq_length:, :],
604
+ )
605
+ return hidden_states, encoder_hidden_states
606
+
607
+
608
+ class HunyuanVideoTransformerBlock(nn.Module):
609
+ def __init__(
610
+ self,
611
+ num_attention_heads: int,
612
+ attention_head_dim: int,
613
+ mlp_ratio: float,
614
+ qk_norm: str = "rms_norm",
615
+ ) -> None:
616
+ super().__init__()
617
+
618
+ hidden_size = num_attention_heads * attention_head_dim
619
+
620
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
621
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
622
+
623
+ self.attn = Attention(
624
+ query_dim=hidden_size,
625
+ cross_attention_dim=None,
626
+ added_kv_proj_dim=hidden_size,
627
+ dim_head=attention_head_dim,
628
+ heads=num_attention_heads,
629
+ out_dim=hidden_size,
630
+ context_pre_only=False,
631
+ bias=True,
632
+ processor=HunyuanAttnProcessorFlashAttnDouble(),
633
+ qk_norm=qk_norm,
634
+ eps=1e-6,
635
+ )
636
+
637
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
638
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
639
+
640
+ self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
641
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
642
+
643
+ def forward(
644
+ self,
645
+ hidden_states: torch.Tensor,
646
+ encoder_hidden_states: torch.Tensor,
647
+ temb: torch.Tensor,
648
+ attention_mask: Optional[torch.Tensor] = None,
649
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
650
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
651
+ # 1. Input normalization
652
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
653
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
654
+
655
+ # 2. Joint attention
656
+ attn_output, context_attn_output = self.attn(
657
+ hidden_states=norm_hidden_states,
658
+ encoder_hidden_states=norm_encoder_hidden_states,
659
+ attention_mask=attention_mask,
660
+ image_rotary_emb=freqs_cis,
661
+ )
662
+
663
+ # 3. Modulation and residual connection
664
+ hidden_states = hidden_states + attn_output * gate_msa
665
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
666
+
667
+ norm_hidden_states = self.norm2(hidden_states)
668
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
669
+
670
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
671
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
672
+
673
+ # 4. Feed-forward
674
+ ff_output = self.ff(norm_hidden_states)
675
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
676
+
677
+ hidden_states = hidden_states + gate_mlp * ff_output
678
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
679
+
680
+ return hidden_states, encoder_hidden_states
681
+
682
+
683
+ class ClipVisionProjection(nn.Module):
684
+ def __init__(self, in_channels, out_channels):
685
+ super().__init__()
686
+ self.up = nn.Linear(in_channels, out_channels * 3)
687
+ self.down = nn.Linear(out_channels * 3, out_channels)
688
+
689
+ def forward(self, x):
690
+ projected_x = self.down(nn.functional.silu(self.up(x)))
691
+ return projected_x
692
+
693
+
694
+ class HunyuanVideoPatchEmbed(nn.Module):
695
+ def __init__(self, patch_size, in_chans, embed_dim):
696
+ super().__init__()
697
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
698
+
699
+
700
+ class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
701
+ def __init__(self, inner_dim):
702
+ super().__init__()
703
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
704
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
705
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
706
+
707
+ @torch.no_grad()
708
+ def initialize_weight_from_another_conv3d(self, another_layer):
709
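+ # Tile the 1x projection weights over the larger 2x/4x kernels and rescale by the kernel volume
+ # (1/8 and 1/64) so the downsampling projections start out as averages of the original patch embed.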
+ weight = another_layer.weight.detach().clone()
710
+ bias = another_layer.bias.detach().clone()
711
+
712
+ sd = {
713
+ 'proj.weight': weight.clone(),
714
+ 'proj.bias': bias.clone(),
715
+ 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
716
+ 'proj_2x.bias': bias.clone(),
717
+ 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
718
+ 'proj_4x.bias': bias.clone(),
719
+ }
720
+
721
+ sd = {k: v.clone() for k, v in sd.items()}
722
+
723
+ self.load_state_dict(sd)
724
+ return
725
+
726
+
727
+ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
728
+ @register_to_config
729
+ def __init__(
730
+ self,
731
+ in_channels: int = 16,
732
+ out_channels: int = 16,
733
+ num_attention_heads: int = 24,
734
+ attention_head_dim: int = 128,
735
+ num_layers: int = 20,
736
+ num_single_layers: int = 40,
737
+ num_refiner_layers: int = 2,
738
+ mlp_ratio: float = 4.0,
739
+ patch_size: int = 2,
740
+ patch_size_t: int = 1,
741
+ qk_norm: str = "rms_norm",
742
+ guidance_embeds: bool = True,
743
+ text_embed_dim: int = 4096,
744
+ pooled_projection_dim: int = 768,
745
+ rope_theta: float = 256.0,
746
+ rope_axes_dim: Tuple[int, int, int] = (16, 56, 56),
747
+ has_image_proj=False,
748
+ image_proj_dim=1152,
749
+ has_clean_x_embedder=False,
750
+ ) -> None:
751
+ super().__init__()
752
+
753
+ inner_dim = num_attention_heads * attention_head_dim
754
+ out_channels = out_channels or in_channels
755
+
756
+ # 1. Latent and condition embedders
757
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
758
+ self.context_embedder = HunyuanVideoTokenRefiner(
759
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
760
+ )
761
+ self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
762
+
763
+ self.clean_x_embedder = None
764
+ self.image_projection = None
765
+
766
+ # 2. RoPE
767
+ self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
768
+
769
+ # 3. Dual stream transformer blocks
770
+ self.transformer_blocks = nn.ModuleList(
771
+ [
772
+ HunyuanVideoTransformerBlock(
773
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
774
+ )
775
+ for _ in range(num_layers)
776
+ ]
777
+ )
778
+
779
+ # 4. Single stream transformer blocks
780
+ self.single_transformer_blocks = nn.ModuleList(
781
+ [
782
+ HunyuanVideoSingleTransformerBlock(
783
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
784
+ )
785
+ for _ in range(num_single_layers)
786
+ ]
787
+ )
788
+
789
+ # 5. Output projection
790
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
791
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
792
+
793
+ self.inner_dim = inner_dim
794
+ self.use_gradient_checkpointing = False
795
+ self.enable_teacache = False
796
+
797
+ if has_image_proj:
798
+ self.install_image_projection(image_proj_dim)
799
+
800
+ if has_clean_x_embedder:
801
+ self.install_clean_x_embedder()
802
+
803
+ self.high_quality_fp32_output_for_inference = False
804
+
805
+ def install_image_projection(self, in_channels):
806
+ self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
807
+ self.config['has_image_proj'] = True
808
+ self.config['image_proj_dim'] = in_channels
809
+
810
+ def install_clean_x_embedder(self):
811
+ self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
812
+ self.config['has_clean_x_embedder'] = True
813
+
814
+ def enable_gradient_checkpointing(self):
815
+ self.use_gradient_checkpointing = True
816
+ print('self.use_gradient_checkpointing = True')
817
+
818
+ def disable_gradient_checkpointing(self):
819
+ self.use_gradient_checkpointing = False
820
+ print('self.use_gradient_checkpointing = False')
821
+
822
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
823
+ self.enable_teacache = enable_teacache
824
+ self.cnt = 0
825
+ self.num_steps = num_steps
826
+ self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
827
+ self.accumulated_rel_l1_distance = 0
828
+ self.previous_modulated_input = None
829
+ self.previous_residual = None
830
+ self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
831
+
832
+ def gradient_checkpointing_method(self, block, *args):
833
+ if self.use_gradient_checkpointing:
834
+ result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
835
+ else:
836
+ result = block(*args)
837
+ return result
838
+
839
+ def process_input_hidden_states(
840
+ self,
841
+ latents, latent_indices=None,
842
+ clean_latents=None, clean_latent_indices=None,
843
+ clean_latents_2x=None, clean_latent_2x_indices=None,
844
+ clean_latents_4x=None, clean_latent_4x_indices=None
845
+ ):
846
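+ # Patchify the current latents and prepend multi-scale "clean" history latents (full, 2x and 4x
+ # downsampled) as extra context tokens, building matching RoPE frequencies for every token group.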
+ hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
847
+ B, C, T, H, W = hidden_states.shape
848
+
849
+ if latent_indices is None:
850
+ latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
851
+
852
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
853
+
854
+ rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
855
+ rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
856
+
857
+ if clean_latents is not None and clean_latent_indices is not None:
858
+ clean_latents = clean_latents.to(hidden_states)
859
+ clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
860
+ clean_latents = clean_latents.flatten(2).transpose(1, 2)
861
+
862
+ clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
863
+ clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
864
+
865
+ hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
866
+ rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
867
+
868
+ if clean_latents_2x is not None and clean_latent_2x_indices is not None:
869
+ clean_latents_2x = clean_latents_2x.to(hidden_states)
870
+ clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
871
+ clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
872
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
873
+
874
+ clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
875
+ clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
876
+ clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
877
+ clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
878
+
879
+ hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
880
+ rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
881
+
882
+ if clean_latents_4x is not None and clean_latent_4x_indices is not None:
883
+ clean_latents_4x = clean_latents_4x.to(hidden_states)
884
+ clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
885
+ clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
886
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
887
+
888
+ clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
889
+ clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
890
+ clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
891
+ clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
892
+
893
+ hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
894
+ rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
895
+
896
+ return hidden_states, rope_freqs
897
+
898
+ def forward(
899
+ self,
900
+ hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
901
+ latent_indices=None,
902
+ clean_latents=None, clean_latent_indices=None,
903
+ clean_latents_2x=None, clean_latent_2x_indices=None,
904
+ clean_latents_4x=None, clean_latent_4x_indices=None,
905
+ image_embeddings=None,
906
+ attention_kwargs=None, return_dict=True
907
+ ):
908
+
909
+ if attention_kwargs is None:
910
+ attention_kwargs = {}
911
+
912
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
913
+ p, p_t = self.config['patch_size'], self.config['patch_size_t']
914
+ post_patch_num_frames = num_frames // p_t
915
+ post_patch_height = height // p
916
+ post_patch_width = width // p
917
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
918
+
919
+ hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
920
+
921
+ temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
922
+ encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
923
+
924
+ if self.image_projection is not None:
925
+ assert image_embeddings is not None, 'You must use image embeddings!'
926
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
927
+ extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
928
+
929
+ # must cat before (not after) encoder_hidden_states, due to attn masking
930
+ encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
931
+ encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
932
+
933
+ if batch_size == 1:
934
+ # When batch size is 1 there is no cross-sample padding, so masks and var-len attention functions
+ # are unnecessary: cropping the text tokens to their true length is mathematically equivalent.
936
+ text_len = encoder_attention_mask.sum().item()
937
+ encoder_hidden_states = encoder_hidden_states[:, :text_len]
938
+ attention_mask = None, None, None, None
939
+ else:
940
+ img_seq_len = hidden_states.shape[1]
941
+ txt_seq_len = encoder_hidden_states.shape[1]
942
+
943
+ cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
944
+ cu_seqlens_kv = cu_seqlens_q
945
+ max_seqlen_q = img_seq_len + txt_seq_len
946
+ max_seqlen_kv = max_seqlen_q
947
+
948
+ attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
949
+
950
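+ # TeaCache: track how much the first block's modulated input changes between steps and, when the
+ # accumulated (rescaled) relative L1 change stays below the threshold, skip all transformer blocks
+ # and reuse the residual cached from the previous step.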
+ if self.enable_teacache:
951
+ modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
952
+
953
+ if self.cnt == 0 or self.cnt == self.num_steps-1:
954
+ should_calc = True
955
+ self.accumulated_rel_l1_distance = 0
956
+ else:
957
+ curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
958
+ self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
959
+ should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
960
+
961
+ if should_calc:
962
+ self.accumulated_rel_l1_distance = 0
963
+
964
+ self.previous_modulated_input = modulated_inp
965
+ self.cnt += 1
966
+
967
+ if self.cnt == self.num_steps:
968
+ self.cnt = 0
969
+
970
+ if not should_calc:
971
+ hidden_states = hidden_states + self.previous_residual
972
+ else:
973
+ ori_hidden_states = hidden_states.clone()
974
+
975
+ for block_id, block in enumerate(self.transformer_blocks):
976
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
977
+ block,
978
+ hidden_states,
979
+ encoder_hidden_states,
980
+ temb,
981
+ attention_mask,
982
+ rope_freqs
983
+ )
984
+
985
+ for block_id, block in enumerate(self.single_transformer_blocks):
986
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
987
+ block,
988
+ hidden_states,
989
+ encoder_hidden_states,
990
+ temb,
991
+ attention_mask,
992
+ rope_freqs
993
+ )
994
+
995
+ self.previous_residual = hidden_states - ori_hidden_states
996
+ else:
997
+ for block_id, block in enumerate(self.transformer_blocks):
998
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
999
+ block,
1000
+ hidden_states,
1001
+ encoder_hidden_states,
1002
+ temb,
1003
+ attention_mask,
1004
+ rope_freqs
1005
+ )
1006
+
1007
+ for block_id, block in enumerate(self.single_transformer_blocks):
1008
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1009
+ block,
1010
+ hidden_states,
1011
+ encoder_hidden_states,
1012
+ temb,
1013
+ attention_mask,
1014
+ rope_freqs
1015
+ )
1016
+
1017
+ hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1018
+
1019
+ hidden_states = hidden_states[:, -original_context_length:, :]
1020
+
1021
+ if self.high_quality_fp32_output_for_inference:
1022
+ hidden_states = hidden_states.to(dtype=torch.float32)
1023
+ if self.proj_out.weight.dtype != torch.float32:
1024
+ self.proj_out.to(dtype=torch.float32)
1025
+
1026
+ hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1027
+
1028
+ hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1029
+ t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1030
+ pt=p_t, ph=p, pw=p)
1031
+
1032
+ if return_dict:
1033
+ return Transformer2DModelOutput(sample=hidden_states)
1034
+
1035
+ return hidden_states,
diffusers_helper/pipelines/__pycache__/k_diffusion_hunyuan.cpython-310.pyc ADDED
Binary file (2.88 kB).
 
diffusers_helper/pipelines/k_diffusion_hunyuan.py ADDED
@@ -0,0 +1,120 @@
1
+ import torch
2
+ import math
3
+
4
+ from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
5
+ from diffusers_helper.k_diffusion.wrapper import fm_wrapper
6
+ from diffusers_helper.utils import repeat_to_batch_size
7
+
8
+
9
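+ # Flux-style timestep shift: remaps a uniform schedule in (0, 1] so more steps are spent at high
+ # noise levels; mu is interpolated from the token count in calculate_flux_mu below.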
+ def flux_time_shift(t, mu=1.15, sigma=1.0):
10
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
11
+
12
+
13
+ def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
14
+ k = (y2 - y1) / (x2 - x1)
15
+ b = y1 - k * x1
16
+ mu = k * context_length + b
17
+ mu = min(mu, math.log(exp_max))
18
+ return mu
19
+
20
+
21
+ def get_flux_sigmas_from_mu(n, mu):
22
+ sigmas = torch.linspace(1, 0, steps=n + 1)
23
+ sigmas = flux_time_shift(sigmas, mu=mu)
24
+ return sigmas
25
+
26
+
27
+ @torch.inference_mode()
28
+ def sample_hunyuan(
29
+ transformer,
30
+ sampler='unipc',
31
+ initial_latent=None,
32
+ concat_latent=None,
33
+ strength=1.0,
34
+ width=512,
35
+ height=512,
36
+ frames=16,
37
+ real_guidance_scale=1.0,
38
+ distilled_guidance_scale=6.0,
39
+ guidance_rescale=0.0,
40
+ shift=None,
41
+ num_inference_steps=25,
42
+ batch_size=None,
43
+ generator=None,
44
+ prompt_embeds=None,
45
+ prompt_embeds_mask=None,
46
+ prompt_poolers=None,
47
+ negative_prompt_embeds=None,
48
+ negative_prompt_embeds_mask=None,
49
+ negative_prompt_poolers=None,
50
+ dtype=torch.bfloat16,
51
+ device=None,
52
+ negative_kwargs=None,
53
+ callback=None,
54
+ **kwargs,
55
+ ):
56
+ device = device or transformer.device
57
+
58
+ if batch_size is None:
59
+ batch_size = int(prompt_embeds.shape[0])
60
+
61
+ latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32)
62
+
63
+ B, C, T, H, W = latents.shape
64
+ seq_length = T * H * W // 4
65
+
66
+ if shift is None:
67
+ mu = calculate_flux_mu(seq_length, exp_max=7.0)
68
+ else:
69
+ mu = math.log(shift)
70
+
71
+ sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
72
+
73
+ k_model = fm_wrapper(transformer)
74
+
75
+ if initial_latent is not None:
76
+ sigmas = sigmas * strength
77
+ first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
78
+ initial_latent = initial_latent.to(device=device, dtype=torch.float32)
79
+ latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
80
+
81
+ if concat_latent is not None:
82
+ concat_latent = concat_latent.to(latents)
83
+
84
+ distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
85
+
86
+ prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
87
+ prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
88
+ prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
89
+ negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
90
+ negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
91
+ negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
92
+ concat_latent = repeat_to_batch_size(concat_latent, batch_size)
93
+
94
+ sampler_kwargs = dict(
95
+ dtype=dtype,
96
+ cfg_scale=real_guidance_scale,
97
+ cfg_rescale=guidance_rescale,
98
+ concat_latent=concat_latent,
99
+ positive=dict(
100
+ pooled_projections=prompt_poolers,
101
+ encoder_hidden_states=prompt_embeds,
102
+ encoder_attention_mask=prompt_embeds_mask,
103
+ guidance=distilled_guidance,
104
+ **kwargs,
105
+ ),
106
+ negative=dict(
107
+ pooled_projections=negative_prompt_poolers,
108
+ encoder_hidden_states=negative_prompt_embeds,
109
+ encoder_attention_mask=negative_prompt_embeds_mask,
110
+ guidance=distilled_guidance,
111
+ **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
112
+ )
113
+ )
114
+
115
+ if sampler == 'unipc':
116
+ results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
117
+ else:
118
+ raise NotImplementedError(f'Sampler {sampler} is not supported.')
119
+
120
+ return results
diffusers_helper/thread_utils.py ADDED
@@ -0,0 +1,76 @@
1
+ import time
2
+
3
+ from threading import Thread, Lock
4
+
5
+
6
+ class Listener:
7
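+ # Minimal single-worker task runner: tasks are queued under a lock and executed one at a time
+ # on a lazily started daemon thread.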
+ task_queue = []
8
+ lock = Lock()
9
+ thread = None
10
+
11
+ @classmethod
12
+ def _process_tasks(cls):
13
+ while True:
14
+ task = None
15
+ with cls.lock:
16
+ if cls.task_queue:
17
+ task = cls.task_queue.pop(0)
18
+
19
+ if task is None:
20
+ time.sleep(0.001)
21
+ continue
22
+
23
+ func, args, kwargs = task
24
+ try:
25
+ func(*args, **kwargs)
26
+ except Exception as e:
27
+ print(f"Error in listener thread: {e}")
28
+
29
+ @classmethod
30
+ def add_task(cls, func, *args, **kwargs):
31
+ with cls.lock:
32
+ cls.task_queue.append((func, args, kwargs))
33
+
34
+ if cls.thread is None:
35
+ cls.thread = Thread(target=cls._process_tasks, daemon=True)
36
+ cls.thread.start()
37
+
38
+
39
+ def async_run(func, *args, **kwargs):
40
+ Listener.add_task(func, *args, **kwargs)
41
+
42
+
43
+ class FIFOQueue:
44
+ def __init__(self):
45
+ self.queue = []
46
+ self.lock = Lock()
47
+
48
+ def push(self, item):
49
+ with self.lock:
50
+ self.queue.append(item)
51
+
52
+ def pop(self):
53
+ with self.lock:
54
+ if self.queue:
55
+ return self.queue.pop(0)
56
+ return None
57
+
58
+ def top(self):
59
+ with self.lock:
60
+ if self.queue:
61
+ return self.queue[0]
62
+ return None
63
+
64
+ def next(self):
65
+ while True:
66
+ with self.lock:
67
+ if self.queue:
68
+ return self.queue.pop(0)
69
+
70
+ time.sleep(0.001)
71
+
72
+
73
+ class AsyncStream:
74
+ def __init__(self):
75
+ self.input_queue = FIFOQueue()
76
+ self.output_queue = FIFOQueue()
diffusers_helper/utils.py ADDED
@@ -0,0 +1,613 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import glob
6
+ import torch
7
+ import einops
8
+ import numpy as np
9
+ import datetime
10
+ import torchvision
11
+
12
+ import safetensors.torch as sf
13
+ from PIL import Image
14
+
15
+
16
+ def min_resize(x, m):
17
+ if x.shape[0] < x.shape[1]:
18
+ s0 = m
19
+ s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
20
+ else:
21
+ s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
22
+ s1 = m
23
+ new_max = max(s1, s0)
24
+ raw_max = max(x.shape[0], x.shape[1])
25
+ if new_max < raw_max:
26
+ interpolation = cv2.INTER_AREA
27
+ else:
28
+ interpolation = cv2.INTER_LANCZOS4
29
+ y = cv2.resize(x, (s1, s0), interpolation=interpolation)
30
+ return y
31
+
32
+
33
+ def d_resize(x, y):
34
+ H, W, C = y.shape
35
+ new_min = min(H, W)
36
+ raw_min = min(x.shape[0], x.shape[1])
37
+ if new_min < raw_min:
38
+ interpolation = cv2.INTER_AREA
39
+ else:
40
+ interpolation = cv2.INTER_LANCZOS4
41
+ y = cv2.resize(x, (W, H), interpolation=interpolation)
42
+ return y
43
+
44
+
45
+ def resize_and_center_crop(image, target_width, target_height):
46
+ if target_height == image.shape[0] and target_width == image.shape[1]:
47
+ return image
48
+
49
+ pil_image = Image.fromarray(image)
50
+ original_width, original_height = pil_image.size
51
+ scale_factor = max(target_width / original_width, target_height / original_height)
52
+ resized_width = int(round(original_width * scale_factor))
53
+ resized_height = int(round(original_height * scale_factor))
54
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
55
+ left = (resized_width - target_width) / 2
56
+ top = (resized_height - target_height) / 2
57
+ right = (resized_width + target_width) / 2
58
+ bottom = (resized_height + target_height) / 2
59
+ cropped_image = resized_image.crop((left, top, right, bottom))
60
+ return np.array(cropped_image)
61
+
62
+
63
+ def resize_and_center_crop_pytorch(image, target_width, target_height):
64
+ B, C, H, W = image.shape
65
+
66
+ if H == target_height and W == target_width:
67
+ return image
68
+
69
+ scale_factor = max(target_width / W, target_height / H)
70
+ resized_width = int(round(W * scale_factor))
71
+ resized_height = int(round(H * scale_factor))
72
+
73
+ resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
74
+
75
+ top = (resized_height - target_height) // 2
76
+ left = (resized_width - target_width) // 2
77
+ cropped = resized[:, :, top:top + target_height, left:left + target_width]
78
+
79
+ return cropped
80
+
81
+
82
+ def resize_without_crop(image, target_width, target_height):
83
+ if target_height == image.shape[0] and target_width == image.shape[1]:
84
+ return image
85
+
86
+ pil_image = Image.fromarray(image)
87
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
88
+ return np.array(resized_image)
89
+
90
+
91
+ def just_crop(image, w, h):
92
+ if h == image.shape[0] and w == image.shape[1]:
93
+ return image
94
+
95
+ original_height, original_width = image.shape[:2]
96
+ k = min(original_height / h, original_width / w)
97
+ new_width = int(round(w * k))
98
+ new_height = int(round(h * k))
99
+ x_start = (original_width - new_width) // 2
100
+ y_start = (original_height - new_height) // 2
101
+ cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
102
+ return cropped_image
103
+
104
+
105
+ def write_to_json(data, file_path):
106
+ temp_file_path = file_path + ".tmp"
107
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
108
+ json.dump(data, temp_file, indent=4)
109
+ os.replace(temp_file_path, file_path)
110
+ return
111
+
112
+
113
+ def read_from_json(file_path):
114
+ with open(file_path, 'rt', encoding='utf-8') as file:
115
+ data = json.load(file)
116
+ return data
117
+
118
+
119
+ def get_active_parameters(m):
120
+ return {k: v for k, v in m.named_parameters() if v.requires_grad}
121
+
122
+
123
+ def cast_training_params(m, dtype=torch.float32):
124
+ result = {}
125
+ for n, param in m.named_parameters():
126
+ if param.requires_grad:
127
+ param.data = param.to(dtype)
128
+ result[n] = param
129
+ return result
130
+
131
+
132
+ def separate_lora_AB(parameters, B_patterns=None):
133
+ parameters_normal = {}
134
+ parameters_B = {}
135
+
136
+ if B_patterns is None:
137
+ B_patterns = ['.lora_B.', '__zero__']
138
+
139
+ for k, v in parameters.items():
140
+ if any(B_pattern in k for B_pattern in B_patterns):
141
+ parameters_B[k] = v
142
+ else:
143
+ parameters_normal[k] = v
144
+
145
+ return parameters_normal, parameters_B
146
+
147
+
148
+ def set_attr_recursive(obj, attr, value):
149
+ attrs = attr.split(".")
150
+ for name in attrs[:-1]:
151
+ obj = getattr(obj, name)
152
+ setattr(obj, attrs[-1], value)
153
+ return
154
+
155
+
156
+ def print_tensor_list_size(tensors):
157
+ total_size = 0
158
+ total_elements = 0
159
+
160
+ if isinstance(tensors, dict):
161
+ tensors = tensors.values()
162
+
163
+ for tensor in tensors:
164
+ total_size += tensor.nelement() * tensor.element_size()
165
+ total_elements += tensor.nelement()
166
+
167
+ total_size_MB = total_size / (1024 ** 2)
168
+ total_elements_B = total_elements / 1e9
169
+
170
+ print(f"Total number of tensors: {len(tensors)}")
171
+ print(f"Total size of tensors: {total_size_MB:.2f} MB")
172
+ print(f"Total number of parameters: {total_elements_B:.3f} billion")
173
+ return
174
+
175
+
176
+ @torch.no_grad()
177
+ def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
178
+ batch_size = a.size(0)
179
+
180
+ if b is None:
181
+ b = torch.zeros_like(a)
182
+
183
+ if mask_a is None:
184
+ mask_a = torch.rand(batch_size) < probability_a
185
+
186
+ mask_a = mask_a.to(a.device)
187
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
188
+ result = torch.where(mask_a, a, b)
189
+ return result
190
+
191
+
192
+ @torch.no_grad()
193
+ def zero_module(module):
194
+ for p in module.parameters():
195
+ p.detach().zero_()
196
+ return module
197
+
198
+
199
+ @torch.no_grad()
200
+ def supress_lower_channels(m, k, alpha=0.01):
201
+ data = m.weight.data.clone()
202
+
203
+ assert int(data.shape[1]) >= k
204
+
205
+ data[:, :k] = data[:, :k] * alpha
206
+ m.weight.data = data.contiguous().clone()
207
+ return m
208
+
209
+
210
+ def freeze_module(m):
211
+ if not hasattr(m, '_forward_inside_frozen_module'):
212
+ m._forward_inside_frozen_module = m.forward
213
+ m.requires_grad_(False)
214
+ m.forward = torch.no_grad()(m.forward)
215
+ return m
216
+
217
+
218
+ def get_latest_safetensors(folder_path):
219
+ safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
220
+
221
+ if not safetensors_files:
222
+ raise ValueError('No file to resume!')
223
+
224
+ latest_file = max(safetensors_files, key=os.path.getmtime)
225
+ latest_file = os.path.abspath(os.path.realpath(latest_file))
226
+ return latest_file
227
+
228
+
229
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
230
+ tags = tags_str.split(', ')
231
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
232
+ prompt = ', '.join(tags)
233
+ return prompt
234
+
235
+
236
+ def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
237
+ numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
238
+ if round_to_int:
239
+ numbers = np.round(numbers).astype(int)
240
+ return numbers.tolist()
241
+
242
+
243
+ def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
244
+ edges = np.linspace(0, 1, n + 1)
245
+ points = np.random.uniform(edges[:-1], edges[1:])
246
+ numbers = inclusive + (exclusive - inclusive) * points
247
+ if round_to_int:
248
+ numbers = np.round(numbers).astype(int)
249
+ return numbers.tolist()
250
+
251
+
252
+ def soft_append_bcthw(history, current, overlap=0):
253
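+ # Append `current` to `history` along the time axis, linearly cross-fading the last `overlap`
+ # frames of history into the first `overlap` frames of current to avoid a visible seam.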
+ if overlap <= 0:
254
+ return torch.cat([history, current], dim=2)
255
+
256
+ assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
257
+ assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
258
+
259
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
260
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
261
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
262
+
263
+ return output.to(history)
264
+
265
+
266
+ def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
267
+ b, c, t, h, w = x.shape
268
+
269
+ per_row = b
270
+ for p in [6, 5, 4, 3, 2]:
271
+ if b % p == 0:
272
+ per_row = p
273
+ break
274
+
275
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
276
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
277
+ x = x.detach().cpu().to(torch.uint8)
278
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
279
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
280
+ return x
281
+
282
+
283
+ def save_bcthw_as_png(x, output_filename):
284
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
285
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
286
+ x = x.detach().cpu().to(torch.uint8)
287
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
288
+ torchvision.io.write_png(x, output_filename)
289
+ return output_filename
290
+
291
+
292
+ def save_bchw_as_png(x, output_filename):
293
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
294
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
295
+ x = x.detach().cpu().to(torch.uint8)
296
+ x = einops.rearrange(x, 'b c h w -> c h (b w)')
297
+ torchvision.io.write_png(x, output_filename)
298
+ return output_filename
299
+
300
+
301
+ def add_tensors_with_padding(tensor1, tensor2):
302
+ if tensor1.shape == tensor2.shape:
303
+ return tensor1 + tensor2
304
+
305
+ shape1 = tensor1.shape
306
+ shape2 = tensor2.shape
307
+
308
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
309
+
310
+ padded_tensor1 = torch.zeros(new_shape)
311
+ padded_tensor2 = torch.zeros(new_shape)
312
+
313
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
314
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
315
+
316
+ result = padded_tensor1 + padded_tensor2
317
+ return result
318
+
319
+
320
+ def print_free_mem():
321
+ torch.cuda.empty_cache()
322
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
323
+ free_mem_mb = free_mem / (1024 ** 2)
324
+ total_mem_mb = total_mem / (1024 ** 2)
325
+ print(f"Free memory: {free_mem_mb:.2f} MB")
326
+ print(f"Total memory: {total_mem_mb:.2f} MB")
327
+ return
328
+
329
+
330
+ def print_gpu_parameters(device, state_dict, log_count=1):
331
+ summary = {"device": device, "keys_count": len(state_dict)}
332
+
333
+ logged_params = {}
334
+ for i, (key, tensor) in enumerate(state_dict.items()):
335
+ if i >= log_count:
336
+ break
337
+ logged_params[key] = tensor.flatten()[:3].tolist()
338
+
339
+ summary["params"] = logged_params
340
+
341
+ print(str(summary))
342
+ return
343
+
344
+
345
+ def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
346
+ from PIL import Image, ImageDraw, ImageFont
347
+
348
+ txt = Image.new("RGB", (width, height), color="white")
349
+ draw = ImageDraw.Draw(txt)
350
+ font = ImageFont.truetype(font_path, size=size)
351
+
352
+ if text == '':
353
+ return np.array(txt)
354
+
355
+ # Split text into lines that fit within the image width
356
+ lines = []
357
+ words = text.split()
358
+ current_line = words[0]
359
+
360
+ for word in words[1:]:
361
+ line_with_word = f"{current_line} {word}"
362
+ if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
363
+ current_line = line_with_word
364
+ else:
365
+ lines.append(current_line)
366
+ current_line = word
367
+
368
+ lines.append(current_line)
369
+
370
+ # Draw the text line by line
371
+ y = 0
372
+ line_height = draw.textbbox((0, 0), "A", font=font)[3]
373
+
374
+ for line in lines:
375
+ if y + line_height > height:
376
+ break # stop drawing if the next line will be outside the image
377
+ draw.text((0, y), line, fill="black", font=font)
378
+ y += line_height
379
+
380
+ return np.array(txt)
381
+
382
+
383
+ def blue_mark(x):
384
+ x = x.copy()
385
+ c = x[:, :, 2]
386
+ b = cv2.blur(c, (9, 9))
387
+ x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
388
+ return x
389
+
390
+
391
+ def green_mark(x):
392
+ x = x.copy()
393
+ x[:, :, 2] = -1
394
+ x[:, :, 0] = -1
395
+ return x
396
+
397
+
398
+ def frame_mark(x):
399
+ x = x.copy()
400
+ x[:64] = -1
401
+ x[-64:] = -1
402
+ x[:, :8] = 1
403
+ x[:, -8:] = 1
404
+ return x
405
+
406
+
407
+ @torch.inference_mode()
408
+ def pytorch2numpy(imgs):
409
+ results = []
410
+ for x in imgs:
411
+ y = x.movedim(0, -1)
412
+ y = y * 127.5 + 127.5
413
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
414
+ results.append(y)
415
+ return results
416
+
417
+
418
+ @torch.inference_mode()
419
+ def numpy2pytorch(imgs):
420
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
421
+ h = h.movedim(-1, 1)
422
+ return h
423
+
424
+
425
+ @torch.no_grad()
426
+ def duplicate_prefix_to_suffix(x, count, zero_out=False):
427
+ if zero_out:
428
+ return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
429
+ else:
430
+ return torch.cat([x, x[:count]], dim=0)
431
+
432
+
433
+ def weighted_mse(a, b, weight):
434
+ return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
435
+
436
+
437
+ def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
438
+ x = (x - x_min) / (x_max - x_min)
439
+ x = max(0.0, min(x, 1.0))
440
+ x = x ** sigma
441
+ return y_min + x * (y_max - y_min)
442
+
443
+
444
+ def expand_to_dims(x, target_dims):
445
+ return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
446
+
447
+
448
+ def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
449
+ if tensor is None:
450
+ return None
451
+
452
+ first_dim = tensor.shape[0]
453
+
454
+ if first_dim == batch_size:
455
+ return tensor
456
+
457
+ if batch_size % first_dim != 0:
458
+ raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
459
+
460
+ repeat_times = batch_size // first_dim
461
+
462
+ return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
463
+
464
+
465
+ def dim5(x):
466
+ return expand_to_dims(x, 5)
467
+
468
+
469
+ def dim4(x):
470
+ return expand_to_dims(x, 4)
471
+
472
+
473
+ def dim3(x):
474
+ return expand_to_dims(x, 3)
475
+
476
+
477
+ def crop_or_pad_yield_mask(x, length):
478
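+ # Zero-pad or crop the token dimension to exactly `length`, returning a boolean mask that marks
+ # which positions hold real (non-padded) tokens.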
+ B, F, C = x.shape
479
+ device = x.device
480
+ dtype = x.dtype
481
+
482
+ if F < length:
483
+ y = torch.zeros((B, length, C), dtype=dtype, device=device)
484
+ mask = torch.zeros((B, length), dtype=torch.bool, device=device)
485
+ y[:, :F, :] = x
486
+ mask[:, :F] = True
487
+ return y, mask
488
+
489
+ return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
490
+
491
+
492
+ def extend_dim(x, dim, minimal_length, zero_pad=False):
493
+ original_length = int(x.shape[dim])
494
+
495
+ if original_length >= minimal_length:
496
+ return x
497
+
498
+ if zero_pad:
499
+ padding_shape = list(x.shape)
500
+ padding_shape[dim] = minimal_length - original_length
501
+ padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
502
+ else:
503
+ idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
504
+ last_element = x[idx]
505
+ padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
506
+
507
+ return torch.cat([x, padding], dim=dim)
508
+
509
+
510
+ def lazy_positional_encoding(t, repeats=None):
511
+ if not isinstance(t, list):
512
+ t = [t]
513
+
514
+ from diffusers.models.embeddings import get_timestep_embedding
515
+
516
+ te = torch.tensor(t)
517
+ te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
518
+
519
+ if repeats is None:
520
+ return te
521
+
522
+ te = te[:, None, :].expand(-1, repeats, -1)
523
+
524
+ return te
525
+
526
+
527
+ def state_dict_offset_merge(A, B, C=None):
528
+ result = {}
529
+ keys = A.keys()
530
+
531
+ for key in keys:
532
+ A_value = A[key]
533
+ B_value = B[key].to(A_value)
534
+
535
+ if C is None:
536
+ result[key] = A_value + B_value
537
+ else:
538
+ C_value = C[key].to(A_value)
539
+ result[key] = A_value + B_value - C_value
540
+
541
+ return result
542
+
543
+
544
+ def state_dict_weighted_merge(state_dicts, weights):
545
+ if len(state_dicts) != len(weights):
546
+ raise ValueError("Number of state dictionaries must match number of weights")
547
+
548
+ if not state_dicts:
549
+ return {}
550
+
551
+ total_weight = sum(weights)
552
+
553
+ if total_weight == 0:
554
+ raise ValueError("Sum of weights cannot be zero")
555
+
556
+ normalized_weights = [w / total_weight for w in weights]
557
+
558
+ keys = state_dicts[0].keys()
559
+ result = {}
560
+
561
+ for key in keys:
562
+ result[key] = state_dicts[0][key] * normalized_weights[0]
563
+
564
+ for i in range(1, len(state_dicts)):
565
+ state_dict_value = state_dicts[i][key].to(result[key])
566
+ result[key] += state_dict_value * normalized_weights[i]
567
+
568
+ return result
569
+
570
+
571
+ def group_files_by_folder(all_files):
572
+ grouped_files = {}
573
+
574
+ for file in all_files:
575
+ folder_name = os.path.basename(os.path.dirname(file))
576
+ if folder_name not in grouped_files:
577
+ grouped_files[folder_name] = []
578
+ grouped_files[folder_name].append(file)
579
+
580
+ list_of_lists = list(grouped_files.values())
581
+ return list_of_lists
582
+
583
+
584
+ def generate_timestamp():
585
+ now = datetime.datetime.now()
586
+ timestamp = now.strftime('%y%m%d_%H%M%S')
587
+ milliseconds = f"{int(now.microsecond / 1000):03d}"
588
+ random_number = random.randint(0, 9999)
589
+ return f"{timestamp}_{milliseconds}_{random_number}"
590
+
591
+
592
+ def write_PIL_image_with_png_info(image, metadata, path):
593
+ from PIL.PngImagePlugin import PngInfo
594
+
595
+ png_info = PngInfo()
596
+ for key, value in metadata.items():
597
+ png_info.add_text(key, value)
598
+
599
+ image.save(path, "PNG", pnginfo=png_info)
600
+ return image
601
+
602
+
603
+ def torch_safe_save(content, path):
604
+ torch.save(content, path + '_tmp')
605
+ os.replace(path + '_tmp', path)
606
+ return path
607
+
608
+
609
+ def move_optimizer_to_device(optimizer, device):
610
+ for state in optimizer.state.values():
611
+ for k, v in state.items():
612
+ if isinstance(v, torch.Tensor):
613
+ state[k] = v.to(device)
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ accelerate==1.6.0
+ diffusers==0.33.1
+ transformers==4.46.2
+ gradio==5.23.0
+ sentencepiece==0.2.0
+ pillow==11.1.0
+ av==12.1.0
+ numpy==1.26.2
+ scipy==1.12.0
+ requests==2.31.0
+ torchsde==0.2.6
+
+ einops
+ opencv-contrib-python
+ safetensors