Fabrice-TIERCELIN commited on
Commit
e2bfe5b
·
verified ·
1 Parent(s): 481c850

Delete diffusers_helper

Browse files
diffusers_helper/bucket_tools.py DELETED
@@ -1,30 +0,0 @@
1
- bucket_options = {
2
- 640: [
3
- (416, 960),
4
- (448, 864),
5
- (480, 832),
6
- (512, 768),
7
- (544, 704),
8
- (576, 672),
9
- (608, 640),
10
- (640, 608),
11
- (672, 576),
12
- (704, 544),
13
- (768, 512),
14
- (832, 480),
15
- (864, 448),
16
- (960, 416),
17
- ],
18
- }
19
-
20
-
21
- def find_nearest_bucket(h, w, resolution=640):
22
- min_metric = float('inf')
23
- best_bucket = None
24
- for (bucket_h, bucket_w) in bucket_options[resolution]:
25
- metric = abs(h * bucket_w - w * bucket_h)
26
- if metric <= min_metric:
27
- min_metric = metric
28
- best_bucket = (bucket_h, bucket_w)
29
- return best_bucket
30
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/clip_vision.py DELETED
@@ -1,12 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- def hf_clip_vision_encode(image, feature_extractor, image_encoder):
5
- assert isinstance(image, np.ndarray)
6
- assert image.ndim == 3 and image.shape[2] == 3
7
- assert image.dtype == np.uint8
8
-
9
- preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
10
- image_encoder_output = image_encoder(**preprocessed)
11
-
12
- return image_encoder_output
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/dit_common.py DELETED
@@ -1,53 +0,0 @@
1
- import torch
2
- import accelerate.accelerator
3
-
4
- from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
5
-
6
-
7
- accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
8
-
9
-
10
- def LayerNorm_forward(self, x):
11
- return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
12
-
13
-
14
- LayerNorm.forward = LayerNorm_forward
15
- torch.nn.LayerNorm.forward = LayerNorm_forward
16
-
17
-
18
- def FP32LayerNorm_forward(self, x):
19
- origin_dtype = x.dtype
20
- return torch.nn.functional.layer_norm(
21
- x.float(),
22
- self.normalized_shape,
23
- self.weight.float() if self.weight is not None else None,
24
- self.bias.float() if self.bias is not None else None,
25
- self.eps,
26
- ).to(origin_dtype)
27
-
28
-
29
- FP32LayerNorm.forward = FP32LayerNorm_forward
30
-
31
-
32
- def RMSNorm_forward(self, hidden_states):
33
- input_dtype = hidden_states.dtype
34
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
35
- hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
36
-
37
- if self.weight is None:
38
- return hidden_states.to(input_dtype)
39
-
40
- return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
41
-
42
-
43
- RMSNorm.forward = RMSNorm_forward
44
-
45
-
46
- def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
47
- emb = self.linear(self.silu(conditioning_embedding))
48
- scale, shift = emb.chunk(2, dim=1)
49
- x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
50
- return x
51
-
52
-
53
- AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/gradio/progress_bar.py DELETED
@@ -1,86 +0,0 @@
1
- progress_html = '''
2
- <div class="loader-container">
3
- <div class="loader"></div>
4
- <div class="progress-container">
5
- <progress value="*number*" max="100"></progress>
6
- </div>
7
- <span>*text*</span>
8
- </div>
9
- '''
10
-
11
- css = '''
12
- .loader-container {
13
- display: flex; /* Use flex to align items horizontally */
14
- align-items: center; /* Center items vertically within the container */
15
- white-space: nowrap; /* Prevent line breaks within the container */
16
- }
17
-
18
- .loader {
19
- border: 8px solid #f3f3f3; /* Light grey */
20
- border-top: 8px solid #3498db; /* Blue */
21
- border-radius: 50%;
22
- width: 30px;
23
- height: 30px;
24
- animation: spin 2s linear infinite;
25
- }
26
-
27
- @keyframes spin {
28
- 0% { transform: rotate(0deg); }
29
- 100% { transform: rotate(360deg); }
30
- }
31
-
32
- /* Style the progress bar */
33
- progress {
34
- appearance: none; /* Remove default styling */
35
- height: 20px; /* Set the height of the progress bar */
36
- border-radius: 5px; /* Round the corners of the progress bar */
37
- background-color: #f3f3f3; /* Light grey background */
38
- width: 100%;
39
- vertical-align: middle !important;
40
- }
41
-
42
- /* Style the progress bar container */
43
- .progress-container {
44
- margin-left: 20px;
45
- margin-right: 20px;
46
- flex-grow: 1; /* Allow the progress container to take up remaining space */
47
- }
48
-
49
- /* Set the color of the progress bar fill */
50
- progress::-webkit-progress-value {
51
- background-color: #3498db; /* Blue color for the fill */
52
- }
53
-
54
- progress::-moz-progress-bar {
55
- background-color: #3498db; /* Blue color for the fill in Firefox */
56
- }
57
-
58
- /* Style the text on the progress bar */
59
- progress::after {
60
- content: attr(value '%'); /* Display the progress value followed by '%' */
61
- position: absolute;
62
- top: 50%;
63
- left: 50%;
64
- transform: translate(-50%, -50%);
65
- color: white; /* Set text color */
66
- font-size: 14px; /* Set font size */
67
- }
68
-
69
- /* Style other texts */
70
- .loader-container > span {
71
- margin-left: 5px; /* Add spacing between the progress bar and the text */
72
- }
73
-
74
- .no-generating-animation > .generating {
75
- display: none !important;
76
- }
77
-
78
- '''
79
-
80
-
81
- def make_progress_bar_html(number, text):
82
- return progress_html.replace('*number*', str(number)).replace('*text*', text)
83
-
84
-
85
- def make_progress_bar_css():
86
- return css
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/hf_login.py DELETED
@@ -1,21 +0,0 @@
1
- import os
2
-
3
-
4
- def login(token):
5
- from huggingface_hub import login
6
- import time
7
-
8
- while True:
9
- try:
10
- login(token)
11
- print('HF login ok.')
12
- break
13
- except Exception as e:
14
- print(f'HF login failed: {e}. Retrying')
15
- time.sleep(0.5)
16
-
17
-
18
- hf_token = os.environ.get('HF_TOKEN', None)
19
-
20
- if hf_token is not None:
21
- login(hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/hunyuan.py DELETED
@@ -1,111 +0,0 @@
1
- import torch
2
-
3
- from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
4
- from diffusers_helper.utils import crop_or_pad_yield_mask
5
-
6
-
7
- @torch.no_grad()
8
- def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
9
- assert isinstance(prompt, str)
10
-
11
- prompt = [prompt]
12
-
13
- # LLAMA
14
-
15
- prompt_llama = [DEFAULT_PROMPT_TEMPLATE["template"].format(p) for p in prompt]
16
- crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]
17
-
18
- llama_inputs = tokenizer(
19
- prompt_llama,
20
- padding="max_length",
21
- max_length=max_length + crop_start,
22
- truncation=True,
23
- return_tensors="pt",
24
- return_length=False,
25
- return_overflowing_tokens=False,
26
- return_attention_mask=True,
27
- )
28
-
29
- llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
30
- llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
31
- llama_attention_length = int(llama_attention_mask.sum())
32
-
33
- llama_outputs = text_encoder(
34
- input_ids=llama_input_ids,
35
- attention_mask=llama_attention_mask,
36
- output_hidden_states=True,
37
- )
38
-
39
- llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
40
- # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
41
- llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
42
-
43
- assert torch.all(llama_attention_mask.bool())
44
-
45
- # CLIP
46
-
47
- clip_l_input_ids = tokenizer_2(
48
- prompt,
49
- padding="max_length",
50
- max_length=77,
51
- truncation=True,
52
- return_overflowing_tokens=False,
53
- return_length=False,
54
- return_tensors="pt",
55
- ).input_ids
56
- clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
57
-
58
- return llama_vec, clip_l_pooler
59
-
60
-
61
- @torch.no_grad()
62
- def vae_decode_fake(latents):
63
- latent_rgb_factors = [
64
- [-0.0395, -0.0331, 0.0445],
65
- [0.0696, 0.0795, 0.0518],
66
- [0.0135, -0.0945, -0.0282],
67
- [0.0108, -0.0250, -0.0765],
68
- [-0.0209, 0.0032, 0.0224],
69
- [-0.0804, -0.0254, -0.0639],
70
- [-0.0991, 0.0271, -0.0669],
71
- [-0.0646, -0.0422, -0.0400],
72
- [-0.0696, -0.0595, -0.0894],
73
- [-0.0799, -0.0208, -0.0375],
74
- [0.1166, 0.1627, 0.0962],
75
- [0.1165, 0.0432, 0.0407],
76
- [-0.2315, -0.1920, -0.1355],
77
- [-0.0270, 0.0401, -0.0821],
78
- [-0.0616, -0.0997, -0.0727],
79
- [0.0249, -0.0469, -0.1703]
80
- ] # From comfyui
81
-
82
- latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
83
-
84
- weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
85
- bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
86
-
87
- images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
88
- images = images.clamp(0.0, 1.0)
89
-
90
- return images
91
-
92
-
93
- @torch.no_grad()
94
- def vae_decode(latents, vae, image_mode=False):
95
- latents = latents / vae.config.scaling_factor
96
-
97
- if not image_mode:
98
- image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
99
- else:
100
- latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
101
- image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
102
- image = torch.cat(image, dim=2)
103
-
104
- return image
105
-
106
-
107
- @torch.no_grad()
108
- def vae_encode(image, vae):
109
- latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
110
- latents = latents * vae.config.scaling_factor
111
- return latents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/k_diffusion/uni_pc_fm.py DELETED
@@ -1,141 +0,0 @@
1
- # Better Flow Matching UniPC by Lvmin Zhang
2
- # (c) 2025
3
- # CC BY-SA 4.0
4
- # Attribution-ShareAlike 4.0 International Licence
5
-
6
-
7
- import torch
8
-
9
- from tqdm.auto import trange
10
-
11
-
12
- def expand_dims(v, dims):
13
- return v[(...,) + (None,) * (dims - 1)]
14
-
15
-
16
- class FlowMatchUniPC:
17
- def __init__(self, model, extra_args, variant='bh1'):
18
- self.model = model
19
- self.variant = variant
20
- self.extra_args = extra_args
21
-
22
- def model_fn(self, x, t):
23
- return self.model(x, t, **self.extra_args)
24
-
25
- def update_fn(self, x, model_prev_list, t_prev_list, t, order):
26
- assert order <= len(model_prev_list)
27
- dims = x.dim()
28
-
29
- t_prev_0 = t_prev_list[-1]
30
- lambda_prev_0 = - torch.log(t_prev_0)
31
- lambda_t = - torch.log(t)
32
- model_prev_0 = model_prev_list[-1]
33
-
34
- h = lambda_t - lambda_prev_0
35
-
36
- rks = []
37
- D1s = []
38
- for i in range(1, order):
39
- t_prev_i = t_prev_list[-(i + 1)]
40
- model_prev_i = model_prev_list[-(i + 1)]
41
- lambda_prev_i = - torch.log(t_prev_i)
42
- rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
43
- rks.append(rk)
44
- D1s.append((model_prev_i - model_prev_0) / rk)
45
-
46
- rks.append(1.)
47
- rks = torch.tensor(rks, device=x.device)
48
-
49
- R = []
50
- b = []
51
-
52
- hh = -h[0]
53
- h_phi_1 = torch.expm1(hh)
54
- h_phi_k = h_phi_1 / hh - 1
55
-
56
- factorial_i = 1
57
-
58
- if self.variant == 'bh1':
59
- B_h = hh
60
- elif self.variant == 'bh2':
61
- B_h = torch.expm1(hh)
62
- else:
63
- raise NotImplementedError('Bad variant!')
64
-
65
- for i in range(1, order + 1):
66
- R.append(torch.pow(rks, i - 1))
67
- b.append(h_phi_k * factorial_i / B_h)
68
- factorial_i *= (i + 1)
69
- h_phi_k = h_phi_k / hh - 1 / factorial_i
70
-
71
- R = torch.stack(R)
72
- b = torch.tensor(b, device=x.device)
73
-
74
- use_predictor = len(D1s) > 0
75
-
76
- if use_predictor:
77
- D1s = torch.stack(D1s, dim=1)
78
- if order == 2:
79
- rhos_p = torch.tensor([0.5], device=b.device)
80
- else:
81
- rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
82
- else:
83
- D1s = None
84
- rhos_p = None
85
-
86
- if order == 1:
87
- rhos_c = torch.tensor([0.5], device=b.device)
88
- else:
89
- rhos_c = torch.linalg.solve(R, b)
90
-
91
- x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
92
-
93
- if use_predictor:
94
- pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
95
- else:
96
- pred_res = 0
97
-
98
- x_t = x_t_ - expand_dims(B_h, dims) * pred_res
99
- model_t = self.model_fn(x_t, t)
100
-
101
- if D1s is not None:
102
- corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
103
- else:
104
- corr_res = 0
105
-
106
- D1_t = (model_t - model_prev_0)
107
- x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
108
-
109
- return x_t, model_t
110
-
111
- def sample(self, x, sigmas, callback=None, disable_pbar=False):
112
- order = min(3, len(sigmas) - 2)
113
- model_prev_list, t_prev_list = [], []
114
- for i in trange(len(sigmas) - 1, disable=disable_pbar):
115
- vec_t = sigmas[i].expand(x.shape[0])
116
-
117
- if i == 0:
118
- model_prev_list = [self.model_fn(x, vec_t)]
119
- t_prev_list = [vec_t]
120
- elif i < order:
121
- init_order = i
122
- x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
123
- model_prev_list.append(model_x)
124
- t_prev_list.append(vec_t)
125
- else:
126
- x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
127
- model_prev_list.append(model_x)
128
- t_prev_list.append(vec_t)
129
-
130
- model_prev_list = model_prev_list[-order:]
131
- t_prev_list = t_prev_list[-order:]
132
-
133
- if callback is not None:
134
- callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
135
-
136
- return model_prev_list[-1]
137
-
138
-
139
- def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
140
- assert variant in ['bh1', 'bh2']
141
- return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/k_diffusion/wrapper.py DELETED
@@ -1,51 +0,0 @@
1
- import torch
2
-
3
-
4
- def append_dims(x, target_dims):
5
- return x[(...,) + (None,) * (target_dims - x.ndim)]
6
-
7
-
8
- def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
9
- if guidance_rescale == 0:
10
- return noise_cfg
11
-
12
- std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
13
- std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
14
- noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
15
- noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
16
- return noise_cfg
17
-
18
-
19
- def fm_wrapper(transformer, t_scale=1000.0):
20
- def k_model(x, sigma, **extra_args):
21
- dtype = extra_args['dtype']
22
- cfg_scale = extra_args['cfg_scale']
23
- cfg_rescale = extra_args['cfg_rescale']
24
- concat_latent = extra_args['concat_latent']
25
-
26
- original_dtype = x.dtype
27
- sigma = sigma.float()
28
-
29
- x = x.to(dtype)
30
- timestep = (sigma * t_scale).to(dtype)
31
-
32
- if concat_latent is None:
33
- hidden_states = x
34
- else:
35
- hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
36
-
37
- pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
38
-
39
- if cfg_scale == 1.0:
40
- pred_negative = torch.zeros_like(pred_positive)
41
- else:
42
- pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
43
-
44
- pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
45
- pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
46
-
47
- x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
48
-
49
- return x0.to(dtype=original_dtype)
50
-
51
- return k_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/memory.py DELETED
@@ -1,134 +0,0 @@
1
- # By lllyasviel
2
-
3
-
4
- import torch
5
-
6
-
7
- cpu = torch.device('cpu')
8
- gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
9
- gpu_complete_modules = []
10
-
11
-
12
- class DynamicSwapInstaller:
13
- @staticmethod
14
- def _install_module(module: torch.nn.Module, **kwargs):
15
- original_class = module.__class__
16
- module.__dict__['forge_backup_original_class'] = original_class
17
-
18
- def hacked_get_attr(self, name: str):
19
- if '_parameters' in self.__dict__:
20
- _parameters = self.__dict__['_parameters']
21
- if name in _parameters:
22
- p = _parameters[name]
23
- if p is None:
24
- return None
25
- if p.__class__ == torch.nn.Parameter:
26
- return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
27
- else:
28
- return p.to(**kwargs)
29
- if '_buffers' in self.__dict__:
30
- _buffers = self.__dict__['_buffers']
31
- if name in _buffers:
32
- return _buffers[name].to(**kwargs)
33
- return super(original_class, self).__getattr__(name)
34
-
35
- module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
36
- '__getattr__': hacked_get_attr,
37
- })
38
-
39
- return
40
-
41
- @staticmethod
42
- def _uninstall_module(module: torch.nn.Module):
43
- if 'forge_backup_original_class' in module.__dict__:
44
- module.__class__ = module.__dict__.pop('forge_backup_original_class')
45
- return
46
-
47
- @staticmethod
48
- def install_model(model: torch.nn.Module, **kwargs):
49
- for m in model.modules():
50
- DynamicSwapInstaller._install_module(m, **kwargs)
51
- return
52
-
53
- @staticmethod
54
- def uninstall_model(model: torch.nn.Module):
55
- for m in model.modules():
56
- DynamicSwapInstaller._uninstall_module(m)
57
- return
58
-
59
-
60
- def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
61
- if hasattr(model, 'scale_shift_table'):
62
- model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
63
- return
64
-
65
- for k, p in model.named_modules():
66
- if hasattr(p, 'weight'):
67
- p.to(target_device)
68
- return
69
-
70
-
71
- def get_cuda_free_memory_gb(device=None):
72
- if device is None:
73
- device = gpu
74
-
75
- memory_stats = torch.cuda.memory_stats(device)
76
- bytes_active = memory_stats['active_bytes.all.current']
77
- bytes_reserved = memory_stats['reserved_bytes.all.current']
78
- bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
79
- bytes_inactive_reserved = bytes_reserved - bytes_active
80
- bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
81
- return bytes_total_available / (1024 ** 3)
82
-
83
-
84
- def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
85
- print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
86
-
87
- for m in model.modules():
88
- if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
89
- torch.cuda.empty_cache()
90
- return
91
-
92
- if hasattr(m, 'weight'):
93
- m.to(device=target_device)
94
-
95
- model.to(device=target_device)
96
- torch.cuda.empty_cache()
97
- return
98
-
99
-
100
- def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
101
- print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
102
-
103
- for m in model.modules():
104
- if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
105
- torch.cuda.empty_cache()
106
- return
107
-
108
- if hasattr(m, 'weight'):
109
- m.to(device=cpu)
110
-
111
- model.to(device=cpu)
112
- torch.cuda.empty_cache()
113
- return
114
-
115
-
116
- def unload_complete_models(*args):
117
- for m in gpu_complete_modules + list(args):
118
- m.to(device=cpu)
119
- print(f'Unloaded {m.__class__.__name__} as complete.')
120
-
121
- gpu_complete_modules.clear()
122
- torch.cuda.empty_cache()
123
- return
124
-
125
-
126
- def load_model_as_complete(model, target_device, unload=True):
127
- if unload:
128
- unload_complete_models()
129
-
130
- model.to(device=target_device)
131
- print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
132
-
133
- gpu_complete_modules.append(model)
134
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers_helper/models/hunyuan_video_packed.py DELETED
@@ -1,1035 +0,0 @@
1
- from typing import Any, Dict, List, Optional, Tuple, Union
2
-
3
- import torch
4
- import einops
5
- import torch.nn as nn
6
- import numpy as np
7
-
8
- from diffusers.loaders import FromOriginalModelMixin
9
- from diffusers.configuration_utils import ConfigMixin, register_to_config
10
- from diffusers.loaders import PeftAdapterMixin
11
- from diffusers.utils import logging
12
- from diffusers.models.attention import FeedForward
13
- from diffusers.models.attention_processor import Attention
14
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
- from diffusers.models.modeling_utils import ModelMixin
17
- from diffusers_helper.dit_common import LayerNorm
18
- from diffusers_helper.utils import zero_module
19
-
20
-
21
- enabled_backends = []
22
-
23
- if torch.backends.cuda.flash_sdp_enabled():
24
- enabled_backends.append("flash")
25
- if torch.backends.cuda.math_sdp_enabled():
26
- enabled_backends.append("math")
27
- if torch.backends.cuda.mem_efficient_sdp_enabled():
28
- enabled_backends.append("mem_efficient")
29
- if torch.backends.cuda.cudnn_sdp_enabled():
30
- enabled_backends.append("cudnn")
31
-
32
- print("Currently enabled native sdp backends:", enabled_backends)
33
-
34
- try:
35
- # raise NotImplementedError
36
- from xformers.ops import memory_efficient_attention as xformers_attn_func
37
- print('Xformers is installed!')
38
- except:
39
- print('Xformers is not installed!')
40
- xformers_attn_func = None
41
-
42
- try:
43
- # raise NotImplementedError
44
- from flash_attn import flash_attn_varlen_func, flash_attn_func
45
- print('Flash Attn is installed!')
46
- except:
47
- print('Flash Attn is not installed!')
48
- flash_attn_varlen_func = None
49
- flash_attn_func = None
50
-
51
- try:
52
- # raise NotImplementedError
53
- from sageattention import sageattn_varlen, sageattn
54
- print('Sage Attn is installed!')
55
- except:
56
- print('Sage Attn is not installed!')
57
- sageattn_varlen = None
58
- sageattn = None
59
-
60
-
61
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
-
63
-
64
- def pad_for_3d_conv(x, kernel_size):
65
- b, c, t, h, w = x.shape
66
- pt, ph, pw = kernel_size
67
- pad_t = (pt - (t % pt)) % pt
68
- pad_h = (ph - (h % ph)) % ph
69
- pad_w = (pw - (w % pw)) % pw
70
- return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71
-
72
-
73
- def center_down_sample_3d(x, kernel_size):
74
- # pt, ph, pw = kernel_size
75
- # cp = (pt * ph * pw) // 2
76
- # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77
- # xc = xp[cp]
78
- # return xc
79
- return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80
-
81
-
82
- def get_cu_seqlens(text_mask, img_len):
83
- batch_size = text_mask.shape[0]
84
- text_len = text_mask.sum(dim=1)
85
- max_len = text_mask.shape[1] + img_len
86
-
87
- cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88
-
89
- for i in range(batch_size):
90
- s = text_len[i] + img_len
91
- s1 = i * max_len + s
92
- s2 = (i + 1) * max_len
93
- cu_seqlens[2 * i + 1] = s1
94
- cu_seqlens[2 * i + 2] = s2
95
-
96
- return cu_seqlens
97
-
98
-
99
- def apply_rotary_emb_transposed(x, freqs_cis):
100
- cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101
- x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103
- out = x.float() * cos + x_rotated.float() * sin
104
- out = out.to(x)
105
- return out
106
-
107
-
108
- def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109
- if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110
- if sageattn is not None:
111
- x = sageattn(q, k, v, tensor_layout='NHD')
112
- return x
113
-
114
- if flash_attn_func is not None:
115
- x = flash_attn_func(q, k, v)
116
- return x
117
-
118
- if xformers_attn_func is not None:
119
- x = xformers_attn_func(q, k, v)
120
- return x
121
-
122
- x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123
- return x
124
-
125
- B, L, H, C = q.shape
126
-
127
- q = q.flatten(0, 1)
128
- k = k.flatten(0, 1)
129
- v = v.flatten(0, 1)
130
-
131
- if sageattn_varlen is not None:
132
- x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133
- elif flash_attn_varlen_func is not None:
134
- x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
135
- else:
136
- raise NotImplementedError('No Attn Installed!')
137
-
138
- x = x.unflatten(0, (B, L))
139
-
140
- return x
141
-
142
-
143
- class HunyuanAttnProcessorFlashAttnDouble:
144
- def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
145
- cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
146
-
147
- query = attn.to_q(hidden_states)
148
- key = attn.to_k(hidden_states)
149
- value = attn.to_v(hidden_states)
150
-
151
- query = query.unflatten(2, (attn.heads, -1))
152
- key = key.unflatten(2, (attn.heads, -1))
153
- value = value.unflatten(2, (attn.heads, -1))
154
-
155
- query = attn.norm_q(query)
156
- key = attn.norm_k(key)
157
-
158
- query = apply_rotary_emb_transposed(query, image_rotary_emb)
159
- key = apply_rotary_emb_transposed(key, image_rotary_emb)
160
-
161
- encoder_query = attn.add_q_proj(encoder_hidden_states)
162
- encoder_key = attn.add_k_proj(encoder_hidden_states)
163
- encoder_value = attn.add_v_proj(encoder_hidden_states)
164
-
165
- encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
166
- encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
167
- encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
168
-
169
- encoder_query = attn.norm_added_q(encoder_query)
170
- encoder_key = attn.norm_added_k(encoder_key)
171
-
172
- query = torch.cat([query, encoder_query], dim=1)
173
- key = torch.cat([key, encoder_key], dim=1)
174
- value = torch.cat([value, encoder_value], dim=1)
175
-
176
- hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
177
- hidden_states = hidden_states.flatten(-2)
178
-
179
- txt_length = encoder_hidden_states.shape[1]
180
- hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
181
-
182
- hidden_states = attn.to_out[0](hidden_states)
183
- hidden_states = attn.to_out[1](hidden_states)
184
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
185
-
186
- return hidden_states, encoder_hidden_states
187
-
188
-
189
- class HunyuanAttnProcessorFlashAttnSingle:
190
- def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
191
- cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
192
-
193
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
194
-
195
- query = attn.to_q(hidden_states)
196
- key = attn.to_k(hidden_states)
197
- value = attn.to_v(hidden_states)
198
-
199
- query = query.unflatten(2, (attn.heads, -1))
200
- key = key.unflatten(2, (attn.heads, -1))
201
- value = value.unflatten(2, (attn.heads, -1))
202
-
203
- query = attn.norm_q(query)
204
- key = attn.norm_k(key)
205
-
206
- txt_length = encoder_hidden_states.shape[1]
207
-
208
- query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
209
- key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
210
-
211
- hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
212
- hidden_states = hidden_states.flatten(-2)
213
-
214
- hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
215
-
216
- return hidden_states, encoder_hidden_states
217
-
218
-
219
- class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
220
- def __init__(self, embedding_dim, pooled_projection_dim):
221
- super().__init__()
222
-
223
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
224
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
225
- self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
226
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
227
-
228
- def forward(self, timestep, guidance, pooled_projection):
229
- timesteps_proj = self.time_proj(timestep)
230
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
231
-
232
- guidance_proj = self.time_proj(guidance)
233
- guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
234
-
235
- time_guidance_emb = timesteps_emb + guidance_emb
236
-
237
- pooled_projections = self.text_embedder(pooled_projection)
238
- conditioning = time_guidance_emb + pooled_projections
239
-
240
- return conditioning
241
-
242
-
243
- class CombinedTimestepTextProjEmbeddings(nn.Module):
244
- def __init__(self, embedding_dim, pooled_projection_dim):
245
- super().__init__()
246
-
247
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
248
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
249
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
250
-
251
- def forward(self, timestep, pooled_projection):
252
- timesteps_proj = self.time_proj(timestep)
253
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
254
-
255
- pooled_projections = self.text_embedder(pooled_projection)
256
-
257
- conditioning = timesteps_emb + pooled_projections
258
-
259
- return conditioning
260
-
261
-
262
- class HunyuanVideoAdaNorm(nn.Module):
263
- def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
264
- super().__init__()
265
-
266
- out_features = out_features or 2 * in_features
267
- self.linear = nn.Linear(in_features, out_features)
268
- self.nonlinearity = nn.SiLU()
269
-
270
- def forward(
271
- self, temb: torch.Tensor
272
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
273
- temb = self.linear(self.nonlinearity(temb))
274
- gate_msa, gate_mlp = temb.chunk(2, dim=-1)
275
- gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
276
- return gate_msa, gate_mlp
277
-
278
-
279
- class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
280
- def __init__(
281
- self,
282
- num_attention_heads: int,
283
- attention_head_dim: int,
284
- mlp_width_ratio: str = 4.0,
285
- mlp_drop_rate: float = 0.0,
286
- attention_bias: bool = True,
287
- ) -> None:
288
- super().__init__()
289
-
290
- hidden_size = num_attention_heads * attention_head_dim
291
-
292
- self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
293
- self.attn = Attention(
294
- query_dim=hidden_size,
295
- cross_attention_dim=None,
296
- heads=num_attention_heads,
297
- dim_head=attention_head_dim,
298
- bias=attention_bias,
299
- )
300
-
301
- self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
302
- self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
303
-
304
- self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
305
-
306
- def forward(
307
- self,
308
- hidden_states: torch.Tensor,
309
- temb: torch.Tensor,
310
- attention_mask: Optional[torch.Tensor] = None,
311
- ) -> torch.Tensor:
312
- norm_hidden_states = self.norm1(hidden_states)
313
-
314
- attn_output = self.attn(
315
- hidden_states=norm_hidden_states,
316
- encoder_hidden_states=None,
317
- attention_mask=attention_mask,
318
- )
319
-
320
- gate_msa, gate_mlp = self.norm_out(temb)
321
- hidden_states = hidden_states + attn_output * gate_msa
322
-
323
- ff_output = self.ff(self.norm2(hidden_states))
324
- hidden_states = hidden_states + ff_output * gate_mlp
325
-
326
- return hidden_states
327
-
328
-
329
- class HunyuanVideoIndividualTokenRefiner(nn.Module):
330
- def __init__(
331
- self,
332
- num_attention_heads: int,
333
- attention_head_dim: int,
334
- num_layers: int,
335
- mlp_width_ratio: float = 4.0,
336
- mlp_drop_rate: float = 0.0,
337
- attention_bias: bool = True,
338
- ) -> None:
339
- super().__init__()
340
-
341
- self.refiner_blocks = nn.ModuleList(
342
- [
343
- HunyuanVideoIndividualTokenRefinerBlock(
344
- num_attention_heads=num_attention_heads,
345
- attention_head_dim=attention_head_dim,
346
- mlp_width_ratio=mlp_width_ratio,
347
- mlp_drop_rate=mlp_drop_rate,
348
- attention_bias=attention_bias,
349
- )
350
- for _ in range(num_layers)
351
- ]
352
- )
353
-
354
- def forward(
355
- self,
356
- hidden_states: torch.Tensor,
357
- temb: torch.Tensor,
358
- attention_mask: Optional[torch.Tensor] = None,
359
- ) -> None:
360
- self_attn_mask = None
361
- if attention_mask is not None:
362
- batch_size = attention_mask.shape[0]
363
- seq_len = attention_mask.shape[1]
364
- attention_mask = attention_mask.to(hidden_states.device).bool()
365
- self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
366
- self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
367
- self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
368
- self_attn_mask[:, :, :, 0] = True
369
-
370
- for block in self.refiner_blocks:
371
- hidden_states = block(hidden_states, temb, self_attn_mask)
372
-
373
- return hidden_states
374
-
375
-
376
- class HunyuanVideoTokenRefiner(nn.Module):
377
- def __init__(
378
- self,
379
- in_channels: int,
380
- num_attention_heads: int,
381
- attention_head_dim: int,
382
- num_layers: int,
383
- mlp_ratio: float = 4.0,
384
- mlp_drop_rate: float = 0.0,
385
- attention_bias: bool = True,
386
- ) -> None:
387
- super().__init__()
388
-
389
- hidden_size = num_attention_heads * attention_head_dim
390
-
391
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
392
- embedding_dim=hidden_size, pooled_projection_dim=in_channels
393
- )
394
- self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
395
- self.token_refiner = HunyuanVideoIndividualTokenRefiner(
396
- num_attention_heads=num_attention_heads,
397
- attention_head_dim=attention_head_dim,
398
- num_layers=num_layers,
399
- mlp_width_ratio=mlp_ratio,
400
- mlp_drop_rate=mlp_drop_rate,
401
- attention_bias=attention_bias,
402
- )
403
-
404
- def forward(
405
- self,
406
- hidden_states: torch.Tensor,
407
- timestep: torch.LongTensor,
408
- attention_mask: Optional[torch.LongTensor] = None,
409
- ) -> torch.Tensor:
410
- if attention_mask is None:
411
- pooled_projections = hidden_states.mean(dim=1)
412
- else:
413
- original_dtype = hidden_states.dtype
414
- mask_float = attention_mask.float().unsqueeze(-1)
415
- pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
416
- pooled_projections = pooled_projections.to(original_dtype)
417
-
418
- temb = self.time_text_embed(timestep, pooled_projections)
419
- hidden_states = self.proj_in(hidden_states)
420
- hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
421
-
422
- return hidden_states
423
-
424
-
425
- class HunyuanVideoRotaryPosEmbed(nn.Module):
426
- def __init__(self, rope_dim, theta):
427
- super().__init__()
428
- self.DT, self.DY, self.DX = rope_dim
429
- self.theta = theta
430
-
431
- @torch.no_grad()
432
- def get_frequency(self, dim, pos):
433
- T, H, W = pos.shape
434
- freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
435
- freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
436
- return freqs.cos(), freqs.sin()
437
-
438
- @torch.no_grad()
439
- def forward_inner(self, frame_indices, height, width, device):
440
- GT, GY, GX = torch.meshgrid(
441
- frame_indices.to(device=device, dtype=torch.float32),
442
- torch.arange(0, height, device=device, dtype=torch.float32),
443
- torch.arange(0, width, device=device, dtype=torch.float32),
444
- indexing="ij"
445
- )
446
-
447
- FCT, FST = self.get_frequency(self.DT, GT)
448
- FCY, FSY = self.get_frequency(self.DY, GY)
449
- FCX, FSX = self.get_frequency(self.DX, GX)
450
-
451
- result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
452
-
453
- return result.to(device)
454
-
455
- @torch.no_grad()
456
- def forward(self, frame_indices, height, width, device):
457
- frame_indices = frame_indices.unbind(0)
458
- results = [self.forward_inner(f, height, width, device) for f in frame_indices]
459
- results = torch.stack(results, dim=0)
460
- return results
461
-
462
-
463
- class AdaLayerNormZero(nn.Module):
464
- def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
465
- super().__init__()
466
- self.silu = nn.SiLU()
467
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
468
- if norm_type == "layer_norm":
469
- self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
470
- else:
471
- raise ValueError(f"unknown norm_type {norm_type}")
472
-
473
- def forward(
474
- self,
475
- x: torch.Tensor,
476
- emb: Optional[torch.Tensor] = None,
477
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
478
- emb = emb.unsqueeze(-2)
479
- emb = self.linear(self.silu(emb))
480
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
481
- x = self.norm(x) * (1 + scale_msa) + shift_msa
482
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
483
-
484
-
485
- class AdaLayerNormZeroSingle(nn.Module):
486
- def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
487
- super().__init__()
488
-
489
- self.silu = nn.SiLU()
490
- self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
491
- if norm_type == "layer_norm":
492
- self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
493
- else:
494
- raise ValueError(f"unknown norm_type {norm_type}")
495
-
496
- def forward(
497
- self,
498
- x: torch.Tensor,
499
- emb: Optional[torch.Tensor] = None,
500
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
501
- emb = emb.unsqueeze(-2)
502
- emb = self.linear(self.silu(emb))
503
- shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
504
- x = self.norm(x) * (1 + scale_msa) + shift_msa
505
- return x, gate_msa
506
-
507
-
508
- class AdaLayerNormContinuous(nn.Module):
509
- def __init__(
510
- self,
511
- embedding_dim: int,
512
- conditioning_embedding_dim: int,
513
- elementwise_affine=True,
514
- eps=1e-5,
515
- bias=True,
516
- norm_type="layer_norm",
517
- ):
518
- super().__init__()
519
- self.silu = nn.SiLU()
520
- self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
521
- if norm_type == "layer_norm":
522
- self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
523
- else:
524
- raise ValueError(f"unknown norm_type {norm_type}")
525
-
526
- def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
527
- emb = emb.unsqueeze(-2)
528
- emb = self.linear(self.silu(emb))
529
- scale, shift = emb.chunk(2, dim=-1)
530
- x = self.norm(x) * (1 + scale) + shift
531
- return x
532
-
533
-
534
- class HunyuanVideoSingleTransformerBlock(nn.Module):
535
- def __init__(
536
- self,
537
- num_attention_heads: int,
538
- attention_head_dim: int,
539
- mlp_ratio: float = 4.0,
540
- qk_norm: str = "rms_norm",
541
- ) -> None:
542
- super().__init__()
543
-
544
- hidden_size = num_attention_heads * attention_head_dim
545
- mlp_dim = int(hidden_size * mlp_ratio)
546
-
547
- self.attn = Attention(
548
- query_dim=hidden_size,
549
- cross_attention_dim=None,
550
- dim_head=attention_head_dim,
551
- heads=num_attention_heads,
552
- out_dim=hidden_size,
553
- bias=True,
554
- processor=HunyuanAttnProcessorFlashAttnSingle(),
555
- qk_norm=qk_norm,
556
- eps=1e-6,
557
- pre_only=True,
558
- )
559
-
560
- self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
561
- self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
562
- self.act_mlp = nn.GELU(approximate="tanh")
563
- self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
564
-
565
- def forward(
566
- self,
567
- hidden_states: torch.Tensor,
568
- encoder_hidden_states: torch.Tensor,
569
- temb: torch.Tensor,
570
- attention_mask: Optional[torch.Tensor] = None,
571
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
572
- ) -> torch.Tensor:
573
- text_seq_length = encoder_hidden_states.shape[1]
574
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
575
-
576
- residual = hidden_states
577
-
578
- # 1. Input normalization
579
- norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
580
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
581
-
582
- norm_hidden_states, norm_encoder_hidden_states = (
583
- norm_hidden_states[:, :-text_seq_length, :],
584
- norm_hidden_states[:, -text_seq_length:, :],
585
- )
586
-
587
- # 2. Attention
588
- attn_output, context_attn_output = self.attn(
589
- hidden_states=norm_hidden_states,
590
- encoder_hidden_states=norm_encoder_hidden_states,
591
- attention_mask=attention_mask,
592
- image_rotary_emb=image_rotary_emb,
593
- )
594
- attn_output = torch.cat([attn_output, context_attn_output], dim=1)
595
-
596
- # 3. Modulation and residual connection
597
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
598
- hidden_states = gate * self.proj_out(hidden_states)
599
- hidden_states = hidden_states + residual
600
-
601
- hidden_states, encoder_hidden_states = (
602
- hidden_states[:, :-text_seq_length, :],
603
- hidden_states[:, -text_seq_length:, :],
604
- )
605
- return hidden_states, encoder_hidden_states
606
-
607
-
608
- class HunyuanVideoTransformerBlock(nn.Module):
609
- def __init__(
610
- self,
611
- num_attention_heads: int,
612
- attention_head_dim: int,
613
- mlp_ratio: float,
614
- qk_norm: str = "rms_norm",
615
- ) -> None:
616
- super().__init__()
617
-
618
- hidden_size = num_attention_heads * attention_head_dim
619
-
620
- self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
621
- self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
622
-
623
- self.attn = Attention(
624
- query_dim=hidden_size,
625
- cross_attention_dim=None,
626
- added_kv_proj_dim=hidden_size,
627
- dim_head=attention_head_dim,
628
- heads=num_attention_heads,
629
- out_dim=hidden_size,
630
- context_pre_only=False,
631
- bias=True,
632
- processor=HunyuanAttnProcessorFlashAttnDouble(),
633
- qk_norm=qk_norm,
634
- eps=1e-6,
635
- )
636
-
637
- self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
638
- self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
639
-
640
- self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
641
- self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
642
-
643
- def forward(
644
- self,
645
- hidden_states: torch.Tensor,
646
- encoder_hidden_states: torch.Tensor,
647
- temb: torch.Tensor,
648
- attention_mask: Optional[torch.Tensor] = None,
649
- freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
650
- ) -> Tuple[torch.Tensor, torch.Tensor]:
651
- # 1. Input normalization
652
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
653
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
654
-
655
- # 2. Joint attention
656
- attn_output, context_attn_output = self.attn(
657
- hidden_states=norm_hidden_states,
658
- encoder_hidden_states=norm_encoder_hidden_states,
659
- attention_mask=attention_mask,
660
- image_rotary_emb=freqs_cis,
661
- )
662
-
663
- # 3. Modulation and residual connection
664
- hidden_states = hidden_states + attn_output * gate_msa
665
- encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
666
-
667
- norm_hidden_states = self.norm2(hidden_states)
668
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
669
-
670
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
671
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
672
-
673
- # 4. Feed-forward
674
- ff_output = self.ff(norm_hidden_states)
675
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
676
-
677
- hidden_states = hidden_states + gate_mlp * ff_output
678
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
679
-
680
- return hidden_states, encoder_hidden_states
681
-
682
-
683
- class ClipVisionProjection(nn.Module):
684
- def __init__(self, in_channels, out_channels):
685
- super().__init__()
686
- self.up = nn.Linear(in_channels, out_channels * 3)
687
- self.down = nn.Linear(out_channels * 3, out_channels)
688
-
689
- def forward(self, x):
690
- projected_x = self.down(nn.functional.silu(self.up(x)))
691
- return projected_x
692
-
693
-
694
- class HunyuanVideoPatchEmbed(nn.Module):
695
- def __init__(self, patch_size, in_chans, embed_dim):
696
- super().__init__()
697
- self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
698
-
699
-
700
- class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
701
- def __init__(self, inner_dim):
702
- super().__init__()
703
- self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
704
- self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
705
- self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
706
-
707
- @torch.no_grad()
708
- def initialize_weight_from_another_conv3d(self, another_layer):
709
- weight = another_layer.weight.detach().clone()
710
- bias = another_layer.bias.detach().clone()
711
-
712
- sd = {
713
- 'proj.weight': weight.clone(),
714
- 'proj.bias': bias.clone(),
715
- 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
716
- 'proj_2x.bias': bias.clone(),
717
- 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
718
- 'proj_4x.bias': bias.clone(),
719
- }
720
-
721
- sd = {k: v.clone() for k, v in sd.items()}
722
-
723
- self.load_state_dict(sd)
724
- return
725
-
726
-
727
- class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
728
- @register_to_config
729
- def __init__(
730
- self,
731
- in_channels: int = 16,
732
- out_channels: int = 16,
733
- num_attention_heads: int = 24,
734
- attention_head_dim: int = 128,
735
- num_layers: int = 20,
736
- num_single_layers: int = 40,
737
- num_refiner_layers: int = 2,
738
- mlp_ratio: float = 4.0,
739
- patch_size: int = 2,
740
- patch_size_t: int = 1,
741
- qk_norm: str = "rms_norm",
742
- guidance_embeds: bool = True,
743
- text_embed_dim: int = 4096,
744
- pooled_projection_dim: int = 768,
745
- rope_theta: float = 256.0,
746
- rope_axes_dim: Tuple[int] = (16, 56, 56),
747
- has_image_proj=False,
748
- image_proj_dim=1152,
749
- has_clean_x_embedder=False,
750
- ) -> None:
751
- super().__init__()
752
-
753
- inner_dim = num_attention_heads * attention_head_dim
754
- out_channels = out_channels or in_channels
755
-
756
- # 1. Latent and condition embedders
757
- self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
758
- self.context_embedder = HunyuanVideoTokenRefiner(
759
- text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
760
- )
761
- self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
762
-
763
- self.clean_x_embedder = None
764
- self.image_projection = None
765
-
766
- # 2. RoPE
767
- self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
768
-
769
- # 3. Dual stream transformer blocks
770
- self.transformer_blocks = nn.ModuleList(
771
- [
772
- HunyuanVideoTransformerBlock(
773
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
774
- )
775
- for _ in range(num_layers)
776
- ]
777
- )
778
-
779
- # 4. Single stream transformer blocks
780
- self.single_transformer_blocks = nn.ModuleList(
781
- [
782
- HunyuanVideoSingleTransformerBlock(
783
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
784
- )
785
- for _ in range(num_single_layers)
786
- ]
787
- )
788
-
789
- # 5. Output projection
790
- self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
791
- self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
792
-
793
- self.inner_dim = inner_dim
794
- self.use_gradient_checkpointing = False
795
- self.enable_teacache = False
796
-
797
- if has_image_proj:
798
- self.install_image_projection(image_proj_dim)
799
-
800
- if has_clean_x_embedder:
801
- self.install_clean_x_embedder()
802
-
803
- self.high_quality_fp32_output_for_inference = False
804
-
805
- def install_image_projection(self, in_channels):
806
- self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
807
- self.config['has_image_proj'] = True
808
- self.config['image_proj_dim'] = in_channels
809
-
810
- def install_clean_x_embedder(self):
811
- self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
812
- self.config['has_clean_x_embedder'] = True
813
-
814
- def enable_gradient_checkpointing(self):
815
- self.use_gradient_checkpointing = True
816
- print('self.use_gradient_checkpointing = True')
817
-
818
- def disable_gradient_checkpointing(self):
819
- self.use_gradient_checkpointing = False
820
- print('self.use_gradient_checkpointing = False')
821
-
822
- def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
823
- self.enable_teacache = enable_teacache
824
- self.cnt = 0
825
- self.num_steps = num_steps
826
- self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
827
- self.accumulated_rel_l1_distance = 0
828
- self.previous_modulated_input = None
829
- self.previous_residual = None
830
- self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
831
-
832
- def gradient_checkpointing_method(self, block, *args):
833
- if self.use_gradient_checkpointing:
834
- result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
835
- else:
836
- result = block(*args)
837
- return result
838
-
839
- def process_input_hidden_states(
840
- self,
841
- latents, latent_indices=None,
842
- clean_latents=None, clean_latent_indices=None,
843
- clean_latents_2x=None, clean_latent_2x_indices=None,
844
- clean_latents_4x=None, clean_latent_4x_indices=None
845
- ):
846
- hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
847
- B, C, T, H, W = hidden_states.shape
848
-
849
- if latent_indices is None:
850
- latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
851
-
852
- hidden_states = hidden_states.flatten(2).transpose(1, 2)
853
-
854
- rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
855
- rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
856
-
857
- if clean_latents is not None and clean_latent_indices is not None:
858
- clean_latents = clean_latents.to(hidden_states)
859
- clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
860
- clean_latents = clean_latents.flatten(2).transpose(1, 2)
861
-
862
- clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
863
- clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
864
-
865
- hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
866
- rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
867
-
868
- if clean_latents_2x is not None and clean_latent_2x_indices is not None:
869
- clean_latents_2x = clean_latents_2x.to(hidden_states)
870
- clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
871
- clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
872
- clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
873
-
874
- clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
875
- clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
876
- clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
877
- clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
878
-
879
- hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
880
- rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
881
-
882
- if clean_latents_4x is not None and clean_latent_4x_indices is not None:
883
- clean_latents_4x = clean_latents_4x.to(hidden_states)
884
- clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
885
- clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
886
- clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
887
-
888
- clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
889
- clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
890
- clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
891
- clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
892
-
893
- hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
894
- rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
895
-
896
- return hidden_states, rope_freqs
897
-
898
- def forward(
899
- self,
900
- hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
901
- latent_indices=None,
902
- clean_latents=None, clean_latent_indices=None,
903
- clean_latents_2x=None, clean_latent_2x_indices=None,
904
- clean_latents_4x=None, clean_latent_4x_indices=None,
905
- image_embeddings=None,
906
- attention_kwargs=None, return_dict=True
907
- ):
908
-
909
- if attention_kwargs is None:
910
- attention_kwargs = {}
911
-
912
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
913
- p, p_t = self.config['patch_size'], self.config['patch_size_t']
914
- post_patch_num_frames = num_frames // p_t
915
- post_patch_height = height // p
916
- post_patch_width = width // p
917
- original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
918
-
919
- hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
920
-
921
- temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
922
- encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
923
-
924
- if self.image_projection is not None:
925
- assert image_embeddings is not None, 'You must use image embeddings!'
926
- extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
927
- extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
928
-
929
- # must cat before (not after) encoder_hidden_states, due to attn masking
930
- encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
931
- encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
932
-
933
- if batch_size == 1:
934
-             # When batch size is 1, we do not need any masks or var-len functions, since cropping is mathematically equivalent to what we want
935
-             # If they are not equivalent, then their implementations are wrong. Ours is always the correct one.
936
- text_len = encoder_attention_mask.sum().item()
937
- encoder_hidden_states = encoder_hidden_states[:, :text_len]
938
- attention_mask = None, None, None, None
939
- else:
940
- img_seq_len = hidden_states.shape[1]
941
- txt_seq_len = encoder_hidden_states.shape[1]
942
-
943
- cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
944
- cu_seqlens_kv = cu_seqlens_q
945
- max_seqlen_q = img_seq_len + txt_seq_len
946
- max_seqlen_kv = max_seqlen_q
947
-
948
- attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
949
-
950
- if self.enable_teacache:
951
- modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
952
-
953
- if self.cnt == 0 or self.cnt == self.num_steps-1:
954
- should_calc = True
955
- self.accumulated_rel_l1_distance = 0
956
- else:
957
- curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
958
- self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
959
- should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
960
-
961
- if should_calc:
962
- self.accumulated_rel_l1_distance = 0
963
-
964
- self.previous_modulated_input = modulated_inp
965
- self.cnt += 1
966
-
967
- if self.cnt == self.num_steps:
968
- self.cnt = 0
969
-
970
- if not should_calc:
971
- hidden_states = hidden_states + self.previous_residual
972
- else:
973
- ori_hidden_states = hidden_states.clone()
974
-
975
- for block_id, block in enumerate(self.transformer_blocks):
976
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
977
- block,
978
- hidden_states,
979
- encoder_hidden_states,
980
- temb,
981
- attention_mask,
982
- rope_freqs
983
- )
984
-
985
- for block_id, block in enumerate(self.single_transformer_blocks):
986
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
987
- block,
988
- hidden_states,
989
- encoder_hidden_states,
990
- temb,
991
- attention_mask,
992
- rope_freqs
993
- )
994
-
995
- self.previous_residual = hidden_states - ori_hidden_states
996
- else:
997
- for block_id, block in enumerate(self.transformer_blocks):
998
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
999
- block,
1000
- hidden_states,
1001
- encoder_hidden_states,
1002
- temb,
1003
- attention_mask,
1004
- rope_freqs
1005
- )
1006
-
1007
- for block_id, block in enumerate(self.single_transformer_blocks):
1008
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1009
- block,
1010
- hidden_states,
1011
- encoder_hidden_states,
1012
- temb,
1013
- attention_mask,
1014
- rope_freqs
1015
- )
1016
-
1017
- hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1018
-
1019
- hidden_states = hidden_states[:, -original_context_length:, :]
1020
-
1021
- if self.high_quality_fp32_output_for_inference:
1022
- hidden_states = hidden_states.to(dtype=torch.float32)
1023
- if self.proj_out.weight.dtype != torch.float32:
1024
- self.proj_out.to(dtype=torch.float32)
1025
-
1026
- hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1027
-
1028
- hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1029
- t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1030
- pt=p_t, ph=p, pw=p)
1031
-
1032
- if return_dict:
1033
- return Transformer2DModelOutput(sample=hidden_states)
1034
-
1035
- return hidden_states,
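
For reference, the TeaCache branch in the forward() above keeps a running, rescaled relative-L1 distance between successive modulated inputs and only re-runs the transformer blocks once that distance crosses rel_l1_thresh; otherwise it re-applies the residual cached from the last full pass. The snippet below is a minimal sketch of that pattern, not the deleted module itself; the class name SimpleTeaCache and the run_blocks callback are made up for illustration.

class SimpleTeaCache:
    # Minimal sketch of the residual-caching pattern above (hypothetical class).
    def __init__(self, num_steps, rel_l1_thresh=0.15, rescale_func=lambda x: x):
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.rescale_func = rescale_func
        self.cnt = 0
        self.accumulated = 0.0
        self.prev_modulated = None
        self.prev_residual = None

    def step(self, modulated_inp, hidden_states, run_blocks):
        # Always compute on the first and on the last step of a sampling run.
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            should_calc = True
            self.accumulated = 0.0
        else:
            rel_l1 = ((modulated_inp - self.prev_modulated).abs().mean()
                      / self.prev_modulated.abs().mean()).item()
            self.accumulated += self.rescale_func(rel_l1)
            should_calc = self.accumulated >= self.rel_l1_thresh
            if should_calc:
                self.accumulated = 0.0

        self.prev_modulated = modulated_inp
        self.cnt = (self.cnt + 1) % self.num_steps

        if not should_calc:
            # Cheap path: reuse the cached residual instead of running the blocks.
            return hidden_states + self.prev_residual

        # Full path: run the blocks and cache the new residual for later steps.
        ori = hidden_states.clone()
        hidden_states = run_blocks(hidden_states)
        self.prev_residual = hidden_states - ori
        return hidden_states

# Per sampling step (names hypothetical): x = cache.step(modulated, x, run_blocks=blocks_forward)

On skipped steps only the residual addition is paid, which is where the speed-up of enable_teacache comes from.
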
diffusers_helper/pipelines/k_diffusion_hunyuan.py DELETED
@@ -1,120 +0,0 @@
1
- import torch
2
- import math
3
-
4
- from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
5
- from diffusers_helper.k_diffusion.wrapper import fm_wrapper
6
- from diffusers_helper.utils import repeat_to_batch_size
7
-
8
-
9
- def flux_time_shift(t, mu=1.15, sigma=1.0):
10
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
11
-
12
-
13
- def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
14
- k = (y2 - y1) / (x2 - x1)
15
- b = y1 - k * x1
16
- mu = k * context_length + b
17
- mu = min(mu, math.log(exp_max))
18
- return mu
19
-
20
-
21
- def get_flux_sigmas_from_mu(n, mu):
22
- sigmas = torch.linspace(1, 0, steps=n + 1)
23
- sigmas = flux_time_shift(sigmas, mu=mu)
24
- return sigmas
25
-
26
-
27
- @torch.inference_mode()
28
- def sample_hunyuan(
29
- transformer,
30
- sampler='unipc',
31
- initial_latent=None,
32
- concat_latent=None,
33
- strength=1.0,
34
- width=512,
35
- height=512,
36
- frames=16,
37
- real_guidance_scale=1.0,
38
- distilled_guidance_scale=6.0,
39
- guidance_rescale=0.0,
40
- shift=None,
41
- num_inference_steps=25,
42
- batch_size=None,
43
- generator=None,
44
- prompt_embeds=None,
45
- prompt_embeds_mask=None,
46
- prompt_poolers=None,
47
- negative_prompt_embeds=None,
48
- negative_prompt_embeds_mask=None,
49
- negative_prompt_poolers=None,
50
- dtype=torch.bfloat16,
51
- device=None,
52
- negative_kwargs=None,
53
- callback=None,
54
- **kwargs,
55
- ):
56
- device = device or transformer.device
57
-
58
- if batch_size is None:
59
- batch_size = int(prompt_embeds.shape[0])
60
-
61
- latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32)
62
-
63
- B, C, T, H, W = latents.shape
64
- seq_length = T * H * W // 4
65
-
66
- if shift is None:
67
- mu = calculate_flux_mu(seq_length, exp_max=7.0)
68
- else:
69
- mu = math.log(shift)
70
-
71
- sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
72
-
73
- k_model = fm_wrapper(transformer)
74
-
75
- if initial_latent is not None:
76
- sigmas = sigmas * strength
77
- first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
78
- initial_latent = initial_latent.to(device=device, dtype=torch.float32)
79
- latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
80
-
81
- if concat_latent is not None:
82
- concat_latent = concat_latent.to(latents)
83
-
84
- distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
85
-
86
- prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
87
- prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
88
- prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
89
- negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
90
- negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
91
- negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
92
- concat_latent = repeat_to_batch_size(concat_latent, batch_size)
93
-
94
- sampler_kwargs = dict(
95
- dtype=dtype,
96
- cfg_scale=real_guidance_scale,
97
- cfg_rescale=guidance_rescale,
98
- concat_latent=concat_latent,
99
- positive=dict(
100
- pooled_projections=prompt_poolers,
101
- encoder_hidden_states=prompt_embeds,
102
- encoder_attention_mask=prompt_embeds_mask,
103
- guidance=distilled_guidance,
104
- **kwargs,
105
- ),
106
- negative=dict(
107
- pooled_projections=negative_prompt_poolers,
108
- encoder_hidden_states=negative_prompt_embeds,
109
- encoder_attention_mask=negative_prompt_embeds_mask,
110
- guidance=distilled_guidance,
111
- **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
112
- )
113
- )
114
-
115
- if sampler == 'unipc':
116
- results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
117
- else:
118
- raise NotImplementedError(f'Sampler {sampler} is not supported.')
119
-
120
- return results
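
As a standalone illustration of the sigma schedule that sample_hunyuan builds above: calculate_flux_mu maps the latent sequence length to a shift parameter mu (capped at log(exp_max)), and get_flux_sigmas_from_mu pushes a uniform 1-to-0 grid through flux_time_shift. The sketch below restates the three helpers so it runs without the deleted module; the resolution, frame count, and step count are arbitrary example values.

import math
import torch

def flux_time_shift(t, mu=1.15, sigma=1.0):
    # Same mapping as above: larger mu shifts the schedule toward high-noise timesteps.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
    # Linear in the sequence length, clamped so exp(mu) never exceeds exp_max.
    k = (y2 - y1) / (x2 - x1)
    mu = k * context_length + (y1 - k * x1)
    return min(mu, math.log(exp_max))

def get_flux_sigmas_from_mu(n, mu):
    sigmas = torch.linspace(1, 0, steps=n + 1)
    return flux_time_shift(sigmas, mu=mu)

# Example: 33 frames at 512x512 pixels, 25 inference steps.
T, H, W = (33 + 3) // 4, 512 // 8, 512 // 8
seq_length = T * H * W // 4
sigmas = get_flux_sigmas_from_mu(25, calculate_flux_mu(seq_length, exp_max=7.0))
# 26 values decreasing from 1.0 to 0.0, matching the sigmas handed to sample_unipc above.
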
diffusers_helper/thread_utils.py DELETED
@@ -1,76 +0,0 @@
1
- import time
2
-
3
- from threading import Thread, Lock
4
-
5
-
6
- class Listener:
7
- task_queue = []
8
- lock = Lock()
9
- thread = None
10
-
11
- @classmethod
12
- def _process_tasks(cls):
13
- while True:
14
- task = None
15
- with cls.lock:
16
- if cls.task_queue:
17
- task = cls.task_queue.pop(0)
18
-
19
- if task is None:
20
- time.sleep(0.001)
21
- continue
22
-
23
- func, args, kwargs = task
24
- try:
25
- func(*args, **kwargs)
26
- except Exception as e:
27
- print(f"Error in listener thread: {e}")
28
-
29
- @classmethod
30
- def add_task(cls, func, *args, **kwargs):
31
- with cls.lock:
32
- cls.task_queue.append((func, args, kwargs))
33
-
34
- if cls.thread is None:
35
- cls.thread = Thread(target=cls._process_tasks, daemon=True)
36
- cls.thread.start()
37
-
38
-
39
- def async_run(func, *args, **kwargs):
40
- Listener.add_task(func, *args, **kwargs)
41
-
42
-
43
- class FIFOQueue:
44
- def __init__(self):
45
- self.queue = []
46
- self.lock = Lock()
47
-
48
- def push(self, item):
49
- with self.lock:
50
- self.queue.append(item)
51
-
52
- def pop(self):
53
- with self.lock:
54
- if self.queue:
55
- return self.queue.pop(0)
56
- return None
57
-
58
- def top(self):
59
- with self.lock:
60
- if self.queue:
61
- return self.queue[0]
62
- return None
63
-
64
- def next(self):
65
- while True:
66
- with self.lock:
67
- if self.queue:
68
- return self.queue.pop(0)
69
-
70
- time.sleep(0.001)
71
-
72
-
73
- class AsyncStream:
74
- def __init__(self):
75
- self.input_queue = FIFOQueue()
76
- self.output_queue = FIFOQueue()
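
The deleted thread_utils module is a small producer/consumer toolkit: async_run pushes work onto a single daemon listener thread, and AsyncStream pairs two FIFOQueues whose blocking next() lets a foreground loop wait for results. Below is a minimal usage sketch against the pre-deletion tree; the worker function and the ('progress', ...) / ('end', ...) payloads are made up for illustration.

import time

from diffusers_helper.thread_utils import AsyncStream, async_run  # pre-deletion import path

def worker(stream):
    # Producer: executed on the shared background listener thread by async_run.
    for i in range(3):
        time.sleep(0.1)
        stream.output_queue.push(('progress', i))
    stream.output_queue.push(('end', None))

stream = AsyncStream()
async_run(worker, stream)

# Consumer: next() polls with a 1 ms sleep until the producer pushes an item.
while True:
    kind, payload = stream.output_queue.next()
    if kind == 'end':
        break
    print('progress:', payload)
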
diffusers_helper/utils.py DELETED
@@ -1,613 +0,0 @@
1
- import os
2
- import cv2
3
- import json
4
- import random
5
- import glob
6
- import torch
7
- import einops
8
- import numpy as np
9
- import datetime
10
- import torchvision
11
-
12
- import safetensors.torch as sf
13
- from PIL import Image
14
-
15
-
16
- def min_resize(x, m):
17
- if x.shape[0] < x.shape[1]:
18
- s0 = m
19
- s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
20
- else:
21
- s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
22
- s1 = m
23
- new_max = max(s1, s0)
24
- raw_max = max(x.shape[0], x.shape[1])
25
- if new_max < raw_max:
26
- interpolation = cv2.INTER_AREA
27
- else:
28
- interpolation = cv2.INTER_LANCZOS4
29
- y = cv2.resize(x, (s1, s0), interpolation=interpolation)
30
- return y
31
-
32
-
33
- def d_resize(x, y):
34
- H, W, C = y.shape
35
- new_min = min(H, W)
36
- raw_min = min(x.shape[0], x.shape[1])
37
- if new_min < raw_min:
38
- interpolation = cv2.INTER_AREA
39
- else:
40
- interpolation = cv2.INTER_LANCZOS4
41
- y = cv2.resize(x, (W, H), interpolation=interpolation)
42
- return y
43
-
44
-
45
- def resize_and_center_crop(image, target_width, target_height):
46
- if target_height == image.shape[0] and target_width == image.shape[1]:
47
- return image
48
-
49
- pil_image = Image.fromarray(image)
50
- original_width, original_height = pil_image.size
51
- scale_factor = max(target_width / original_width, target_height / original_height)
52
- resized_width = int(round(original_width * scale_factor))
53
- resized_height = int(round(original_height * scale_factor))
54
- resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
55
- left = (resized_width - target_width) / 2
56
- top = (resized_height - target_height) / 2
57
- right = (resized_width + target_width) / 2
58
- bottom = (resized_height + target_height) / 2
59
- cropped_image = resized_image.crop((left, top, right, bottom))
60
- return np.array(cropped_image)
61
-
62
-
63
- def resize_and_center_crop_pytorch(image, target_width, target_height):
64
- B, C, H, W = image.shape
65
-
66
- if H == target_height and W == target_width:
67
- return image
68
-
69
- scale_factor = max(target_width / W, target_height / H)
70
- resized_width = int(round(W * scale_factor))
71
- resized_height = int(round(H * scale_factor))
72
-
73
- resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
74
-
75
- top = (resized_height - target_height) // 2
76
- left = (resized_width - target_width) // 2
77
- cropped = resized[:, :, top:top + target_height, left:left + target_width]
78
-
79
- return cropped
80
-
81
-
82
- def resize_without_crop(image, target_width, target_height):
83
- if target_height == image.shape[0] and target_width == image.shape[1]:
84
- return image
85
-
86
- pil_image = Image.fromarray(image)
87
- resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
88
- return np.array(resized_image)
89
-
90
-
91
- def just_crop(image, w, h):
92
- if h == image.shape[0] and w == image.shape[1]:
93
- return image
94
-
95
- original_height, original_width = image.shape[:2]
96
- k = min(original_height / h, original_width / w)
97
- new_width = int(round(w * k))
98
- new_height = int(round(h * k))
99
- x_start = (original_width - new_width) // 2
100
- y_start = (original_height - new_height) // 2
101
- cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
102
- return cropped_image
103
-
104
-
105
- def write_to_json(data, file_path):
106
- temp_file_path = file_path + ".tmp"
107
- with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
108
- json.dump(data, temp_file, indent=4)
109
- os.replace(temp_file_path, file_path)
110
- return
111
-
112
-
113
- def read_from_json(file_path):
114
- with open(file_path, 'rt', encoding='utf-8') as file:
115
- data = json.load(file)
116
- return data
117
-
118
-
119
- def get_active_parameters(m):
120
- return {k: v for k, v in m.named_parameters() if v.requires_grad}
121
-
122
-
123
- def cast_training_params(m, dtype=torch.float32):
124
- result = {}
125
- for n, param in m.named_parameters():
126
- if param.requires_grad:
127
- param.data = param.to(dtype)
128
- result[n] = param
129
- return result
130
-
131
-
132
- def separate_lora_AB(parameters, B_patterns=None):
133
- parameters_normal = {}
134
- parameters_B = {}
135
-
136
- if B_patterns is None:
137
- B_patterns = ['.lora_B.', '__zero__']
138
-
139
- for k, v in parameters.items():
140
- if any(B_pattern in k for B_pattern in B_patterns):
141
- parameters_B[k] = v
142
- else:
143
- parameters_normal[k] = v
144
-
145
- return parameters_normal, parameters_B
146
-
147
-
148
- def set_attr_recursive(obj, attr, value):
149
- attrs = attr.split(".")
150
- for name in attrs[:-1]:
151
- obj = getattr(obj, name)
152
- setattr(obj, attrs[-1], value)
153
- return
154
-
155
-
156
- def print_tensor_list_size(tensors):
157
- total_size = 0
158
- total_elements = 0
159
-
160
- if isinstance(tensors, dict):
161
- tensors = tensors.values()
162
-
163
- for tensor in tensors:
164
- total_size += tensor.nelement() * tensor.element_size()
165
- total_elements += tensor.nelement()
166
-
167
- total_size_MB = total_size / (1024 ** 2)
168
- total_elements_B = total_elements / 1e9
169
-
170
- print(f"Total number of tensors: {len(tensors)}")
171
- print(f"Total size of tensors: {total_size_MB:.2f} MB")
172
- print(f"Total number of parameters: {total_elements_B:.3f} billion")
173
- return
174
-
175
-
176
- @torch.no_grad()
177
- def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
178
- batch_size = a.size(0)
179
-
180
- if b is None:
181
- b = torch.zeros_like(a)
182
-
183
- if mask_a is None:
184
- mask_a = torch.rand(batch_size) < probability_a
185
-
186
- mask_a = mask_a.to(a.device)
187
- mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
188
- result = torch.where(mask_a, a, b)
189
- return result
190
-
191
-
192
- @torch.no_grad()
193
- def zero_module(module):
194
- for p in module.parameters():
195
- p.detach().zero_()
196
- return module
197
-
198
-
199
- @torch.no_grad()
200
- def supress_lower_channels(m, k, alpha=0.01):
201
- data = m.weight.data.clone()
202
-
203
- assert int(data.shape[1]) >= k
204
-
205
- data[:, :k] = data[:, :k] * alpha
206
- m.weight.data = data.contiguous().clone()
207
- return m
208
-
209
-
210
- def freeze_module(m):
211
- if not hasattr(m, '_forward_inside_frozen_module'):
212
- m._forward_inside_frozen_module = m.forward
213
- m.requires_grad_(False)
214
- m.forward = torch.no_grad()(m.forward)
215
- return m
216
-
217
-
218
- def get_latest_safetensors(folder_path):
219
- safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
220
-
221
- if not safetensors_files:
222
- raise ValueError('No file to resume!')
223
-
224
- latest_file = max(safetensors_files, key=os.path.getmtime)
225
- latest_file = os.path.abspath(os.path.realpath(latest_file))
226
- return latest_file
227
-
228
-
229
- def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
230
- tags = tags_str.split(', ')
231
- tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
232
- prompt = ', '.join(tags)
233
- return prompt
234
-
235
-
236
- def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
237
- numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
238
- if round_to_int:
239
- numbers = np.round(numbers).astype(int)
240
- return numbers.tolist()
241
-
242
-
243
- def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
244
- edges = np.linspace(0, 1, n + 1)
245
- points = np.random.uniform(edges[:-1], edges[1:])
246
- numbers = inclusive + (exclusive - inclusive) * points
247
- if round_to_int:
248
- numbers = np.round(numbers).astype(int)
249
- return numbers.tolist()
250
-
251
-
252
- def soft_append_bcthw(history, current, overlap=0):
253
- if overlap <= 0:
254
- return torch.cat([history, current], dim=2)
255
-
256
- assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
257
- assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
258
-
259
- weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
260
- blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
261
- output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
262
-
263
- return output.to(history)
264
-
265
-
266
- def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
267
- b, c, t, h, w = x.shape
268
-
269
- per_row = b
270
- for p in [6, 5, 4, 3, 2]:
271
- if b % p == 0:
272
- per_row = p
273
- break
274
-
275
- os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
276
- x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
277
- x = x.detach().cpu().to(torch.uint8)
278
- x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
279
- torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
280
- return x
281
-
282
-
283
- def save_bcthw_as_png(x, output_filename):
284
- os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
285
- x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
286
- x = x.detach().cpu().to(torch.uint8)
287
- x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
288
- torchvision.io.write_png(x, output_filename)
289
- return output_filename
290
-
291
-
292
- def save_bchw_as_png(x, output_filename):
293
- os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
294
- x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
295
- x = x.detach().cpu().to(torch.uint8)
296
- x = einops.rearrange(x, 'b c h w -> c h (b w)')
297
- torchvision.io.write_png(x, output_filename)
298
- return output_filename
299
-
300
-
301
- def add_tensors_with_padding(tensor1, tensor2):
302
- if tensor1.shape == tensor2.shape:
303
- return tensor1 + tensor2
304
-
305
- shape1 = tensor1.shape
306
- shape2 = tensor2.shape
307
-
308
- new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
309
-
310
- padded_tensor1 = torch.zeros(new_shape)
311
- padded_tensor2 = torch.zeros(new_shape)
312
-
313
- padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
314
- padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
315
-
316
- result = padded_tensor1 + padded_tensor2
317
- return result
318
-
319
-
320
- def print_free_mem():
321
- torch.cuda.empty_cache()
322
- free_mem, total_mem = torch.cuda.mem_get_info(0)
323
- free_mem_mb = free_mem / (1024 ** 2)
324
- total_mem_mb = total_mem / (1024 ** 2)
325
- print(f"Free memory: {free_mem_mb:.2f} MB")
326
- print(f"Total memory: {total_mem_mb:.2f} MB")
327
- return
328
-
329
-
330
- def print_gpu_parameters(device, state_dict, log_count=1):
331
- summary = {"device": device, "keys_count": len(state_dict)}
332
-
333
- logged_params = {}
334
- for i, (key, tensor) in enumerate(state_dict.items()):
335
- if i >= log_count:
336
- break
337
- logged_params[key] = tensor.flatten()[:3].tolist()
338
-
339
- summary["params"] = logged_params
340
-
341
- print(str(summary))
342
- return
343
-
344
-
345
- def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
346
- from PIL import Image, ImageDraw, ImageFont
347
-
348
- txt = Image.new("RGB", (width, height), color="white")
349
- draw = ImageDraw.Draw(txt)
350
- font = ImageFont.truetype(font_path, size=size)
351
-
352
- if text == '':
353
- return np.array(txt)
354
-
355
- # Split text into lines that fit within the image width
356
- lines = []
357
- words = text.split()
358
- current_line = words[0]
359
-
360
- for word in words[1:]:
361
- line_with_word = f"{current_line} {word}"
362
- if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
363
- current_line = line_with_word
364
- else:
365
- lines.append(current_line)
366
- current_line = word
367
-
368
- lines.append(current_line)
369
-
370
- # Draw the text line by line
371
- y = 0
372
- line_height = draw.textbbox((0, 0), "A", font=font)[3]
373
-
374
- for line in lines:
375
- if y + line_height > height:
376
- break # stop drawing if the next line will be outside the image
377
- draw.text((0, y), line, fill="black", font=font)
378
- y += line_height
379
-
380
- return np.array(txt)
381
-
382
-
383
- def blue_mark(x):
384
- x = x.copy()
385
- c = x[:, :, 2]
386
- b = cv2.blur(c, (9, 9))
387
- x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
388
- return x
389
-
390
-
391
- def green_mark(x):
392
- x = x.copy()
393
- x[:, :, 2] = -1
394
- x[:, :, 0] = -1
395
- return x
396
-
397
-
398
- def frame_mark(x):
399
- x = x.copy()
400
- x[:64] = -1
401
- x[-64:] = -1
402
- x[:, :8] = 1
403
- x[:, -8:] = 1
404
- return x
405
-
406
-
407
- @torch.inference_mode()
408
- def pytorch2numpy(imgs):
409
- results = []
410
- for x in imgs:
411
- y = x.movedim(0, -1)
412
- y = y * 127.5 + 127.5
413
- y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
414
- results.append(y)
415
- return results
416
-
417
-
418
- @torch.inference_mode()
419
- def numpy2pytorch(imgs):
420
- h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
421
- h = h.movedim(-1, 1)
422
- return h
423
-
424
-
425
- @torch.no_grad()
426
- def duplicate_prefix_to_suffix(x, count, zero_out=False):
427
- if zero_out:
428
- return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
429
- else:
430
- return torch.cat([x, x[:count]], dim=0)
431
-
432
-
433
- def weighted_mse(a, b, weight):
434
- return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
435
-
436
-
437
- def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
438
- x = (x - x_min) / (x_max - x_min)
439
- x = max(0.0, min(x, 1.0))
440
- x = x ** sigma
441
- return y_min + x * (y_max - y_min)
442
-
443
-
444
- def expand_to_dims(x, target_dims):
445
- return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
446
-
447
-
448
- def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
449
- if tensor is None:
450
- return None
451
-
452
- first_dim = tensor.shape[0]
453
-
454
- if first_dim == batch_size:
455
- return tensor
456
-
457
- if batch_size % first_dim != 0:
458
- raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
459
-
460
- repeat_times = batch_size // first_dim
461
-
462
- return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
463
-
464
-
465
- def dim5(x):
466
- return expand_to_dims(x, 5)
467
-
468
-
469
- def dim4(x):
470
- return expand_to_dims(x, 4)
471
-
472
-
473
- def dim3(x):
474
- return expand_to_dims(x, 3)
475
-
476
-
477
- def crop_or_pad_yield_mask(x, length):
478
- B, F, C = x.shape
479
- device = x.device
480
- dtype = x.dtype
481
-
482
- if F < length:
483
- y = torch.zeros((B, length, C), dtype=dtype, device=device)
484
- mask = torch.zeros((B, length), dtype=torch.bool, device=device)
485
- y[:, :F, :] = x
486
- mask[:, :F] = True
487
- return y, mask
488
-
489
- return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
490
-
491
-
492
- def extend_dim(x, dim, minimal_length, zero_pad=False):
493
- original_length = int(x.shape[dim])
494
-
495
- if original_length >= minimal_length:
496
- return x
497
-
498
- if zero_pad:
499
- padding_shape = list(x.shape)
500
- padding_shape[dim] = minimal_length - original_length
501
- padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
502
- else:
503
- idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
504
- last_element = x[idx]
505
- padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
506
-
507
- return torch.cat([x, padding], dim=dim)
508
-
509
-
510
- def lazy_positional_encoding(t, repeats=None):
511
- if not isinstance(t, list):
512
- t = [t]
513
-
514
- from diffusers.models.embeddings import get_timestep_embedding
515
-
516
- te = torch.tensor(t)
517
- te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
518
-
519
- if repeats is None:
520
- return te
521
-
522
- te = te[:, None, :].expand(-1, repeats, -1)
523
-
524
- return te
525
-
526
-
527
- def state_dict_offset_merge(A, B, C=None):
528
- result = {}
529
- keys = A.keys()
530
-
531
- for key in keys:
532
- A_value = A[key]
533
- B_value = B[key].to(A_value)
534
-
535
- if C is None:
536
- result[key] = A_value + B_value
537
- else:
538
- C_value = C[key].to(A_value)
539
- result[key] = A_value + B_value - C_value
540
-
541
- return result
542
-
543
-
544
- def state_dict_weighted_merge(state_dicts, weights):
545
- if len(state_dicts) != len(weights):
546
- raise ValueError("Number of state dictionaries must match number of weights")
547
-
548
- if not state_dicts:
549
- return {}
550
-
551
- total_weight = sum(weights)
552
-
553
- if total_weight == 0:
554
- raise ValueError("Sum of weights cannot be zero")
555
-
556
- normalized_weights = [w / total_weight for w in weights]
557
-
558
- keys = state_dicts[0].keys()
559
- result = {}
560
-
561
- for key in keys:
562
- result[key] = state_dicts[0][key] * normalized_weights[0]
563
-
564
- for i in range(1, len(state_dicts)):
565
- state_dict_value = state_dicts[i][key].to(result[key])
566
- result[key] += state_dict_value * normalized_weights[i]
567
-
568
- return result
569
-
570
-
571
- def group_files_by_folder(all_files):
572
- grouped_files = {}
573
-
574
- for file in all_files:
575
- folder_name = os.path.basename(os.path.dirname(file))
576
- if folder_name not in grouped_files:
577
- grouped_files[folder_name] = []
578
- grouped_files[folder_name].append(file)
579
-
580
- list_of_lists = list(grouped_files.values())
581
- return list_of_lists
582
-
583
-
584
- def generate_timestamp():
585
- now = datetime.datetime.now()
586
- timestamp = now.strftime('%y%m%d_%H%M%S')
587
- milliseconds = f"{int(now.microsecond / 1000):03d}"
588
- random_number = random.randint(0, 9999)
589
- return f"{timestamp}_{milliseconds}_{random_number}"
590
-
591
-
592
- def write_PIL_image_with_png_info(image, metadata, path):
593
- from PIL.PngImagePlugin import PngInfo
594
-
595
- png_info = PngInfo()
596
- for key, value in metadata.items():
597
- png_info.add_text(key, value)
598
-
599
- image.save(path, "PNG", pnginfo=png_info)
600
- return image
601
-
602
-
603
- def torch_safe_save(content, path):
604
- torch.save(content, path + '_tmp')
605
- os.replace(path + '_tmp', path)
606
- return path
607
-
608
-
609
- def move_optimizer_to_device(optimizer, device):
610
- for state in optimizer.state.values():
611
- for k, v in state.items():
612
- if isinstance(v, torch.Tensor):
613
- state[k] = v.to(device)
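
To make one of the less obvious helpers above concrete: soft_append_bcthw linearly cross-fades the last overlap frames of history into the first overlap frames of current before concatenating along the time axis, so the merged clip has T1 + T2 - overlap frames. A minimal shape check against the pre-deletion module follows; the tensor sizes are arbitrary.

import torch

from diffusers_helper.utils import soft_append_bcthw  # pre-deletion import path

history = torch.randn(1, 16, 12, 60, 104)  # (B, C, T, H, W): latents generated so far
current = torch.randn(1, 16, 9, 60, 104)   # next chunk, sharing 3 frames with history

merged = soft_append_bcthw(history, current, overlap=3)

# 12 + 9 - 3 = 18 frames; the 3 shared frames are blended with weights running 1 -> 0.
assert merged.shape == (1, 16, 18, 60, 104)
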