This Pull Request also adds video extension and end-frame support

#1
.gitattributes CHANGED
@@ -2,4 +2,9 @@
  *.bin filter=lfs diff=lfs merge=lfs -text
  *.mp4 filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text img_examples/Example1.png filter=lfs diff=lfs merge=lfs -text
+ img_examples/Example2.webp filter=lfs diff=lfs merge=lfs -text
+ img_examples/Example3.jpg filter=lfs diff=lfs merge=lfs -text
+ img_examples/Example4.webp filter=lfs diff=lfs merge=lfs -text
+ img_examples/Example5.png filter=lfs diff=lfs merge=lfs -text
+ img_examples/Example6.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -4,7 +4,18 @@ emoji: 🎬
  colorFrom: blue
  colorTo: green
  sdk: gradio
- sdk_version: 5.29.0
+ sdk_version: 5.29.1
  app_file: app.py
- pinned: false
- ---
+ license: apache-2.0
+ short_description: Text-to-Video/Image-to-Video/Video extender (timed prompt)
+ tags:
+ - Image-to-Video
+ - Image-2-Video
+ - Img-to-Vid
+ - Img-2-Vid
+ - language models
+ - LLMs
+ suggested_hardware: zero-a10g
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
app_lora.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers_helper/bucket_tools.py CHANGED
@@ -1,30 +1,103 @@
- bucket_options = {
-     640: [
-         (416, 960),
-         (448, 864),
-         (480, 832),
-         (512, 768),
-         (544, 704),
-         (576, 672),
-         (608, 640),
-         (640, 608),
-         (672, 576),
-         (704, 544),
-         (768, 512),
-         (832, 480),
-         (864, 448),
-         (960, 416),
-     ],
- }
-
-
- def find_nearest_bucket(h, w, resolution=640):
-     min_metric = float('inf')
-     best_bucket = None
-     for (bucket_h, bucket_w) in bucket_options[resolution]:
-         metric = abs(h * bucket_w - w * bucket_h)
-         if metric <= min_metric:
-             min_metric = metric
-             best_bucket = (bucket_h, bucket_w)
-     return best_bucket
-
+ bucket_options = {
+     640: [
+         (416, 960),
+         (448, 864),
+         (480, 832),
+         (512, 768),
+         (544, 704),
+         (576, 672),
+         (608, 640),
+         (640, 608),
+         (672, 576),
+         (704, 544),
+         (768, 512),
+         (832, 480),
+         (864, 448),
+         (960, 416),
+     ],
+     672: [
+         (480, 864),
+         (512, 832),
+         (544, 768),
+         (576, 704),
+         (608, 672),
+         (640, 640),
+         (672, 608),
+         (704, 576),
+         (768, 544),
+         (832, 512),
+         (864, 480),
+     ],
+     704: [
+         (480, 960),
+         (512, 864),
+         (544, 832),
+         (576, 768),
+         (608, 704),
+         (640, 672),
+         (672, 640),
+         (704, 608),
+         (768, 576),
+         (832, 544),
+         (864, 512),
+         (960, 480),
+     ],
+     768: [
+         (512, 960),
+         (544, 864),
+         (576, 832),
+         (608, 768),
+         (640, 704),
+         (672, 672),
+         (704, 640),
+         (768, 608),
+         (832, 576),
+         (864, 544),
+         (960, 512),
+     ],
+     832: [
+         (544, 960),
+         (576, 864),
+         (608, 832),
+         (640, 768),
+         (672, 704),
+         (704, 672),
+         (768, 640),
+         (832, 608),
+         (864, 576),
+         (960, 544),
+     ],
+     864: [
+         (576, 960),
+         (608, 864),
+         (640, 832),
+         (672, 768),
+         (704, 704),
+         (768, 672),
+         (832, 640),
+         (864, 608),
+         (960, 576),
+     ],
+     960: [
+         (608, 960),
+         (640, 864),
+         (672, 832),
+         (704, 768),
+         (768, 704),
+         (832, 672),
+         (864, 640),
+         (960, 608),
+     ],
+ }
+
+
+ def find_nearest_bucket(h, w, resolution=640):
+     min_metric = float('inf')
+     best_bucket = None
+     for (bucket_h, bucket_w) in bucket_options[resolution]:
+         metric = abs(h * bucket_w - w * bucket_h)
+         if metric <= min_metric:
+             min_metric = metric
+             best_bucket = (bucket_h, bucket_w)
+     print("The resolution of the generated video will be " + str(best_bucket))
+     return best_bucket
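
For readers skimming the diff, a minimal usage sketch of the extended bucket table (not part of this PR; it assumes the Space's `diffusers_helper` package from this repository is importable):

```python
# Minimal sketch, assuming diffusers_helper/bucket_tools.py from this diff is on the path.
from diffusers_helper.bucket_tools import find_nearest_bucket

# Snap a 720x1280 (16:9 landscape) frame to the 640 resolution band.
# `resolution` must be one of the table keys: 640, 672, 704, 768, 832, 864, 960.
# The bucket whose aspect ratio is closest to 720/1280 in bucket_options[640] is (480, 832).
height, width = find_nearest_bucket(720, 1280, resolution=640)
# The print() added in this PR also logs:
# "The resolution of the generated video will be (480, 832)"
assert (height, width) == (480, 832)
```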
diffusers_helper/models/hunyuan_video_packed.py CHANGED
@@ -1,1032 +1,1032 @@
1
- from typing import Any, Dict, List, Optional, Tuple, Union
2
-
3
- import torch
4
- import einops
5
- import torch.nn as nn
6
- import numpy as np
7
-
8
- from diffusers.loaders import FromOriginalModelMixin
9
- from diffusers.configuration_utils import ConfigMixin, register_to_config
10
- from diffusers.loaders import PeftAdapterMixin
11
- from diffusers.utils import logging
12
- from diffusers.models.attention import FeedForward
13
- from diffusers.models.attention_processor import Attention
14
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
- from diffusers.models.modeling_utils import ModelMixin
17
- from diffusers_helper.dit_common import LayerNorm
18
- from diffusers_helper.utils import zero_module
19
-
20
-
21
- enabled_backends = []
22
-
23
- if torch.backends.cuda.flash_sdp_enabled():
24
- enabled_backends.append("flash")
25
- if torch.backends.cuda.math_sdp_enabled():
26
- enabled_backends.append("math")
27
- if torch.backends.cuda.mem_efficient_sdp_enabled():
28
- enabled_backends.append("mem_efficient")
29
- if torch.backends.cuda.cudnn_sdp_enabled():
30
- enabled_backends.append("cudnn")
31
-
32
- print("Currently enabled native sdp backends:", enabled_backends)
33
-
34
- try:
35
- # raise NotImplementedError
36
- from xformers.ops import memory_efficient_attention as xformers_attn_func
37
- print('Xformers is installed!')
38
- except:
39
- print('Xformers is not installed!')
40
- xformers_attn_func = None
41
-
42
- try:
43
- # raise NotImplementedError
44
- from flash_attn import flash_attn_varlen_func, flash_attn_func
45
- print('Flash Attn is installed!')
46
- except:
47
- print('Flash Attn is not installed!')
48
- flash_attn_varlen_func = None
49
- flash_attn_func = None
50
-
51
- try:
52
- # raise NotImplementedError
53
- from sageattention import sageattn_varlen, sageattn
54
- print('Sage Attn is installed!')
55
- except:
56
- print('Sage Attn is not installed!')
57
- sageattn_varlen = None
58
- sageattn = None
59
-
60
-
61
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
-
63
-
64
- def pad_for_3d_conv(x, kernel_size):
65
- b, c, t, h, w = x.shape
66
- pt, ph, pw = kernel_size
67
- pad_t = (pt - (t % pt)) % pt
68
- pad_h = (ph - (h % ph)) % ph
69
- pad_w = (pw - (w % pw)) % pw
70
- return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71
-
72
-
73
- def center_down_sample_3d(x, kernel_size):
74
- # pt, ph, pw = kernel_size
75
- # cp = (pt * ph * pw) // 2
76
- # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77
- # xc = xp[cp]
78
- # return xc
79
- return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80
-
81
-
82
- def get_cu_seqlens(text_mask, img_len):
83
- batch_size = text_mask.shape[0]
84
- text_len = text_mask.sum(dim=1)
85
- max_len = text_mask.shape[1] + img_len
86
-
87
- cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88
-
89
- for i in range(batch_size):
90
- s = text_len[i] + img_len
91
- s1 = i * max_len + s
92
- s2 = (i + 1) * max_len
93
- cu_seqlens[2 * i + 1] = s1
94
- cu_seqlens[2 * i + 2] = s2
95
-
96
- return cu_seqlens
97
-
98
-
99
- def apply_rotary_emb_transposed(x, freqs_cis):
100
- cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101
- x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103
- out = x.float() * cos + x_rotated.float() * sin
104
- out = out.to(x)
105
- return out
106
-
107
-
108
- def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109
- if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110
- if sageattn is not None:
111
- x = sageattn(q, k, v, tensor_layout='NHD')
112
- return x
113
-
114
- if flash_attn_func is not None:
115
- x = flash_attn_func(q, k, v)
116
- return x
117
-
118
- if xformers_attn_func is not None:
119
- x = xformers_attn_func(q, k, v)
120
- return x
121
-
122
- x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123
- return x
124
-
125
- batch_size = q.shape[0]
126
- q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
127
- k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
128
- v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
129
- if sageattn_varlen is not None:
130
- x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
131
- elif flash_attn_varlen_func is not None:
132
- x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133
- else:
134
- raise NotImplementedError('No Attn Installed!')
135
- x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
136
- return x
137
-
138
-
139
- class HunyuanAttnProcessorFlashAttnDouble:
140
- def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
141
- cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
142
-
143
- query = attn.to_q(hidden_states)
144
- key = attn.to_k(hidden_states)
145
- value = attn.to_v(hidden_states)
146
-
147
- query = query.unflatten(2, (attn.heads, -1))
148
- key = key.unflatten(2, (attn.heads, -1))
149
- value = value.unflatten(2, (attn.heads, -1))
150
-
151
- query = attn.norm_q(query)
152
- key = attn.norm_k(key)
153
-
154
- query = apply_rotary_emb_transposed(query, image_rotary_emb)
155
- key = apply_rotary_emb_transposed(key, image_rotary_emb)
156
-
157
- encoder_query = attn.add_q_proj(encoder_hidden_states)
158
- encoder_key = attn.add_k_proj(encoder_hidden_states)
159
- encoder_value = attn.add_v_proj(encoder_hidden_states)
160
-
161
- encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
162
- encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
163
- encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
164
-
165
- encoder_query = attn.norm_added_q(encoder_query)
166
- encoder_key = attn.norm_added_k(encoder_key)
167
-
168
- query = torch.cat([query, encoder_query], dim=1)
169
- key = torch.cat([key, encoder_key], dim=1)
170
- value = torch.cat([value, encoder_value], dim=1)
171
-
172
- hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
173
- hidden_states = hidden_states.flatten(-2)
174
-
175
- txt_length = encoder_hidden_states.shape[1]
176
- hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
177
-
178
- hidden_states = attn.to_out[0](hidden_states)
179
- hidden_states = attn.to_out[1](hidden_states)
180
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
181
-
182
- return hidden_states, encoder_hidden_states
183
-
184
-
185
- class HunyuanAttnProcessorFlashAttnSingle:
186
- def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
187
- cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
188
-
189
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
190
-
191
- query = attn.to_q(hidden_states)
192
- key = attn.to_k(hidden_states)
193
- value = attn.to_v(hidden_states)
194
-
195
- query = query.unflatten(2, (attn.heads, -1))
196
- key = key.unflatten(2, (attn.heads, -1))
197
- value = value.unflatten(2, (attn.heads, -1))
198
-
199
- query = attn.norm_q(query)
200
- key = attn.norm_k(key)
201
-
202
- txt_length = encoder_hidden_states.shape[1]
203
-
204
- query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
205
- key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
206
-
207
- hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
208
- hidden_states = hidden_states.flatten(-2)
209
-
210
- hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
211
-
212
- return hidden_states, encoder_hidden_states
213
-
214
-
215
- class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
216
- def __init__(self, embedding_dim, pooled_projection_dim):
217
- super().__init__()
218
-
219
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
220
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
221
- self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
222
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
223
-
224
- def forward(self, timestep, guidance, pooled_projection):
225
- timesteps_proj = self.time_proj(timestep)
226
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
227
-
228
- guidance_proj = self.time_proj(guidance)
229
- guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
230
-
231
- time_guidance_emb = timesteps_emb + guidance_emb
232
-
233
- pooled_projections = self.text_embedder(pooled_projection)
234
- conditioning = time_guidance_emb + pooled_projections
235
-
236
- return conditioning
237
-
238
-
239
- class CombinedTimestepTextProjEmbeddings(nn.Module):
240
- def __init__(self, embedding_dim, pooled_projection_dim):
241
- super().__init__()
242
-
243
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
244
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
245
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
246
-
247
- def forward(self, timestep, pooled_projection):
248
- timesteps_proj = self.time_proj(timestep)
249
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
250
-
251
- pooled_projections = self.text_embedder(pooled_projection)
252
-
253
- conditioning = timesteps_emb + pooled_projections
254
-
255
- return conditioning
256
-
257
-
258
- class HunyuanVideoAdaNorm(nn.Module):
259
- def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
260
- super().__init__()
261
-
262
- out_features = out_features or 2 * in_features
263
- self.linear = nn.Linear(in_features, out_features)
264
- self.nonlinearity = nn.SiLU()
265
-
266
- def forward(
267
- self, temb: torch.Tensor
268
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
269
- temb = self.linear(self.nonlinearity(temb))
270
- gate_msa, gate_mlp = temb.chunk(2, dim=-1)
271
- gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
272
- return gate_msa, gate_mlp
273
-
274
-
275
- class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
276
- def __init__(
277
- self,
278
- num_attention_heads: int,
279
- attention_head_dim: int,
280
- mlp_width_ratio: str = 4.0,
281
- mlp_drop_rate: float = 0.0,
282
- attention_bias: bool = True,
283
- ) -> None:
284
- super().__init__()
285
-
286
- hidden_size = num_attention_heads * attention_head_dim
287
-
288
- self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
289
- self.attn = Attention(
290
- query_dim=hidden_size,
291
- cross_attention_dim=None,
292
- heads=num_attention_heads,
293
- dim_head=attention_head_dim,
294
- bias=attention_bias,
295
- )
296
-
297
- self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
298
- self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
299
-
300
- self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
301
-
302
- def forward(
303
- self,
304
- hidden_states: torch.Tensor,
305
- temb: torch.Tensor,
306
- attention_mask: Optional[torch.Tensor] = None,
307
- ) -> torch.Tensor:
308
- norm_hidden_states = self.norm1(hidden_states)
309
-
310
- attn_output = self.attn(
311
- hidden_states=norm_hidden_states,
312
- encoder_hidden_states=None,
313
- attention_mask=attention_mask,
314
- )
315
-
316
- gate_msa, gate_mlp = self.norm_out(temb)
317
- hidden_states = hidden_states + attn_output * gate_msa
318
-
319
- ff_output = self.ff(self.norm2(hidden_states))
320
- hidden_states = hidden_states + ff_output * gate_mlp
321
-
322
- return hidden_states
323
-
324
-
325
- class HunyuanVideoIndividualTokenRefiner(nn.Module):
326
- def __init__(
327
- self,
328
- num_attention_heads: int,
329
- attention_head_dim: int,
330
- num_layers: int,
331
- mlp_width_ratio: float = 4.0,
332
- mlp_drop_rate: float = 0.0,
333
- attention_bias: bool = True,
334
- ) -> None:
335
- super().__init__()
336
-
337
- self.refiner_blocks = nn.ModuleList(
338
- [
339
- HunyuanVideoIndividualTokenRefinerBlock(
340
- num_attention_heads=num_attention_heads,
341
- attention_head_dim=attention_head_dim,
342
- mlp_width_ratio=mlp_width_ratio,
343
- mlp_drop_rate=mlp_drop_rate,
344
- attention_bias=attention_bias,
345
- )
346
- for _ in range(num_layers)
347
- ]
348
- )
349
-
350
- def forward(
351
- self,
352
- hidden_states: torch.Tensor,
353
- temb: torch.Tensor,
354
- attention_mask: Optional[torch.Tensor] = None,
355
- ) -> None:
356
- self_attn_mask = None
357
- if attention_mask is not None:
358
- batch_size = attention_mask.shape[0]
359
- seq_len = attention_mask.shape[1]
360
- attention_mask = attention_mask.to(hidden_states.device).bool()
361
- self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
362
- self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
363
- self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
364
- self_attn_mask[:, :, :, 0] = True
365
-
366
- for block in self.refiner_blocks:
367
- hidden_states = block(hidden_states, temb, self_attn_mask)
368
-
369
- return hidden_states
370
-
371
-
372
- class HunyuanVideoTokenRefiner(nn.Module):
373
- def __init__(
374
- self,
375
- in_channels: int,
376
- num_attention_heads: int,
377
- attention_head_dim: int,
378
- num_layers: int,
379
- mlp_ratio: float = 4.0,
380
- mlp_drop_rate: float = 0.0,
381
- attention_bias: bool = True,
382
- ) -> None:
383
- super().__init__()
384
-
385
- hidden_size = num_attention_heads * attention_head_dim
386
-
387
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
388
- embedding_dim=hidden_size, pooled_projection_dim=in_channels
389
- )
390
- self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
391
- self.token_refiner = HunyuanVideoIndividualTokenRefiner(
392
- num_attention_heads=num_attention_heads,
393
- attention_head_dim=attention_head_dim,
394
- num_layers=num_layers,
395
- mlp_width_ratio=mlp_ratio,
396
- mlp_drop_rate=mlp_drop_rate,
397
- attention_bias=attention_bias,
398
- )
399
-
400
- def forward(
401
- self,
402
- hidden_states: torch.Tensor,
403
- timestep: torch.LongTensor,
404
- attention_mask: Optional[torch.LongTensor] = None,
405
- ) -> torch.Tensor:
406
- if attention_mask is None:
407
- pooled_projections = hidden_states.mean(dim=1)
408
- else:
409
- original_dtype = hidden_states.dtype
410
- mask_float = attention_mask.float().unsqueeze(-1)
411
- pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
412
- pooled_projections = pooled_projections.to(original_dtype)
413
-
414
- temb = self.time_text_embed(timestep, pooled_projections)
415
- hidden_states = self.proj_in(hidden_states)
416
- hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
417
-
418
- return hidden_states
419
-
420
-
421
- class HunyuanVideoRotaryPosEmbed(nn.Module):
422
- def __init__(self, rope_dim, theta):
423
- super().__init__()
424
- self.DT, self.DY, self.DX = rope_dim
425
- self.theta = theta
426
-
427
- @torch.no_grad()
428
- def get_frequency(self, dim, pos):
429
- T, H, W = pos.shape
430
- freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
431
- freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
432
- return freqs.cos(), freqs.sin()
433
-
434
- @torch.no_grad()
435
- def forward_inner(self, frame_indices, height, width, device):
436
- GT, GY, GX = torch.meshgrid(
437
- frame_indices.to(device=device, dtype=torch.float32),
438
- torch.arange(0, height, device=device, dtype=torch.float32),
439
- torch.arange(0, width, device=device, dtype=torch.float32),
440
- indexing="ij"
441
- )
442
-
443
- FCT, FST = self.get_frequency(self.DT, GT)
444
- FCY, FSY = self.get_frequency(self.DY, GY)
445
- FCX, FSX = self.get_frequency(self.DX, GX)
446
-
447
- result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
448
-
449
- return result.to(device)
450
-
451
- @torch.no_grad()
452
- def forward(self, frame_indices, height, width, device):
453
- frame_indices = frame_indices.unbind(0)
454
- results = [self.forward_inner(f, height, width, device) for f in frame_indices]
455
- results = torch.stack(results, dim=0)
456
- return results
457
-
458
-
459
- class AdaLayerNormZero(nn.Module):
460
- def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
461
- super().__init__()
462
- self.silu = nn.SiLU()
463
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
464
- if norm_type == "layer_norm":
465
- self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
466
- else:
467
- raise ValueError(f"unknown norm_type {norm_type}")
468
-
469
- def forward(
470
- self,
471
- x: torch.Tensor,
472
- emb: Optional[torch.Tensor] = None,
473
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
474
- emb = emb.unsqueeze(-2)
475
- emb = self.linear(self.silu(emb))
476
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
477
- x = self.norm(x) * (1 + scale_msa) + shift_msa
478
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
479
-
480
-
481
- class AdaLayerNormZeroSingle(nn.Module):
482
- def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
483
- super().__init__()
484
-
485
- self.silu = nn.SiLU()
486
- self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
487
- if norm_type == "layer_norm":
488
- self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
489
- else:
490
- raise ValueError(f"unknown norm_type {norm_type}")
491
-
492
- def forward(
493
- self,
494
- x: torch.Tensor,
495
- emb: Optional[torch.Tensor] = None,
496
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
497
- emb = emb.unsqueeze(-2)
498
- emb = self.linear(self.silu(emb))
499
- shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
500
- x = self.norm(x) * (1 + scale_msa) + shift_msa
501
- return x, gate_msa
502
-
503
-
504
- class AdaLayerNormContinuous(nn.Module):
505
- def __init__(
506
- self,
507
- embedding_dim: int,
508
- conditioning_embedding_dim: int,
509
- elementwise_affine=True,
510
- eps=1e-5,
511
- bias=True,
512
- norm_type="layer_norm",
513
- ):
514
- super().__init__()
515
- self.silu = nn.SiLU()
516
- self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
517
- if norm_type == "layer_norm":
518
- self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
519
- else:
520
- raise ValueError(f"unknown norm_type {norm_type}")
521
-
522
- def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
523
- emb = emb.unsqueeze(-2)
524
- emb = self.linear(self.silu(emb))
525
- scale, shift = emb.chunk(2, dim=-1)
526
- x = self.norm(x) * (1 + scale) + shift
527
- return x
528
-
529
-
530
- class HunyuanVideoSingleTransformerBlock(nn.Module):
531
- def __init__(
532
- self,
533
- num_attention_heads: int,
534
- attention_head_dim: int,
535
- mlp_ratio: float = 4.0,
536
- qk_norm: str = "rms_norm",
537
- ) -> None:
538
- super().__init__()
539
-
540
- hidden_size = num_attention_heads * attention_head_dim
541
- mlp_dim = int(hidden_size * mlp_ratio)
542
-
543
- self.attn = Attention(
544
- query_dim=hidden_size,
545
- cross_attention_dim=None,
546
- dim_head=attention_head_dim,
547
- heads=num_attention_heads,
548
- out_dim=hidden_size,
549
- bias=True,
550
- processor=HunyuanAttnProcessorFlashAttnSingle(),
551
- qk_norm=qk_norm,
552
- eps=1e-6,
553
- pre_only=True,
554
- )
555
-
556
- self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
557
- self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
558
- self.act_mlp = nn.GELU(approximate="tanh")
559
- self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
560
-
561
- def forward(
562
- self,
563
- hidden_states: torch.Tensor,
564
- encoder_hidden_states: torch.Tensor,
565
- temb: torch.Tensor,
566
- attention_mask: Optional[torch.Tensor] = None,
567
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
568
- ) -> torch.Tensor:
569
- text_seq_length = encoder_hidden_states.shape[1]
570
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
571
-
572
- residual = hidden_states
573
-
574
- # 1. Input normalization
575
- norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
576
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
577
-
578
- norm_hidden_states, norm_encoder_hidden_states = (
579
- norm_hidden_states[:, :-text_seq_length, :],
580
- norm_hidden_states[:, -text_seq_length:, :],
581
- )
582
-
583
- # 2. Attention
584
- attn_output, context_attn_output = self.attn(
585
- hidden_states=norm_hidden_states,
586
- encoder_hidden_states=norm_encoder_hidden_states,
587
- attention_mask=attention_mask,
588
- image_rotary_emb=image_rotary_emb,
589
- )
590
- attn_output = torch.cat([attn_output, context_attn_output], dim=1)
591
-
592
- # 3. Modulation and residual connection
593
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
594
- hidden_states = gate * self.proj_out(hidden_states)
595
- hidden_states = hidden_states + residual
596
-
597
- hidden_states, encoder_hidden_states = (
598
- hidden_states[:, :-text_seq_length, :],
599
- hidden_states[:, -text_seq_length:, :],
600
- )
601
- return hidden_states, encoder_hidden_states
602
-
603
-
604
- class HunyuanVideoTransformerBlock(nn.Module):
605
- def __init__(
606
- self,
607
- num_attention_heads: int,
608
- attention_head_dim: int,
609
- mlp_ratio: float,
610
- qk_norm: str = "rms_norm",
611
- ) -> None:
612
- super().__init__()
613
-
614
- hidden_size = num_attention_heads * attention_head_dim
615
-
616
- self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
617
- self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
618
-
619
- self.attn = Attention(
620
- query_dim=hidden_size,
621
- cross_attention_dim=None,
622
- added_kv_proj_dim=hidden_size,
623
- dim_head=attention_head_dim,
624
- heads=num_attention_heads,
625
- out_dim=hidden_size,
626
- context_pre_only=False,
627
- bias=True,
628
- processor=HunyuanAttnProcessorFlashAttnDouble(),
629
- qk_norm=qk_norm,
630
- eps=1e-6,
631
- )
632
-
633
- self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
634
- self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
635
-
636
- self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
637
- self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
638
-
639
- def forward(
640
- self,
641
- hidden_states: torch.Tensor,
642
- encoder_hidden_states: torch.Tensor,
643
- temb: torch.Tensor,
644
- attention_mask: Optional[torch.Tensor] = None,
645
- freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
646
- ) -> Tuple[torch.Tensor, torch.Tensor]:
647
- # 1. Input normalization
648
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
649
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
650
-
651
- # 2. Joint attention
652
- attn_output, context_attn_output = self.attn(
653
- hidden_states=norm_hidden_states,
654
- encoder_hidden_states=norm_encoder_hidden_states,
655
- attention_mask=attention_mask,
656
- image_rotary_emb=freqs_cis,
657
- )
658
-
659
- # 3. Modulation and residual connection
660
- hidden_states = hidden_states + attn_output * gate_msa
661
- encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
662
-
663
- norm_hidden_states = self.norm2(hidden_states)
664
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
665
-
666
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
667
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
668
-
669
- # 4. Feed-forward
670
- ff_output = self.ff(norm_hidden_states)
671
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
672
-
673
- hidden_states = hidden_states + gate_mlp * ff_output
674
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
675
-
676
- return hidden_states, encoder_hidden_states
677
-
678
-
679
- class ClipVisionProjection(nn.Module):
680
- def __init__(self, in_channels, out_channels):
681
- super().__init__()
682
- self.up = nn.Linear(in_channels, out_channels * 3)
683
- self.down = nn.Linear(out_channels * 3, out_channels)
684
-
685
- def forward(self, x):
686
- projected_x = self.down(nn.functional.silu(self.up(x)))
687
- return projected_x
688
-
689
-
690
- class HunyuanVideoPatchEmbed(nn.Module):
691
- def __init__(self, patch_size, in_chans, embed_dim):
692
- super().__init__()
693
- self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
694
-
695
-
696
- class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
697
- def __init__(self, inner_dim):
698
- super().__init__()
699
- self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
700
- self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
701
- self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
702
-
703
- @torch.no_grad()
704
- def initialize_weight_from_another_conv3d(self, another_layer):
705
- weight = another_layer.weight.detach().clone()
706
- bias = another_layer.bias.detach().clone()
707
-
708
- sd = {
709
- 'proj.weight': weight.clone(),
710
- 'proj.bias': bias.clone(),
711
- 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
712
- 'proj_2x.bias': bias.clone(),
713
- 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
714
- 'proj_4x.bias': bias.clone(),
715
- }
716
-
717
- sd = {k: v.clone() for k, v in sd.items()}
718
-
719
- self.load_state_dict(sd)
720
- return
721
-
722
-
723
- class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
724
- @register_to_config
725
- def __init__(
726
- self,
727
- in_channels: int = 16,
728
- out_channels: int = 16,
729
- num_attention_heads: int = 24,
730
- attention_head_dim: int = 128,
731
- num_layers: int = 20,
732
- num_single_layers: int = 40,
733
- num_refiner_layers: int = 2,
734
- mlp_ratio: float = 4.0,
735
- patch_size: int = 2,
736
- patch_size_t: int = 1,
737
- qk_norm: str = "rms_norm",
738
- guidance_embeds: bool = True,
739
- text_embed_dim: int = 4096,
740
- pooled_projection_dim: int = 768,
741
- rope_theta: float = 256.0,
742
- rope_axes_dim: Tuple[int] = (16, 56, 56),
743
- has_image_proj=False,
744
- image_proj_dim=1152,
745
- has_clean_x_embedder=False,
746
- ) -> None:
747
- super().__init__()
748
-
749
- inner_dim = num_attention_heads * attention_head_dim
750
- out_channels = out_channels or in_channels
751
-
752
- # 1. Latent and condition embedders
753
- self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
754
- self.context_embedder = HunyuanVideoTokenRefiner(
755
- text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
756
- )
757
- self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
758
-
759
- self.clean_x_embedder = None
760
- self.image_projection = None
761
-
762
- # 2. RoPE
763
- self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
764
-
765
- # 3. Dual stream transformer blocks
766
- self.transformer_blocks = nn.ModuleList(
767
- [
768
- HunyuanVideoTransformerBlock(
769
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
770
- )
771
- for _ in range(num_layers)
772
- ]
773
- )
774
-
775
- # 4. Single stream transformer blocks
776
- self.single_transformer_blocks = nn.ModuleList(
777
- [
778
- HunyuanVideoSingleTransformerBlock(
779
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
780
- )
781
- for _ in range(num_single_layers)
782
- ]
783
- )
784
-
785
- # 5. Output projection
786
- self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
787
- self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
788
-
789
- self.inner_dim = inner_dim
790
- self.use_gradient_checkpointing = False
791
- self.enable_teacache = False
792
-
793
- if has_image_proj:
794
- self.install_image_projection(image_proj_dim)
795
-
796
- if has_clean_x_embedder:
797
- self.install_clean_x_embedder()
798
-
799
- self.high_quality_fp32_output_for_inference = False
800
-
801
- def install_image_projection(self, in_channels):
802
- self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
803
- self.config['has_image_proj'] = True
804
- self.config['image_proj_dim'] = in_channels
805
-
806
- def install_clean_x_embedder(self):
807
- self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
808
- self.config['has_clean_x_embedder'] = True
809
-
810
- def enable_gradient_checkpointing(self):
811
- self.use_gradient_checkpointing = True
812
- print('self.use_gradient_checkpointing = True')
813
-
814
- def disable_gradient_checkpointing(self):
815
- self.use_gradient_checkpointing = False
816
- print('self.use_gradient_checkpointing = False')
817
-
818
- def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
819
- self.enable_teacache = enable_teacache
820
- self.cnt = 0
821
- self.num_steps = num_steps
822
- self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
823
- self.accumulated_rel_l1_distance = 0
824
- self.previous_modulated_input = None
825
- self.previous_residual = None
826
- self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
827
-
828
- def gradient_checkpointing_method(self, block, *args):
829
- if self.use_gradient_checkpointing:
830
- result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
831
- else:
832
- result = block(*args)
833
- return result
834
-
835
- def process_input_hidden_states(
836
- self,
837
- latents, latent_indices=None,
838
- clean_latents=None, clean_latent_indices=None,
839
- clean_latents_2x=None, clean_latent_2x_indices=None,
840
- clean_latents_4x=None, clean_latent_4x_indices=None
841
- ):
842
- hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
843
- B, C, T, H, W = hidden_states.shape
844
-
845
- if latent_indices is None:
846
- latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
847
-
848
- hidden_states = hidden_states.flatten(2).transpose(1, 2)
849
-
850
- rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
851
- rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
852
-
853
- if clean_latents is not None and clean_latent_indices is not None:
854
- clean_latents = clean_latents.to(hidden_states)
855
- clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
856
- clean_latents = clean_latents.flatten(2).transpose(1, 2)
857
-
858
- clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
859
- clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
860
-
861
- hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
862
- rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
863
-
864
- if clean_latents_2x is not None and clean_latent_2x_indices is not None:
865
- clean_latents_2x = clean_latents_2x.to(hidden_states)
866
- clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
867
- clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
868
- clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
869
-
870
- clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
871
- clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
872
- clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
873
- clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
874
-
875
- hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
876
- rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
877
-
878
- if clean_latents_4x is not None and clean_latent_4x_indices is not None:
879
- clean_latents_4x = clean_latents_4x.to(hidden_states)
880
- clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
881
- clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
882
- clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
883
-
884
- clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
885
- clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
886
- clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
887
- clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
888
-
889
- hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
890
- rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
891
-
892
- return hidden_states, rope_freqs
893
-
894
- def forward(
895
- self,
896
- hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
897
- latent_indices=None,
898
- clean_latents=None, clean_latent_indices=None,
899
- clean_latents_2x=None, clean_latent_2x_indices=None,
900
- clean_latents_4x=None, clean_latent_4x_indices=None,
901
- image_embeddings=None,
902
- attention_kwargs=None, return_dict=True
903
- ):
904
-
905
- if attention_kwargs is None:
906
- attention_kwargs = {}
907
-
908
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
909
- p, p_t = self.config['patch_size'], self.config['patch_size_t']
910
- post_patch_num_frames = num_frames // p_t
911
- post_patch_height = height // p
912
- post_patch_width = width // p
913
- original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
914
-
915
- hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
916
-
917
- temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
918
- encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
919
-
920
- if self.image_projection is not None:
921
- assert image_embeddings is not None, 'You must use image embeddings!'
922
- extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
923
- extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
924
-
925
- # must cat before (not after) encoder_hidden_states, due to attn masking
926
- encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
927
- encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
928
-
929
- with torch.no_grad():
930
- if batch_size == 1:
931
- # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
932
- # If they are not same, then their impls are wrong. Ours are always the correct one.
933
- text_len = encoder_attention_mask.sum().item()
934
- encoder_hidden_states = encoder_hidden_states[:, :text_len]
935
- attention_mask = None, None, None, None
936
- else:
937
- img_seq_len = hidden_states.shape[1]
938
- txt_seq_len = encoder_hidden_states.shape[1]
939
-
940
- cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
941
- cu_seqlens_kv = cu_seqlens_q
942
- max_seqlen_q = img_seq_len + txt_seq_len
943
- max_seqlen_kv = max_seqlen_q
944
-
945
- attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
946
-
947
- if self.enable_teacache:
948
- modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
949
-
950
- if self.cnt == 0 or self.cnt == self.num_steps-1:
951
- should_calc = True
952
- self.accumulated_rel_l1_distance = 0
953
- else:
954
- curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
955
- self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
956
- should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
957
-
958
- if should_calc:
959
- self.accumulated_rel_l1_distance = 0
960
-
961
- self.previous_modulated_input = modulated_inp
962
- self.cnt += 1
963
-
964
- if self.cnt == self.num_steps:
965
- self.cnt = 0
966
-
967
- if not should_calc:
968
- hidden_states = hidden_states + self.previous_residual
969
- else:
970
- ori_hidden_states = hidden_states.clone()
971
-
972
- for block_id, block in enumerate(self.transformer_blocks):
973
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
974
- block,
975
- hidden_states,
976
- encoder_hidden_states,
977
- temb,
978
- attention_mask,
979
- rope_freqs
980
- )
981
-
982
- for block_id, block in enumerate(self.single_transformer_blocks):
983
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
984
- block,
985
- hidden_states,
986
- encoder_hidden_states,
987
- temb,
988
- attention_mask,
989
- rope_freqs
990
- )
991
-
992
- self.previous_residual = hidden_states - ori_hidden_states
993
- else:
994
- for block_id, block in enumerate(self.transformer_blocks):
995
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
996
- block,
997
- hidden_states,
998
- encoder_hidden_states,
999
- temb,
1000
- attention_mask,
1001
- rope_freqs
1002
- )
1003
-
1004
- for block_id, block in enumerate(self.single_transformer_blocks):
1005
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1006
- block,
1007
- hidden_states,
1008
- encoder_hidden_states,
1009
- temb,
1010
- attention_mask,
1011
- rope_freqs
1012
- )
1013
-
1014
- hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1015
-
1016
- hidden_states = hidden_states[:, -original_context_length:, :]
1017
-
1018
- if self.high_quality_fp32_output_for_inference:
1019
- hidden_states = hidden_states.to(dtype=torch.float32)
1020
- if self.proj_out.weight.dtype != torch.float32:
1021
- self.proj_out.to(dtype=torch.float32)
1022
-
1023
- hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1024
-
1025
- hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1026
- t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1027
- pt=p_t, ph=p, pw=p)
1028
-
1029
- if return_dict:
1030
- return Transformer2DModelOutput(sample=hidden_states)
1031
-
1032
- return hidden_states,
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import einops
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ from diffusers.loaders import FromOriginalModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders import PeftAdapterMixin
11
+ from diffusers.utils import logging
12
+ from diffusers.models.attention import FeedForward
13
+ from diffusers.models.attention_processor import Attention
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers_helper.dit_common import LayerNorm
18
+ from diffusers_helper.utils import zero_module
19
+
20
+
21
+ enabled_backends = []
22
+
23
+ if torch.backends.cuda.flash_sdp_enabled():
24
+ enabled_backends.append("flash")
25
+ if torch.backends.cuda.math_sdp_enabled():
26
+ enabled_backends.append("math")
27
+ if torch.backends.cuda.mem_efficient_sdp_enabled():
28
+ enabled_backends.append("mem_efficient")
29
+ if torch.backends.cuda.cudnn_sdp_enabled():
30
+ enabled_backends.append("cudnn")
31
+
32
+ print("Currently enabled native sdp backends:", enabled_backends)
33
+
34
+ try:
35
+ # raise NotImplementedError
36
+ from xformers.ops import memory_efficient_attention as xformers_attn_func
37
+ print('Xformers is installed!')
38
+ except:
39
+ print('Xformers is not installed!')
40
+ xformers_attn_func = None
41
+
42
+ try:
43
+ # raise NotImplementedError
44
+ from flash_attn import flash_attn_varlen_func, flash_attn_func
45
+ print('Flash Attn is installed!')
46
+ except:
47
+ print('Flash Attn is not installed!')
48
+ flash_attn_varlen_func = None
49
+ flash_attn_func = None
50
+
51
+ try:
52
+ # raise NotImplementedError
53
+ from sageattention import sageattn_varlen, sageattn
54
+ print('Sage Attn is installed!')
55
+ except:
56
+ print('Sage Attn is not installed!')
57
+ sageattn_varlen = None
58
+ sageattn = None
59
+
60
+
61
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
+
63
+
64
+ def pad_for_3d_conv(x, kernel_size):
65
+ b, c, t, h, w = x.shape
66
+ pt, ph, pw = kernel_size
67
+ pad_t = (pt - (t % pt)) % pt
68
+ pad_h = (ph - (h % ph)) % ph
69
+ pad_w = (pw - (w % pw)) % pw
70
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71
+
72
+
73
+ def center_down_sample_3d(x, kernel_size):
74
+ # pt, ph, pw = kernel_size
75
+ # cp = (pt * ph * pw) // 2
76
+ # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77
+ # xc = xp[cp]
78
+ # return xc
79
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80
+
81
+
82
+ def get_cu_seqlens(text_mask, img_len):
83
+ batch_size = text_mask.shape[0]
84
+ text_len = text_mask.sum(dim=1)
85
+ max_len = text_mask.shape[1] + img_len
86
+
87
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88
+
89
+ for i in range(batch_size):
90
+ s = text_len[i] + img_len
91
+ s1 = i * max_len + s
92
+ s2 = (i + 1) * max_len
93
+ cu_seqlens[2 * i + 1] = s1
94
+ cu_seqlens[2 * i + 2] = s2
95
+
96
+ return cu_seqlens
97
+
98
+
99
+ def apply_rotary_emb_transposed(x, freqs_cis):
100
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101
+ x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103
+ out = x.float() * cos + x_rotated.float() * sin
104
+ out = out.to(x)
105
+ return out
106
+
107
+
108
+ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109
+ if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110
+ if sageattn is not None:
111
+ x = sageattn(q, k, v, tensor_layout='NHD')
112
+ return x
113
+
114
+ if flash_attn_func is not None:
115
+ x = flash_attn_func(q, k, v)
116
+ return x
117
+
118
+ if xformers_attn_func is not None:
119
+ x = xformers_attn_func(q, k, v)
120
+ return x
121
+
122
+ x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123
+ return x
124
+
125
+ batch_size = q.shape[0]
126
+ q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
127
+ k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
128
+ v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
129
+ if sageattn_varlen is not None:
130
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
131
+ elif flash_attn_varlen_func is not None:
132
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133
+ else:
134
+ raise NotImplementedError('No Attn Installed!')
135
+ x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
136
+ return x
137
+
138
+
139
+ class HunyuanAttnProcessorFlashAttnDouble:
140
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
141
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
142
+
143
+ query = attn.to_q(hidden_states)
144
+ key = attn.to_k(hidden_states)
145
+ value = attn.to_v(hidden_states)
146
+
147
+ query = query.unflatten(2, (attn.heads, -1))
148
+ key = key.unflatten(2, (attn.heads, -1))
149
+ value = value.unflatten(2, (attn.heads, -1))
150
+
151
+ query = attn.norm_q(query)
152
+ key = attn.norm_k(key)
153
+
154
+ query = apply_rotary_emb_transposed(query, image_rotary_emb)
155
+ key = apply_rotary_emb_transposed(key, image_rotary_emb)
156
+
157
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
158
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
159
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
160
+
161
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
162
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
163
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
164
+
165
+ encoder_query = attn.norm_added_q(encoder_query)
166
+ encoder_key = attn.norm_added_k(encoder_key)
167
+
168
+ query = torch.cat([query, encoder_query], dim=1)
169
+ key = torch.cat([key, encoder_key], dim=1)
170
+ value = torch.cat([value, encoder_value], dim=1)
171
+
172
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
173
+ hidden_states = hidden_states.flatten(-2)
174
+
175
+ txt_length = encoder_hidden_states.shape[1]
176
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
177
+
178
+ hidden_states = attn.to_out[0](hidden_states)
179
+ hidden_states = attn.to_out[1](hidden_states)
180
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
181
+
182
+ return hidden_states, encoder_hidden_states
183
+
184
+
185
+ class HunyuanAttnProcessorFlashAttnSingle:
186
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
187
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
188
+
189
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
190
+
191
+ query = attn.to_q(hidden_states)
192
+ key = attn.to_k(hidden_states)
193
+ value = attn.to_v(hidden_states)
194
+
195
+ query = query.unflatten(2, (attn.heads, -1))
196
+ key = key.unflatten(2, (attn.heads, -1))
197
+ value = value.unflatten(2, (attn.heads, -1))
198
+
199
+ query = attn.norm_q(query)
200
+ key = attn.norm_k(key)
201
+
202
+ txt_length = encoder_hidden_states.shape[1]
203
+
204
+ query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
205
+ key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
206
+
207
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
208
+ hidden_states = hidden_states.flatten(-2)
209
+
210
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
211
+
212
+ return hidden_states, encoder_hidden_states
213
+
214
+
215
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
216
+ def __init__(self, embedding_dim, pooled_projection_dim):
217
+ super().__init__()
218
+
219
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
220
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
221
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
222
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
223
+
224
+ def forward(self, timestep, guidance, pooled_projection):
225
+ timesteps_proj = self.time_proj(timestep)
226
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
227
+
228
+ guidance_proj = self.time_proj(guidance)
229
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
230
+
231
+ time_guidance_emb = timesteps_emb + guidance_emb
232
+
233
+ pooled_projections = self.text_embedder(pooled_projection)
234
+ conditioning = time_guidance_emb + pooled_projections
235
+
236
+ return conditioning
237
+
238
+
239
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
240
+ def __init__(self, embedding_dim, pooled_projection_dim):
241
+ super().__init__()
242
+
243
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
244
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
245
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
246
+
247
+ def forward(self, timestep, pooled_projection):
248
+ timesteps_proj = self.time_proj(timestep)
249
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
250
+
251
+ pooled_projections = self.text_embedder(pooled_projection)
252
+
253
+ conditioning = timesteps_emb + pooled_projections
254
+
255
+ return conditioning
256
+
257
+
258
+ class HunyuanVideoAdaNorm(nn.Module):
259
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
260
+ super().__init__()
261
+
262
+ out_features = out_features or 2 * in_features
263
+ self.linear = nn.Linear(in_features, out_features)
264
+ self.nonlinearity = nn.SiLU()
265
+
266
+ def forward(
267
+ self, temb: torch.Tensor
268
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
269
+ temb = self.linear(self.nonlinearity(temb))
270
+ gate_msa, gate_mlp = temb.chunk(2, dim=-1)
271
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
272
+ return gate_msa, gate_mlp
273
+
274
+
275
+ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
276
+ def __init__(
277
+ self,
278
+ num_attention_heads: int,
279
+ attention_head_dim: int,
280
+ mlp_width_ratio: str = 4.0,
281
+ mlp_drop_rate: float = 0.0,
282
+ attention_bias: bool = True,
283
+ ) -> None:
284
+ super().__init__()
285
+
286
+ hidden_size = num_attention_heads * attention_head_dim
287
+
288
+ self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
289
+ self.attn = Attention(
290
+ query_dim=hidden_size,
291
+ cross_attention_dim=None,
292
+ heads=num_attention_heads,
293
+ dim_head=attention_head_dim,
294
+ bias=attention_bias,
295
+ )
296
+
297
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
298
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
299
+
300
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states: torch.Tensor,
305
+ temb: torch.Tensor,
306
+ attention_mask: Optional[torch.Tensor] = None,
307
+ ) -> torch.Tensor:
308
+ norm_hidden_states = self.norm1(hidden_states)
309
+
310
+ attn_output = self.attn(
311
+ hidden_states=norm_hidden_states,
312
+ encoder_hidden_states=None,
313
+ attention_mask=attention_mask,
314
+ )
315
+
316
+ gate_msa, gate_mlp = self.norm_out(temb)
317
+ hidden_states = hidden_states + attn_output * gate_msa
318
+
319
+ ff_output = self.ff(self.norm2(hidden_states))
320
+ hidden_states = hidden_states + ff_output * gate_mlp
321
+
322
+ return hidden_states
323
+
324
+
325
+ class HunyuanVideoIndividualTokenRefiner(nn.Module):
326
+ def __init__(
327
+ self,
328
+ num_attention_heads: int,
329
+ attention_head_dim: int,
330
+ num_layers: int,
331
+ mlp_width_ratio: float = 4.0,
332
+ mlp_drop_rate: float = 0.0,
333
+ attention_bias: bool = True,
334
+ ) -> None:
335
+ super().__init__()
336
+
337
+ self.refiner_blocks = nn.ModuleList(
338
+ [
339
+ HunyuanVideoIndividualTokenRefinerBlock(
340
+ num_attention_heads=num_attention_heads,
341
+ attention_head_dim=attention_head_dim,
342
+ mlp_width_ratio=mlp_width_ratio,
343
+ mlp_drop_rate=mlp_drop_rate,
344
+ attention_bias=attention_bias,
345
+ )
346
+ for _ in range(num_layers)
347
+ ]
348
+ )
349
+
350
+ def forward(
351
+ self,
352
+ hidden_states: torch.Tensor,
353
+ temb: torch.Tensor,
354
+ attention_mask: Optional[torch.Tensor] = None,
355
+ ) -> torch.Tensor:
356
+ self_attn_mask = None
357
+ if attention_mask is not None:
358
+ batch_size = attention_mask.shape[0]
359
+ seq_len = attention_mask.shape[1]
360
+ attention_mask = attention_mask.to(hidden_states.device).bool()
361
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).expand(-1, -1, seq_len, -1)
362
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
363
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
364
+ self_attn_mask[:, :, :, 0] = True
365
+
366
+ for block in self.refiner_blocks:
367
+ hidden_states = block(hidden_states, temb, self_attn_mask)
368
+
369
+ return hidden_states
370
+
371
+
372
+ class HunyuanVideoTokenRefiner(nn.Module):
373
+ def __init__(
374
+ self,
375
+ in_channels: int,
376
+ num_attention_heads: int,
377
+ attention_head_dim: int,
378
+ num_layers: int,
379
+ mlp_ratio: float = 4.0,
380
+ mlp_drop_rate: float = 0.0,
381
+ attention_bias: bool = True,
382
+ ) -> None:
383
+ super().__init__()
384
+
385
+ hidden_size = num_attention_heads * attention_head_dim
386
+
387
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
388
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
389
+ )
390
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
391
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
392
+ num_attention_heads=num_attention_heads,
393
+ attention_head_dim=attention_head_dim,
394
+ num_layers=num_layers,
395
+ mlp_width_ratio=mlp_ratio,
396
+ mlp_drop_rate=mlp_drop_rate,
397
+ attention_bias=attention_bias,
398
+ )
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states: torch.Tensor,
403
+ timestep: torch.LongTensor,
404
+ attention_mask: Optional[torch.LongTensor] = None,
405
+ ) -> torch.Tensor:
406
+ if attention_mask is None:
407
+ pooled_projections = hidden_states.mean(dim=1)
408
+ else:
409
+ original_dtype = hidden_states.dtype
410
+ mask_float = attention_mask.float().unsqueeze(-1)
411
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
412
+ pooled_projections = pooled_projections.to(original_dtype)
413
+
414
+ temb = self.time_text_embed(timestep, pooled_projections)
415
+ hidden_states = self.proj_in(hidden_states)
416
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
417
+
418
+ return hidden_states
419
+
420
+
421
+ class HunyuanVideoRotaryPosEmbed(nn.Module):
422
+ def __init__(self, rope_dim, theta):
423
+ super().__init__()
424
+ self.DT, self.DY, self.DX = rope_dim
425
+ self.theta = theta
426
+
427
+ @torch.no_grad()
428
+ def get_frequency(self, dim, pos):
429
+ T, H, W = pos.shape
430
+ freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
431
+ freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
432
+ return freqs.cos(), freqs.sin()
433
+
434
+ @torch.no_grad()
435
+ def forward_inner(self, frame_indices, height, width, device):
436
+ GT, GY, GX = torch.meshgrid(
437
+ frame_indices.to(device=device, dtype=torch.float32),
438
+ torch.arange(0, height, device=device, dtype=torch.float32),
439
+ torch.arange(0, width, device=device, dtype=torch.float32),
440
+ indexing="ij"
441
+ )
442
+
443
+ FCT, FST = self.get_frequency(self.DT, GT)
444
+ FCY, FSY = self.get_frequency(self.DY, GY)
445
+ FCX, FSX = self.get_frequency(self.DX, GX)
446
+
447
+ result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
448
+
449
+ return result.to(device)
450
+
451
+ @torch.no_grad()
452
+ def forward(self, frame_indices, height, width, device):
453
+ frame_indices = frame_indices.unbind(0)
454
+ results = [self.forward_inner(f, height, width, device) for f in frame_indices]
455
+ results = torch.stack(results, dim=0)
456
+ return results
457
+
458
+
459
+ class AdaLayerNormZero(nn.Module):
460
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
461
+ super().__init__()
462
+ self.silu = nn.SiLU()
463
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
464
+ if norm_type == "layer_norm":
465
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
466
+ else:
467
+ raise ValueError(f"unknown norm_type {norm_type}")
468
+
469
+ def forward(
470
+ self,
471
+ x: torch.Tensor,
472
+ emb: Optional[torch.Tensor] = None,
473
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
474
+ emb = emb.unsqueeze(-2)
475
+ emb = self.linear(self.silu(emb))
476
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
477
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
478
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
479
+
480
+
481
+ class AdaLayerNormZeroSingle(nn.Module):
482
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
483
+ super().__init__()
484
+
485
+ self.silu = nn.SiLU()
486
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
487
+ if norm_type == "layer_norm":
488
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
489
+ else:
490
+ raise ValueError(f"unknown norm_type {norm_type}")
491
+
492
+ def forward(
493
+ self,
494
+ x: torch.Tensor,
495
+ emb: Optional[torch.Tensor] = None,
496
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
497
+ emb = emb.unsqueeze(-2)
498
+ emb = self.linear(self.silu(emb))
499
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
500
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
501
+ return x, gate_msa
502
+
503
+
504
+ class AdaLayerNormContinuous(nn.Module):
505
+ def __init__(
506
+ self,
507
+ embedding_dim: int,
508
+ conditioning_embedding_dim: int,
509
+ elementwise_affine=True,
510
+ eps=1e-5,
511
+ bias=True,
512
+ norm_type="layer_norm",
513
+ ):
514
+ super().__init__()
515
+ self.silu = nn.SiLU()
516
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
517
+ if norm_type == "layer_norm":
518
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
519
+ else:
520
+ raise ValueError(f"unknown norm_type {norm_type}")
521
+
522
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
523
+ emb = emb.unsqueeze(-2)
524
+ emb = self.linear(self.silu(emb))
525
+ scale, shift = emb.chunk(2, dim=-1)
526
+ x = self.norm(x) * (1 + scale) + shift
527
+ return x
528
+
529
+
530
+ class HunyuanVideoSingleTransformerBlock(nn.Module):
531
+ def __init__(
532
+ self,
533
+ num_attention_heads: int,
534
+ attention_head_dim: int,
535
+ mlp_ratio: float = 4.0,
536
+ qk_norm: str = "rms_norm",
537
+ ) -> None:
538
+ super().__init__()
539
+
540
+ hidden_size = num_attention_heads * attention_head_dim
541
+ mlp_dim = int(hidden_size * mlp_ratio)
542
+
543
+ self.attn = Attention(
544
+ query_dim=hidden_size,
545
+ cross_attention_dim=None,
546
+ dim_head=attention_head_dim,
547
+ heads=num_attention_heads,
548
+ out_dim=hidden_size,
549
+ bias=True,
550
+ processor=HunyuanAttnProcessorFlashAttnSingle(),
551
+ qk_norm=qk_norm,
552
+ eps=1e-6,
553
+ pre_only=True,
554
+ )
555
+
556
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
557
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
558
+ self.act_mlp = nn.GELU(approximate="tanh")
559
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
560
+
561
+ def forward(
562
+ self,
563
+ hidden_states: torch.Tensor,
564
+ encoder_hidden_states: torch.Tensor,
565
+ temb: torch.Tensor,
566
+ attention_mask: Optional[torch.Tensor] = None,
567
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
568
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
569
+ text_seq_length = encoder_hidden_states.shape[1]
570
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
571
+
572
+ residual = hidden_states
573
+
574
+ # 1. Input normalization
575
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
576
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
577
+
578
+ norm_hidden_states, norm_encoder_hidden_states = (
579
+ norm_hidden_states[:, :-text_seq_length, :],
580
+ norm_hidden_states[:, -text_seq_length:, :],
581
+ )
582
+
583
+ # 2. Attention
584
+ attn_output, context_attn_output = self.attn(
585
+ hidden_states=norm_hidden_states,
586
+ encoder_hidden_states=norm_encoder_hidden_states,
587
+ attention_mask=attention_mask,
588
+ image_rotary_emb=image_rotary_emb,
589
+ )
590
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
591
+
592
+ # 3. Modulation and residual connection
593
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
594
+ hidden_states = gate * self.proj_out(hidden_states)
595
+ hidden_states = hidden_states + residual
596
+
597
+ hidden_states, encoder_hidden_states = (
598
+ hidden_states[:, :-text_seq_length, :],
599
+ hidden_states[:, -text_seq_length:, :],
600
+ )
601
+ return hidden_states, encoder_hidden_states
602
+
603
+
604
+ class HunyuanVideoTransformerBlock(nn.Module):
605
+ def __init__(
606
+ self,
607
+ num_attention_heads: int,
608
+ attention_head_dim: int,
609
+ mlp_ratio: float,
610
+ qk_norm: str = "rms_norm",
611
+ ) -> None:
612
+ super().__init__()
613
+
614
+ hidden_size = num_attention_heads * attention_head_dim
615
+
616
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
617
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
618
+
619
+ self.attn = Attention(
620
+ query_dim=hidden_size,
621
+ cross_attention_dim=None,
622
+ added_kv_proj_dim=hidden_size,
623
+ dim_head=attention_head_dim,
624
+ heads=num_attention_heads,
625
+ out_dim=hidden_size,
626
+ context_pre_only=False,
627
+ bias=True,
628
+ processor=HunyuanAttnProcessorFlashAttnDouble(),
629
+ qk_norm=qk_norm,
630
+ eps=1e-6,
631
+ )
632
+
633
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
634
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
635
+
636
+ self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
637
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
638
+
639
+ def forward(
640
+ self,
641
+ hidden_states: torch.Tensor,
642
+ encoder_hidden_states: torch.Tensor,
643
+ temb: torch.Tensor,
644
+ attention_mask: Optional[torch.Tensor] = None,
645
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
646
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
647
+ # 1. Input normalization
648
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
649
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
650
+
651
+ # 2. Joint attention
652
+ attn_output, context_attn_output = self.attn(
653
+ hidden_states=norm_hidden_states,
654
+ encoder_hidden_states=norm_encoder_hidden_states,
655
+ attention_mask=attention_mask,
656
+ image_rotary_emb=freqs_cis,
657
+ )
658
+
659
+ # 3. Modulation and residual connection
660
+ hidden_states = hidden_states + attn_output * gate_msa
661
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
662
+
663
+ norm_hidden_states = self.norm2(hidden_states)
664
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
665
+
666
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
667
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
668
+
669
+ # 4. Feed-forward
670
+ ff_output = self.ff(norm_hidden_states)
671
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
672
+
673
+ hidden_states = hidden_states + gate_mlp * ff_output
674
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
675
+
676
+ return hidden_states, encoder_hidden_states
677
+
678
+
679
+ class ClipVisionProjection(nn.Module):
680
+ def __init__(self, in_channels, out_channels):
681
+ super().__init__()
682
+ self.up = nn.Linear(in_channels, out_channels * 3)
683
+ self.down = nn.Linear(out_channels * 3, out_channels)
684
+
685
+ def forward(self, x):
686
+ projected_x = self.down(nn.functional.silu(self.up(x)))
687
+ return projected_x
688
+
689
+
690
+ class HunyuanVideoPatchEmbed(nn.Module):
691
+ def __init__(self, patch_size, in_chans, embed_dim):
692
+ super().__init__()
693
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
694
+
695
+
696
+ class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
697
+ def __init__(self, inner_dim):
698
+ super().__init__()
699
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
700
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
701
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
702
+
703
+ @torch.no_grad()
704
+ def initialize_weight_from_another_conv3d(self, another_layer):
705
+ weight = another_layer.weight.detach().clone()
706
+ bias = another_layer.bias.detach().clone()
707
+
708
+ sd = {
709
+ 'proj.weight': weight.clone(),
710
+ 'proj.bias': bias.clone(),
711
+ 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
712
+ 'proj_2x.bias': bias.clone(),
713
+ 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
714
+ 'proj_4x.bias': bias.clone(),
715
+ }
716
+
717
+ sd = {k: v.clone() for k, v in sd.items()}
718
+
719
+ self.load_state_dict(sd)
720
+ return
721
+
722
+
723
+ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
724
+ @register_to_config
725
+ def __init__(
726
+ self,
727
+ in_channels: int = 16,
728
+ out_channels: int = 16,
729
+ num_attention_heads: int = 24,
730
+ attention_head_dim: int = 128,
731
+ num_layers: int = 20,
732
+ num_single_layers: int = 40,
733
+ num_refiner_layers: int = 2,
734
+ mlp_ratio: float = 4.0,
735
+ patch_size: int = 2,
736
+ patch_size_t: int = 1,
737
+ qk_norm: str = "rms_norm",
738
+ guidance_embeds: bool = True,
739
+ text_embed_dim: int = 4096,
740
+ pooled_projection_dim: int = 768,
741
+ rope_theta: float = 256.0,
742
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
743
+ has_image_proj=False,
744
+ image_proj_dim=1152,
745
+ has_clean_x_embedder=False,
746
+ ) -> None:
747
+ super().__init__()
748
+
749
+ inner_dim = num_attention_heads * attention_head_dim
750
+ out_channels = out_channels or in_channels
751
+
752
+ # 1. Latent and condition embedders
753
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
754
+ self.context_embedder = HunyuanVideoTokenRefiner(
755
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
756
+ )
757
+ self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
758
+
759
+ self.clean_x_embedder = None
760
+ self.image_projection = None
761
+
762
+ # 2. RoPE
763
+ self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
764
+
765
+ # 3. Dual stream transformer blocks
766
+ self.transformer_blocks = nn.ModuleList(
767
+ [
768
+ HunyuanVideoTransformerBlock(
769
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
770
+ )
771
+ for _ in range(num_layers)
772
+ ]
773
+ )
774
+
775
+ # 4. Single stream transformer blocks
776
+ self.single_transformer_blocks = nn.ModuleList(
777
+ [
778
+ HunyuanVideoSingleTransformerBlock(
779
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
780
+ )
781
+ for _ in range(num_single_layers)
782
+ ]
783
+ )
784
+
785
+ # 5. Output projection
786
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
787
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
788
+
789
+ self.inner_dim = inner_dim
790
+ self.use_gradient_checkpointing = False
791
+ self.enable_teacache = False
792
+
793
+ if has_image_proj:
794
+ self.install_image_projection(image_proj_dim)
795
+
796
+ if has_clean_x_embedder:
797
+ self.install_clean_x_embedder()
798
+
799
+ self.high_quality_fp32_output_for_inference = False
800
+
801
+ def install_image_projection(self, in_channels):
802
+ self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
803
+ self.config['has_image_proj'] = True
804
+ self.config['image_proj_dim'] = in_channels
805
+
806
+ def install_clean_x_embedder(self):
807
+ self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
808
+ self.config['has_clean_x_embedder'] = True
809
+
810
+ def enable_gradient_checkpointing(self):
811
+ self.use_gradient_checkpointing = True
812
+ print('self.use_gradient_checkpointing = True')
813
+
814
+ def disable_gradient_checkpointing(self):
815
+ self.use_gradient_checkpointing = False
816
+ print('self.use_gradient_checkpointing = False')
817
+
818
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
819
+ self.enable_teacache = enable_teacache
820
+ self.cnt = 0
821
+ self.num_steps = num_steps
822
+ self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
823
+ self.accumulated_rel_l1_distance = 0
824
+ self.previous_modulated_input = None
825
+ self.previous_residual = None
826
+ self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
827
+
828
+ def gradient_checkpointing_method(self, block, *args):
829
+ if self.use_gradient_checkpointing:
830
+ result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
831
+ else:
832
+ result = block(*args)
833
+ return result
834
+
835
+ def process_input_hidden_states(
836
+ self,
837
+ latents, latent_indices=None,
838
+ clean_latents=None, clean_latent_indices=None,
839
+ clean_latents_2x=None, clean_latent_2x_indices=None,
840
+ clean_latents_4x=None, clean_latent_4x_indices=None
841
+ ):
842
+ hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
843
+ B, C, T, H, W = hidden_states.shape
844
+
845
+ if latent_indices is None:
846
+ latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
847
+
848
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
849
+
850
+ rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
851
+ rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
852
+
853
+ if clean_latents is not None and clean_latent_indices is not None:
854
+ clean_latents = clean_latents.to(hidden_states)
855
+ clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
856
+ clean_latents = clean_latents.flatten(2).transpose(1, 2)
857
+
858
+ clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
859
+ clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
860
+
861
+ hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
862
+ rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
863
+
864
+ if clean_latents_2x is not None and clean_latent_2x_indices is not None:
865
+ clean_latents_2x = clean_latents_2x.to(hidden_states)
866
+ clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
867
+ clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
868
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
869
+
870
+ clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
871
+ clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
872
+ clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
873
+ clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
874
+
875
+ hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
876
+ rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
877
+
878
+ if clean_latents_4x is not None and clean_latent_4x_indices is not None:
879
+ clean_latents_4x = clean_latents_4x.to(hidden_states)
880
+ clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
881
+ clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
882
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
883
+
884
+ clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
885
+ clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
886
+ clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
887
+ clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
888
+
889
+ hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
890
+ rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
891
+
892
+ return hidden_states, rope_freqs
893
+
894
+ def forward(
895
+ self,
896
+ hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
897
+ latent_indices=None,
898
+ clean_latents=None, clean_latent_indices=None,
899
+ clean_latents_2x=None, clean_latent_2x_indices=None,
900
+ clean_latents_4x=None, clean_latent_4x_indices=None,
901
+ image_embeddings=None,
902
+ attention_kwargs=None, return_dict=True
903
+ ):
904
+
905
+ if attention_kwargs is None:
906
+ attention_kwargs = {}
907
+
908
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
909
+ p, p_t = self.config['patch_size'], self.config['patch_size_t']
910
+ post_patch_num_frames = num_frames // p_t
911
+ post_patch_height = height // p
912
+ post_patch_width = width // p
913
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
914
+
915
+ hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
916
+
917
+ temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
918
+ encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
919
+
920
+ if self.image_projection is not None:
921
+ assert image_embeddings is not None, 'You must use image embeddings!'
922
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
923
+ extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
924
+
925
+ # must cat before (not after) encoder_hidden_states, due to attn masking
926
+ encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
927
+ encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
928
+
929
+ with torch.no_grad():
930
+ if batch_size == 1:
931
+ # When batch size is 1, we do not need any masks or var-len funcs, since cropping is mathematically the same as what we want
932
+ # If they are not the same, then their implementations are wrong; ours is always the correct one.
933
+ text_len = encoder_attention_mask.sum().item()
934
+ encoder_hidden_states = encoder_hidden_states[:, :text_len]
935
+ attention_mask = None, None, None, None
936
+ else:
937
+ img_seq_len = hidden_states.shape[1]
938
+ txt_seq_len = encoder_hidden_states.shape[1]
939
+
940
+ cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
941
+ cu_seqlens_kv = cu_seqlens_q
942
+ max_seqlen_q = img_seq_len + txt_seq_len
943
+ max_seqlen_kv = max_seqlen_q
944
+
945
+ attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
946
+
947
+ if self.enable_teacache:
948
+ modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
949
+
950
+ if self.cnt == 0 or self.cnt == self.num_steps-1:
951
+ should_calc = True
952
+ self.accumulated_rel_l1_distance = 0
953
+ else:
954
+ curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
955
+ self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
956
+ should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
957
+
958
+ if should_calc:
959
+ self.accumulated_rel_l1_distance = 0
960
+
961
+ self.previous_modulated_input = modulated_inp
962
+ self.cnt += 1
963
+
964
+ if self.cnt == self.num_steps:
965
+ self.cnt = 0
966
+
967
+ if not should_calc:
968
+ hidden_states = hidden_states + self.previous_residual
969
+ else:
970
+ ori_hidden_states = hidden_states.clone()
971
+
972
+ for block_id, block in enumerate(self.transformer_blocks):
973
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
974
+ block,
975
+ hidden_states,
976
+ encoder_hidden_states,
977
+ temb,
978
+ attention_mask,
979
+ rope_freqs
980
+ )
981
+
982
+ for block_id, block in enumerate(self.single_transformer_blocks):
983
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
984
+ block,
985
+ hidden_states,
986
+ encoder_hidden_states,
987
+ temb,
988
+ attention_mask,
989
+ rope_freqs
990
+ )
991
+
992
+ self.previous_residual = hidden_states - ori_hidden_states
993
+ else:
994
+ for block_id, block in enumerate(self.transformer_blocks):
995
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
996
+ block,
997
+ hidden_states,
998
+ encoder_hidden_states,
999
+ temb,
1000
+ attention_mask,
1001
+ rope_freqs
1002
+ )
1003
+
1004
+ for block_id, block in enumerate(self.single_transformer_blocks):
1005
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1006
+ block,
1007
+ hidden_states,
1008
+ encoder_hidden_states,
1009
+ temb,
1010
+ attention_mask,
1011
+ rope_freqs
1012
+ )
1013
+
1014
+ hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1015
+
1016
+ hidden_states = hidden_states[:, -original_context_length:, :]
1017
+
1018
+ if self.high_quality_fp32_output_for_inference:
1019
+ hidden_states = hidden_states.to(dtype=torch.float32)
1020
+ if self.proj_out.weight.dtype != torch.float32:
1021
+ self.proj_out.to(dtype=torch.float32)
1022
+
1023
+ hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1024
+
1025
+ hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1026
+ t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1027
+ pt=p_t, ph=p, pw=p)
1028
+
1029
+ if return_dict:
1030
+ return Transformer2DModelOutput(sample=hidden_states)
1031
+
1032
+ return hidden_states,
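
The forward pass above adds a TeaCache fast path: on each sampling step it measures how much the first block's modulated input changed relative to the previous step, rescales that distance with a fitted polynomial, and reuses the cached transformer residual whenever the accumulated value stays below `rel_l1_thresh`. The sketch below reproduces just that decision rule outside the model, with random tensors standing in for the real activations; the coefficients and default threshold are copied from `initialize_teacache`, everything else is illustrative.

```python
# Minimal standalone sketch of the TeaCache skip rule used above
# (initialize_teacache / the enable_teacache branch in forward).
# The polynomial coefficients and the 0.15 threshold are copied from the
# diff; the random "modulated inputs" below are placeholder data, not
# real model activations.
import numpy as np
import torch

num_steps = 25
rel_l1_thresh = 0.15
rescale = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                     -3.14987800e+00, 9.61237896e-02])

accumulated = 0.0
previous = None
for step in range(num_steps):
    modulated = torch.randn(1, 16, 128)  # stand-in for norm1(hidden_states, emb=temb)[0]
    if step == 0 or step == num_steps - 1:
        should_calc = True
        accumulated = 0.0
    else:
        curr = ((modulated - previous).abs().mean() / previous.abs().mean()).item()
        accumulated += rescale(curr)
        should_calc = accumulated >= rel_l1_thresh
        if should_calc:
            accumulated = 0.0
    previous = modulated
    print(step, "recompute" if should_calc else "reuse cached residual")
```

Note that the first and last steps always recompute, so the cached residual can never drift across an entire sampling run; in between, whole dual-stream and single-stream passes are skipped whenever the accumulated rescaled relative change stays under the threshold.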
img_examples/Example1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a906a1d14d1699f67ca54865c7aa5857e55246f4ec63bbaf3edcf359e73bebd1
3
+ size 240647
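
The new assets under img_examples/ (this MP4 clip and the six stills listed below) are stored through Git LFS, so the repository only carries pointer files. The snippet that follows is a minimal, hedged sketch of how one might sanity-check them locally after `git lfs pull`, using Pillow plus pillow-heif for the images and decord for the clip, all of which the updated requirements.txt pins; the file names come from this diff, and the snippet is illustrative only, not part of app.py.

```python
# Quick local check of the example assets added in img_examples/.
# Paths come from this diff; Pillow, pillow-heif and decord are pinned in
# the updated requirements.txt. This is an illustration, not app code.
from PIL import Image
from pillow_heif import register_heif_opener
from decord import VideoReader, cpu

register_heif_opener()  # lets Pillow open HEIF/AVIF files if any are added later

for name in ["Example1.png", "Example2.webp", "Example3.jpg",
             "Example4.webp", "Example5.png", "Example6.png"]:
    with Image.open(f"img_examples/{name}") as im:
        print(name, im.size, im.mode)

vr = VideoReader("img_examples/Example1.mp4", ctx=cpu(0))
print("Example1.mp4:", len(vr), "frames at", vr.get_avg_fps(), "fps")
```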
img_examples/Example1.png ADDED

Git LFS Details

  • SHA256: a057c160bcf3ecfa41d150ec9550423f87efefb9a9793420fea382760daff98b
  • Pointer size: 131 Bytes
  • Size of remote file: 513 kB
img_examples/Example2.webp ADDED

Git LFS Details

  • SHA256: 736480a5f8d043eacad5758f0e80b427aabfa4d98839769615ee61f3fda9f77e
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
img_examples/Example3.jpg ADDED

Git LFS Details

  • SHA256: b1a9be93d2f117d687e08c91c043e67598bdb7c44f5c932f18a3026790fb82fa
  • Pointer size: 131 Bytes
  • Size of remote file: 208 kB
img_examples/Example4.webp ADDED

Git LFS Details

  • SHA256: dd4e7ef35f4cfc8d44ff97f38b68ba7cc248ad5b54c89f8525f5046508f7c4a3
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
img_examples/Example5.png ADDED

Git LFS Details

  • SHA256: b6a7b7521a2ffe77f60a78bb52013c1ef73bfcefbd809f45cfdeef804aee8906
  • Pointer size: 131 Bytes
  • Size of remote file: 431 kB
img_examples/Example6.png ADDED

Git LFS Details

  • SHA256: 59e76d165d9bece1775302a7e4032f31b28545937726d42f41b0c67aae9d4143
  • Pointer size: 131 Bytes
  • Size of remote file: 721 kB
requirements.txt CHANGED
@@ -1,12 +1,12 @@
1
- accelerate==1.6.0
2
  diffusers==0.33.1
3
- transformers==4.46.2
4
  sentencepiece==0.2.0
5
- pillow==11.1.0
6
  av==12.1.0
7
  numpy==1.26.2
8
  scipy==1.12.0
9
- requests==2.31.0
10
  torchsde==0.2.6
11
  torch>=2.0.0
12
  torchvision
@@ -15,4 +15,9 @@ einops
15
  opencv-contrib-python
16
  safetensors
17
  huggingface_hub
18
- spaces
 
 
 
 
 
 
1
+ accelerate==1.7.0
2
  diffusers==0.33.1
3
+ transformers==4.52.4
4
  sentencepiece==0.2.0
5
+ pillow==11.2.1
6
  av==12.1.0
7
  numpy==1.26.2
8
  scipy==1.12.0
9
+ requests==2.32.4
10
  torchsde==0.2.6
11
  torch>=2.0.0
12
  torchvision
 
15
  opencv-contrib-python
16
  safetensors
17
  huggingface_hub
18
+ decord
19
+ imageio_ffmpeg==0.6.0
20
+ sageattention==1.0.6
21
+ xformers==0.0.29.post3
22
+ bitsandbytes==0.46.0
23
+ pillow-heif==0.22.0
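
Since several pins moved (accelerate, transformers, pillow, requests) and new runtime dependencies were added (decord, imageio_ffmpeg, sageattention, xformers, bitsandbytes, pillow-heif), a quick environment check like the hedged sketch below can confirm that a Space rebuild actually picked up the new versions; the version strings are copied from this requirements.txt, and unpinned packages are deliberately skipped.

```python
# Sanity check that the installed environment matches the new pins.
# Version strings are copied from the updated requirements.txt; packages
# without an exact pin (torch, torchvision, decord, ...) are skipped here.
from importlib.metadata import version, PackageNotFoundError

pins = {
    "accelerate": "1.7.0",
    "transformers": "4.52.4",
    "pillow": "11.2.1",
    "requests": "2.32.4",
    "imageio_ffmpeg": "0.6.0",
    "sageattention": "1.0.6",
    "xformers": "0.0.29.post3",
    "bitsandbytes": "0.46.0",
    "pillow-heif": "0.22.0",
}

for name, expected in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: NOT INSTALLED (expected {expected})")
        continue
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{name}: {installed} {status}")
```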