Fabrice-TIERCELIN committed on
Commit 7542c74 · verified · 1 Parent(s): 0684df1

self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).expand(1, 1, seq_len, 1)

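Note (editor's illustration, not part of the commit): the change swaps .repeat(1, 1, seq_len, 1) for .expand(...) when broadcasting the (batch_size, seq_len) padding mask into a square self-attention mask in HunyuanVideoIndividualTokenRefiner.forward. The minimal PyTorch sketch below is not the commit's exact call (the commit passes explicit sizes, the sketch uses -1 for the dimensions that keep their size); it only demonstrates that repeat and expand yield the same mask values, with expand returning a broadcast view instead of a materialised copy.

import torch

# Editor's sketch: broadcast a (batch_size, seq_len) padding mask into a
# square (batch_size, 1, seq_len, seq_len) self-attention mask.
batch_size, seq_len = 2, 4
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]]).bool()

# repeat materialises the broadcast copy ...
mask_repeat = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
# ... while expand (with -1 for the unchanged dims) returns a view of the same values.
mask_expand = attention_mask.view(batch_size, 1, 1, seq_len).expand(-1, -1, seq_len, -1)

assert mask_repeat.shape == mask_expand.shape == (batch_size, 1, seq_len, seq_len)
assert torch.equal(mask_repeat, mask_expand)

# Combined with its transpose this gives the pairwise "both tokens valid" mask
# used by the token refiner; the & operator allocates a fresh tensor, so the
# in-place write to column 0 that follows in the model code is safe here too.
self_attn_mask = mask_expand & mask_expand.transpose(2, 3)
self_attn_mask[:, :, :, 0] = True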
diffusers_helper/models/hunyuan_video_packed.py CHANGED
@@ -1,1035 +1,1036 @@
1
- from typing import Any, Dict, List, Optional, Tuple, Union
2
-
3
- import torch
4
- import einops
5
- import torch.nn as nn
6
- import numpy as np
7
-
8
- from diffusers.loaders import FromOriginalModelMixin
9
- from diffusers.configuration_utils import ConfigMixin, register_to_config
10
- from diffusers.loaders import PeftAdapterMixin
11
- from diffusers.utils import logging
12
- from diffusers.models.attention import FeedForward
13
- from diffusers.models.attention_processor import Attention
14
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
- from diffusers.models.modeling_utils import ModelMixin
17
- from diffusers_helper.dit_common import LayerNorm
18
- from diffusers_helper.utils import zero_module
19
-
20
-
21
- enabled_backends = []
22
-
23
- if torch.backends.cuda.flash_sdp_enabled():
24
- enabled_backends.append("flash")
25
- if torch.backends.cuda.math_sdp_enabled():
26
- enabled_backends.append("math")
27
- if torch.backends.cuda.mem_efficient_sdp_enabled():
28
- enabled_backends.append("mem_efficient")
29
- if torch.backends.cuda.cudnn_sdp_enabled():
30
- enabled_backends.append("cudnn")
31
-
32
- print("Currently enabled native sdp backends:", enabled_backends)
33
-
34
- try:
35
- # raise NotImplementedError
36
- from xformers.ops import memory_efficient_attention as xformers_attn_func
37
- print('Xformers is installed!')
38
- except:
39
- print('Xformers is not installed!')
40
- xformers_attn_func = None
41
-
42
- try:
43
- # raise NotImplementedError
44
- from flash_attn import flash_attn_varlen_func, flash_attn_func
45
- print('Flash Attn is installed!')
46
- except:
47
- print('Flash Attn is not installed!')
48
- flash_attn_varlen_func = None
49
- flash_attn_func = None
50
-
51
- try:
52
- # raise NotImplementedError
53
- from sageattention import sageattn_varlen, sageattn
54
- print('Sage Attn is installed!')
55
- except:
56
- print('Sage Attn is not installed!')
57
- sageattn_varlen = None
58
- sageattn = None
59
-
60
-
61
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
-
63
-
64
- def pad_for_3d_conv(x, kernel_size):
65
- b, c, t, h, w = x.shape
66
- pt, ph, pw = kernel_size
67
- pad_t = (pt - (t % pt)) % pt
68
- pad_h = (ph - (h % ph)) % ph
69
- pad_w = (pw - (w % pw)) % pw
70
- return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71
-
72
-
73
- def center_down_sample_3d(x, kernel_size):
74
- # pt, ph, pw = kernel_size
75
- # cp = (pt * ph * pw) // 2
76
- # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77
- # xc = xp[cp]
78
- # return xc
79
- return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80
-
81
-
82
- def get_cu_seqlens(text_mask, img_len):
83
- batch_size = text_mask.shape[0]
84
- text_len = text_mask.sum(dim=1)
85
- max_len = text_mask.shape[1] + img_len
86
-
87
- cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88
-
89
- for i in range(batch_size):
90
- s = text_len[i] + img_len
91
- s1 = i * max_len + s
92
- s2 = (i + 1) * max_len
93
- cu_seqlens[2 * i + 1] = s1
94
- cu_seqlens[2 * i + 2] = s2
95
-
96
- return cu_seqlens
97
-
98
-
99
- def apply_rotary_emb_transposed(x, freqs_cis):
100
- cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101
- x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103
- out = x.float() * cos + x_rotated.float() * sin
104
- out = out.to(x)
105
- return out
106
-
107
-
108
- def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109
- if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110
- if sageattn is not None:
111
- x = sageattn(q, k, v, tensor_layout='NHD')
112
- return x
113
-
114
- if flash_attn_func is not None:
115
- x = flash_attn_func(q, k, v)
116
- return x
117
-
118
- if xformers_attn_func is not None:
119
- x = xformers_attn_func(q, k, v)
120
- return x
121
-
122
- x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123
- return x
124
-
125
- B, L, H, C = q.shape
126
-
127
- q = q.flatten(0, 1)
128
- k = k.flatten(0, 1)
129
- v = v.flatten(0, 1)
130
-
131
- if sageattn_varlen is not None:
132
- x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133
- elif flash_attn_varlen_func is not None:
134
- x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
135
- else:
136
- raise NotImplementedError('No Attn Installed!')
137
-
138
- x = x.unflatten(0, (B, L))
139
-
140
- return x
141
-
142
-
143
- class HunyuanAttnProcessorFlashAttnDouble:
144
- def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
145
- cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
146
-
147
- query = attn.to_q(hidden_states)
148
- key = attn.to_k(hidden_states)
149
- value = attn.to_v(hidden_states)
150
-
151
- query = query.unflatten(2, (attn.heads, -1))
152
- key = key.unflatten(2, (attn.heads, -1))
153
- value = value.unflatten(2, (attn.heads, -1))
154
-
155
- query = attn.norm_q(query)
156
- key = attn.norm_k(key)
157
-
158
- query = apply_rotary_emb_transposed(query, image_rotary_emb)
159
- key = apply_rotary_emb_transposed(key, image_rotary_emb)
160
-
161
- encoder_query = attn.add_q_proj(encoder_hidden_states)
162
- encoder_key = attn.add_k_proj(encoder_hidden_states)
163
- encoder_value = attn.add_v_proj(encoder_hidden_states)
164
-
165
- encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
166
- encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
167
- encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
168
-
169
- encoder_query = attn.norm_added_q(encoder_query)
170
- encoder_key = attn.norm_added_k(encoder_key)
171
-
172
- query = torch.cat([query, encoder_query], dim=1)
173
- key = torch.cat([key, encoder_key], dim=1)
174
- value = torch.cat([value, encoder_value], dim=1)
175
-
176
- hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
177
- hidden_states = hidden_states.flatten(-2)
178
-
179
- txt_length = encoder_hidden_states.shape[1]
180
- hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
181
-
182
- hidden_states = attn.to_out[0](hidden_states)
183
- hidden_states = attn.to_out[1](hidden_states)
184
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
185
-
186
- return hidden_states, encoder_hidden_states
187
-
188
-
189
- class HunyuanAttnProcessorFlashAttnSingle:
190
- def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
191
- cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
192
-
193
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
194
-
195
- query = attn.to_q(hidden_states)
196
- key = attn.to_k(hidden_states)
197
- value = attn.to_v(hidden_states)
198
-
199
- query = query.unflatten(2, (attn.heads, -1))
200
- key = key.unflatten(2, (attn.heads, -1))
201
- value = value.unflatten(2, (attn.heads, -1))
202
-
203
- query = attn.norm_q(query)
204
- key = attn.norm_k(key)
205
-
206
- txt_length = encoder_hidden_states.shape[1]
207
-
208
- query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
209
- key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
210
-
211
- hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
212
- hidden_states = hidden_states.flatten(-2)
213
-
214
- hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
215
-
216
- return hidden_states, encoder_hidden_states
217
-
218
-
219
- class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
220
- def __init__(self, embedding_dim, pooled_projection_dim):
221
- super().__init__()
222
-
223
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
224
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
225
- self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
226
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
227
-
228
- def forward(self, timestep, guidance, pooled_projection):
229
- timesteps_proj = self.time_proj(timestep)
230
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
231
-
232
- guidance_proj = self.time_proj(guidance)
233
- guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
234
-
235
- time_guidance_emb = timesteps_emb + guidance_emb
236
-
237
- pooled_projections = self.text_embedder(pooled_projection)
238
- conditioning = time_guidance_emb + pooled_projections
239
-
240
- return conditioning
241
-
242
-
243
- class CombinedTimestepTextProjEmbeddings(nn.Module):
244
- def __init__(self, embedding_dim, pooled_projection_dim):
245
- super().__init__()
246
-
247
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
248
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
249
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
250
-
251
- def forward(self, timestep, pooled_projection):
252
- timesteps_proj = self.time_proj(timestep)
253
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
254
-
255
- pooled_projections = self.text_embedder(pooled_projection)
256
-
257
- conditioning = timesteps_emb + pooled_projections
258
-
259
- return conditioning
260
-
261
-
262
- class HunyuanVideoAdaNorm(nn.Module):
263
- def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
264
- super().__init__()
265
-
266
- out_features = out_features or 2 * in_features
267
- self.linear = nn.Linear(in_features, out_features)
268
- self.nonlinearity = nn.SiLU()
269
-
270
- def forward(
271
- self, temb: torch.Tensor
272
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
273
- temb = self.linear(self.nonlinearity(temb))
274
- gate_msa, gate_mlp = temb.chunk(2, dim=-1)
275
- gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
276
- return gate_msa, gate_mlp
277
-
278
-
279
- class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
280
- def __init__(
281
- self,
282
- num_attention_heads: int,
283
- attention_head_dim: int,
284
- mlp_width_ratio: str = 4.0,
285
- mlp_drop_rate: float = 0.0,
286
- attention_bias: bool = True,
287
- ) -> None:
288
- super().__init__()
289
-
290
- hidden_size = num_attention_heads * attention_head_dim
291
-
292
- self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
293
- self.attn = Attention(
294
- query_dim=hidden_size,
295
- cross_attention_dim=None,
296
- heads=num_attention_heads,
297
- dim_head=attention_head_dim,
298
- bias=attention_bias,
299
- )
300
-
301
- self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
302
- self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
303
-
304
- self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
305
-
306
- def forward(
307
- self,
308
- hidden_states: torch.Tensor,
309
- temb: torch.Tensor,
310
- attention_mask: Optional[torch.Tensor] = None,
311
- ) -> torch.Tensor:
312
- norm_hidden_states = self.norm1(hidden_states)
313
-
314
- attn_output = self.attn(
315
- hidden_states=norm_hidden_states,
316
- encoder_hidden_states=None,
317
- attention_mask=attention_mask,
318
- )
319
-
320
- gate_msa, gate_mlp = self.norm_out(temb)
321
- hidden_states = hidden_states + attn_output * gate_msa
322
-
323
- ff_output = self.ff(self.norm2(hidden_states))
324
- hidden_states = hidden_states + ff_output * gate_mlp
325
-
326
- return hidden_states
327
-
328
-
329
- class HunyuanVideoIndividualTokenRefiner(nn.Module):
330
- def __init__(
331
- self,
332
- num_attention_heads: int,
333
- attention_head_dim: int,
334
- num_layers: int,
335
- mlp_width_ratio: float = 4.0,
336
- mlp_drop_rate: float = 0.0,
337
- attention_bias: bool = True,
338
- ) -> None:
339
- super().__init__()
340
-
341
- self.refiner_blocks = nn.ModuleList(
342
- [
343
- HunyuanVideoIndividualTokenRefinerBlock(
344
- num_attention_heads=num_attention_heads,
345
- attention_head_dim=attention_head_dim,
346
- mlp_width_ratio=mlp_width_ratio,
347
- mlp_drop_rate=mlp_drop_rate,
348
- attention_bias=attention_bias,
349
- )
350
- for _ in range(num_layers)
351
- ]
352
- )
353
-
354
- def forward(
355
- self,
356
- hidden_states: torch.Tensor,
357
- temb: torch.Tensor,
358
- attention_mask: Optional[torch.Tensor] = None,
359
- ) -> None:
360
- self_attn_mask = None
361
- if attention_mask is not None:
362
- batch_size = attention_mask.shape[0]
363
- seq_len = attention_mask.shape[1]
364
- attention_mask = attention_mask.to(hidden_states.device).bool()
365
- self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
366
- self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
367
- self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
368
- self_attn_mask[:, :, :, 0] = True
369
-
370
- for block in self.refiner_blocks:
371
- hidden_states = block(hidden_states, temb, self_attn_mask)
372
-
373
- return hidden_states
374
-
375
-
376
- class HunyuanVideoTokenRefiner(nn.Module):
377
- def __init__(
378
- self,
379
- in_channels: int,
380
- num_attention_heads: int,
381
- attention_head_dim: int,
382
- num_layers: int,
383
- mlp_ratio: float = 4.0,
384
- mlp_drop_rate: float = 0.0,
385
- attention_bias: bool = True,
386
- ) -> None:
387
- super().__init__()
388
-
389
- hidden_size = num_attention_heads * attention_head_dim
390
-
391
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
392
- embedding_dim=hidden_size, pooled_projection_dim=in_channels
393
- )
394
- self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
395
- self.token_refiner = HunyuanVideoIndividualTokenRefiner(
396
- num_attention_heads=num_attention_heads,
397
- attention_head_dim=attention_head_dim,
398
- num_layers=num_layers,
399
- mlp_width_ratio=mlp_ratio,
400
- mlp_drop_rate=mlp_drop_rate,
401
- attention_bias=attention_bias,
402
- )
403
-
404
- def forward(
405
- self,
406
- hidden_states: torch.Tensor,
407
- timestep: torch.LongTensor,
408
- attention_mask: Optional[torch.LongTensor] = None,
409
- ) -> torch.Tensor:
410
- if attention_mask is None:
411
- pooled_projections = hidden_states.mean(dim=1)
412
- else:
413
- original_dtype = hidden_states.dtype
414
- mask_float = attention_mask.float().unsqueeze(-1)
415
- pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
416
- pooled_projections = pooled_projections.to(original_dtype)
417
-
418
- temb = self.time_text_embed(timestep, pooled_projections)
419
- hidden_states = self.proj_in(hidden_states)
420
- hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
421
-
422
- return hidden_states
423
-
424
-
425
- class HunyuanVideoRotaryPosEmbed(nn.Module):
426
- def __init__(self, rope_dim, theta):
427
- super().__init__()
428
- self.DT, self.DY, self.DX = rope_dim
429
- self.theta = theta
430
-
431
- @torch.no_grad()
432
- def get_frequency(self, dim, pos):
433
- T, H, W = pos.shape
434
- freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
435
- freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
436
- return freqs.cos(), freqs.sin()
437
-
438
- @torch.no_grad()
439
- def forward_inner(self, frame_indices, height, width, device):
440
- GT, GY, GX = torch.meshgrid(
441
- frame_indices.to(device=device, dtype=torch.float32),
442
- torch.arange(0, height, device=device, dtype=torch.float32),
443
- torch.arange(0, width, device=device, dtype=torch.float32),
444
- indexing="ij"
445
- )
446
-
447
- FCT, FST = self.get_frequency(self.DT, GT)
448
- FCY, FSY = self.get_frequency(self.DY, GY)
449
- FCX, FSX = self.get_frequency(self.DX, GX)
450
-
451
- result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
452
-
453
- return result.to(device)
454
-
455
- @torch.no_grad()
456
- def forward(self, frame_indices, height, width, device):
457
- frame_indices = frame_indices.unbind(0)
458
- results = [self.forward_inner(f, height, width, device) for f in frame_indices]
459
- results = torch.stack(results, dim=0)
460
- return results
461
-
462
-
463
- class AdaLayerNormZero(nn.Module):
464
- def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
465
- super().__init__()
466
- self.silu = nn.SiLU()
467
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
468
- if norm_type == "layer_norm":
469
- self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
470
- else:
471
- raise ValueError(f"unknown norm_type {norm_type}")
472
-
473
- def forward(
474
- self,
475
- x: torch.Tensor,
476
- emb: Optional[torch.Tensor] = None,
477
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
478
- emb = emb.unsqueeze(-2)
479
- emb = self.linear(self.silu(emb))
480
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
481
- x = self.norm(x) * (1 + scale_msa) + shift_msa
482
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
483
-
484
-
485
- class AdaLayerNormZeroSingle(nn.Module):
486
- def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
487
- super().__init__()
488
-
489
- self.silu = nn.SiLU()
490
- self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
491
- if norm_type == "layer_norm":
492
- self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
493
- else:
494
- raise ValueError(f"unknown norm_type {norm_type}")
495
-
496
- def forward(
497
- self,
498
- x: torch.Tensor,
499
- emb: Optional[torch.Tensor] = None,
500
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
501
- emb = emb.unsqueeze(-2)
502
- emb = self.linear(self.silu(emb))
503
- shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
504
- x = self.norm(x) * (1 + scale_msa) + shift_msa
505
- return x, gate_msa
506
-
507
-
508
- class AdaLayerNormContinuous(nn.Module):
509
- def __init__(
510
- self,
511
- embedding_dim: int,
512
- conditioning_embedding_dim: int,
513
- elementwise_affine=True,
514
- eps=1e-5,
515
- bias=True,
516
- norm_type="layer_norm",
517
- ):
518
- super().__init__()
519
- self.silu = nn.SiLU()
520
- self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
521
- if norm_type == "layer_norm":
522
- self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
523
- else:
524
- raise ValueError(f"unknown norm_type {norm_type}")
525
-
526
- def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
527
- emb = emb.unsqueeze(-2)
528
- emb = self.linear(self.silu(emb))
529
- scale, shift = emb.chunk(2, dim=-1)
530
- x = self.norm(x) * (1 + scale) + shift
531
- return x
532
-
533
-
534
- class HunyuanVideoSingleTransformerBlock(nn.Module):
535
- def __init__(
536
- self,
537
- num_attention_heads: int,
538
- attention_head_dim: int,
539
- mlp_ratio: float = 4.0,
540
- qk_norm: str = "rms_norm",
541
- ) -> None:
542
- super().__init__()
543
-
544
- hidden_size = num_attention_heads * attention_head_dim
545
- mlp_dim = int(hidden_size * mlp_ratio)
546
-
547
- self.attn = Attention(
548
- query_dim=hidden_size,
549
- cross_attention_dim=None,
550
- dim_head=attention_head_dim,
551
- heads=num_attention_heads,
552
- out_dim=hidden_size,
553
- bias=True,
554
- processor=HunyuanAttnProcessorFlashAttnSingle(),
555
- qk_norm=qk_norm,
556
- eps=1e-6,
557
- pre_only=True,
558
- )
559
-
560
- self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
561
- self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
562
- self.act_mlp = nn.GELU(approximate="tanh")
563
- self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
564
-
565
- def forward(
566
- self,
567
- hidden_states: torch.Tensor,
568
- encoder_hidden_states: torch.Tensor,
569
- temb: torch.Tensor,
570
- attention_mask: Optional[torch.Tensor] = None,
571
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
572
- ) -> torch.Tensor:
573
- text_seq_length = encoder_hidden_states.shape[1]
574
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
575
-
576
- residual = hidden_states
577
-
578
- # 1. Input normalization
579
- norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
580
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
581
-
582
- norm_hidden_states, norm_encoder_hidden_states = (
583
- norm_hidden_states[:, :-text_seq_length, :],
584
- norm_hidden_states[:, -text_seq_length:, :],
585
- )
586
-
587
- # 2. Attention
588
- attn_output, context_attn_output = self.attn(
589
- hidden_states=norm_hidden_states,
590
- encoder_hidden_states=norm_encoder_hidden_states,
591
- attention_mask=attention_mask,
592
- image_rotary_emb=image_rotary_emb,
593
- )
594
- attn_output = torch.cat([attn_output, context_attn_output], dim=1)
595
-
596
- # 3. Modulation and residual connection
597
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
598
- hidden_states = gate * self.proj_out(hidden_states)
599
- hidden_states = hidden_states + residual
600
-
601
- hidden_states, encoder_hidden_states = (
602
- hidden_states[:, :-text_seq_length, :],
603
- hidden_states[:, -text_seq_length:, :],
604
- )
605
- return hidden_states, encoder_hidden_states
606
-
607
-
608
- class HunyuanVideoTransformerBlock(nn.Module):
609
- def __init__(
610
- self,
611
- num_attention_heads: int,
612
- attention_head_dim: int,
613
- mlp_ratio: float,
614
- qk_norm: str = "rms_norm",
615
- ) -> None:
616
- super().__init__()
617
-
618
- hidden_size = num_attention_heads * attention_head_dim
619
-
620
- self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
621
- self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
622
-
623
- self.attn = Attention(
624
- query_dim=hidden_size,
625
- cross_attention_dim=None,
626
- added_kv_proj_dim=hidden_size,
627
- dim_head=attention_head_dim,
628
- heads=num_attention_heads,
629
- out_dim=hidden_size,
630
- context_pre_only=False,
631
- bias=True,
632
- processor=HunyuanAttnProcessorFlashAttnDouble(),
633
- qk_norm=qk_norm,
634
- eps=1e-6,
635
- )
636
-
637
- self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
638
- self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
639
-
640
- self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
641
- self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
642
-
643
- def forward(
644
- self,
645
- hidden_states: torch.Tensor,
646
- encoder_hidden_states: torch.Tensor,
647
- temb: torch.Tensor,
648
- attention_mask: Optional[torch.Tensor] = None,
649
- freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
650
- ) -> Tuple[torch.Tensor, torch.Tensor]:
651
- # 1. Input normalization
652
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
653
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
654
-
655
- # 2. Joint attention
656
- attn_output, context_attn_output = self.attn(
657
- hidden_states=norm_hidden_states,
658
- encoder_hidden_states=norm_encoder_hidden_states,
659
- attention_mask=attention_mask,
660
- image_rotary_emb=freqs_cis,
661
- )
662
-
663
- # 3. Modulation and residual connection
664
- hidden_states = hidden_states + attn_output * gate_msa
665
- encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
666
-
667
- norm_hidden_states = self.norm2(hidden_states)
668
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
669
-
670
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
671
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
672
-
673
- # 4. Feed-forward
674
- ff_output = self.ff(norm_hidden_states)
675
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
676
-
677
- hidden_states = hidden_states + gate_mlp * ff_output
678
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
679
-
680
- return hidden_states, encoder_hidden_states
681
-
682
-
683
- class ClipVisionProjection(nn.Module):
684
- def __init__(self, in_channels, out_channels):
685
- super().__init__()
686
- self.up = nn.Linear(in_channels, out_channels * 3)
687
- self.down = nn.Linear(out_channels * 3, out_channels)
688
-
689
- def forward(self, x):
690
- projected_x = self.down(nn.functional.silu(self.up(x)))
691
- return projected_x
692
-
693
-
694
- class HunyuanVideoPatchEmbed(nn.Module):
695
- def __init__(self, patch_size, in_chans, embed_dim):
696
- super().__init__()
697
- self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
698
-
699
-
700
- class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
701
- def __init__(self, inner_dim):
702
- super().__init__()
703
- self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
704
- self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
705
- self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
706
-
707
- @torch.no_grad()
708
- def initialize_weight_from_another_conv3d(self, another_layer):
709
- weight = another_layer.weight.detach().clone()
710
- bias = another_layer.bias.detach().clone()
711
-
712
- sd = {
713
- 'proj.weight': weight.clone(),
714
- 'proj.bias': bias.clone(),
715
- 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
716
- 'proj_2x.bias': bias.clone(),
717
- 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
718
- 'proj_4x.bias': bias.clone(),
719
- }
720
-
721
- sd = {k: v.clone() for k, v in sd.items()}
722
-
723
- self.load_state_dict(sd)
724
- return
725
-
726
-
727
- class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
728
- @register_to_config
729
- def __init__(
730
- self,
731
- in_channels: int = 16,
732
- out_channels: int = 16,
733
- num_attention_heads: int = 24,
734
- attention_head_dim: int = 128,
735
- num_layers: int = 20,
736
- num_single_layers: int = 40,
737
- num_refiner_layers: int = 2,
738
- mlp_ratio: float = 4.0,
739
- patch_size: int = 2,
740
- patch_size_t: int = 1,
741
- qk_norm: str = "rms_norm",
742
- guidance_embeds: bool = True,
743
- text_embed_dim: int = 4096,
744
- pooled_projection_dim: int = 768,
745
- rope_theta: float = 256.0,
746
- rope_axes_dim: Tuple[int] = (16, 56, 56),
747
- has_image_proj=False,
748
- image_proj_dim=1152,
749
- has_clean_x_embedder=False,
750
- ) -> None:
751
- super().__init__()
752
-
753
- inner_dim = num_attention_heads * attention_head_dim
754
- out_channels = out_channels or in_channels
755
-
756
- # 1. Latent and condition embedders
757
- self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
758
- self.context_embedder = HunyuanVideoTokenRefiner(
759
- text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
760
- )
761
- self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
762
-
763
- self.clean_x_embedder = None
764
- self.image_projection = None
765
-
766
- # 2. RoPE
767
- self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
768
-
769
- # 3. Dual stream transformer blocks
770
- self.transformer_blocks = nn.ModuleList(
771
- [
772
- HunyuanVideoTransformerBlock(
773
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
774
- )
775
- for _ in range(num_layers)
776
- ]
777
- )
778
-
779
- # 4. Single stream transformer blocks
780
- self.single_transformer_blocks = nn.ModuleList(
781
- [
782
- HunyuanVideoSingleTransformerBlock(
783
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
784
- )
785
- for _ in range(num_single_layers)
786
- ]
787
- )
788
-
789
- # 5. Output projection
790
- self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
791
- self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
792
-
793
- self.inner_dim = inner_dim
794
- self.use_gradient_checkpointing = False
795
- self.enable_teacache = False
796
-
797
- if has_image_proj:
798
- self.install_image_projection(image_proj_dim)
799
-
800
- if has_clean_x_embedder:
801
- self.install_clean_x_embedder()
802
-
803
- self.high_quality_fp32_output_for_inference = False
804
-
805
- def install_image_projection(self, in_channels):
806
- self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
807
- self.config['has_image_proj'] = True
808
- self.config['image_proj_dim'] = in_channels
809
-
810
- def install_clean_x_embedder(self):
811
- self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
812
- self.config['has_clean_x_embedder'] = True
813
-
814
- def enable_gradient_checkpointing(self):
815
- self.use_gradient_checkpointing = True
816
- print('self.use_gradient_checkpointing = True')
817
-
818
- def disable_gradient_checkpointing(self):
819
- self.use_gradient_checkpointing = False
820
- print('self.use_gradient_checkpointing = False')
821
-
822
- def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
823
- self.enable_teacache = enable_teacache
824
- self.cnt = 0
825
- self.num_steps = num_steps
826
- self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
827
- self.accumulated_rel_l1_distance = 0
828
- self.previous_modulated_input = None
829
- self.previous_residual = None
830
- self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
831
-
832
- def gradient_checkpointing_method(self, block, *args):
833
- if self.use_gradient_checkpointing:
834
- result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
835
- else:
836
- result = block(*args)
837
- return result
838
-
839
- def process_input_hidden_states(
840
- self,
841
- latents, latent_indices=None,
842
- clean_latents=None, clean_latent_indices=None,
843
- clean_latents_2x=None, clean_latent_2x_indices=None,
844
- clean_latents_4x=None, clean_latent_4x_indices=None
845
- ):
846
- hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
847
- B, C, T, H, W = hidden_states.shape
848
-
849
- if latent_indices is None:
850
- latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
851
-
852
- hidden_states = hidden_states.flatten(2).transpose(1, 2)
853
-
854
- rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
855
- rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
856
-
857
- if clean_latents is not None and clean_latent_indices is not None:
858
- clean_latents = clean_latents.to(hidden_states)
859
- clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
860
- clean_latents = clean_latents.flatten(2).transpose(1, 2)
861
-
862
- clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
863
- clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
864
-
865
- hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
866
- rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
867
-
868
- if clean_latents_2x is not None and clean_latent_2x_indices is not None:
869
- clean_latents_2x = clean_latents_2x.to(hidden_states)
870
- clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
871
- clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
872
- clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
873
-
874
- clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
875
- clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
876
- clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
877
- clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
878
-
879
- hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
880
- rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
881
-
882
- if clean_latents_4x is not None and clean_latent_4x_indices is not None:
883
- clean_latents_4x = clean_latents_4x.to(hidden_states)
884
- clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
885
- clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
886
- clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
887
-
888
- clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
889
- clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
890
- clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
891
- clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
892
-
893
- hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
894
- rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
895
-
896
- return hidden_states, rope_freqs
897
-
898
- def forward(
899
- self,
900
- hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
901
- latent_indices=None,
902
- clean_latents=None, clean_latent_indices=None,
903
- clean_latents_2x=None, clean_latent_2x_indices=None,
904
- clean_latents_4x=None, clean_latent_4x_indices=None,
905
- image_embeddings=None,
906
- attention_kwargs=None, return_dict=True
907
- ):
908
-
909
- if attention_kwargs is None:
910
- attention_kwargs = {}
911
-
912
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
913
- p, p_t = self.config['patch_size'], self.config['patch_size_t']
914
- post_patch_num_frames = num_frames // p_t
915
- post_patch_height = height // p
916
- post_patch_width = width // p
917
- original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
918
-
919
- hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
920
-
921
- temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
922
- encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
923
-
924
- if self.image_projection is not None:
925
- assert image_embeddings is not None, 'You must use image embeddings!'
926
- extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
927
- extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
928
-
929
- # must cat before (not after) encoder_hidden_states, due to attn masking
930
- encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
931
- encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
932
-
933
- if batch_size == 1:
934
- # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
935
- # If they are not same, then their impls are wrong. Ours are always the correct one.
936
- text_len = encoder_attention_mask.sum().item()
937
- encoder_hidden_states = encoder_hidden_states[:, :text_len]
938
- attention_mask = None, None, None, None
939
- else:
940
- img_seq_len = hidden_states.shape[1]
941
- txt_seq_len = encoder_hidden_states.shape[1]
942
-
943
- cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
944
- cu_seqlens_kv = cu_seqlens_q
945
- max_seqlen_q = img_seq_len + txt_seq_len
946
- max_seqlen_kv = max_seqlen_q
947
-
948
- attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
949
-
950
- if self.enable_teacache:
951
- modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
952
-
953
- if self.cnt == 0 or self.cnt == self.num_steps-1:
954
- should_calc = True
955
- self.accumulated_rel_l1_distance = 0
956
- else:
957
- curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
958
- self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
959
- should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
960
-
961
- if should_calc:
962
- self.accumulated_rel_l1_distance = 0
963
-
964
- self.previous_modulated_input = modulated_inp
965
- self.cnt += 1
966
-
967
- if self.cnt == self.num_steps:
968
- self.cnt = 0
969
-
970
- if not should_calc:
971
- hidden_states = hidden_states + self.previous_residual
972
- else:
973
- ori_hidden_states = hidden_states.clone()
974
-
975
- for block_id, block in enumerate(self.transformer_blocks):
976
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
977
- block,
978
- hidden_states,
979
- encoder_hidden_states,
980
- temb,
981
- attention_mask,
982
- rope_freqs
983
- )
984
-
985
- for block_id, block in enumerate(self.single_transformer_blocks):
986
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
987
- block,
988
- hidden_states,
989
- encoder_hidden_states,
990
- temb,
991
- attention_mask,
992
- rope_freqs
993
- )
994
-
995
- self.previous_residual = hidden_states - ori_hidden_states
996
- else:
997
- for block_id, block in enumerate(self.transformer_blocks):
998
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
999
- block,
1000
- hidden_states,
1001
- encoder_hidden_states,
1002
- temb,
1003
- attention_mask,
1004
- rope_freqs
1005
- )
1006
-
1007
- for block_id, block in enumerate(self.single_transformer_blocks):
1008
- hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1009
- block,
1010
- hidden_states,
1011
- encoder_hidden_states,
1012
- temb,
1013
- attention_mask,
1014
- rope_freqs
1015
- )
1016
-
1017
- hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1018
-
1019
- hidden_states = hidden_states[:, -original_context_length:, :]
1020
-
1021
- if self.high_quality_fp32_output_for_inference:
1022
- hidden_states = hidden_states.to(dtype=torch.float32)
1023
- if self.proj_out.weight.dtype != torch.float32:
1024
- self.proj_out.to(dtype=torch.float32)
1025
-
1026
- hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1027
-
1028
- hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1029
- t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1030
- pt=p_t, ph=p, pw=p)
1031
-
1032
- if return_dict:
1033
- return Transformer2DModelOutput(sample=hidden_states)
1034
-
1035
- return hidden_states,
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import einops
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ from diffusers.loaders import FromOriginalModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders import PeftAdapterMixin
11
+ from diffusers.utils import logging
12
+ from diffusers.models.attention import FeedForward
13
+ from diffusers.models.attention_processor import Attention
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers_helper.dit_common import LayerNorm
18
+ from diffusers_helper.utils import zero_module
19
+
20
+
21
+ enabled_backends = []
22
+
23
+ if torch.backends.cuda.flash_sdp_enabled():
24
+ enabled_backends.append("flash")
25
+ if torch.backends.cuda.math_sdp_enabled():
26
+ enabled_backends.append("math")
27
+ if torch.backends.cuda.mem_efficient_sdp_enabled():
28
+ enabled_backends.append("mem_efficient")
29
+ if torch.backends.cuda.cudnn_sdp_enabled():
30
+ enabled_backends.append("cudnn")
31
+
32
+ print("Currently enabled native sdp backends:", enabled_backends)
33
+
34
+ try:
35
+ # raise NotImplementedError
36
+ from xformers.ops import memory_efficient_attention as xformers_attn_func
37
+ print('Xformers is installed!')
38
+ except:
39
+ print('Xformers is not installed!')
40
+ xformers_attn_func = None
41
+
42
+ try:
43
+ # raise NotImplementedError
44
+ from flash_attn import flash_attn_varlen_func, flash_attn_func
45
+ print('Flash Attn is installed!')
46
+ except:
47
+ print('Flash Attn is not installed!')
48
+ flash_attn_varlen_func = None
49
+ flash_attn_func = None
50
+
51
+ try:
52
+ # raise NotImplementedError
53
+ from sageattention import sageattn_varlen, sageattn
54
+ print('Sage Attn is installed!')
55
+ except:
56
+ print('Sage Attn is not installed!')
57
+ sageattn_varlen = None
58
+ sageattn = None
59
+
60
+
61
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
62
+
63
+
64
+ def pad_for_3d_conv(x, kernel_size):
65
+ b, c, t, h, w = x.shape
66
+ pt, ph, pw = kernel_size
67
+ pad_t = (pt - (t % pt)) % pt
68
+ pad_h = (ph - (h % ph)) % ph
69
+ pad_w = (pw - (w % pw)) % pw
70
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
71
+
72
+
73
+ def center_down_sample_3d(x, kernel_size):
74
+ # pt, ph, pw = kernel_size
75
+ # cp = (pt * ph * pw) // 2
76
+ # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
77
+ # xc = xp[cp]
78
+ # return xc
79
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
80
+
81
+
82
+ def get_cu_seqlens(text_mask, img_len):
83
+ batch_size = text_mask.shape[0]
84
+ text_len = text_mask.sum(dim=1)
85
+ max_len = text_mask.shape[1] + img_len
86
+
87
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
88
+
89
+ for i in range(batch_size):
90
+ s = text_len[i] + img_len
91
+ s1 = i * max_len + s
92
+ s2 = (i + 1) * max_len
93
+ cu_seqlens[2 * i + 1] = s1
94
+ cu_seqlens[2 * i + 2] = s2
95
+
96
+ return cu_seqlens
97
+
98
+
99
+ def apply_rotary_emb_transposed(x, freqs_cis):
100
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
101
+ x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
102
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
103
+ out = x.float() * cos + x_rotated.float() * sin
104
+ out = out.to(x)
105
+ return out
106
+
107
+
108
+ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
109
+ if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
110
+ if sageattn is not None:
111
+ x = sageattn(q, k, v, tensor_layout='NHD')
112
+ return x
113
+
114
+ if flash_attn_func is not None:
115
+ x = flash_attn_func(q, k, v)
116
+ return x
117
+
118
+ if xformers_attn_func is not None:
119
+ x = xformers_attn_func(q, k, v)
120
+ return x
121
+
122
+ x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
123
+ return x
124
+
125
+ B, L, H, C = q.shape
126
+
127
+ q = q.flatten(0, 1)
128
+ k = k.flatten(0, 1)
129
+ v = v.flatten(0, 1)
130
+
131
+ if sageattn_varlen is not None:
132
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
133
+ elif flash_attn_varlen_func is not None:
134
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
135
+ else:
136
+ raise NotImplementedError('No Attn Installed!')
137
+
138
+ x = x.unflatten(0, (B, L))
139
+
140
+ return x
141
+
142
+
143
+ class HunyuanAttnProcessorFlashAttnDouble:
144
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
145
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
146
+
147
+ query = attn.to_q(hidden_states)
148
+ key = attn.to_k(hidden_states)
149
+ value = attn.to_v(hidden_states)
150
+
151
+ query = query.unflatten(2, (attn.heads, -1))
152
+ key = key.unflatten(2, (attn.heads, -1))
153
+ value = value.unflatten(2, (attn.heads, -1))
154
+
155
+ query = attn.norm_q(query)
156
+ key = attn.norm_k(key)
157
+
158
+ query = apply_rotary_emb_transposed(query, image_rotary_emb)
159
+ key = apply_rotary_emb_transposed(key, image_rotary_emb)
160
+
161
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
162
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
163
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
164
+
165
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
166
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
167
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
168
+
169
+ encoder_query = attn.norm_added_q(encoder_query)
170
+ encoder_key = attn.norm_added_k(encoder_key)
171
+
172
+ query = torch.cat([query, encoder_query], dim=1)
173
+ key = torch.cat([key, encoder_key], dim=1)
174
+ value = torch.cat([value, encoder_value], dim=1)
175
+
176
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
177
+ hidden_states = hidden_states.flatten(-2)
178
+
179
+ txt_length = encoder_hidden_states.shape[1]
180
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
181
+
182
+ hidden_states = attn.to_out[0](hidden_states)
183
+ hidden_states = attn.to_out[1](hidden_states)
184
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
185
+
186
+ return hidden_states, encoder_hidden_states
187
+
188
+
189
+ class HunyuanAttnProcessorFlashAttnSingle:
190
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
191
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
192
+
193
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
194
+
195
+ query = attn.to_q(hidden_states)
196
+ key = attn.to_k(hidden_states)
197
+ value = attn.to_v(hidden_states)
198
+
199
+ query = query.unflatten(2, (attn.heads, -1))
200
+ key = key.unflatten(2, (attn.heads, -1))
201
+ value = value.unflatten(2, (attn.heads, -1))
202
+
203
+ query = attn.norm_q(query)
204
+ key = attn.norm_k(key)
205
+
206
+ txt_length = encoder_hidden_states.shape[1]
207
+
208
+ query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
209
+ key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
210
+
211
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
212
+ hidden_states = hidden_states.flatten(-2)
213
+
214
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
215
+
216
+ return hidden_states, encoder_hidden_states
217
+
218
+
219
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
220
+ def __init__(self, embedding_dim, pooled_projection_dim):
221
+ super().__init__()
222
+
223
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
224
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
225
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
226
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
227
+
228
+ def forward(self, timestep, guidance, pooled_projection):
229
+ timesteps_proj = self.time_proj(timestep)
230
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
231
+
232
+ guidance_proj = self.time_proj(guidance)
233
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
234
+
235
+ time_guidance_emb = timesteps_emb + guidance_emb
236
+
237
+ pooled_projections = self.text_embedder(pooled_projection)
238
+ conditioning = time_guidance_emb + pooled_projections
239
+
240
+ return conditioning
241
+
242
+
243
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
244
+ def __init__(self, embedding_dim, pooled_projection_dim):
245
+ super().__init__()
246
+
247
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
248
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
249
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
250
+
251
+ def forward(self, timestep, pooled_projection):
252
+ timesteps_proj = self.time_proj(timestep)
253
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
254
+
255
+ pooled_projections = self.text_embedder(pooled_projection)
256
+
257
+ conditioning = timesteps_emb + pooled_projections
258
+
259
+ return conditioning
260
+
261
+
262
+ class HunyuanVideoAdaNorm(nn.Module):
263
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
264
+ super().__init__()
265
+
266
+ out_features = out_features or 2 * in_features
267
+ self.linear = nn.Linear(in_features, out_features)
268
+ self.nonlinearity = nn.SiLU()
269
+
270
+ def forward(
271
+ self, temb: torch.Tensor
272
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
273
+ temb = self.linear(self.nonlinearity(temb))
274
+ gate_msa, gate_mlp = temb.chunk(2, dim=-1)
275
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
276
+ return gate_msa, gate_mlp
277
+
278
+
279
+ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
280
+ def __init__(
281
+ self,
282
+ num_attention_heads: int,
283
+ attention_head_dim: int,
284
+ mlp_width_ratio: str = 4.0,
285
+ mlp_drop_rate: float = 0.0,
286
+ attention_bias: bool = True,
287
+ ) -> None:
288
+ super().__init__()
289
+
290
+ hidden_size = num_attention_heads * attention_head_dim
291
+
292
+ self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
293
+ self.attn = Attention(
294
+ query_dim=hidden_size,
295
+ cross_attention_dim=None,
296
+ heads=num_attention_heads,
297
+ dim_head=attention_head_dim,
298
+ bias=attention_bias,
299
+ )
300
+
301
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
302
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
303
+
304
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ temb: torch.Tensor,
310
+ attention_mask: Optional[torch.Tensor] = None,
311
+ ) -> torch.Tensor:
312
+ norm_hidden_states = self.norm1(hidden_states)
313
+
314
+ attn_output = self.attn(
315
+ hidden_states=norm_hidden_states,
316
+ encoder_hidden_states=None,
317
+ attention_mask=attention_mask,
318
+ )
319
+
320
+ gate_msa, gate_mlp = self.norm_out(temb)
321
+ hidden_states = hidden_states + attn_output * gate_msa
322
+
323
+ ff_output = self.ff(self.norm2(hidden_states))
324
+ hidden_states = hidden_states + ff_output * gate_mlp
325
+
326
+ return hidden_states
327
+
328
+
329
+ class HunyuanVideoIndividualTokenRefiner(nn.Module):
330
+ def __init__(
331
+ self,
332
+ num_attention_heads: int,
333
+ attention_head_dim: int,
334
+ num_layers: int,
335
+ mlp_width_ratio: float = 4.0,
336
+ mlp_drop_rate: float = 0.0,
337
+ attention_bias: bool = True,
338
+ ) -> None:
339
+ super().__init__()
340
+
341
+ self.refiner_blocks = nn.ModuleList(
342
+ [
343
+ HunyuanVideoIndividualTokenRefinerBlock(
344
+ num_attention_heads=num_attention_heads,
345
+ attention_head_dim=attention_head_dim,
346
+ mlp_width_ratio=mlp_width_ratio,
347
+ mlp_drop_rate=mlp_drop_rate,
348
+ attention_bias=attention_bias,
349
+ )
350
+ for _ in range(num_layers)
351
+ ]
352
+ )
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states: torch.Tensor,
357
+ temb: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ ) -> None:
360
+ self_attn_mask = None
361
+ if attention_mask is not None:
362
+ batch_size = attention_mask.shape[0]
363
+ seq_len = attention_mask.shape[1]
364
+ attention_mask = attention_mask.to(hidden_states.device).bool()
365
+ print ("self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).expand(1, 1, seq_len, 1)")
366
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).expand(1, 1, seq_len, 1)
367
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
368
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
369
+ self_attn_mask[:, :, :, 0] = True
370
+
371
+ for block in self.refiner_blocks:
372
+ hidden_states = block(hidden_states, temb, self_attn_mask)
373
+
374
+ return hidden_states
375
+
376
+
377
+ class HunyuanVideoTokenRefiner(nn.Module):
378
+ def __init__(
379
+ self,
380
+ in_channels: int,
381
+ num_attention_heads: int,
382
+ attention_head_dim: int,
383
+ num_layers: int,
384
+ mlp_ratio: float = 4.0,
385
+ mlp_drop_rate: float = 0.0,
386
+ attention_bias: bool = True,
387
+ ) -> None:
388
+ super().__init__()
389
+
390
+ hidden_size = num_attention_heads * attention_head_dim
391
+
392
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
393
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
394
+ )
395
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
396
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
397
+ num_attention_heads=num_attention_heads,
398
+ attention_head_dim=attention_head_dim,
399
+ num_layers=num_layers,
400
+ mlp_width_ratio=mlp_ratio,
401
+ mlp_drop_rate=mlp_drop_rate,
402
+ attention_bias=attention_bias,
403
+ )
404
+
405
+ def forward(
406
+ self,
407
+ hidden_states: torch.Tensor,
408
+ timestep: torch.LongTensor,
409
+ attention_mask: Optional[torch.LongTensor] = None,
410
+ ) -> torch.Tensor:
411
+ if attention_mask is None:
412
+ pooled_projections = hidden_states.mean(dim=1)
413
+ else:
414
+ original_dtype = hidden_states.dtype
415
+ mask_float = attention_mask.float().unsqueeze(-1)
416
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
417
+ pooled_projections = pooled_projections.to(original_dtype)
418
+
419
+ temb = self.time_text_embed(timestep, pooled_projections)
420
+ hidden_states = self.proj_in(hidden_states)
421
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
422
+
423
+ return hidden_states
424
+
425
+
426
+ class HunyuanVideoRotaryPosEmbed(nn.Module):
427
+ def __init__(self, rope_dim, theta):
428
+ super().__init__()
429
+ self.DT, self.DY, self.DX = rope_dim
430
+ self.theta = theta
431
+
432
+ @torch.no_grad()
433
+ def get_frequency(self, dim, pos):
434
+ T, H, W = pos.shape
435
+ freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
436
+ freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
437
+ return freqs.cos(), freqs.sin()
438
+
439
+ @torch.no_grad()
440
+ def forward_inner(self, frame_indices, height, width, device):
441
+ GT, GY, GX = torch.meshgrid(
442
+ frame_indices.to(device=device, dtype=torch.float32),
443
+ torch.arange(0, height, device=device, dtype=torch.float32),
444
+ torch.arange(0, width, device=device, dtype=torch.float32),
445
+ indexing="ij"
446
+ )
447
+
448
+ FCT, FST = self.get_frequency(self.DT, GT)
449
+ FCY, FSY = self.get_frequency(self.DY, GY)
450
+ FCX, FSX = self.get_frequency(self.DX, GX)
451
+
452
+ result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
453
+
454
+ return result.to(device)
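+ # The result stacks, along the channel axis, the cosine tables for the temporal, vertical and
+ # horizontal axes (DT, DY and DX channels) followed by the matching sine tables, so downstream
+ # rotary-embedding code can split it back into (cos, sin) halves.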
455
+
456
+ @torch.no_grad()
457
+ def forward(self, frame_indices, height, width, device):
458
+ frame_indices = frame_indices.unbind(0)
459
+ results = [self.forward_inner(f, height, width, device) for f in frame_indices]
460
+ results = torch.stack(results, dim=0)
461
+ return results
462
+
463
+
464
+ class AdaLayerNormZero(nn.Module):
465
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
466
+ super().__init__()
467
+ self.silu = nn.SiLU()
468
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
469
+ if norm_type == "layer_norm":
470
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
471
+ else:
472
+ raise ValueError(f"unknown norm_type {norm_type}")
473
+
474
+ def forward(
475
+ self,
476
+ x: torch.Tensor,
477
+ emb: Optional[torch.Tensor] = None,
478
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
479
+ emb = emb.unsqueeze(-2)
480
+ emb = self.linear(self.silu(emb))
481
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
482
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
483
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
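+ # One conditioning embedding is projected to six modulation tensors: shift/scale are applied to
+ # the normalized input here, while the gates and the MLP shift/scale are returned for the
+ # caller to apply around the attention and feed-forward branches.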
484
+
485
+
486
+ class AdaLayerNormZeroSingle(nn.Module):
487
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
488
+ super().__init__()
489
+
490
+ self.silu = nn.SiLU()
491
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
492
+ if norm_type == "layer_norm":
493
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
494
+ else:
495
+ raise ValueError(f"unknown norm_type {norm_type}")
496
+
497
+ def forward(
498
+ self,
499
+ x: torch.Tensor,
500
+ emb: Optional[torch.Tensor] = None,
501
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
502
+ emb = emb.unsqueeze(-2)
503
+ emb = self.linear(self.silu(emb))
504
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
505
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
506
+ return x, gate_msa
507
+
508
+
509
+ class AdaLayerNormContinuous(nn.Module):
510
+ def __init__(
511
+ self,
512
+ embedding_dim: int,
513
+ conditioning_embedding_dim: int,
514
+ elementwise_affine=True,
515
+ eps=1e-5,
516
+ bias=True,
517
+ norm_type="layer_norm",
518
+ ):
519
+ super().__init__()
520
+ self.silu = nn.SiLU()
521
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
522
+ if norm_type == "layer_norm":
523
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
524
+ else:
525
+ raise ValueError(f"unknown norm_type {norm_type}")
526
+
527
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
528
+ emb = emb.unsqueeze(-2)
529
+ emb = self.linear(self.silu(emb))
530
+ scale, shift = emb.chunk(2, dim=-1)
531
+ x = self.norm(x) * (1 + scale) + shift
532
+ return x
533
+
534
+
535
+ class HunyuanVideoSingleTransformerBlock(nn.Module):
536
+ def __init__(
537
+ self,
538
+ num_attention_heads: int,
539
+ attention_head_dim: int,
540
+ mlp_ratio: float = 4.0,
541
+ qk_norm: str = "rms_norm",
542
+ ) -> None:
543
+ super().__init__()
544
+
545
+ hidden_size = num_attention_heads * attention_head_dim
546
+ mlp_dim = int(hidden_size * mlp_ratio)
547
+
548
+ self.attn = Attention(
549
+ query_dim=hidden_size,
550
+ cross_attention_dim=None,
551
+ dim_head=attention_head_dim,
552
+ heads=num_attention_heads,
553
+ out_dim=hidden_size,
554
+ bias=True,
555
+ processor=HunyuanAttnProcessorFlashAttnSingle(),
556
+ qk_norm=qk_norm,
557
+ eps=1e-6,
558
+ pre_only=True,
559
+ )
560
+
561
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
562
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
563
+ self.act_mlp = nn.GELU(approximate="tanh")
564
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
565
+
566
+ def forward(
567
+ self,
568
+ hidden_states: torch.Tensor,
569
+ encoder_hidden_states: torch.Tensor,
570
+ temb: torch.Tensor,
571
+ attention_mask: Optional[torch.Tensor] = None,
572
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
573
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
574
+ text_seq_length = encoder_hidden_states.shape[1]
575
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
576
+
577
+ residual = hidden_states
578
+
579
+ # 1. Input normalization
580
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
581
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
582
+
583
+ norm_hidden_states, norm_encoder_hidden_states = (
584
+ norm_hidden_states[:, :-text_seq_length, :],
585
+ norm_hidden_states[:, -text_seq_length:, :],
586
+ )
587
+
588
+ # 2. Attention
589
+ attn_output, context_attn_output = self.attn(
590
+ hidden_states=norm_hidden_states,
591
+ encoder_hidden_states=norm_encoder_hidden_states,
592
+ attention_mask=attention_mask,
593
+ image_rotary_emb=image_rotary_emb,
594
+ )
595
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
596
+
597
+ # 3. Modulation and residual connection
598
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
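+ # Single-stream design: attention and the MLP act in parallel on the same normalized input, and
+ # a single projection over their concatenated features fuses the two branches.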
599
+ hidden_states = gate * self.proj_out(hidden_states)
600
+ hidden_states = hidden_states + residual
601
+
602
+ hidden_states, encoder_hidden_states = (
603
+ hidden_states[:, :-text_seq_length, :],
604
+ hidden_states[:, -text_seq_length:, :],
605
+ )
606
+ return hidden_states, encoder_hidden_states
607
+
608
+
609
+ class HunyuanVideoTransformerBlock(nn.Module):
610
+ def __init__(
611
+ self,
612
+ num_attention_heads: int,
613
+ attention_head_dim: int,
614
+ mlp_ratio: float,
615
+ qk_norm: str = "rms_norm",
616
+ ) -> None:
617
+ super().__init__()
618
+
619
+ hidden_size = num_attention_heads * attention_head_dim
620
+
621
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
622
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
623
+
624
+ self.attn = Attention(
625
+ query_dim=hidden_size,
626
+ cross_attention_dim=None,
627
+ added_kv_proj_dim=hidden_size,
628
+ dim_head=attention_head_dim,
629
+ heads=num_attention_heads,
630
+ out_dim=hidden_size,
631
+ context_pre_only=False,
632
+ bias=True,
633
+ processor=HunyuanAttnProcessorFlashAttnDouble(),
634
+ qk_norm=qk_norm,
635
+ eps=1e-6,
636
+ )
637
+
638
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
639
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
640
+
641
+ self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
642
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
643
+
644
+ def forward(
645
+ self,
646
+ hidden_states: torch.Tensor,
647
+ encoder_hidden_states: torch.Tensor,
648
+ temb: torch.Tensor,
649
+ attention_mask: Optional[torch.Tensor] = None,
650
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
651
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
652
+ # 1. Input normalization
653
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
654
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
655
+
656
+ # 2. Joint attention
657
+ attn_output, context_attn_output = self.attn(
658
+ hidden_states=norm_hidden_states,
659
+ encoder_hidden_states=norm_encoder_hidden_states,
660
+ attention_mask=attention_mask,
661
+ image_rotary_emb=freqs_cis,
662
+ )
663
+
664
+ # 3. Modulation and residual connection
665
+ hidden_states = hidden_states + attn_output * gate_msa
666
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
667
+
668
+ norm_hidden_states = self.norm2(hidden_states)
669
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
670
+
671
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
672
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
673
+
674
+ # 4. Feed-forward
675
+ ff_output = self.ff(norm_hidden_states)
676
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
677
+
678
+ hidden_states = hidden_states + gate_mlp * ff_output
679
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
680
+
681
+ return hidden_states, encoder_hidden_states
682
+
683
+
684
+ class ClipVisionProjection(nn.Module):
685
+ def __init__(self, in_channels, out_channels):
686
+ super().__init__()
687
+ self.up = nn.Linear(in_channels, out_channels * 3)
688
+ self.down = nn.Linear(out_channels * 3, out_channels)
689
+
690
+ def forward(self, x):
691
+ projected_x = self.down(nn.functional.silu(self.up(x)))
692
+ return projected_x
693
+
694
+
695
+ class HunyuanVideoPatchEmbed(nn.Module):
696
+ def __init__(self, patch_size, in_chans, embed_dim):
697
+ super().__init__()
698
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
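+ # No forward() is needed here: the packed transformer calls self.x_embedder.proj directly in
+ # process_input_hidden_states, so this module only owns the Conv3d weights.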
699
+
700
+
701
+ class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
702
+ def __init__(self, inner_dim):
703
+ super().__init__()
704
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
705
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
706
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
707
+
708
+ @torch.no_grad()
709
+ def initialize_weight_from_another_conv3d(self, another_layer):
710
+ weight = another_layer.weight.detach().clone()
711
+ bias = another_layer.bias.detach().clone()
712
+
713
+ sd = {
714
+ 'proj.weight': weight.clone(),
715
+ 'proj.bias': bias.clone(),
716
+ 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
717
+ 'proj_2x.bias': bias.clone(),
718
+ 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
719
+ 'proj_4x.bias': bias.clone(),
720
+ }
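+ # The 2x/4x projections are initialized by tiling the base (1, 2, 2) kernel over a (2, 4, 4) /
+ # (4, 8, 8) window and dividing by the number of copies (8 / 64), so a freshly initialized layer
+ # behaves like the base projection applied to an average-pooled input.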
721
+
722
+ sd = {k: v.clone() for k, v in sd.items()}
723
+
724
+ self.load_state_dict(sd)
725
+ return
726
+
727
+
728
+ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
729
+ @register_to_config
730
+ def __init__(
731
+ self,
732
+ in_channels: int = 16,
733
+ out_channels: int = 16,
734
+ num_attention_heads: int = 24,
735
+ attention_head_dim: int = 128,
736
+ num_layers: int = 20,
737
+ num_single_layers: int = 40,
738
+ num_refiner_layers: int = 2,
739
+ mlp_ratio: float = 4.0,
740
+ patch_size: int = 2,
741
+ patch_size_t: int = 1,
742
+ qk_norm: str = "rms_norm",
743
+ guidance_embeds: bool = True,
744
+ text_embed_dim: int = 4096,
745
+ pooled_projection_dim: int = 768,
746
+ rope_theta: float = 256.0,
747
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
748
+ has_image_proj=False,
749
+ image_proj_dim=1152,
750
+ has_clean_x_embedder=False,
751
+ ) -> None:
752
+ super().__init__()
753
+
754
+ inner_dim = num_attention_heads * attention_head_dim
755
+ out_channels = out_channels or in_channels
756
+
757
+ # 1. Latent and condition embedders
758
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
759
+ self.context_embedder = HunyuanVideoTokenRefiner(
760
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
761
+ )
762
+ self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
763
+
764
+ self.clean_x_embedder = None
765
+ self.image_projection = None
766
+
767
+ # 2. RoPE
768
+ self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
769
+
770
+ # 3. Dual stream transformer blocks
771
+ self.transformer_blocks = nn.ModuleList(
772
+ [
773
+ HunyuanVideoTransformerBlock(
774
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
775
+ )
776
+ for _ in range(num_layers)
777
+ ]
778
+ )
779
+
780
+ # 4. Single stream transformer blocks
781
+ self.single_transformer_blocks = nn.ModuleList(
782
+ [
783
+ HunyuanVideoSingleTransformerBlock(
784
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
785
+ )
786
+ for _ in range(num_single_layers)
787
+ ]
788
+ )
789
+
790
+ # 5. Output projection
791
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
792
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
793
+
794
+ self.inner_dim = inner_dim
795
+ self.use_gradient_checkpointing = False
796
+ self.enable_teacache = False
797
+
798
+ if has_image_proj:
799
+ self.install_image_projection(image_proj_dim)
800
+
801
+ if has_clean_x_embedder:
802
+ self.install_clean_x_embedder()
803
+
804
+ self.high_quality_fp32_output_for_inference = False
805
+
806
+ def install_image_projection(self, in_channels):
807
+ self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
808
+ self.config['has_image_proj'] = True
809
+ self.config['image_proj_dim'] = in_channels
810
+
811
+ def install_clean_x_embedder(self):
812
+ self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
813
+ self.config['has_clean_x_embedder'] = True
814
+
815
+ def enable_gradient_checkpointing(self):
816
+ self.use_gradient_checkpointing = True
817
+ print('self.use_gradient_checkpointing = True')
818
+
819
+ def disable_gradient_checkpointing(self):
820
+ self.use_gradient_checkpointing = False
821
+ print('self.use_gradient_checkpointing = False')
822
+
823
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
824
+ self.enable_teacache = enable_teacache
825
+ self.cnt = 0
826
+ self.num_steps = num_steps
827
+ self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
828
+ self.accumulated_rel_l1_distance = 0
829
+ self.previous_modulated_input = None
830
+ self.previous_residual = None
831
+ self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
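+ # The polynomial maps the raw relative-L1 change of the modulated input to an estimate of the
+ # change in block outputs; the coefficients appear to be fitted empirically for this model,
+ # following the TeaCache caching heuristic.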
832
+
833
+ def gradient_checkpointing_method(self, block, *args):
834
+ if self.use_gradient_checkpointing:
835
+ result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
836
+ else:
837
+ result = block(*args)
838
+ return result
839
+
840
+ def process_input_hidden_states(
841
+ self,
842
+ latents, latent_indices=None,
843
+ clean_latents=None, clean_latent_indices=None,
844
+ clean_latents_2x=None, clean_latent_2x_indices=None,
845
+ clean_latents_4x=None, clean_latent_4x_indices=None
846
+ ):
847
+ hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
848
+ B, C, T, H, W = hidden_states.shape
849
+
850
+ if latent_indices is None:
851
+ latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
852
+
853
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
854
+
855
+ rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
856
+ rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
857
+
858
+ if clean_latents is not None and clean_latent_indices is not None:
859
+ clean_latents = clean_latents.to(hidden_states)
860
+ clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
861
+ clean_latents = clean_latents.flatten(2).transpose(1, 2)
862
+
863
+ clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
864
+ clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
865
+
866
+ hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
867
+ rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
868
+
869
+ if clean_latents_2x is not None and clean_latent_2x_indices is not None:
870
+ clean_latents_2x = clean_latents_2x.to(hidden_states)
871
+ clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
872
+ clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
873
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
874
+
875
+ clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
876
+ clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
877
+ clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
878
+ clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
879
+
880
+ hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
881
+ rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
882
+
883
+ if clean_latents_4x is not None and clean_latent_4x_indices is not None:
884
+ clean_latents_4x = clean_latents_4x.to(hidden_states)
885
+ clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
886
+ clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
887
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
888
+
889
+ clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
890
+ clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
891
+ clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
892
+ clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
893
+
894
+ hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
895
+ rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
896
+
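+ # Final token order along the sequence axis: [4x-downsampled context, 2x-downsampled context,
+ # full-resolution clean context, current noisy latents]. The forward pass later recovers the
+ # current latents by slicing the last original_context_length tokens.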
897
+ return hidden_states, rope_freqs
898
+
899
+ def forward(
900
+ self,
901
+ hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
902
+ latent_indices=None,
903
+ clean_latents=None, clean_latent_indices=None,
904
+ clean_latents_2x=None, clean_latent_2x_indices=None,
905
+ clean_latents_4x=None, clean_latent_4x_indices=None,
906
+ image_embeddings=None,
907
+ attention_kwargs=None, return_dict=True
908
+ ):
909
+
910
+ if attention_kwargs is None:
911
+ attention_kwargs = {}
912
+
913
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
914
+ p, p_t = self.config['patch_size'], self.config['patch_size_t']
915
+ post_patch_num_frames = num_frames // p_t
916
+ post_patch_height = height // p
917
+ post_patch_width = width // p
918
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
919
+
920
+ hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
921
+
922
+ temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
923
+ encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
924
+
925
+ if self.image_projection is not None:
926
+ assert image_embeddings is not None, 'image_embeddings must be provided when an image projection is installed!'
927
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
928
+ extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
929
+
930
+ # must cat before (not after) encoder_hidden_states, due to attn masking
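+ # The text mask assumes valid tokens form a contiguous prefix (padding only at the end), and the
+ # image-projection tokens are always valid, so they are prepended to keep the combined mask and
+ # the cu_seqlens derived from it consistent.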
931
+ encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
932
+ encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
933
+
934
+ if batch_size == 1:
935
+ # When the batch size is 1, no masks or variable-length attention functions are needed, since cropping the padded text tokens is mathematically equivalent to masking them.
936
+ # If a masked implementation disagrees with the cropped result, that implementation is wrong; the cropped path is the reference behaviour.
937
+ text_len = encoder_attention_mask.sum().item()
938
+ encoder_hidden_states = encoder_hidden_states[:, :text_len]
939
+ attention_mask = None, None, None, None
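+ # This 4-tuple mirrors (cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) built in the
+ # batched branch below; with batch size 1 all four stay None, which presumably tells the
+ # attention processors to run plain (non-varlen) attention without a mask.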
940
+ else:
941
+ img_seq_len = hidden_states.shape[1]
942
+ txt_seq_len = encoder_hidden_states.shape[1]
943
+
944
+ cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
945
+ cu_seqlens_kv = cu_seqlens_q
946
+ max_seqlen_q = img_seq_len + txt_seq_len
947
+ max_seqlen_kv = max_seqlen_q
948
+
949
+ attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
950
+
951
+ if self.enable_teacache:
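+ # TeaCache: track how much the timestep-modulated input of the first block changes between
+ # denoising steps; while the accumulated change stays below rel_l1_thresh, skip every
+ # transformer block and reuse the residual cached from the last fully computed step.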
952
+ modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
953
+
954
+ if self.cnt == 0 or self.cnt == self.num_steps-1:
955
+ should_calc = True
956
+ self.accumulated_rel_l1_distance = 0
957
+ else:
958
+ curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
959
+ self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
960
+ should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
961
+
962
+ if should_calc:
963
+ self.accumulated_rel_l1_distance = 0
964
+
965
+ self.previous_modulated_input = modulated_inp
966
+ self.cnt += 1
967
+
968
+ if self.cnt == self.num_steps:
969
+ self.cnt = 0
970
+
971
+ if not should_calc:
972
+ hidden_states = hidden_states + self.previous_residual
973
+ else:
974
+ ori_hidden_states = hidden_states.clone()
975
+
976
+ for block_id, block in enumerate(self.transformer_blocks):
977
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
978
+ block,
979
+ hidden_states,
980
+ encoder_hidden_states,
981
+ temb,
982
+ attention_mask,
983
+ rope_freqs
984
+ )
985
+
986
+ for block_id, block in enumerate(self.single_transformer_blocks):
987
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
988
+ block,
989
+ hidden_states,
990
+ encoder_hidden_states,
991
+ temb,
992
+ attention_mask,
993
+ rope_freqs
994
+ )
995
+
996
+ self.previous_residual = hidden_states - ori_hidden_states
997
+ else:
998
+ for block_id, block in enumerate(self.transformer_blocks):
999
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1000
+ block,
1001
+ hidden_states,
1002
+ encoder_hidden_states,
1003
+ temb,
1004
+ attention_mask,
1005
+ rope_freqs
1006
+ )
1007
+
1008
+ for block_id, block in enumerate(self.single_transformer_blocks):
1009
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1010
+ block,
1011
+ hidden_states,
1012
+ encoder_hidden_states,
1013
+ temb,
1014
+ attention_mask,
1015
+ rope_freqs
1016
+ )
1017
+
1018
+ hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1019
+
1020
+ hidden_states = hidden_states[:, -original_context_length:, :]
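+ # Drop the prepended clean-latent context tokens; only the tokens belonging to the current
+ # noisy latents are unpatchified and returned.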
1021
+
1022
+ if self.high_quality_fp32_output_for_inference:
1023
+ hidden_states = hidden_states.to(dtype=torch.float32)
1024
+ if self.proj_out.weight.dtype != torch.float32:
1025
+ self.proj_out.to(dtype=torch.float32)
1026
+
1027
+ hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1028
+
1029
+ hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1030
+ t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1031
+ pt=p_t, ph=p, pw=p)
1032
+
1033
+ if return_dict:
1034
+ return Transformer2DModelOutput(sample=hidden_states)
1035
+
1036
+ return hidden_states,