1inkusFace committed
Commit 73d5421 · verified · 1 Parent(s): f22e445

Upload 18 files
skyreels_v2_infer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .pipelines import DiffusionForcingPipeline
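
A minimal import sketch for the public entry point exported above, assuming the repository root is on the Python path (illustrative, not part of the commit):

# Sketch: the package __init__ re-exports the pipeline class.
from skyreels_v2_infer import DiffusionForcingPipeline
print(DiffusionForcingPipeline)
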
skyreels_v2_infer/distributed/__init__.py ADDED
File without changes
skyreels_v2_infer/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,286 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.amp as amp
4
+ from torch.backends.cuda import sdp_kernel
5
+ from xfuser.core.distributed import get_sequence_parallel_rank
6
+ from xfuser.core.distributed import get_sequence_parallel_world_size
7
+ from xfuser.core.distributed import get_sp_group
8
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
9
+
10
+ from ..modules.transformer import sinusoidal_embedding_1d
11
+
12
+
13
+ def pad_freqs(original_tensor, target_len):
14
+ seq_len, s1, s2 = original_tensor.shape
15
+ pad_size = target_len - seq_len
16
+ padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
17
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
18
+ return padded_tensor
19
+
20
+
21
+ @amp.autocast("cuda", enabled=False)
22
+ def rope_apply(x, grid_sizes, freqs):
23
+ """
24
+ x: [B, L, N, C].
25
+ grid_sizes: [B, 3].
26
+ freqs: [M, C // 2].
27
+ """
28
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
29
+ # split freqs
30
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
31
+
32
+ # loop over samples
33
+ output = []
34
+ grid = [grid_sizes.tolist()] * x.size(0)
35
+ for i, (f, h, w) in enumerate(grid):
36
+ seq_len = f * h * w
37
+
38
+ # precompute multipliers
39
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
40
+ freqs_i = torch.cat(
41
+ [
42
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
43
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
44
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
45
+ ],
46
+ dim=-1,
47
+ ).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_sequence_parallel_world_size()
51
+ sp_rank = get_sequence_parallel_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
55
+ x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2)
56
+ x_i = torch.cat([x_i, x[i, s:]])
57
+
58
+ # append to collection
59
+ output.append(x_i)
60
+ return torch.stack(output).float()
61
+
62
+
63
+ def broadcast_should_calc(should_calc: bool) -> bool:
64
+ import torch.distributed as dist
65
+
66
+ device = torch.cuda.current_device()
67
+ int_should_calc = 1 if should_calc else 0
68
+ tensor = torch.tensor([int_should_calc], device=device, dtype=torch.int8)
69
+ dist.broadcast(tensor, src=0)
70
+ should_calc = tensor.item() == 1
71
+ return should_calc
72
+
73
+
74
+ def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
75
+ """
76
+ x: A list of videos each with shape [C, T, H, W].
77
+ t: [B].
78
+ context: A list of text embeddings each with shape [L, C].
79
+ """
80
+ if self.model_type == "i2v":
81
+ assert clip_fea is not None and y is not None
82
+ # params
83
+ device = self.patch_embedding.weight.device
84
+ if self.freqs.device != device:
85
+ self.freqs = self.freqs.to(device)
86
+
87
+ if y is not None:
88
+ x = torch.cat([x, y], dim=1)
89
+
90
+ # embeddings
91
+ x = self.patch_embedding(x)
92
+ grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
93
+ x = x.flatten(2).transpose(1, 2)
94
+
95
+ if self.flag_causal_attention:
96
+ frame_num = grid_sizes[0]
97
+ height = grid_sizes[1]
98
+ width = grid_sizes[2]
99
+ block_num = frame_num // self.num_frame_per_block
100
+ range_tensor = torch.arange(block_num).view(-1, 1)
101
+ range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
102
+ causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
103
+ causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
104
+ causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
105
+ causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
106
+ self.block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
107
+
108
+ # time embeddings
109
+ with amp.autocast("cuda", dtype=torch.float32):
110
+ if t.dim() == 2:
111
+ b, f = t.shape
112
+ _flag_df = True
113
+ else:
114
+ _flag_df = False
115
+ e = self.time_embedding(
116
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
117
+ ) # b, dim
118
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
119
+
120
+ if self.inject_sample_info:
121
+ fps = torch.tensor(fps, dtype=torch.long, device=device)
122
+
123
+ fps_emb = self.fps_embedding(fps).float()
124
+ if _flag_df:
125
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
126
+ else:
127
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
128
+
129
+ if _flag_df:
130
+ e = e.view(b, f, 1, 1, self.dim)
131
+ e0 = e0.view(b, f, 1, 1, 6, self.dim)
132
+ e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
133
+ e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
134
+ e0 = e0.transpose(1, 2).contiguous()
135
+
136
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
137
+
138
+ # context
139
+ context = self.text_embedding(context)
140
+
141
+ if clip_fea is not None:
142
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
143
+ context = torch.concat([context_clip, context], dim=1)
144
+
145
+ # arguments
146
+ if e0.ndim == 4:
147
+ e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()]
148
+ kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
149
+
150
+ if self.enable_teacache:
151
+ modulated_inp = e0 if self.use_ref_steps else e
152
+ # teacache
153
+ if self.cnt % 2 == 0: # even -> condition
154
+ self.is_even = True
155
+ if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
156
+ should_calc_even = True
157
+ self.accumulated_rel_l1_distance_even = 0
158
+ else:
159
+ rescale_func = np.poly1d(self.coefficients)
160
+ self.accumulated_rel_l1_distance_even += rescale_func(
161
+ ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
162
+ .cpu()
163
+ .item()
164
+ )
165
+ if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
166
+ should_calc_even = False
167
+ else:
168
+ should_calc_even = True
169
+ self.accumulated_rel_l1_distance_even = 0
170
+ self.previous_e0_even = modulated_inp.clone()
171
+ else: # odd -> uncondition
172
+ self.is_even = False
173
+ if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
174
+ should_calc_odd = True
175
+ self.accumulated_rel_l1_distance_odd = 0
176
+ else:
177
+ rescale_func = np.poly1d(self.coefficients)
178
+ self.accumulated_rel_l1_distance_odd += rescale_func(
179
+ ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
180
+ .cpu()
181
+ .item()
182
+ )
183
+ if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
184
+ should_calc_odd = False
185
+ else:
186
+ should_calc_odd = True
187
+ self.accumulated_rel_l1_distance_odd = 0
188
+ self.previous_e0_odd = modulated_inp.clone()
189
+
190
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
191
+ if self.enable_teacache:
192
+ if self.is_even:
193
+ should_calc_even = broadcast_should_calc(should_calc_even)
194
+ if not should_calc_even:
195
+ x += self.previous_residual_even
196
+ else:
197
+ ori_x = x.clone()
198
+ for block in self.blocks:
199
+ x = block(x, **kwargs)
200
+ ori_x.mul_(-1)
201
+ ori_x.add_(x)
202
+ self.previous_residual_even = ori_x
203
+ else:
204
+ should_calc_odd = broadcast_should_calc(should_calc_odd)
205
+ if not should_calc_odd:
206
+ x += self.previous_residual_odd
207
+ else:
208
+ ori_x = x.clone()
209
+ for block in self.blocks:
210
+ x = block(x, **kwargs)
211
+ ori_x.mul_(-1)
212
+ ori_x.add_(x)
213
+ self.previous_residual_odd = ori_x
214
+ self.cnt += 1
215
+ if self.cnt >= self.num_steps:
216
+ self.cnt = 0
217
+ else:
218
+ # Context Parallel
219
+ for block in self.blocks:
220
+ x = block(x, **kwargs)
221
+
222
+ # head
223
+ if e.ndim == 3:
224
+ e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
225
+ x = self.head(x, e)
226
+ # Context Parallel
227
+ x = get_sp_group().all_gather(x, dim=1)
228
+ # unpatchify
229
+ x = self.unpatchify(x, grid_sizes)
230
+ return x.float()
231
+
232
+
233
+ def usp_attn_forward(self, x, grid_sizes, freqs, block_mask):
234
+
235
+ r"""
236
+ Args:
237
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
238
+ block_mask(Tensor, optional): Boolean attention mask of shape [1, 1, S, S] built in usp_dit_forward for causal attention (S = F*H*W)
239
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
240
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
241
+ """
242
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
243
+ half_dtypes = (torch.float16, torch.bfloat16)
244
+
245
+ def half(x):
246
+ return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
247
+
248
+ # query, key, value function
249
+ def qkv_fn(x):
250
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
251
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
252
+ v = self.v(x).view(b, s, n, d)
253
+ return q, k, v
254
+
255
+ x = x.to(self.q.weight.dtype)
256
+ q, k, v = qkv_fn(x)
257
+
258
+ if not self._flag_ar_attention:
259
+ q = rope_apply(q, grid_sizes, freqs)
260
+ k = rope_apply(k, grid_sizes, freqs)
261
+ else:
262
+
263
+ q = rope_apply(q, grid_sizes, freqs)
264
+ k = rope_apply(k, grid_sizes, freqs)
265
+ q = q.to(torch.bfloat16)
266
+ k = k.to(torch.bfloat16)
267
+ v = v.to(torch.bfloat16)
268
+ # x = torch.nn.functional.scaled_dot_product_attention(
269
+ # q.transpose(1, 2),
270
+ # k.transpose(1, 2),
271
+ # v.transpose(1, 2),
272
+ # ).transpose(1, 2).contiguous()
273
+ with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
274
+ x = (
275
+ torch.nn.functional.scaled_dot_product_attention(
276
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
277
+ )
278
+ .transpose(1, 2)
279
+ .contiguous()
280
+ )
281
+ x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size)
282
+
283
+ # output
284
+ x = x.flatten(2)
285
+ x = self.o(x)
286
+ return x
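
usp_dit_forward above splits the token sequence across sequence-parallel ranks with torch.chunk and reassembles it with all_gather, padding the RoPE table so every rank indexes a full-length slice. A single-process sketch of that pattern, assuming pure PyTorch (world_size and the per-shard loop are stand-ins for xfuser's sequence-parallel group):

# Single-process sketch of the chunk/all-gather pattern used in usp_dit_forward.
import torch

world_size = 4                              # hypothetical sequence-parallel size
x = torch.randn(1, 16, 8)                   # [B, L, C] token sequence

shards = torch.chunk(x, world_size, dim=1)  # each rank keeps one slice of the tokens
outputs = [s * 2.0 for s in shards]         # placeholder for the per-rank transformer blocks
gathered = torch.cat(outputs, dim=1)        # plays the role of get_sp_group().all_gather(dim=1)

assert gathered.shape == x.shape
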
skyreels_v2_infer/modules/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ import gc
2
+ import os
3
+
4
+ import torch
5
+ from safetensors.torch import load_file
6
+
7
+ from .clip import CLIPModel
8
+ from .t5 import T5EncoderModel
9
+ from .transformer import WanModel
10
+ from .vae import WanVAE
11
+
12
+
13
+ def download_model(model_id):
14
+ if not os.path.exists(model_id):
15
+ from huggingface_hub import snapshot_download
16
+
17
+ model_id = snapshot_download(repo_id=model_id)
18
+ return model_id
19
+
20
+
21
+ def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
22
+ vae = WanVAE(model_path).to(device).to(weight_dtype)
23
+ vae.vae.requires_grad_(False)
24
+ vae.vae.eval()
25
+ gc.collect()
26
+ torch.cuda.empty_cache()
27
+ return vae
28
+
29
+
30
+ def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
31
+ config_path = os.path.join(model_path, "config.json")
32
+ transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)
33
+
34
+ for file in os.listdir(model_path):
35
+ if file.endswith(".safetensors"):
36
+ file_path = os.path.join(model_path, file)
37
+ state_dict = load_file(file_path)
38
+ transformer.load_state_dict(state_dict, strict=False)
39
+ del state_dict
40
+ gc.collect()
41
+ torch.cuda.empty_cache()
42
+
43
+ transformer.requires_grad_(False)
44
+ transformer.eval()
45
+ gc.collect()
46
+ torch.cuda.empty_cache()
47
+ return transformer
48
+
49
+
50
+ def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
51
+ t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
52
+ tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
53
+ text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
54
+ text_encoder.requires_grad_(False)
55
+ text_encoder.eval()
56
+ gc.collect()
57
+ torch.cuda.empty_cache()
58
+ return text_encoder
59
+
60
+
61
+ def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
62
+ checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
63
+ tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
64
+ image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
65
+ image_enc.requires_grad_(False)
66
+ image_enc.eval()
67
+ gc.collect()
68
+ torch.cuda.empty_cache()
69
+ return image_enc
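
A sketch of how the loaders above compose; the repo id, device, and dtypes are illustrative assumptions, not taken from the commit:

# Sketch: wiring the helpers together (repo id / devices are illustrative).
import torch
from skyreels_v2_infer.modules import download_model, get_text_encoder, get_transformer, get_vae

model_path = download_model("Skywork/SkyReels-V2-DF-1.3B-540P")   # hypothetical repo id
transformer = get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16)
text_encoder = get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16)
vae = get_vae(model_path, device="cuda", weight_dtype=torch.float32)
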
skyreels_v2_infer/modules/attention.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+
7
+ FLASH_ATTN_3_AVAILABLE = True
8
+ except ModuleNotFoundError:
9
+ FLASH_ATTN_3_AVAILABLE = False
10
+
11
+ try:
12
+ import flash_attn
13
+
14
+ FLASH_ATTN_2_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_2_AVAILABLE = False
17
+
18
+ import warnings
19
+
20
+ __all__ = [
21
+ "flash_attention",
22
+ "attention",
23
+ ]
24
+
25
+
26
+ def flash_attention(
27
+ q,
28
+ k,
29
+ v,
30
+ q_lens=None,
31
+ k_lens=None,
32
+ dropout_p=0.0,
33
+ softmax_scale=None,
34
+ q_scale=None,
35
+ causal=False,
36
+ window_size=(-1, -1),
37
+ deterministic=False,
38
+ dtype=torch.bfloat16,
39
+ version=None,
40
+ ):
41
+ """
42
+ q: [B, Lq, Nq, C1].
43
+ k: [B, Lk, Nk, C1].
44
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
45
+ q_lens: [B].
46
+ k_lens: [B].
47
+ dropout_p: float. Dropout probability.
48
+ softmax_scale: float. The scaling of QK^T before applying softmax.
49
+ causal: bool. Whether to apply causal attention mask.
50
+ window_size: (left, right). If not (-1, -1), apply sliding-window local attention.
51
+ deterministic: bool. If True, slightly slower and uses more memory.
52
+ dtype: torch.dtype. Cast q/k/v to this dtype when they are not already float16/bfloat16.
53
+ """
54
+ half_dtypes = (torch.float16, torch.bfloat16)
55
+ assert dtype in half_dtypes
56
+ assert q.device.type == "cuda" and q.size(-1) <= 256
57
+
58
+ # params
59
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
60
+
61
+ def half(x):
62
+ return x if x.dtype in half_dtypes else x.to(dtype)
63
+
64
+ # preprocess query
65
+
66
+ q = half(q.flatten(0, 1))
67
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
68
+
69
+ # preprocess key, value
70
+
71
+ k = half(k.flatten(0, 1))
72
+ v = half(v.flatten(0, 1))
73
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
74
+
75
+ q = q.to(v.dtype)
76
+ k = k.to(v.dtype)
77
+
78
+ if q_scale is not None:
79
+ q = q * q_scale
80
+
81
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
82
+ warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.")
83
+
84
+ torch.cuda.nvtx.range_push(f"{list(q.shape)}-{list(k.shape)}-{list(v.shape)}-{q.dtype}-{k.dtype}-{v.dtype}")
85
+ # apply attention
86
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
87
+ # Note: dropout_p, window_size are not supported in FA3 now.
88
+ x = flash_attn_interface.flash_attn_varlen_func(
89
+ q=q,
90
+ k=k,
91
+ v=v,
92
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
93
+ .cumsum(0, dtype=torch.int32)
94
+ .to(q.device, non_blocking=True),
95
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
96
+ .cumsum(0, dtype=torch.int32)
97
+ .to(q.device, non_blocking=True),
98
+ seqused_q=None,
99
+ seqused_k=None,
100
+ max_seqlen_q=lq,
101
+ max_seqlen_k=lk,
102
+ softmax_scale=softmax_scale,
103
+ causal=causal,
104
+ deterministic=deterministic,
105
+ )[0].unflatten(0, (b, lq))
106
+ else:
107
+ assert FLASH_ATTN_2_AVAILABLE
108
+ x = flash_attn.flash_attn_varlen_func(
109
+ q=q,
110
+ k=k,
111
+ v=v,
112
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
113
+ .cumsum(0, dtype=torch.int32)
114
+ .to(q.device, non_blocking=True),
115
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
116
+ .cumsum(0, dtype=torch.int32)
117
+ .to(q.device, non_blocking=True),
118
+ max_seqlen_q=lq,
119
+ max_seqlen_k=lk,
120
+ dropout_p=dropout_p,
121
+ softmax_scale=softmax_scale,
122
+ causal=causal,
123
+ window_size=window_size,
124
+ deterministic=deterministic,
125
+ ).unflatten(0, (b, lq))
126
+ torch.cuda.nvtx.range_pop()
127
+
128
+ # output
129
+ return x
130
+
131
+
132
+ def attention(
133
+ q,
134
+ k,
135
+ v,
136
+ q_lens=None,
137
+ k_lens=None,
138
+ dropout_p=0.0,
139
+ softmax_scale=None,
140
+ q_scale=None,
141
+ causal=False,
142
+ window_size=(-1, -1),
143
+ deterministic=False,
144
+ dtype=torch.bfloat16,
145
+ fa_version=None,
146
+ ):
147
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
148
+ return flash_attention(
149
+ q=q,
150
+ k=k,
151
+ v=v,
152
+ q_lens=q_lens,
153
+ k_lens=k_lens,
154
+ dropout_p=dropout_p,
155
+ softmax_scale=softmax_scale,
156
+ q_scale=q_scale,
157
+ causal=causal,
158
+ window_size=window_size,
159
+ deterministic=deterministic,
160
+ dtype=dtype,
161
+ version=fa_version,
162
+ )
163
+ else:
164
+ if q_lens is not None or k_lens is not None:
165
+ warnings.warn(
166
+ "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
167
+ )
168
+ attn_mask = None
169
+
170
+ q = q.transpose(1, 2).to(dtype)
171
+ k = k.transpose(1, 2).to(dtype)
172
+ v = v.transpose(1, 2).to(dtype)
173
+
174
+ out = torch.nn.functional.scaled_dot_product_attention(
175
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
176
+ )
177
+
178
+ out = out.transpose(1, 2).contiguous()
179
+ return out
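
attention() expects the [B, L, N, C] layout and falls back to torch's scaled_dot_product_attention when neither flash-attn build is installed; the flash path additionally requires CUDA tensors and head_dim <= 256. A minimal shape-check sketch, assuming the package is importable:

# Sketch: shape contract of attention(); runs on the SDPA fallback when flash-attn is absent.
import torch
from skyreels_v2_infer.modules.attention import attention

b, l, n, c = 1, 128, 8, 64
q = torch.randn(b, l, n, c, dtype=torch.bfloat16)
k = torch.randn(b, l, n, c, dtype=torch.bfloat16)
v = torch.randn(b, l, n, c, dtype=torch.bfloat16)

out = attention(q, k, v, causal=False)      # returns [B, L, N, C]
print(out.shape)
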
skyreels_v2_infer/modules/clip.py ADDED
@@ -0,0 +1,525 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+ from diffusers.models import ModelMixin
11
+
12
+ from .attention import flash_attention
13
+ from .tokenizers import HuggingfaceTokenizer
14
+ from .xlm_roberta import XLMRoberta
15
+
16
+ __all__ = [
17
+ "XLMRobertaCLIP",
18
+ "clip_xlm_roberta_vit_h_14",
19
+ "CLIPModel",
20
+ ]
21
+
22
+
23
+ def pos_interpolate(pos, seq_len):
24
+ if pos.size(1) == seq_len:
25
+ return pos
26
+ else:
27
+ src_grid = int(math.sqrt(pos.size(1)))
28
+ tar_grid = int(math.sqrt(seq_len))
29
+ n = pos.size(1) - src_grid * src_grid
30
+ return torch.cat(
31
+ [
32
+ pos[:, :n],
33
+ F.interpolate(
34
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
35
+ size=(tar_grid, tar_grid),
36
+ mode="bicubic",
37
+ align_corners=False,
38
+ )
39
+ .flatten(2)
40
+ .transpose(1, 2),
41
+ ],
42
+ dim=1,
43
+ )
44
+
45
+
46
+ class QuickGELU(nn.Module):
47
+ def forward(self, x):
48
+ return x * torch.sigmoid(1.702 * x)
49
+
50
+
51
+ class LayerNorm(nn.LayerNorm):
52
+ def forward(self, x):
53
+ return super().forward(x.float()).type_as(x)
54
+
55
+
56
+ class SelfAttention(nn.Module):
57
+ def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
58
+ assert dim % num_heads == 0
59
+ super().__init__()
60
+ self.dim = dim
61
+ self.num_heads = num_heads
62
+ self.head_dim = dim // num_heads
63
+ self.causal = causal
64
+ self.attn_dropout = attn_dropout
65
+ self.proj_dropout = proj_dropout
66
+
67
+ # layers
68
+ self.to_qkv = nn.Linear(dim, dim * 3)
69
+ self.proj = nn.Linear(dim, dim)
70
+
71
+ def forward(self, x):
72
+ """
73
+ x: [B, L, C].
74
+ """
75
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
76
+
77
+ # compute query, key, value
78
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
79
+
80
+ # compute attention
81
+ p = self.attn_dropout if self.training else 0.0
82
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
83
+ x = x.reshape(b, s, c)
84
+
85
+ # output
86
+ x = self.proj(x)
87
+ x = F.dropout(x, self.proj_dropout, self.training)
88
+ return x
89
+
90
+
91
+ class SwiGLU(nn.Module):
92
+ def __init__(self, dim, mid_dim):
93
+ super().__init__()
94
+ self.dim = dim
95
+ self.mid_dim = mid_dim
96
+
97
+ # layers
98
+ self.fc1 = nn.Linear(dim, mid_dim)
99
+ self.fc2 = nn.Linear(dim, mid_dim)
100
+ self.fc3 = nn.Linear(mid_dim, dim)
101
+
102
+ def forward(self, x):
103
+ x = F.silu(self.fc1(x)) * self.fc2(x)
104
+ x = self.fc3(x)
105
+ return x
106
+
107
+
108
+ class AttentionBlock(nn.Module):
109
+ def __init__(
110
+ self,
111
+ dim,
112
+ mlp_ratio,
113
+ num_heads,
114
+ post_norm=False,
115
+ causal=False,
116
+ activation="quick_gelu",
117
+ attn_dropout=0.0,
118
+ proj_dropout=0.0,
119
+ norm_eps=1e-5,
120
+ ):
121
+ assert activation in ["quick_gelu", "gelu", "swi_glu"]
122
+ super().__init__()
123
+ self.dim = dim
124
+ self.mlp_ratio = mlp_ratio
125
+ self.num_heads = num_heads
126
+ self.post_norm = post_norm
127
+ self.causal = causal
128
+ self.norm_eps = norm_eps
129
+
130
+ # layers
131
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
132
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
133
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
134
+ if activation == "swi_glu":
135
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
136
+ else:
137
+ self.mlp = nn.Sequential(
138
+ nn.Linear(dim, int(dim * mlp_ratio)),
139
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
140
+ nn.Linear(int(dim * mlp_ratio), dim),
141
+ nn.Dropout(proj_dropout),
142
+ )
143
+
144
+ def forward(self, x):
145
+ if self.post_norm:
146
+ x = x + self.norm1(self.attn(x))
147
+ x = x + self.norm2(self.mlp(x))
148
+ else:
149
+ x = x + self.attn(self.norm1(x))
150
+ x = x + self.mlp(self.norm2(x))
151
+ return x
152
+
153
+
154
+ class AttentionPool(nn.Module):
155
+ def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
156
+ assert dim % num_heads == 0
157
+ super().__init__()
158
+ self.dim = dim
159
+ self.mlp_ratio = mlp_ratio
160
+ self.num_heads = num_heads
161
+ self.head_dim = dim // num_heads
162
+ self.proj_dropout = proj_dropout
163
+ self.norm_eps = norm_eps
164
+
165
+ # layers
166
+ gain = 1.0 / math.sqrt(dim)
167
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
168
+ self.to_q = nn.Linear(dim, dim)
169
+ self.to_kv = nn.Linear(dim, dim * 2)
170
+ self.proj = nn.Linear(dim, dim)
171
+ self.norm = LayerNorm(dim, eps=norm_eps)
172
+ self.mlp = nn.Sequential(
173
+ nn.Linear(dim, int(dim * mlp_ratio)),
174
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
175
+ nn.Linear(int(dim * mlp_ratio), dim),
176
+ nn.Dropout(proj_dropout),
177
+ )
178
+
179
+ def forward(self, x):
180
+ """
181
+ x: [B, L, C].
182
+ """
183
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
184
+
185
+ # compute query, key, value
186
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
187
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
188
+
189
+ # compute attention
190
+ x = flash_attention(q, k, v, version=2)
191
+ x = x.reshape(b, 1, c)
192
+
193
+ # output
194
+ x = self.proj(x)
195
+ x = F.dropout(x, self.proj_dropout, self.training)
196
+
197
+ # mlp
198
+ x = x + self.mlp(self.norm(x))
199
+ return x[:, 0]
200
+
201
+
202
+ class VisionTransformer(nn.Module):
203
+ def __init__(
204
+ self,
205
+ image_size=224,
206
+ patch_size=16,
207
+ dim=768,
208
+ mlp_ratio=4,
209
+ out_dim=512,
210
+ num_heads=12,
211
+ num_layers=12,
212
+ pool_type="token",
213
+ pre_norm=True,
214
+ post_norm=False,
215
+ activation="quick_gelu",
216
+ attn_dropout=0.0,
217
+ proj_dropout=0.0,
218
+ embedding_dropout=0.0,
219
+ norm_eps=1e-5,
220
+ ):
221
+ if image_size % patch_size != 0:
222
+ print("[WARNING] image_size is not divisible by patch_size", flush=True)
223
+ assert pool_type in ("token", "token_fc", "attn_pool")
224
+ out_dim = out_dim or dim
225
+ super().__init__()
226
+ self.image_size = image_size
227
+ self.patch_size = patch_size
228
+ self.num_patches = (image_size // patch_size) ** 2
229
+ self.dim = dim
230
+ self.mlp_ratio = mlp_ratio
231
+ self.out_dim = out_dim
232
+ self.num_heads = num_heads
233
+ self.num_layers = num_layers
234
+ self.pool_type = pool_type
235
+ self.post_norm = post_norm
236
+ self.norm_eps = norm_eps
237
+
238
+ # embeddings
239
+ gain = 1.0 / math.sqrt(dim)
240
+ self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
241
+ if pool_type in ("token", "token_fc"):
242
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
243
+ self.pos_embedding = nn.Parameter(
244
+ gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim)
245
+ )
246
+ self.dropout = nn.Dropout(embedding_dropout)
247
+
248
+ # transformer
249
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
250
+ self.transformer = nn.Sequential(
251
+ *[
252
+ AttentionBlock(
253
+ dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps
254
+ )
255
+ for _ in range(num_layers)
256
+ ]
257
+ )
258
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
259
+
260
+ # head
261
+ if pool_type == "token":
262
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
263
+ elif pool_type == "token_fc":
264
+ self.head = nn.Linear(dim, out_dim)
265
+ elif pool_type == "attn_pool":
266
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
267
+
268
+ def forward(self, x, interpolation=False, use_31_block=False):
269
+ b = x.size(0)
270
+
271
+ # embeddings
272
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
273
+ if self.pool_type in ("token", "token_fc"):
274
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
275
+ if interpolation:
276
+ e = pos_interpolate(self.pos_embedding, x.size(1))
277
+ else:
278
+ e = self.pos_embedding
279
+ x = self.dropout(x + e)
280
+ if self.pre_norm is not None:
281
+ x = self.pre_norm(x)
282
+
283
+ # transformer
284
+ if use_31_block:
285
+ x = self.transformer[:-1](x)
286
+ return x
287
+ else:
288
+ x = self.transformer(x)
289
+ return x
290
+
291
+
292
+ class XLMRobertaWithHead(XLMRoberta):
293
+ def __init__(self, **kwargs):
294
+ self.out_dim = kwargs.pop("out_dim")
295
+ super().__init__(**kwargs)
296
+
297
+ # head
298
+ mid_dim = (self.dim + self.out_dim) // 2
299
+ self.head = nn.Sequential(
300
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)
301
+ )
302
+
303
+ def forward(self, ids):
304
+ # xlm-roberta
305
+ x = super().forward(ids)
306
+
307
+ # average pooling
308
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
309
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
310
+
311
+ # head
312
+ x = self.head(x)
313
+ return x
314
+
315
+
316
+ class XLMRobertaCLIP(nn.Module):
317
+ def __init__(
318
+ self,
319
+ embed_dim=1024,
320
+ image_size=224,
321
+ patch_size=14,
322
+ vision_dim=1280,
323
+ vision_mlp_ratio=4,
324
+ vision_heads=16,
325
+ vision_layers=32,
326
+ vision_pool="token",
327
+ vision_pre_norm=True,
328
+ vision_post_norm=False,
329
+ activation="gelu",
330
+ vocab_size=250002,
331
+ max_text_len=514,
332
+ type_size=1,
333
+ pad_id=1,
334
+ text_dim=1024,
335
+ text_heads=16,
336
+ text_layers=24,
337
+ text_post_norm=True,
338
+ text_dropout=0.1,
339
+ attn_dropout=0.0,
340
+ proj_dropout=0.0,
341
+ embedding_dropout=0.0,
342
+ norm_eps=1e-5,
343
+ ):
344
+ super().__init__()
345
+ self.embed_dim = embed_dim
346
+ self.image_size = image_size
347
+ self.patch_size = patch_size
348
+ self.vision_dim = vision_dim
349
+ self.vision_mlp_ratio = vision_mlp_ratio
350
+ self.vision_heads = vision_heads
351
+ self.vision_layers = vision_layers
352
+ self.vision_pre_norm = vision_pre_norm
353
+ self.vision_post_norm = vision_post_norm
354
+ self.activation = activation
355
+ self.vocab_size = vocab_size
356
+ self.max_text_len = max_text_len
357
+ self.type_size = type_size
358
+ self.pad_id = pad_id
359
+ self.text_dim = text_dim
360
+ self.text_heads = text_heads
361
+ self.text_layers = text_layers
362
+ self.text_post_norm = text_post_norm
363
+ self.norm_eps = norm_eps
364
+
365
+ # models
366
+ self.visual = VisionTransformer(
367
+ image_size=image_size,
368
+ patch_size=patch_size,
369
+ dim=vision_dim,
370
+ mlp_ratio=vision_mlp_ratio,
371
+ out_dim=embed_dim,
372
+ num_heads=vision_heads,
373
+ num_layers=vision_layers,
374
+ pool_type=vision_pool,
375
+ pre_norm=vision_pre_norm,
376
+ post_norm=vision_post_norm,
377
+ activation=activation,
378
+ attn_dropout=attn_dropout,
379
+ proj_dropout=proj_dropout,
380
+ embedding_dropout=embedding_dropout,
381
+ norm_eps=norm_eps,
382
+ )
383
+ self.textual = XLMRobertaWithHead(
384
+ vocab_size=vocab_size,
385
+ max_seq_len=max_text_len,
386
+ type_size=type_size,
387
+ pad_id=pad_id,
388
+ dim=text_dim,
389
+ out_dim=embed_dim,
390
+ num_heads=text_heads,
391
+ num_layers=text_layers,
392
+ post_norm=text_post_norm,
393
+ dropout=text_dropout,
394
+ )
395
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
396
+
397
+ def forward(self, imgs, txt_ids):
398
+ """
399
+ imgs: [B, 3, H, W] of torch.float32.
400
+ - mean: [0.48145466, 0.4578275, 0.40821073]
401
+ - std: [0.26862954, 0.26130258, 0.27577711]
402
+ txt_ids: [B, L] of torch.long.
403
+ Encoded by data.CLIPTokenizer.
404
+ """
405
+ xi = self.visual(imgs)
406
+ xt = self.textual(txt_ids)
407
+ return xi, xt
408
+
409
+ def param_groups(self):
410
+ groups = [
411
+ {
412
+ "params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")],
413
+ "weight_decay": 0.0,
414
+ },
415
+ {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
416
+ ]
417
+ return groups
418
+
419
+
420
+ def _clip(
421
+ pretrained=False,
422
+ pretrained_name=None,
423
+ model_cls=XLMRobertaCLIP,
424
+ return_transforms=False,
425
+ return_tokenizer=False,
426
+ tokenizer_padding="eos",
427
+ dtype=torch.float32,
428
+ device="cpu",
429
+ **kwargs,
430
+ ):
431
+ # init a model on device
432
+ with torch.device(device):
433
+ model = model_cls(**kwargs)
434
+
435
+ # set device
436
+ model = model.to(dtype=dtype, device=device)
437
+ output = (model,)
438
+
439
+ # init transforms
440
+ if return_transforms:
441
+ # mean and std
442
+ if "siglip" in pretrained_name.lower():
443
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
444
+ else:
445
+ mean = [0.48145466, 0.4578275, 0.40821073]
446
+ std = [0.26862954, 0.26130258, 0.27577711]
447
+
448
+ # transforms
449
+ transforms = T.Compose(
450
+ [
451
+ T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
452
+ T.ToTensor(),
453
+ T.Normalize(mean=mean, std=std),
454
+ ]
455
+ )
456
+ output += (transforms,)
457
+ return output[0] if len(output) == 1 else output
458
+
459
+
460
+ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
461
+ cfg = dict(
462
+ embed_dim=1024,
463
+ image_size=224,
464
+ patch_size=14,
465
+ vision_dim=1280,
466
+ vision_mlp_ratio=4,
467
+ vision_heads=16,
468
+ vision_layers=32,
469
+ vision_pool="token",
470
+ activation="gelu",
471
+ vocab_size=250002,
472
+ max_text_len=514,
473
+ type_size=1,
474
+ pad_id=1,
475
+ text_dim=1024,
476
+ text_heads=16,
477
+ text_layers=24,
478
+ text_post_norm=True,
479
+ text_dropout=0.1,
480
+ attn_dropout=0.0,
481
+ proj_dropout=0.0,
482
+ embedding_dropout=0.0,
483
+ )
484
+ cfg.update(**kwargs)
485
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
486
+
487
+
488
+ class CLIPModel(ModelMixin):
489
+ def __init__(self, checkpoint_path, tokenizer_path):
490
+ self.checkpoint_path = checkpoint_path
491
+ self.tokenizer_path = tokenizer_path
492
+
493
+ super().__init__()
494
+ # init model
495
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
496
+ pretrained=False, return_transforms=True, return_tokenizer=False
497
+ )
498
+ self.model = self.model.eval().requires_grad_(False)
499
+ logging.info(f"loading {checkpoint_path}")
500
+ self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
501
+
502
+ # init tokenizer
503
+ self.tokenizer = HuggingfaceTokenizer(
504
+ name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace"
505
+ )
506
+
507
+ def encode_video(self, video):
508
+ # preprocess
509
+ b, c, t, h, w = video.shape
510
+ video = video.transpose(1, 2)
511
+ video = video.reshape(b * t, c, h, w)
512
+ size = (self.model.image_size,) * 2
513
+ video = F.interpolate(
514
+ video,
515
+ size=size,
516
+ mode='bicubic',
517
+ align_corners=False)
518
+
519
+ video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))
520
+
521
+ # forward
522
+ with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
523
+ out = self.model.visual(video, use_31_block=True)
524
+
525
+ return out
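
pos_interpolate above resizes the learned position table (one cls token plus a square patch grid) with bicubic interpolation so the ViT can accept a different input resolution. A toy shape check, assuming the full package (including its xlm_roberta and tokenizers modules) is importable:

# Sketch: resizing a 16x16 position grid (plus cls token) to 32x32.
import torch
from skyreels_v2_infer.modules.clip import pos_interpolate

pos = torch.randn(1, 1 + 16 * 16, 1280)      # cls + 16x16 patch positions
pos_big = pos_interpolate(pos, 1 + 32 * 32)  # bicubic resize of the grid part
print(pos_big.shape)                         # torch.Size([1, 1025, 1280])
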
skyreels_v2_infer/modules/t5.py ADDED
@@ -0,0 +1,454 @@
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from diffusers.models import ModelMixin
10
+
11
+ from .tokenizers import HuggingfaceTokenizer
12
+
13
+ __all__ = [
14
+ "T5Model",
15
+ "T5Encoder",
16
+ "T5Decoder",
17
+ "T5EncoderModel",
18
+ ]
19
+
20
+
21
+ def fp16_clamp(x):
22
+ if x.dtype == torch.float16 and torch.isinf(x).any():
23
+ clamp = torch.finfo(x.dtype).max - 1000
24
+ x = torch.clamp(x, min=-clamp, max=clamp)
25
+ return x
26
+
27
+
28
+ def init_weights(m):
29
+ if isinstance(m, T5LayerNorm):
30
+ nn.init.ones_(m.weight)
31
+ elif isinstance(m, T5Model):
32
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
33
+ elif isinstance(m, T5FeedForward):
34
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
36
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
37
+ elif isinstance(m, T5Attention):
38
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
39
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
41
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
42
+ elif isinstance(m, T5RelativeEmbedding):
43
+ nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
44
+
45
+
46
+ class GELU(nn.Module):
47
+ def forward(self, x):
48
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
49
+
50
+
51
+ class T5LayerNorm(nn.Module):
52
+ def __init__(self, dim, eps=1e-6):
53
+ super(T5LayerNorm, self).__init__()
54
+ self.dim = dim
55
+ self.eps = eps
56
+ self.weight = nn.Parameter(torch.ones(dim))
57
+
58
+ def forward(self, x):
59
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
60
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
61
+ x = x.type_as(self.weight)
62
+ return self.weight * x
63
+
64
+
65
+ class T5Attention(nn.Module):
66
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
67
+ assert dim_attn % num_heads == 0
68
+ super(T5Attention, self).__init__()
69
+ self.dim = dim
70
+ self.dim_attn = dim_attn
71
+ self.num_heads = num_heads
72
+ self.head_dim = dim_attn // num_heads
73
+
74
+ # layers
75
+ self.q = nn.Linear(dim, dim_attn, bias=False)
76
+ self.k = nn.Linear(dim, dim_attn, bias=False)
77
+ self.v = nn.Linear(dim, dim_attn, bias=False)
78
+ self.o = nn.Linear(dim_attn, dim, bias=False)
79
+ self.dropout = nn.Dropout(dropout)
80
+
81
+ def forward(self, x, context=None, mask=None, pos_bias=None):
82
+ """
83
+ x: [B, L1, C].
84
+ context: [B, L2, C] or None.
85
+ mask: [B, L2] or [B, L1, L2] or None.
86
+ """
87
+ # check inputs
88
+ context = x if context is None else context
89
+ b, n, c = x.size(0), self.num_heads, self.head_dim
90
+
91
+ # compute query, key, value
92
+ q = self.q(x).view(b, -1, n, c)
93
+ k = self.k(context).view(b, -1, n, c)
94
+ v = self.v(context).view(b, -1, n, c)
95
+
96
+ # attention bias
97
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
98
+ if pos_bias is not None:
99
+ attn_bias += pos_bias
100
+ if mask is not None:
101
+ assert mask.ndim in [2, 3]
102
+ mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
103
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
104
+
105
+ # compute attention (T5 does not use scaling)
106
+ attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
107
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
108
+ x = torch.einsum("bnij,bjnc->binc", attn, v)
109
+
110
+ # output
111
+ x = x.reshape(b, -1, n * c)
112
+ x = self.o(x)
113
+ x = self.dropout(x)
114
+ return x
115
+
116
+
117
+ class T5FeedForward(nn.Module):
118
+ def __init__(self, dim, dim_ffn, dropout=0.1):
119
+ super(T5FeedForward, self).__init__()
120
+ self.dim = dim
121
+ self.dim_ffn = dim_ffn
122
+
123
+ # layers
124
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
125
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
126
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
127
+ self.dropout = nn.Dropout(dropout)
128
+
129
+ def forward(self, x):
130
+ x = self.fc1(x) * self.gate(x)
131
+ x = self.dropout(x)
132
+ x = self.fc2(x)
133
+ x = self.dropout(x)
134
+ return x
135
+
136
+
137
+ class T5SelfAttention(nn.Module):
138
+ def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
139
+ super(T5SelfAttention, self).__init__()
140
+ self.dim = dim
141
+ self.dim_attn = dim_attn
142
+ self.dim_ffn = dim_ffn
143
+ self.num_heads = num_heads
144
+ self.num_buckets = num_buckets
145
+ self.shared_pos = shared_pos
146
+
147
+ # layers
148
+ self.norm1 = T5LayerNorm(dim)
149
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
150
+ self.norm2 = T5LayerNorm(dim)
151
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
152
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
153
+
154
+ def forward(self, x, mask=None, pos_bias=None):
155
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
156
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
157
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
158
+ return x
159
+
160
+
161
+ class T5CrossAttention(nn.Module):
162
+ def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
163
+ super(T5CrossAttention, self).__init__()
164
+ self.dim = dim
165
+ self.dim_attn = dim_attn
166
+ self.dim_ffn = dim_ffn
167
+ self.num_heads = num_heads
168
+ self.num_buckets = num_buckets
169
+ self.shared_pos = shared_pos
170
+
171
+ # layers
172
+ self.norm1 = T5LayerNorm(dim)
173
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
174
+ self.norm2 = T5LayerNorm(dim)
175
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
176
+ self.norm3 = T5LayerNorm(dim)
177
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
178
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
179
+
180
+ def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
181
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
182
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
183
+ x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
184
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
185
+ return x
186
+
187
+
188
+ class T5RelativeEmbedding(nn.Module):
189
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
190
+ super(T5RelativeEmbedding, self).__init__()
191
+ self.num_buckets = num_buckets
192
+ self.num_heads = num_heads
193
+ self.bidirectional = bidirectional
194
+ self.max_dist = max_dist
195
+
196
+ # layers
197
+ self.embedding = nn.Embedding(num_buckets, num_heads)
198
+
199
+ def forward(self, lq, lk):
200
+ device = self.embedding.weight.device
201
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
202
+ # torch.arange(lq).unsqueeze(1).to(device)
203
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
204
+ rel_pos = self._relative_position_bucket(rel_pos)
205
+ rel_pos_embeds = self.embedding(rel_pos)
206
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
207
+ return rel_pos_embeds.contiguous()
208
+
209
+ def _relative_position_bucket(self, rel_pos):
210
+ # preprocess
211
+ if self.bidirectional:
212
+ num_buckets = self.num_buckets // 2
213
+ rel_buckets = (rel_pos > 0).long() * num_buckets
214
+ rel_pos = torch.abs(rel_pos)
215
+ else:
216
+ num_buckets = self.num_buckets
217
+ rel_buckets = 0
218
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
219
+
220
+ # embeddings for small and large positions
221
+ max_exact = num_buckets // 2
222
+ rel_pos_large = (
223
+ max_exact
224
+ + (
225
+ torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)
226
+ ).long()
227
+ )
228
+ rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
229
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
230
+ return rel_buckets
231
+
232
+
233
+ class T5Encoder(nn.Module):
234
+ def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
235
+ super(T5Encoder, self).__init__()
236
+ self.dim = dim
237
+ self.dim_attn = dim_attn
238
+ self.dim_ffn = dim_ffn
239
+ self.num_heads = num_heads
240
+ self.num_layers = num_layers
241
+ self.num_buckets = num_buckets
242
+ self.shared_pos = shared_pos
243
+
244
+ # layers
245
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
246
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
247
+ self.dropout = nn.Dropout(dropout)
248
+ self.blocks = nn.ModuleList(
249
+ [
250
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
251
+ for _ in range(num_layers)
252
+ ]
253
+ )
254
+ self.norm = T5LayerNorm(dim)
255
+
256
+ # initialize weights
257
+ self.apply(init_weights)
258
+
259
+ def forward(self, ids, mask=None):
260
+ x = self.token_embedding(ids)
261
+ x = self.dropout(x)
262
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
263
+ for block in self.blocks:
264
+ x = block(x, mask, pos_bias=e)
265
+ x = self.norm(x)
266
+ x = self.dropout(x)
267
+ return x
268
+
269
+
270
+ class T5Decoder(nn.Module):
271
+ def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
272
+ super(T5Decoder, self).__init__()
273
+ self.dim = dim
274
+ self.dim_attn = dim_attn
275
+ self.dim_ffn = dim_ffn
276
+ self.num_heads = num_heads
277
+ self.num_layers = num_layers
278
+ self.num_buckets = num_buckets
279
+ self.shared_pos = shared_pos
280
+
281
+ # layers
282
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
283
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
284
+ self.dropout = nn.Dropout(dropout)
285
+ self.blocks = nn.ModuleList(
286
+ [
287
+ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
288
+ for _ in range(num_layers)
289
+ ]
290
+ )
291
+ self.norm = T5LayerNorm(dim)
292
+
293
+ # initialize weights
294
+ self.apply(init_weights)
295
+
296
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
297
+ b, s = ids.size()
298
+
299
+ # causal mask
300
+ if mask is None:
301
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
302
+ elif mask.ndim == 2:
303
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
304
+
305
+ # layers
306
+ x = self.token_embedding(ids)
307
+ x = self.dropout(x)
308
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
309
+ for block in self.blocks:
310
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
311
+ x = self.norm(x)
312
+ x = self.dropout(x)
313
+ return x
314
+
315
+
316
+ class T5Model(nn.Module):
317
+ def __init__(
318
+ self,
319
+ vocab_size,
320
+ dim,
321
+ dim_attn,
322
+ dim_ffn,
323
+ num_heads,
324
+ encoder_layers,
325
+ decoder_layers,
326
+ num_buckets,
327
+ shared_pos=True,
328
+ dropout=0.1,
329
+ ):
330
+ super(T5Model, self).__init__()
331
+ self.vocab_size = vocab_size
332
+ self.dim = dim
333
+ self.dim_attn = dim_attn
334
+ self.dim_ffn = dim_ffn
335
+ self.num_heads = num_heads
336
+ self.encoder_layers = encoder_layers
337
+ self.decoder_layers = decoder_layers
338
+ self.num_buckets = num_buckets
339
+
340
+ # layers
341
+ self.token_embedding = nn.Embedding(vocab_size, dim)
342
+ self.encoder = T5Encoder(
343
+ self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout
344
+ )
345
+ self.decoder = T5Decoder(
346
+ self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout
347
+ )
348
+ self.head = nn.Linear(dim, vocab_size, bias=False)
349
+
350
+ # initialize weights
351
+ self.apply(init_weights)
352
+
353
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
354
+ x = self.encoder(encoder_ids, encoder_mask)
355
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
356
+ x = self.head(x)
357
+ return x
358
+
359
+
360
+ def _t5(
361
+ name,
362
+ encoder_only=False,
363
+ decoder_only=False,
364
+ return_tokenizer=False,
365
+ tokenizer_kwargs={},
366
+ dtype=torch.float32,
367
+ device="cpu",
368
+ **kwargs,
369
+ ):
370
+ # sanity check
371
+ assert not (encoder_only and decoder_only)
372
+
373
+ # params
374
+ if encoder_only:
375
+ model_cls = T5Encoder
376
+ kwargs["vocab"] = kwargs.pop("vocab_size")
377
+ kwargs["num_layers"] = kwargs.pop("encoder_layers")
378
+ _ = kwargs.pop("decoder_layers")
379
+ elif decoder_only:
380
+ model_cls = T5Decoder
381
+ kwargs["vocab"] = kwargs.pop("vocab_size")
382
+ kwargs["num_layers"] = kwargs.pop("decoder_layers")
383
+ _ = kwargs.pop("encoder_layers")
384
+ else:
385
+ model_cls = T5Model
386
+
387
+ # init model
388
+ with torch.device(device):
389
+ model = model_cls(**kwargs)
390
+
391
+ # set device
392
+ model = model.to(dtype=dtype, device=device)
393
+
394
+ # init tokenizer
395
+ if return_tokenizer:
396
+ from .tokenizers import HuggingfaceTokenizer
397
+
398
+ tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
399
+ return model, tokenizer
400
+ else:
401
+ return model
402
+
403
+
404
+ def umt5_xxl(**kwargs):
405
+ cfg = dict(
406
+ vocab_size=256384,
407
+ dim=4096,
408
+ dim_attn=4096,
409
+ dim_ffn=10240,
410
+ num_heads=64,
411
+ encoder_layers=24,
412
+ decoder_layers=24,
413
+ num_buckets=32,
414
+ shared_pos=False,
415
+ dropout=0.1,
416
+ )
417
+ cfg.update(**kwargs)
418
+ return _t5("umt5-xxl", **cfg)
419
+
420
+
421
+ class T5EncoderModel(ModelMixin):
422
+ def __init__(
423
+ self,
424
+ checkpoint_path=None,
425
+ tokenizer_path=None,
426
+ text_len=512,
427
+ shard_fn=None,
428
+ ):
429
+ self.text_len = text_len
430
+ self.checkpoint_path = checkpoint_path
431
+ self.tokenizer_path = tokenizer_path
432
+
433
+ super().__init__()
434
+ # init model
435
+ model = umt5_xxl(encoder_only=True, return_tokenizer=False)
436
+ logging.info(f"loading {checkpoint_path}")
437
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
438
+ self.model = model
439
+ if shard_fn is not None:
440
+ self.model = shard_fn(self.model, sync_module_states=False)
441
+ else:
442
+ self.model.eval().requires_grad_(False)
443
+ # init tokenizer
444
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
445
+
446
+ def encode(self, texts):
447
+ ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
448
+ ids = ids.to(self.device)
449
+ mask = mask.to(self.device)
450
+ # seq_lens = mask.gt(0).sum(dim=1).long()
451
+ context = self.model(ids, mask)
452
+ context = context * mask.unsqueeze(-1).cuda()
453
+
454
+ return context
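
T5RelativeEmbedding maps signed token distances to a fixed number of buckets: small distances get exact buckets, larger ones are log-spaced up to max_dist. A standalone re-derivation of the bidirectional case, re-implemented here purely for illustration (num_buckets=32 and max_dist=128 as in the umt5_xxl config):

# Standalone sketch of the bidirectional bucketing in T5RelativeEmbedding.
import math
import torch

def bucket(rel_pos, num_buckets=32, max_dist=128):
    nb = num_buckets // 2                                  # one half of the buckets per sign
    out = (rel_pos > 0).long() * nb
    rel = rel_pos.abs()
    max_exact = nb // 2                                    # small distances kept exact
    large = max_exact + (
        torch.log(rel.float() / max_exact) / math.log(max_dist / max_exact) * (nb - max_exact)
    ).long()
    large = torch.min(large, torch.full_like(large, nb - 1))
    return out + torch.where(rel < max_exact, rel, large)

print(bucket(torch.tensor([-64, -8, -1, 0, 1, 8, 64])))
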
skyreels_v2_infer/modules/tokenizers.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ["HuggingfaceTokenizer"]
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r"\s+", " ", text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace("_", " ")
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans("", "", string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string)
30
+ )
31
+ else:
32
+ text = text.translate(str.maketrans("", "", string.punctuation))
33
+ text = text.lower()
34
+ text = re.sub(r"\s+", " ", text)
35
+ return text.strip()
36
+
37
+
38
+ class HuggingfaceTokenizer:
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, "whitespace", "lower", "canonicalize")
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop("return_mask", False)
51
+
52
+ # arguments
53
+ _kwargs = {"return_tensors": "pt"}
54
+ if self.seq_len is not None:
55
+ _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
56
+ _kwargs.update(**kwargs)
57
+
58
+ # tokenization
59
+ if isinstance(sequence, str):
60
+ sequence = [sequence]
61
+ if self.clean:
62
+ sequence = [self._clean(u) for u in sequence]
63
+ ids = self.tokenizer(sequence, **_kwargs)
64
+
65
+ # output
66
+ if return_mask:
67
+ return ids.input_ids, ids.attention_mask
68
+ else:
69
+ return ids.input_ids
70
+
71
+ def _clean(self, text):
72
+ if self.clean == "whitespace":
73
+ text = whitespace_clean(basic_clean(text))
74
+ elif self.clean == "lower":
75
+ text = whitespace_clean(basic_clean(text)).lower()
76
+ elif self.clean == "canonicalize":
77
+ text = canonicalize(basic_clean(text))
78
+ return text
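
HuggingfaceTokenizer wraps AutoTokenizer with optional text cleaning and fixed-length padding. A usage sketch mirroring how T5EncoderModel calls it; the tokenizer name and seq_len are illustrative (the pipeline points it at a local copy of google/umt5-xxl):

# Sketch: tokenizing a prompt the way T5EncoderModel does (name / seq_len are assumptions).
from skyreels_v2_infer.modules.tokenizers import HuggingfaceTokenizer

tok = HuggingfaceTokenizer(name="google/umt5-xxl", seq_len=512, clean="whitespace")
ids, mask = tok(["A cinematic drone shot over a rocky coastline."], return_mask=True)
print(ids.shape, int(mask.sum()))   # padded ids of shape [1, 512] and the real token count
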
skyreels_v2_infer/modules/transformer.py ADDED
@@ -0,0 +1,839 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torch.amp as amp
6
+ import torch.nn as nn
7
+ from diffusers.configuration_utils import ConfigMixin
8
+ from diffusers.configuration_utils import register_to_config
9
+ from diffusers.loaders import PeftAdapterMixin
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from torch.backends.cuda import sdp_kernel
12
+ from torch.nn.attention.flex_attention import BlockMask
13
+ from torch.nn.attention.flex_attention import create_block_mask
14
+ from torch.nn.attention.flex_attention import flex_attention
15
+
16
+ from .attention import flash_attention
17
+
18
+
19
+ flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")
20
+
21
+ DISABLE_COMPILE = False # could be driven by an environment variable to turn off torch.compile
22
+
23
+ __all__ = ["WanModel"]
24
+
25
+
26
+ def sinusoidal_embedding_1d(dim, position):
27
+ # preprocess
28
+ assert dim % 2 == 0
29
+ half = dim // 2
30
+ position = position.type(torch.float64)
31
+
32
+ # calculation
33
+ sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
34
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
35
+ return x
36
+
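# Editor's illustration (not part of the uploaded file): shape check for the sinusoidal
# timestep embedding. For dim=256 and four timesteps the output is [4, 256] in float64,
# cosine terms in the first 128 channels and sine terms in the last 128.
_t = torch.tensor([0, 250, 500, 999])
_emb = sinusoidal_embedding_1d(256, _t)
assert _emb.shape == (4, 256) and _emb.dtype == torch.float64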
37
+
38
+ @amp.autocast("cuda", enabled=False)
39
+ def rope_params(max_seq_len, dim, theta=10000):
40
+ assert dim % 2 == 0
41
+ freqs = torch.outer(
42
+ torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
43
+ )
44
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
45
+ return freqs
46
+
47
+
48
+ @amp.autocast("cuda", enabled=False)
49
+ def rope_apply(x, grid_sizes, freqs):
50
+ n, c = x.size(2), x.size(3) // 2
51
+ bs = x.size(0)
52
+
53
+ # split freqs
54
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
55
+
56
+ # loop over samples
57
+ f, h, w = grid_sizes.tolist()
58
+ seq_len = f * h * w
59
+
60
+ # precompute multipliers
61
+
62
+ x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2))
63
+ freqs_i = torch.cat(
64
+ [
65
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
66
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
67
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
68
+ ],
69
+ dim=-1,
70
+ ).reshape(seq_len, 1, -1)
71
+
72
+ # apply rotary embedding
73
+ x = torch.view_as_real(x * freqs_i).flatten(3)
74
+
75
+ return x
76
+
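# Editor's illustration (not part of the uploaded file): shape walk-through for rope_apply.
# With head_dim=64 the frequency table, built the same way as in WanModel.__init__, has
# half-dim 32, split 12/10/10 over the frame/height/width axes.
_d = 64
_freqs = torch.cat(
    [rope_params(1024, _d - 4 * (_d // 6)), rope_params(1024, 2 * (_d // 6)), rope_params(1024, 2 * (_d // 6))],
    dim=1,
)
_x = torch.randn(1, 4 * 6 * 6, 8, _d)  # [B, F*H*W, num_heads, head_dim]
_out = rope_apply(_x, torch.tensor([4, 6, 6]), _freqs)
assert _out.shape == _x.shape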
77
+
78
+ @torch.compile(dynamic=True, disable=DISABLE_COMPILE)
79
+ def fast_rms_norm(x, weight, eps):
80
+ x = x.float()
81
+ x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
82
+ x = x.type_as(weight) * weight # cast back from fp32 to the module dtype before applying the learned gain
83
+ return x
84
+
85
+
86
+ class WanRMSNorm(nn.Module):
87
+ def __init__(self, dim, eps=1e-5):
88
+ super().__init__()
89
+ self.dim = dim
90
+ self.eps = eps
91
+ self.weight = nn.Parameter(torch.ones(dim))
92
+
93
+ def forward(self, x):
94
+ r"""
95
+ Args:
96
+ x(Tensor): Shape [B, L, C]
97
+ """
98
+ return fast_rms_norm(x, self.weight, self.eps)
99
+
100
+ def _norm(self, x):
101
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
102
+
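# Editor's illustration (not part of the uploaded file): WanRMSNorm follows the textbook
# RMSNorm, y = x * rsqrt(mean(x^2) + eps) * gain; the check below uses the un-compiled
# _norm helper so it runs without a torch.compile backend.
_norm_layer = WanRMSNorm(dim=16)
_xn = torch.randn(2, 5, 16)
_ref = _xn * torch.rsqrt(_xn.pow(2).mean(dim=-1, keepdim=True) + _norm_layer.eps)
assert torch.allclose(_norm_layer._norm(_xn), _ref)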
103
+
104
+ class WanLayerNorm(nn.LayerNorm):
105
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
106
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
107
+
108
+ def forward(self, x):
109
+ r"""
110
+ Args:
111
+ x(Tensor): Shape [B, L, C]
112
+ """
113
+ return super().forward(x)
114
+
115
+
116
+ class WanSelfAttention(nn.Module):
117
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
118
+ assert dim % num_heads == 0
119
+ super().__init__()
120
+ self.dim = dim
121
+ self.num_heads = num_heads
122
+ self.head_dim = dim // num_heads
123
+ self.window_size = window_size
124
+ self.qk_norm = qk_norm
125
+ self.eps = eps
126
+
127
+ # layers
128
+ self.q = nn.Linear(dim, dim)
129
+ self.k = nn.Linear(dim, dim)
130
+ self.v = nn.Linear(dim, dim)
131
+ self.o = nn.Linear(dim, dim)
132
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
133
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
134
+
135
+ self._flag_ar_attention = False
136
+
137
+ def set_ar_attention(self):
138
+ self._flag_ar_attention = True
139
+
140
+ def forward(self, x, grid_sizes, freqs, block_mask):
141
+ r"""
142
+ Args:
143
+ x(Tensor): Shape [B, L, C]
144
+ block_mask: optional attention mask, used when autoregressive (AR) attention is enabled
145
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
146
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
147
+ """
148
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149
+
150
+ # query, key, value function
151
+ def qkv_fn(x):
152
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
153
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
154
+ v = self.v(x).view(b, s, n, d)
155
+ return q, k, v
156
+
157
+ x = x.to(self.q.weight.dtype)
158
+ q, k, v = qkv_fn(x)
159
+
160
+ if not self._flag_ar_attention:
161
+ q = rope_apply(q, grid_sizes, freqs)
162
+ k = rope_apply(k, grid_sizes, freqs)
163
+ x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
164
+ else:
165
+ q = rope_apply(q, grid_sizes, freqs)
166
+ k = rope_apply(k, grid_sizes, freqs)
167
+ q = q.to(torch.bfloat16)
168
+ k = k.to(torch.bfloat16)
169
+ v = v.to(torch.bfloat16)
170
+
171
+ with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
172
+ x = (
173
+ torch.nn.functional.scaled_dot_product_attention(
174
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
175
+ )
176
+ .transpose(1, 2)
177
+ .contiguous()
178
+ )
179
+
180
+ # output
181
+ x = x.flatten(2)
182
+ x = self.o(x)
183
+ return x
184
+
185
+
186
+ class WanT2VCrossAttention(WanSelfAttention):
187
+ def forward(self, x, context):
188
+ r"""
189
+ Args:
190
+ x(Tensor): Shape [B, L1, C]
191
+ context(Tensor): Shape [B, L2, C]
192
+ context_lens(Tensor): Shape [B]
193
+ """
194
+ b, n, d = x.size(0), self.num_heads, self.head_dim
195
+
196
+ # compute query, key, value
197
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
198
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
199
+ v = self.v(context).view(b, -1, n, d)
200
+
201
+ # compute attention
202
+ x = flash_attention(q, k, v)
203
+
204
+ # output
205
+ x = x.flatten(2)
206
+ x = self.o(x)
207
+ return x
208
+
209
+
210
+ class WanI2VCrossAttention(WanSelfAttention):
211
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
212
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
213
+
214
+ self.k_img = nn.Linear(dim, dim)
215
+ self.v_img = nn.Linear(dim, dim)
216
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
217
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
218
+
219
+ def forward(self, x, context):
220
+ r"""
221
+ Args:
222
+ x(Tensor): Shape [B, L1, C]
223
+ context(Tensor): Shape [B, L2, C]
224
+ context_lens(Tensor): Shape [B]
225
+ """
226
+ context_img = context[:, :257]
227
+ context = context[:, 257:]
228
+ b, n, d = x.size(0), self.num_heads, self.head_dim
229
+
230
+ # compute query, key, value
231
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
232
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
233
+ v = self.v(context).view(b, -1, n, d)
234
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
235
+ v_img = self.v_img(context_img).view(b, -1, n, d)
236
+ img_x = flash_attention(q, k_img, v_img)
237
+ # compute attention
238
+ x = flash_attention(q, k, v)
239
+
240
+ # output
241
+ x = x.flatten(2)
242
+ img_x = img_x.flatten(2)
243
+ x = x + img_x
244
+ x = self.o(x)
245
+ return x
246
+
247
+
248
+ WAN_CROSSATTENTION_CLASSES = {
249
+ "t2v_cross_attn": WanT2VCrossAttention,
250
+ "i2v_cross_attn": WanI2VCrossAttention,
251
+ }
252
+
253
+
254
+ def mul_add(x, y, z):
255
+ return x.float() + y.float() * z.float()
256
+
257
+
258
+ def mul_add_add(x, y, z):
259
+ return x.float() * (1 + y) + z
260
+
261
+
262
+ mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE)
263
+ mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE)
264
+
265
+
266
+ class WanAttentionBlock(nn.Module):
267
+ def __init__(
268
+ self,
269
+ cross_attn_type,
270
+ dim,
271
+ ffn_dim,
272
+ num_heads,
273
+ window_size=(-1, -1),
274
+ qk_norm=True,
275
+ cross_attn_norm=False,
276
+ eps=1e-6,
277
+ ):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.ffn_dim = ffn_dim
281
+ self.num_heads = num_heads
282
+ self.window_size = window_size
283
+ self.qk_norm = qk_norm
284
+ self.cross_attn_norm = cross_attn_norm
285
+ self.eps = eps
286
+
287
+ # layers
288
+ self.norm1 = WanLayerNorm(dim, eps)
289
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
290
+ self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
291
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps)
292
+ self.norm2 = WanLayerNorm(dim, eps)
293
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
294
+
295
+ # modulation
296
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
297
+
298
+ def set_ar_attention(self):
299
+ self.self_attn.set_ar_attention()
300
+
301
+ def forward(
302
+ self,
303
+ x,
304
+ e,
305
+ grid_sizes,
306
+ freqs,
307
+ context,
308
+ block_mask,
309
+ ):
310
+ r"""
311
+ Args:
312
+ x(Tensor): Shape [B, L, C]
313
+ e(Tensor): Shape [B, 6, C]
314
+ context(Tensor): Shape [B, L2, C], text (and optional image) conditioning tokens
315
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
316
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
317
+ """
318
+ if e.dim() == 3:
319
+ modulation = self.modulation # 1, 6, dim
320
+ with amp.autocast("cuda", dtype=torch.float32):
321
+ e = (modulation + e).chunk(6, dim=1)
322
+ elif e.dim() == 4:
323
+ modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim
324
+ with amp.autocast("cuda", dtype=torch.float32):
325
+ e = (modulation + e).chunk(6, dim=1)
326
+ e = [ei.squeeze(1) for ei in e]
327
+
328
+ # self-attention
329
+ out = mul_add_add_compile(self.norm1(x), e[1], e[0])
330
+ y = self.self_attn(out, grid_sizes, freqs, block_mask)
331
+ with amp.autocast("cuda", dtype=torch.float32):
332
+ x = mul_add_compile(x, y, e[2])
333
+
334
+ # cross-attention & ffn function
335
+ def cross_attn_ffn(x, context, e):
336
+ dtype = context.dtype
337
+ x = x + self.cross_attn(self.norm3(x.to(dtype)), context)
338
+ y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype))
339
+ with amp.autocast("cuda", dtype=torch.float32):
340
+ x = mul_add_compile(x, y, e[5])
341
+ return x
342
+
343
+ x = cross_attn_ffn(x, context, e)
344
+ return x.to(torch.bfloat16)
345
+
346
+
347
+ class Head(nn.Module):
348
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
349
+ super().__init__()
350
+ self.dim = dim
351
+ self.out_dim = out_dim
352
+ self.patch_size = patch_size
353
+ self.eps = eps
354
+
355
+ # layers
356
+ out_dim = math.prod(patch_size) * out_dim
357
+ self.norm = WanLayerNorm(dim, eps)
358
+ self.head = nn.Linear(dim, out_dim)
359
+
360
+ # modulation
361
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
362
+
363
+ def forward(self, x, e):
364
+ r"""
365
+ Args:
366
+ x(Tensor): Shape [B, L1, C]
367
+ e(Tensor): Shape [B, C]
368
+ """
369
+ with amp.autocast("cuda", dtype=torch.float32):
370
+ if e.dim() == 2:
371
+ modulation = self.modulation # 1, 2, dim
372
+ e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
373
+
374
+ elif e.dim() == 3:
375
+ modulation = self.modulation.unsqueeze(2) # 1, 2, seq, dim
376
+ e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
377
+ e = [ei.squeeze(1) for ei in e]
378
+ x = self.head(self.norm(x) * (1 + e[1]) + e[0])
379
+ return x
380
+
381
+
382
+ class MLPProj(torch.nn.Module):
383
+ def __init__(self, in_dim, out_dim):
384
+ super().__init__()
385
+
386
+ self.proj = torch.nn.Sequential(
387
+ torch.nn.LayerNorm(in_dim),
388
+ torch.nn.Linear(in_dim, in_dim),
389
+ torch.nn.GELU(),
390
+ torch.nn.Linear(in_dim, out_dim),
391
+ torch.nn.LayerNorm(out_dim),
392
+ )
393
+
394
+ def forward(self, image_embeds):
395
+ clip_extra_context_tokens = self.proj(image_embeds)
396
+ return clip_extra_context_tokens
397
+
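# Editor's illustration (not part of the uploaded file): the projector maps the 257 CLIP
# image tokens (width 1280) to the transformer width for i2v conditioning.
_proj = MLPProj(in_dim=1280, out_dim=2048)
_clip_fea = torch.randn(2, 257, 1280)
assert _proj(_clip_fea).shape == (2, 257, 2048)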
398
+
399
+ class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
400
+ r"""
401
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
402
+ """
403
+
404
+ ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
405
+ _no_split_modules = ["WanAttentionBlock"]
406
+
407
+ _supports_gradient_checkpointing = True
408
+
409
+ @register_to_config
410
+ def __init__(
411
+ self,
412
+ model_type="t2v",
413
+ patch_size=(1, 2, 2),
414
+ text_len=512,
415
+ in_dim=16,
416
+ dim=2048,
417
+ ffn_dim=8192,
418
+ freq_dim=256,
419
+ text_dim=4096,
420
+ out_dim=16,
421
+ num_heads=16,
422
+ num_layers=32,
423
+ window_size=(-1, -1),
424
+ qk_norm=True,
425
+ cross_attn_norm=True,
426
+ inject_sample_info=False,
427
+ eps=1e-6,
428
+ ):
429
+ r"""
430
+ Initialize the diffusion model backbone.
431
+
432
+ Args:
433
+ model_type (`str`, *optional*, defaults to 't2v'):
434
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
435
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
436
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
437
+ text_len (`int`, *optional*, defaults to 512):
438
+ Fixed length for text embeddings
439
+ in_dim (`int`, *optional*, defaults to 16):
440
+ Input video channels (C_in)
441
+ dim (`int`, *optional*, defaults to 2048):
442
+ Hidden dimension of the transformer
443
+ ffn_dim (`int`, *optional*, defaults to 8192):
444
+ Intermediate dimension in feed-forward network
445
+ freq_dim (`int`, *optional*, defaults to 256):
446
+ Dimension for sinusoidal time embeddings
447
+ text_dim (`int`, *optional*, defaults to 4096):
448
+ Input dimension for text embeddings
449
+ out_dim (`int`, *optional*, defaults to 16):
450
+ Output video channels (C_out)
451
+ num_heads (`int`, *optional*, defaults to 16):
452
+ Number of attention heads
453
+ num_layers (`int`, *optional*, defaults to 32):
454
+ Number of transformer blocks
455
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
456
+ Window size for local attention (-1 indicates global attention)
457
+ qk_norm (`bool`, *optional*, defaults to True):
458
+ Enable query/key normalization
459
+ cross_attn_norm (`bool`, *optional*, defaults to True):
460
+ Enable cross-attention normalization
461
+ eps (`float`, *optional*, defaults to 1e-6):
462
+ Epsilon value for normalization layers
463
+ """
464
+
465
+ super().__init__()
466
+
467
+ assert model_type in ["t2v", "i2v"]
468
+ self.model_type = model_type
469
+
470
+ self.patch_size = patch_size
471
+ self.text_len = text_len
472
+ self.in_dim = in_dim
473
+ self.dim = dim
474
+ self.ffn_dim = ffn_dim
475
+ self.freq_dim = freq_dim
476
+ self.text_dim = text_dim
477
+ self.out_dim = out_dim
478
+ self.num_heads = num_heads
479
+ self.num_layers = num_layers
480
+ self.window_size = window_size
481
+ self.qk_norm = qk_norm
482
+ self.cross_attn_norm = cross_attn_norm
483
+ self.eps = eps
484
+ self.num_frame_per_block = 1
485
+ self.flag_causal_attention = False
486
+ self.block_mask = None
487
+ self.enable_teacache = False
488
+
489
+ # embeddings
490
+ self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
491
+ self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
492
+
493
+ self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
494
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
495
+
496
+ if inject_sample_info:
497
+ self.fps_embedding = nn.Embedding(2, dim)
498
+ self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
499
+
500
+ # blocks
501
+ cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
502
+ self.blocks = nn.ModuleList(
503
+ [
504
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
505
+ for _ in range(num_layers)
506
+ ]
507
+ )
508
+
509
+ # head
510
+ self.head = Head(dim, out_dim, patch_size, eps)
511
+
512
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
513
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
514
+ d = dim // num_heads
515
+ self.freqs = torch.cat(
516
+ [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))],
517
+ dim=1,
518
+ )
519
+
520
+ if model_type == "i2v":
521
+ self.img_emb = MLPProj(1280, dim)
522
+
523
+ self.gradient_checkpointing = False
524
+
525
+ self.cpu_offloading = False
526
+
527
+ self.inject_sample_info = inject_sample_info
528
+ # initialize weights
529
+ self.init_weights()
530
+
531
+ def _set_gradient_checkpointing(self, module, value=False):
532
+ self.gradient_checkpointing = value
533
+
534
+ def zero_init_i2v_cross_attn(self):
535
+ print("zero init i2v cross attn")
536
+ for i in range(self.num_layers):
537
+ self.blocks[i].cross_attn.v_img.weight.data.zero_()
538
+ self.blocks[i].cross_attn.v_img.bias.data.zero_()
539
+
540
+ @staticmethod
541
+ def _prepare_blockwise_causal_attn_mask(
542
+ device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1
543
+ ) -> BlockMask:
544
+ """
545
+ we will divide the token sequence into the following format
546
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
547
+ We use flexattention to construct the attention mask
548
+ """
549
+ total_length = num_frames * frame_seqlen
550
+
551
+ # we do right padding to get to a multiple of 128
552
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
553
+
554
+ ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
555
+
556
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
557
+ frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device)
558
+
559
+ for tmp in frame_indices:
560
+ ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block
561
+
562
+ def attention_mask(b, h, q_idx, kv_idx):
563
+ return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
564
+ # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
565
+
566
+ block_mask = create_block_mask(
567
+ attention_mask,
568
+ B=None,
569
+ H=None,
570
+ Q_LEN=total_length + padded_length,
571
+ KV_LEN=total_length + padded_length,
572
+ _compile=False,
573
+ device=device,
574
+ )
575
+
576
+ return block_mask
577
+
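# Editor's illustration (not part of the uploaded file): the "ends" bookkeeping used above,
# on a toy scale. With 4 latent frames of 2 tokens each and 2 frames per causal block, every
# query may attend to all keys with kv_idx < ends[q_idx], i.e. the mask is lower
# block-triangular at the granularity of 2-frame chunks.
_frame_seqlen, _num_frames, _frames_per_block = 2, 4, 2
_ends = torch.zeros(_num_frames * _frame_seqlen, dtype=torch.long)
for _start in range(0, _num_frames * _frame_seqlen, _frame_seqlen * _frames_per_block):
    _ends[_start : _start + _frame_seqlen * _frames_per_block] = _start + _frame_seqlen * _frames_per_block
print(_ends.tolist())  # [4, 4, 4, 4, 8, 8, 8, 8]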
578
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''):
579
+ self.enable_teacache = enable_teacache
580
+ print('using teacache')
581
+ self.cnt = 0
582
+ self.num_steps = num_steps
583
+ self.teacache_thresh = teacache_thresh
584
+ self.accumulated_rel_l1_distance_even = 0
585
+ self.accumulated_rel_l1_distance_odd = 0
586
+ self.previous_e0_even = None
587
+ self.previous_e0_odd = None
588
+ self.previous_residual_even = None
589
+ self.previous_residual_odd = None
590
+ self.use_ref_steps = use_ret_steps
591
+ if "I2V" in ckpt_dir:
592
+ if use_ret_steps:
593
+ if '540P' in ckpt_dir:
594
+ self.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
595
+ if '720P' in ckpt_dir:
596
+ self.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
597
+ self.ret_steps = 5*2
598
+ self.cutoff_steps = num_steps*2
599
+ else:
600
+ if '540P' in ckpt_dir:
601
+ self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
602
+ if '720P' in ckpt_dir:
603
+ self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
604
+ self.ret_steps = 1*2
605
+ self.cutoff_steps = num_steps*2 - 2
606
+ else:
607
+ if use_ret_steps:
608
+ if '1.3B' in ckpt_dir:
609
+ self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
610
+ if '14B' in ckpt_dir:
611
+ self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
612
+ self.ret_steps = 5*2
613
+ self.cutoff_steps = num_steps*2
614
+ else:
615
+ if '1.3B' in ckpt_dir:
616
+ self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
617
+ if '14B' in ckpt_dir:
618
+ self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
619
+ self.ret_steps = 1*2
620
+ self.cutoff_steps = num_steps*2 - 2
621
+
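# Editor's illustration (not part of the uploaded file): the TeaCache skip rule. The
# accumulator adds a polynomial rescaling of the relative L1 change of the modulation input
# between steps, and the transformer blocks are skipped (reusing the cached residual) while
# the accumulator stays below teacache_thresh. The coefficients below are the 1.3B
# non-ret-steps set from initialize_teacache.
_rescale = np.poly1d([2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01])
_rel_l1 = 0.01  # example relative change between consecutive conditional steps
print(_rescale(_rel_l1))  # contribution added to accumulated_rel_l1_distance_even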
622
+ def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
623
+ r"""
624
+ Forward pass through the diffusion model
625
+
626
+ Args:
627
+ x (List[Tensor]):
628
+ List of input video tensors, each with shape [C_in, F, H, W]
629
+ t (Tensor):
630
+ Diffusion timesteps tensor of shape [B]
631
+ context (List[Tensor]):
632
+ List of text embeddings each with shape [L, C]
633
+ fps (`int`, *optional*):
634
+ Frame-rate indicator, used only when inject_sample_info is enabled
635
+ clip_fea (Tensor, *optional*):
636
+ CLIP image features for image-to-video mode
637
+ y (List[Tensor], *optional*):
638
+ Conditional video inputs for image-to-video mode, same shape as x
639
+
640
+ Returns:
641
+ List[Tensor]:
642
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
643
+ """
644
+ if self.model_type == "i2v":
645
+ assert clip_fea is not None and y is not None
646
+ # params
647
+ device = self.patch_embedding.weight.device
648
+ if self.freqs.device != device:
649
+ self.freqs = self.freqs.to(device)
650
+
651
+ if y is not None:
652
+ x = torch.cat([x, y], dim=1)
653
+
654
+ # embeddings
655
+ x = self.patch_embedding(x)
656
+ grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
657
+ x = x.flatten(2).transpose(1, 2)
658
+
659
+ if self.flag_causal_attention:
660
+ frame_num = grid_sizes[0]
661
+ height = grid_sizes[1]
662
+ width = grid_sizes[2]
663
+ block_num = frame_num // self.num_frame_per_block
664
+ range_tensor = torch.arange(block_num).view(-1, 1)
665
+ range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
666
+ casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
667
+ casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
668
+ casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
669
+ casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
670
+ self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
671
+
672
+ # time embeddings
673
+ with amp.autocast("cuda", dtype=torch.float32):
674
+ if t.dim() == 2:
675
+ b, f = t.shape
676
+ _flag_df = True
677
+ else:
678
+ _flag_df = False
679
+
680
+ e = self.time_embedding(
681
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
682
+ ) # b, dim
683
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
684
+
685
+ if self.inject_sample_info:
686
+ fps = torch.tensor(fps, dtype=torch.long, device=device)
687
+
688
+ fps_emb = self.fps_embedding(fps).float()
689
+ if _flag_df:
690
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
691
+ else:
692
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
693
+
694
+ if _flag_df:
695
+ e = e.view(b, f, 1, 1, self.dim)
696
+ e0 = e0.view(b, f, 1, 1, 6, self.dim)
697
+ e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
698
+ e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
699
+ e0 = e0.transpose(1, 2).contiguous()
700
+
701
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
702
+
703
+ # context
704
+ context = self.text_embedding(context)
705
+
706
+ if clip_fea is not None:
707
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
708
+ context = torch.concat([context_clip, context], dim=1)
709
+
710
+ # arguments
711
+ kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
712
+ if self.enable_teacache:
713
+ modulated_inp = e0 if self.use_ref_steps else e
714
+ # teacache
715
+ if self.cnt % 2 == 0: # even step -> conditional branch
716
+ self.is_even = True
717
+ if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
718
+ should_calc_even = True
719
+ self.accumulated_rel_l1_distance_even = 0
720
+ else:
721
+ rescale_func = np.poly1d(self.coefficients)
722
+ self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
723
+ if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
724
+ should_calc_even = False
725
+ else:
726
+ should_calc_even = True
727
+ self.accumulated_rel_l1_distance_even = 0
728
+ self.previous_e0_even = modulated_inp.clone()
729
+
730
+ else: # odd step -> unconditional branch
731
+ self.is_even = False
732
+ if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
733
+ should_calc_odd = True
734
+ self.accumulated_rel_l1_distance_odd = 0
735
+ else:
736
+ rescale_func = np.poly1d(self.coefficients)
737
+ self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
738
+ if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
739
+ should_calc_odd = False
740
+ else:
741
+ should_calc_odd = True
742
+ self.accumulated_rel_l1_distance_odd = 0
743
+ self.previous_e0_odd = modulated_inp.clone()
744
+
745
+ if self.enable_teacache:
746
+ if self.is_even:
747
+ if not should_calc_even:
748
+ x += self.previous_residual_even
749
+ else:
750
+ ori_x = x.clone()
751
+ for block in self.blocks:
752
+ x = block(x, **kwargs)
753
+ self.previous_residual_even = x - ori_x
754
+ else:
755
+ if not should_calc_odd:
756
+ x += self.previous_residual_odd
757
+ else:
758
+ ori_x = x.clone()
759
+ for block in self.blocks:
760
+ x = block(x, **kwargs)
761
+ self.previous_residual_odd = x - ori_x
762
+
763
+ self.cnt += 1
764
+ if self.cnt >= self.num_steps:
765
+ self.cnt = 0
766
+ else:
767
+ for block in self.blocks:
768
+ x = block(x, **kwargs)
769
+
770
+ x = self.head(x, e)
771
+
772
+ # unpatchify
773
+ x = self.unpatchify(x, grid_sizes)
774
+
775
+ return x.float()
776
+
777
+ def unpatchify(self, x, grid_sizes):
778
+ r"""
779
+ Reconstruct video tensors from patch embeddings.
780
+
781
+ Args:
782
+ x (List[Tensor]):
783
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
784
+ grid_sizes (Tensor):
785
+ Original spatial-temporal grid dimensions before patching,
786
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
787
+
788
+ Returns:
789
+ List[Tensor]:
790
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
791
+ """
792
+
793
+ c = self.out_dim
794
+ bs = x.shape[0]
795
+ x = x.view(bs, *grid_sizes, *self.patch_size, c)
796
+ x = torch.einsum("bfhwpqrc->bcfphqwr", x)
797
+ x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
798
+
799
+ return x
800
+
801
+ def set_ar_attention(self, causal_block_size):
802
+ self.num_frame_per_block = causal_block_size
803
+ self.flag_causal_attention = True
804
+ for block in self.blocks:
805
+ block.set_ar_attention()
806
+
807
+ def init_weights(self):
808
+ r"""
809
+ Initialize model parameters using Xavier initialization.
810
+ """
811
+
812
+ # basic init
813
+ for m in self.modules():
814
+ if isinstance(m, nn.Linear):
815
+ nn.init.xavier_uniform_(m.weight)
816
+ if m.bias is not None:
817
+ nn.init.zeros_(m.bias)
818
+
819
+ # init embeddings
820
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
821
+ for m in self.text_embedding.modules():
822
+ if isinstance(m, nn.Linear):
823
+ nn.init.normal_(m.weight, std=0.02)
824
+ for m in self.time_embedding.modules():
825
+ if isinstance(m, nn.Linear):
826
+ nn.init.normal_(m.weight, std=0.02)
827
+
828
+ if self.inject_sample_info:
829
+ nn.init.normal_(self.fps_embedding.weight, std=0.02)
830
+
831
+ for m in self.fps_projection.modules():
832
+ if isinstance(m, nn.Linear):
833
+ nn.init.normal_(m.weight, std=0.02)
834
+
835
+ nn.init.zeros_(self.fps_projection[-1].weight)
836
+ nn.init.zeros_(self.fps_projection[-1].bias)
837
+
838
+ # init output layer
839
+ nn.init.zeros_(self.head.head.weight)
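A construction sketch from the editor (not part of the upload): the released checkpoints are loaded through the repository's get_transformer helper, but a tiny, randomly initialized variant is enough to sanity-check the architecture on CPU.

tiny = WanModel(model_type="t2v", dim=128, ffn_dim=256, num_heads=8, num_layers=1, text_dim=32)
print(sum(p.numel() for p in tiny.parameters()), "parameters in the toy configuration")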
skyreels_v2_infer/modules/vae.py ADDED
@@ -0,0 +1,639 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+
10
+ __all__ = [
11
+ "WanVAE",
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3D convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
25
+ self.padding = (0, 0, 0)
26
+
27
+ def forward(self, x, cache_x=None):
28
+ padding = list(self._padding)
29
+ if cache_x is not None and self._padding[4] > 0:
30
+ cache_x = cache_x.to(x.device)
31
+ x = torch.cat([cache_x, x], dim=2)
32
+ padding[4] -= cache_x.shape[2]
33
+ x = F.pad(x, padding)
34
+
35
+ return super().forward(x)
36
+
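# Editor's illustration (not part of the uploaded file): the causal conv pads only on the
# left of the time axis (2 * padding[0] frames), so the output has as many frames as the
# input and frame t never depends on frames after t.
_cconv = CausalConv3d(4, 8, kernel_size=3, padding=1)
_vx = torch.randn(1, 4, 5, 16, 16)  # [B, C, T, H, W]
assert _cconv(_vx).shape == (1, 8, 5, 16, 16)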
37
+
38
+ class RMS_norm(nn.Module):
39
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
40
+ super().__init__()
41
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
42
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
43
+
44
+ self.channel_first = channel_first
45
+ self.scale = dim**0.5
46
+ self.gamma = nn.Parameter(torch.ones(shape))
47
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
48
+
49
+ def forward(self, x):
50
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
51
+
52
+
53
+ class Upsample(nn.Upsample):
54
+ def forward(self, x):
55
+ """
56
+ Fix bfloat16 support for nearest neighbor interpolation.
57
+ """
58
+ return super().forward(x.float()).type_as(x)
59
+
60
+
61
+ class Resample(nn.Module):
62
+ def __init__(self, dim, mode):
63
+ assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
64
+ super().__init__()
65
+ self.dim = dim
66
+ self.mode = mode
67
+
68
+ # layers
69
+ if mode == "upsample2d":
70
+ self.resample = nn.Sequential(
71
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
72
+ )
73
+ elif mode == "upsample3d":
74
+ self.resample = nn.Sequential(
75
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
76
+ )
77
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
78
+
79
+ elif mode == "downsample2d":
80
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
81
+ elif mode == "downsample3d":
82
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
83
+ self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
84
+
85
+ else:
86
+ self.resample = nn.Identity()
87
+
88
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
89
+ b, c, t, h, w = x.size()
90
+ if self.mode == "upsample3d":
91
+ if feat_cache is not None:
92
+ idx = feat_idx[0]
93
+ if feat_cache[idx] is None:
94
+ feat_cache[idx] = "Rep"
95
+ feat_idx[0] += 1
96
+ else:
97
+
98
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
99
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
100
+ # cache last frame of last two chunk
101
+ cache_x = torch.cat(
102
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
103
+ )
104
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
105
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
106
+ if feat_cache[idx] == "Rep":
107
+ x = self.time_conv(x)
108
+ else:
109
+ x = self.time_conv(x, feat_cache[idx])
110
+ feat_cache[idx] = cache_x
111
+ feat_idx[0] += 1
112
+
113
+ x = x.reshape(b, 2, c, t, h, w)
114
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
115
+ x = x.reshape(b, c, t * 2, h, w)
116
+ t = x.shape[2]
117
+ x = rearrange(x, "b c t h w -> (b t) c h w")
118
+ x = self.resample(x)
119
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
120
+
121
+ if self.mode == "downsample3d":
122
+ if feat_cache is not None:
123
+ idx = feat_idx[0]
124
+ if feat_cache[idx] is None:
125
+ feat_cache[idx] = x.clone()
126
+ feat_idx[0] += 1
127
+ else:
128
+
129
+ cache_x = x[:, :, -1:, :, :].clone()
130
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
131
+ # # cache last frame of last two chunk
132
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
133
+
134
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
135
+ feat_cache[idx] = cache_x
136
+ feat_idx[0] += 1
137
+ return x
138
+
139
+ def init_weight(self, conv):
140
+ conv_weight = conv.weight
141
+ nn.init.zeros_(conv_weight)
142
+ c1, c2, t, h, w = conv_weight.size()
143
+ one_matrix = torch.eye(c1, c2)
144
+ init_matrix = one_matrix
145
+ nn.init.zeros_(conv_weight)
146
+ # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
147
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
148
+ conv.weight.data.copy_(conv_weight)
149
+ nn.init.zeros_(conv.bias.data)
150
+
151
+ def init_weight2(self, conv):
152
+ conv_weight = conv.weight.data
153
+ nn.init.zeros_(conv_weight)
154
+ c1, c2, t, h, w = conv_weight.size()
155
+ init_matrix = torch.eye(c1 // 2, c2)
156
+ # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
157
+ conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
158
+ conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
159
+ conv.weight.data.copy_(conv_weight)
160
+ nn.init.zeros_(conv.bias.data)
161
+
162
+
163
+ class ResidualBlock(nn.Module):
164
+ def __init__(self, in_dim, out_dim, dropout=0.0):
165
+ super().__init__()
166
+ self.in_dim = in_dim
167
+ self.out_dim = out_dim
168
+
169
+ # layers
170
+ self.residual = nn.Sequential(
171
+ RMS_norm(in_dim, images=False),
172
+ nn.SiLU(),
173
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
174
+ RMS_norm(out_dim, images=False),
175
+ nn.SiLU(),
176
+ nn.Dropout(dropout),
177
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
178
+ )
179
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
180
+
181
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
182
+ h = self.shortcut(x)
183
+ for layer in self.residual:
184
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
185
+ idx = feat_idx[0]
186
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
187
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
188
+ # cache last frame of last two chunk
189
+ cache_x = torch.cat(
190
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
191
+ )
192
+ x = layer(x, feat_cache[idx])
193
+ feat_cache[idx] = cache_x
194
+ feat_idx[0] += 1
195
+ else:
196
+ x = layer(x)
197
+ return x + h
198
+
199
+
200
+ class AttentionBlock(nn.Module):
201
+ """
202
+ Causal self-attention with a single head.
203
+ """
204
+
205
+ def __init__(self, dim):
206
+ super().__init__()
207
+ self.dim = dim
208
+
209
+ # layers
210
+ self.norm = RMS_norm(dim)
211
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
212
+ self.proj = nn.Conv2d(dim, dim, 1)
213
+
214
+ # zero out the last layer params
215
+ nn.init.zeros_(self.proj.weight)
216
+
217
+ def forward(self, x):
218
+ identity = x
219
+ b, c, t, h, w = x.size()
220
+ x = rearrange(x, "b c t h w -> (b t) c h w")
221
+ x = self.norm(x)
222
+ # compute query, key, value
223
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
224
+
225
+ # apply attention
226
+ x = F.scaled_dot_product_attention(
227
+ q,
228
+ k,
229
+ v,
230
+ )
231
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
232
+
233
+ # output
234
+ x = self.proj(x)
235
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
236
+ return x + identity
237
+
238
+
239
+ class Encoder3d(nn.Module):
240
+ def __init__(
241
+ self,
242
+ dim=128,
243
+ z_dim=4,
244
+ dim_mult=[1, 2, 4, 4],
245
+ num_res_blocks=2,
246
+ attn_scales=[],
247
+ temperal_downsample=[True, True, False],
248
+ dropout=0.0,
249
+ ):
250
+ super().__init__()
251
+ self.dim = dim
252
+ self.z_dim = z_dim
253
+ self.dim_mult = dim_mult
254
+ self.num_res_blocks = num_res_blocks
255
+ self.attn_scales = attn_scales
256
+ self.temperal_downsample = temperal_downsample
257
+
258
+ # dimensions
259
+ dims = [dim * u for u in [1] + dim_mult]
260
+ scale = 1.0
261
+
262
+ # init block
263
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
264
+
265
+ # downsample blocks
266
+ downsamples = []
267
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
268
+ # residual (+attention) blocks
269
+ for _ in range(num_res_blocks):
270
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
271
+ if scale in attn_scales:
272
+ downsamples.append(AttentionBlock(out_dim))
273
+ in_dim = out_dim
274
+
275
+ # downsample block
276
+ if i != len(dim_mult) - 1:
277
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
278
+ downsamples.append(Resample(out_dim, mode=mode))
279
+ scale /= 2.0
280
+ self.downsamples = nn.Sequential(*downsamples)
281
+
282
+ # middle blocks
283
+ self.middle = nn.Sequential(
284
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
285
+ )
286
+
287
+ # output blocks
288
+ self.head = nn.Sequential(
289
+ RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)
290
+ )
291
+
292
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
293
+ if feat_cache is not None:
294
+ idx = feat_idx[0]
295
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
296
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
297
+ # cache last frame of last two chunk
298
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
299
+ x = self.conv1(x, feat_cache[idx])
300
+ feat_cache[idx] = cache_x
301
+ feat_idx[0] += 1
302
+ else:
303
+ x = self.conv1(x)
304
+
305
+ ## downsamples
306
+ for layer in self.downsamples:
307
+ if feat_cache is not None:
308
+ x = layer(x, feat_cache, feat_idx)
309
+ else:
310
+ x = layer(x)
311
+
312
+ ## middle
313
+ for layer in self.middle:
314
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
315
+ x = layer(x, feat_cache, feat_idx)
316
+ else:
317
+ x = layer(x)
318
+
319
+ ## head
320
+ for layer in self.head:
321
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
322
+ idx = feat_idx[0]
323
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
324
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
325
+ # cache last frame of last two chunk
326
+ cache_x = torch.cat(
327
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
328
+ )
329
+ x = layer(x, feat_cache[idx])
330
+ feat_cache[idx] = cache_x
331
+ feat_idx[0] += 1
332
+ else:
333
+ x = layer(x)
334
+ return x
335
+
336
+
337
+ class Decoder3d(nn.Module):
338
+ def __init__(
339
+ self,
340
+ dim=128,
341
+ z_dim=4,
342
+ dim_mult=[1, 2, 4, 4],
343
+ num_res_blocks=2,
344
+ attn_scales=[],
345
+ temperal_upsample=[False, True, True],
346
+ dropout=0.0,
347
+ ):
348
+ super().__init__()
349
+ self.dim = dim
350
+ self.z_dim = z_dim
351
+ self.dim_mult = dim_mult
352
+ self.num_res_blocks = num_res_blocks
353
+ self.attn_scales = attn_scales
354
+ self.temperal_upsample = temperal_upsample
355
+
356
+ # dimensions
357
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
358
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
359
+
360
+ # init block
361
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
362
+
363
+ # middle blocks
364
+ self.middle = nn.Sequential(
365
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
366
+ )
367
+
368
+ # upsample blocks
369
+ upsamples = []
370
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
371
+ # residual (+attention) blocks
372
+ if i == 1 or i == 2 or i == 3:
373
+ in_dim = in_dim // 2
374
+ for _ in range(num_res_blocks + 1):
375
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
376
+ if scale in attn_scales:
377
+ upsamples.append(AttentionBlock(out_dim))
378
+ in_dim = out_dim
379
+
380
+ # upsample block
381
+ if i != len(dim_mult) - 1:
382
+ mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
383
+ upsamples.append(Resample(out_dim, mode=mode))
384
+ scale *= 2.0
385
+ self.upsamples = nn.Sequential(*upsamples)
386
+
387
+ # output blocks
388
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
389
+
390
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
391
+ ## conv1
392
+ if feat_cache is not None:
393
+ idx = feat_idx[0]
394
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
395
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
396
+ # cache last frame of last two chunk
397
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
398
+ x = self.conv1(x, feat_cache[idx])
399
+ feat_cache[idx] = cache_x
400
+ feat_idx[0] += 1
401
+ else:
402
+ x = self.conv1(x)
403
+
404
+ ## middle
405
+ for layer in self.middle:
406
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
407
+ x = layer(x, feat_cache, feat_idx)
408
+ else:
409
+ x = layer(x)
410
+
411
+ ## upsamples
412
+ for layer in self.upsamples:
413
+ if feat_cache is not None:
414
+ x = layer(x, feat_cache, feat_idx)
415
+ else:
416
+ x = layer(x)
417
+
418
+ ## head
419
+ for layer in self.head:
420
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
421
+ idx = feat_idx[0]
422
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
423
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
424
+ # cache last frame of last two chunk
425
+ cache_x = torch.cat(
426
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
427
+ )
428
+ x = layer(x, feat_cache[idx])
429
+ feat_cache[idx] = cache_x
430
+ feat_idx[0] += 1
431
+ else:
432
+ x = layer(x)
433
+ return x
434
+
435
+
436
+ def count_conv3d(model):
437
+ count = 0
438
+ for m in model.modules():
439
+ if isinstance(m, CausalConv3d):
440
+ count += 1
441
+ return count
442
+
443
+
444
+ class WanVAE_(nn.Module):
445
+ def __init__(
446
+ self,
447
+ dim=128,
448
+ z_dim=4,
449
+ dim_mult=[1, 2, 4, 4],
450
+ num_res_blocks=2,
451
+ attn_scales=[],
452
+ temperal_downsample=[True, True, False],
453
+ dropout=0.0,
454
+ ):
455
+ super().__init__()
456
+ self.dim = dim
457
+ self.z_dim = z_dim
458
+ self.dim_mult = dim_mult
459
+ self.num_res_blocks = num_res_blocks
460
+ self.attn_scales = attn_scales
461
+ self.temperal_downsample = temperal_downsample
462
+ self.temperal_upsample = temperal_downsample[::-1]
463
+
464
+ # modules
465
+ self.encoder = Encoder3d(
466
+ dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
467
+ )
468
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
469
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
470
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
471
+
472
+ def forward(self, x):
473
+ mu, log_var = self.encode(x)
474
+ z = self.reparameterize(mu, log_var)
475
+ x_recon = self.decode(z)
476
+ return x_recon, mu, log_var
477
+
478
+ def encode(self, x, scale):
479
+ self.clear_cache()
480
+ ## cache
481
+ t = x.shape[2]
482
+ iter_ = 1 + (t - 1) // 4
483
+ ## split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ... frames
484
+ for i in range(iter_):
485
+ self._enc_conv_idx = [0]
486
+ if i == 0:
487
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
488
+ else:
489
+ out_ = self.encoder(
490
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
491
+ feat_cache=self._enc_feat_map,
492
+ feat_idx=self._enc_conv_idx,
493
+ )
494
+ out = torch.cat([out, out_], 2)
495
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
496
+ if isinstance(scale[0], torch.Tensor):
497
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
498
+ else:
499
+ mu = (mu - scale[0]) * scale[1]
500
+ self.clear_cache()
501
+ return mu
502
+
503
+ def decode(self, z, scale):
504
+ self.clear_cache()
505
+ # z: [b,c,t,h,w]
506
+ if isinstance(scale[0], torch.Tensor):
507
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
508
+ else:
509
+ z = z / scale[1] + scale[0]
510
+ iter_ = z.shape[2]
511
+ x = self.conv2(z)
512
+ for i in range(iter_):
513
+ self._conv_idx = [0]
514
+ if i == 0:
515
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
516
+ else:
517
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
518
+ out = torch.cat([out, out_], 2)
519
+ self.clear_cache()
520
+ return out
521
+
522
+ def reparameterize(self, mu, log_var):
523
+ std = torch.exp(0.5 * log_var)
524
+ eps = torch.randn_like(std)
525
+ return eps * std + mu
526
+
527
+ def sample(self, imgs, deterministic=False):
528
+ mu, log_var = self.encode(imgs)
529
+ if deterministic:
530
+ return mu
531
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
532
+ return mu + std * torch.randn_like(std)
533
+
534
+ def clear_cache(self):
535
+ self._conv_num = count_conv3d(self.decoder)
536
+ self._conv_idx = [0]
537
+ self._feat_map = [None] * self._conv_num
538
+ # cache encode
539
+ self._enc_conv_num = count_conv3d(self.encoder)
540
+ self._enc_conv_idx = [0]
541
+ self._enc_feat_map = [None] * self._enc_conv_num
542
+
543
+
544
+ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
545
+ """
546
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
547
+ """
548
+ # params
549
+ cfg = dict(
550
+ dim=96,
551
+ z_dim=z_dim,
552
+ dim_mult=[1, 2, 4, 4],
553
+ num_res_blocks=2,
554
+ attn_scales=[],
555
+ temperal_downsample=[False, True, True],
556
+ dropout=0.0,
557
+ )
558
+ cfg.update(**kwargs)
559
+
560
+ # init model
561
+ with torch.device("meta"):
562
+ model = WanVAE_(**cfg)
563
+
564
+ # load checkpoint
565
+ logging.info(f"loading {pretrained_path}")
566
+ model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
567
+
568
+ return model
569
+
570
+
571
+ class WanVAE:
572
+ def __init__(self, vae_pth="cache/vae_step_411000.pth", z_dim=16):
573
+
574
+ mean = [
575
+ -0.7571,
576
+ -0.7089,
577
+ -0.9113,
578
+ 0.1075,
579
+ -0.1745,
580
+ 0.9653,
581
+ -0.1517,
582
+ 1.5508,
583
+ 0.4134,
584
+ -0.0715,
585
+ 0.5517,
586
+ -0.3632,
587
+ -0.1922,
588
+ -0.9497,
589
+ 0.2503,
590
+ -0.2921,
591
+ ]
592
+ std = [
593
+ 2.8184,
594
+ 1.4541,
595
+ 2.3275,
596
+ 2.6558,
597
+ 1.2196,
598
+ 1.7708,
599
+ 2.6052,
600
+ 2.0743,
601
+ 3.2687,
602
+ 2.1526,
603
+ 2.8652,
604
+ 1.5579,
605
+ 1.6382,
606
+ 1.1253,
607
+ 2.8251,
608
+ 1.9160,
609
+ ]
610
+ self.vae_stride = (4, 8, 8)
611
+ self.mean = torch.tensor(mean)
612
+ self.std = torch.tensor(std)
613
+ self.scale = [self.mean, 1.0 / self.std]
614
+
615
+ # init model
616
+ self.vae = (
617
+ _video_vae(
618
+ pretrained_path=vae_pth,
619
+ z_dim=z_dim,
620
+ )
621
+ .eval()
622
+ .requires_grad_(False)
623
+ )
624
+
625
+ def encode(self, video):
626
+ """
627
+ video: A tensor with shape [B, C, T, H, W], values expected in [-1, 1].
628
+ """
629
+ return self.vae.encode(video, self.scale).float()
630
+
631
+ def to(self, *args, **kwargs):
632
+ self.mean = self.mean.to(*args, **kwargs)
633
+ self.std = self.std.to(*args, **kwargs)
634
+ self.scale = [self.mean, 1.0 / self.std]
635
+ self.vae = self.vae.to(*args, **kwargs)
636
+ return self
637
+
638
+ def decode(self, z):
639
+ return self.vae.decode(z, self.scale).float().clamp_(-1, 1)
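A usage sketch from the editor (checkpoint path and device are placeholders): with the (4, 8, 8) stride declared above, a 17-frame 480x832 clip in [-1, 1] encodes to a [1, 16, 5, 60, 104] latent.

vae = WanVAE(vae_pth="path/to/Wan2.1_VAE.pth").to("cuda", dtype=torch.float32)
video = torch.rand(1, 3, 17, 480, 832, device="cuda") * 2 - 1  # [B, C, T, H, W] in [-1, 1]
latents = vae.encode(video)   # [1, 16, 5, 60, 104]
frames = vae.decode(latents)  # back to [1, 3, 17, 480, 832], clamped to [-1, 1]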
skyreels_v2_infer/modules/xlm_roberta.py ADDED
@@ -0,0 +1,165 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ["XLMRoberta", "xlm_roberta_large"]
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
12
+ assert dim % num_heads == 0
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.num_heads = num_heads
16
+ self.head_dim = dim // num_heads
17
+ self.eps = eps
18
+
19
+ # layers
20
+ self.q = nn.Linear(dim, dim)
21
+ self.k = nn.Linear(dim, dim)
22
+ self.v = nn.Linear(dim, dim)
23
+ self.o = nn.Linear(dim, dim)
24
+ self.dropout = nn.Dropout(dropout)
25
+
26
+ def forward(self, x, mask):
27
+ """
28
+ x: [B, L, C].
29
+ """
30
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
31
+
32
+ # compute query, key, value
33
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
34
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+
37
+ # compute attention
38
+ p = self.dropout.p if self.training else 0.0
39
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
40
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
41
+
42
+ # output
43
+ x = self.o(x)
44
+ x = self.dropout(x)
45
+ return x
46
+
47
+
48
+ class AttentionBlock(nn.Module):
49
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
50
+ super().__init__()
51
+ self.dim = dim
52
+ self.num_heads = num_heads
53
+ self.post_norm = post_norm
54
+ self.eps = eps
55
+
56
+ # layers
57
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
58
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
59
+ self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
60
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
61
+
62
+ def forward(self, x, mask):
63
+ if self.post_norm:
64
+ x = self.norm1(x + self.attn(x, mask))
65
+ x = self.norm2(x + self.ffn(x))
66
+ else:
67
+ x = x + self.attn(self.norm1(x), mask)
68
+ x = x + self.ffn(self.norm2(x))
69
+ return x
70
+
71
+
72
+ class XLMRoberta(nn.Module):
73
+ """
74
+ XLMRobertaModel with no pooler and no LM head.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ vocab_size=250002,
80
+ max_seq_len=514,
81
+ type_size=1,
82
+ pad_id=1,
83
+ dim=1024,
84
+ num_heads=16,
85
+ num_layers=24,
86
+ post_norm=True,
87
+ dropout=0.1,
88
+ eps=1e-5,
89
+ ):
90
+ super().__init__()
91
+ self.vocab_size = vocab_size
92
+ self.max_seq_len = max_seq_len
93
+ self.type_size = type_size
94
+ self.pad_id = pad_id
95
+ self.dim = dim
96
+ self.num_heads = num_heads
97
+ self.num_layers = num_layers
98
+ self.post_norm = post_norm
99
+ self.eps = eps
100
+
101
+ # embeddings
102
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
103
+ self.type_embedding = nn.Embedding(type_size, dim)
104
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
105
+ self.dropout = nn.Dropout(dropout)
106
+
107
+ # blocks
108
+ self.blocks = nn.ModuleList(
109
+ [AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)]
110
+ )
111
+
112
+ # norm layer
113
+ self.norm = nn.LayerNorm(dim, eps=eps)
114
+
115
+ def forward(self, ids):
116
+ """
117
+ ids: [B, L] of torch.LongTensor.
118
+ """
119
+ b, s = ids.shape
120
+ mask = ids.ne(self.pad_id).long()
121
+
122
+ # embeddings
123
+ x = (
124
+ self.token_embedding(ids)
125
+ + self.type_embedding(torch.zeros_like(ids))
126
+ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
127
+ )
128
+ if self.post_norm:
129
+ x = self.norm(x)
130
+ x = self.dropout(x)
131
+
132
+ # blocks
133
+ mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
134
+ for block in self.blocks:
135
+ x = block(x, mask)
136
+
137
+ # output
138
+ if not self.post_norm:
139
+ x = self.norm(x)
140
+ return x
141
+
142
+
143
+ def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
144
+ """
145
+ XLMRobertaLarge adapted from Huggingface.
146
+ """
147
+ # params
148
+ cfg = dict(
149
+ vocab_size=250002,
150
+ max_seq_len=514,
151
+ type_size=1,
152
+ pad_id=1,
153
+ dim=1024,
154
+ num_heads=16,
155
+ num_layers=24,
156
+ post_norm=True,
157
+ dropout=0.1,
158
+ eps=1e-5,
159
+ )
160
+ cfg.update(**kwargs)
161
+
162
+ # init a model on device
163
+ with torch.device(device):
164
+ model = XLMRoberta(**cfg)
165
+ return model
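An instantiation sketch from the editor (not part of the upload): the returned encoder is randomly initialized, and note that the pretrained and return_tokenizer flags are accepted but not acted on in this snippet.

model = xlm_roberta_large()              # bare encoder, roughly 560M parameters
ids = torch.randint(2, 250002, (2, 12))  # dummy token ids, avoiding pad_id=1
feats = model(ids)                       # [2, 12, 1024] contextual features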
skyreels_v2_infer/pipelines/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .diffusion_forcing_pipeline import DiffusionForcingPipeline
2
+ from .image2video_pipeline import Image2VideoPipeline
3
+ from .image2video_pipeline import resizecrop
4
+ from .prompt_enhancer import PromptEnhancer
5
+ from .text2video_pipeline import Text2VideoPipeline
skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py ADDED
@@ -0,0 +1,659 @@
1
+ import math
2
+ import os
3
+ from typing import List
4
+ from typing import Optional
5
+ from typing import Tuple
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from diffusers.image_processor import PipelineImageInput
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+ from diffusers.video_processor import VideoProcessor
13
+ from tqdm import tqdm
14
+ import decord
15
+ from decord import VideoReader
16
+
17
+ from ..modules import get_text_encoder
18
+ from ..modules import get_transformer
19
+ from ..modules import get_vae
20
+ from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
21
+
22
+
23
+
24
+
25
+ class DiffusionForcingPipeline:
26
+ """
27
+ A pipeline for diffusion-based video generation tasks.
28
+
29
+ This pipeline supports two main tasks:
30
+ - Image-to-Video (i2v): Generates a video sequence from a source image
31
+ - Text-to-Video (t2v): Generates a video sequence from a text description
32
+
33
+ The pipeline integrates multiple components including:
34
+ - A transformer model for diffusion
35
+ - A VAE for encoding/decoding
36
+ - A text encoder for processing text prompts
37
+ - An image encoder for processing image inputs (i2v mode only)
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ model_path: str,
43
+ dit_path: str,
44
+ device: str = "cuda",
45
+ weight_dtype=torch.bfloat16,
46
+ use_usp=False,
47
+ offload=False,
48
+ ):
49
+ """
50
+ Initialize the diffusion forcing pipeline class
51
+
52
+ Args:
53
+ model_path (str): Path to the model
54
+ dit_path (str): Path to the DiT model directory, containing the model configuration file (config.json) and weight files (*.safetensors)
55
+ device (str): Device to run on, defaults to 'cuda'
56
+ weight_dtype: Weight data type, defaults to torch.bfloat16
57
+ """
58
+ load_device = "cpu" if offload else device
59
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype)
60
+ vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
61
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
62
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
63
+ self.video_processor = VideoProcessor(vae_scale_factor=16)
64
+ self.device = device
65
+ self.offload = offload
66
+
67
+ if use_usp:
68
+ from xfuser.core.distributed import get_sequence_parallel_world_size
69
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
70
+ import types
71
+
72
+ for block in self.transformer.blocks:
73
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
74
+ self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
75
+ self.sp_size = get_sequence_parallel_world_size()
76
+
77
+ self.scheduler = FlowUniPCMultistepScheduler()
78
+
79
+ @property
80
+ def do_classifier_free_guidance(self) -> bool:
81
+ return self._guidance_scale > 1
82
+
83
+ def encode_image(
84
+ self, image: PipelineImageInput, height: int, width: int, num_frames: int
85
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
86
+
87
+ # prefix_video
88
+ prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
89
+ prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
90
+ if prefix_video.dtype == torch.uint8:
91
+ prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
92
+ prefix_video = prefix_video.to(self.device)
93
+ prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
94
+ causal_block_size = self.transformer.num_frame_per_block
95
+ if prefix_video[0].shape[1] % causal_block_size != 0:
96
+ truncate_len = prefix_video[0].shape[1] % causal_block_size
97
+ print("The length of the prefix video is truncated to align with the causal block size.")
98
+ prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
99
+ predix_video_latent_length = prefix_video[0].shape[1]
100
+ return prefix_video, predix_video_latent_length
101
+
102
+ def prepare_latents(
103
+ self,
104
+ shape: Tuple[int],
105
+ dtype: Optional[torch.dtype] = None,
106
+ device: Optional[torch.device] = None,
107
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
108
+ ) -> torch.Tensor:
109
+ return randn_tensor(shape, generator, device=device, dtype=dtype)
110
+
111
+ def generate_timestep_matrix(
112
+ self,
113
+ num_frames,
114
+ step_template,
115
+ base_num_frames,
116
+ ar_step=5,
117
+ num_pre_ready=0,
118
+ casual_block_size=1,
119
+ shrink_interval_with_mask=False,
120
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
121
+ step_matrix, step_index = [], []
122
+ update_mask, valid_interval = [], []
123
+ num_iterations = len(step_template) + 1
124
+ num_frames_block = num_frames // casual_block_size
125
+ base_num_frames_block = base_num_frames // casual_block_size
126
+ if base_num_frames_block < num_frames_block:
127
+ infer_step_num = len(step_template)
128
+ gen_block = base_num_frames_block
129
+ min_ar_step = infer_step_num / gen_block
130
+ assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
131
+ # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
132
+ step_template = torch.cat(
133
+ [
134
+ torch.tensor([999], dtype=torch.int64, device=step_template.device),
135
+ step_template.long(),
136
+ torch.tensor([0], dtype=torch.int64, device=step_template.device),
137
+ ]
138
+ ) # pad so that the per-row counter, which starts from 1, indexes correctly
139
+ pre_row = torch.zeros(num_frames_block, dtype=torch.long)
140
+ if num_pre_ready > 0:
141
+ pre_row[: num_pre_ready // casual_block_size] = num_iterations
142
+
143
+ while not torch.all(pre_row >= (num_iterations - 1)):
144
+ new_row = torch.zeros(num_frames_block, dtype=torch.long)
145
+ for i in range(num_frames_block):
146
+ if i == 0 or pre_row[i - 1] >= (
147
+ num_iterations - 1
148
+ ): # the first block, or the previous block is already fully denoised
149
+ new_row[i] = pre_row[i] + 1
150
+ else:
151
+ new_row[i] = new_row[i - 1] - ar_step
152
+ new_row = new_row.clamp(0, num_iterations)
153
+
154
+ update_mask.append(
155
+ (new_row != pre_row) & (new_row != num_iterations)
156
+ ) # False: no need to update, True: need to update
157
+ step_index.append(new_row)
158
+ step_matrix.append(step_template[new_row])
159
+ pre_row = new_row
160
+
161
+ # for long videos we split the latent sequence into several chunks; base_num_frames is the model's maximum length (from training)
162
+ terminal_flag = base_num_frames_block
163
+ if shrink_interval_with_mask:
164
+ idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
165
+ update_mask = update_mask[0]
166
+ update_mask_idx = idx_sequence[update_mask]
167
+ last_update_idx = update_mask_idx[-1].item()
168
+ terminal_flag = last_update_idx + 1
169
+ # for i in range(0, len(update_mask)):
170
+ for curr_mask in update_mask:
171
+ if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
172
+ terminal_flag += 1
173
+ valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
174
+
175
+ step_update_mask = torch.stack(update_mask, dim=0)
176
+ step_index = torch.stack(step_index, dim=0)
177
+ step_matrix = torch.stack(step_matrix, dim=0)
178
+
179
+ if casual_block_size > 1:
180
+ step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
181
+ step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
182
+ step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
183
+ valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
184
+
185
+ return step_matrix, step_index, step_update_mask, valid_interval
186
+
187
+ def get_video_as_tensor(self, video_path, width, height):
188
+ """
189
+ Loads a video from the given path and returns it as a tensor with proper channel ordering.
190
+ Args:
191
+ video_path (str): Path to the video file
192
+ Returns:
193
+ torch.Tensor: Video tensor in [T, C, H, W] format (frames first, channels second)
194
+ """
195
+
196
+ # Have Decord return frames as torch tensors
197
+ decord.bridge.set_bridge('torch')
198
+
199
+ # Load video
200
+ vr = VideoReader(video_path, width=width, height=height)
201
+ total_frames = len(vr)
202
+
203
+ # Read all frames
204
+ video_frames = vr.get_batch(list(range(total_frames)))
205
+
206
+ # Convert from [T, H, W, C] to [T, C, H, W] format
207
+ video_tensor = video_frames.permute(0, 3, 1, 2).float()
208
+
209
+ return video_tensor
210
+
211
+ @torch.no_grad()
212
+ def extend_video(
213
+ self,
214
+ prompt: Union[str, List[str]],
215
+ negative_prompt: Union[str, List[str]] = "",
216
+ prefix_video_path: str = None,
217
+ height: int = 480,
218
+ width: int = 832,
219
+ num_frames: int = 97,
220
+ num_inference_steps: int = 50,
221
+ shift: float = 1.0,
222
+ guidance_scale: float = 5.0,
223
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
224
+ overlap_history: int = None,
225
+ addnoise_condition: int = 0,
226
+ base_num_frames: int = 97,
227
+ ar_step: int = 5,
228
+ causal_block_size: int = None,
229
+ fps: int = 24,
230
+ ):
231
+ latent_height = height // 8
232
+ latent_width = width // 8
233
+ latent_length = (num_frames - 1) // 4 + 1
234
+
235
+ self._guidance_scale = guidance_scale
236
+
237
+ i2v_extra_kwrags = {}
238
+ prefix_video = None
239
+ predix_video_latent_length = 0
240
+
241
+ self.text_encoder.to(self.device)
242
+ prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
243
+ if self.do_classifier_free_guidance:
244
+ negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
245
+ if self.offload:
246
+ self.text_encoder.cpu()
247
+ torch.cuda.empty_cache()
248
+
249
+ self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
250
+ init_timesteps = self.scheduler.timesteps
251
+ if causal_block_size is None:
252
+ causal_block_size = self.transformer.num_frame_per_block
253
+ fps_embeds = [fps] * prompt_embeds.shape[0]
254
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
255
+ transformer_dtype = self.transformer.dtype
256
+ # with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
257
+
258
+ prefix_video = self.get_video_as_tensor(prefix_video_path, width, height)
259
+ prefix_frame = prefix_video.to(self.device)
260
+ start_video = (prefix_frame.float() / (255.0 / 2.0)) - 1.0
261
+ start_video = start_video.transpose(0, 1)
262
+
263
+ # long video generation
264
+ base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
265
+ overlap_history_frames = (overlap_history - 1) // 4 + 1
266
+ n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
267
+ print(f"n_iter:{n_iter}")
268
+ output_video = start_video.cpu()
269
+ for i in range(n_iter):
270
+ prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
271
+ prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
272
+ if prefix_video[0].shape[1] % causal_block_size != 0:
273
+ truncate_len = prefix_video[0].shape[1] % causal_block_size
274
+ print("The length of the prefix video is truncated to align with the causal block size.")
275
+ prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
276
+ predix_video_latent_length = prefix_video[0].shape[1]
277
+ finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
278
+ left_frame_num = latent_length - finished_frame_num
279
+ base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
280
+ if ar_step > 0 and self.transformer.enable_teacache:
281
+ num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
282
+ self.transformer.num_steps = num_steps
283
+
284
+ latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
285
+ latents = self.prepare_latents(
286
+ latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
287
+ )
288
+ latents = [latents]
289
+ if prefix_video is not None:
290
+ latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
291
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
292
+ base_num_frames_iter,
293
+ init_timesteps,
294
+ base_num_frames_iter,
295
+ ar_step,
296
+ predix_video_latent_length,
297
+ causal_block_size,
298
+ )
299
+ sample_schedulers = []
300
+ for _ in range(base_num_frames_iter):
301
+ sample_scheduler = FlowUniPCMultistepScheduler(
302
+ num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
303
+ )
304
+ sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
305
+ sample_schedulers.append(sample_scheduler)
306
+ sample_schedulers_counter = [0] * base_num_frames_iter
307
+ self.transformer.to(self.device)
308
+ for i, timestep_i in enumerate(tqdm(step_matrix)):
309
+ update_mask_i = step_update_mask[i]
310
+ valid_interval_i = valid_interval[i]
311
+ valid_interval_start, valid_interval_end = valid_interval_i
312
+ timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
313
+ latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
314
+ if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
315
+ noise_factor = 0.001 * addnoise_condition
316
+ timestep_for_noised_condition = addnoise_condition
317
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
318
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
319
+ * (1.0 - noise_factor)
320
+ + torch.randn_like(
321
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
322
+ )
323
+ * noise_factor
324
+ )
325
+ timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
326
+ if not self.do_classifier_free_guidance:
327
+ noise_pred = self.transformer(
328
+ torch.stack([latent_model_input[0]]),
329
+ t=timestep,
330
+ context=prompt_embeds,
331
+ fps=fps_embeds,
332
+ **i2v_extra_kwrags,
333
+ )[0]
334
+ else:
335
+ noise_pred_cond = self.transformer(
336
+ torch.stack([latent_model_input[0]]),
337
+ t=timestep,
338
+ context=prompt_embeds,
339
+ fps=fps_embeds,
340
+ **i2v_extra_kwrags,
341
+ )[0]
342
+ noise_pred_uncond = self.transformer(
343
+ torch.stack([latent_model_input[0]]),
344
+ t=timestep,
345
+ context=negative_prompt_embeds,
346
+ fps=fps_embeds,
347
+ **i2v_extra_kwrags,
348
+ )[0]
349
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
350
+ for idx in range(valid_interval_start, valid_interval_end):
351
+ if update_mask_i[idx].item():
352
+ latents[0][:, idx] = sample_schedulers[idx].step(
353
+ noise_pred[:, idx - valid_interval_start],
354
+ timestep_i[idx],
355
+ latents[0][:, idx],
356
+ return_dict=False,
357
+ generator=generator,
358
+ )[0]
359
+ sample_schedulers_counter[idx] += 1
360
+ if self.offload:
361
+ self.transformer.cpu()
362
+ torch.cuda.empty_cache()
363
+ x0 = latents[0].unsqueeze(0)
364
+ videos = [self.vae.decode(x0)[0]]
365
+ if output_video is None:
366
+ output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
367
+ else:
368
+ output_video = torch.cat(
369
+ [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
370
+ ) # c, f, h, w
371
+ output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
372
+ output_video = [video for video in output_video]
373
+ output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
374
+ output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video]
375
+
376
+ return output_video
377
+
378
+
379
+ @torch.no_grad()
380
+ def __call__(
381
+ self,
382
+ prompt: Union[str, List[str]],
383
+ negative_prompt: Union[str, List[str]] = "",
384
+ image: PipelineImageInput = None,
385
+ end_image: PipelineImageInput = None,
386
+ height: int = 480,
387
+ width: int = 832,
388
+ num_frames: int = 97,
389
+ num_inference_steps: int = 50,
390
+ shift: float = 1.0,
391
+ guidance_scale: float = 5.0,
392
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
393
+ overlap_history: int = None,
394
+ addnoise_condition: int = 0,
395
+ base_num_frames: int = 97,
396
+ ar_step: int = 5,
397
+ causal_block_size: int = None,
398
+ fps: int = 24,
399
+ ):
400
+ latent_height = height // 8
401
+ latent_width = width // 8
402
+ latent_length = (num_frames - 1) // 4 + 1
403
+
404
+ self._guidance_scale = guidance_scale
405
+
406
+ i2v_extra_kwrags = {}
407
+ prefix_video = None
408
+ predix_video_latent_length = 0
409
+ end_video = None
410
+ end_video_latent_length = 0
411
+
412
+ if image:
413
+ prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)
414
+
415
+ if end_image:
416
+ end_video, end_video_latent_length = self.encode_image(end_image, height, width, num_frames)
417
+
418
+ self.text_encoder.to(self.device)
419
+ prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
420
+ if self.do_classifier_free_guidance:
421
+ negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
422
+ if self.offload:
423
+ self.text_encoder.cpu()
424
+ torch.cuda.empty_cache()
425
+
426
+ self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
427
+ init_timesteps = self.scheduler.timesteps
428
+ if causal_block_size is None:
429
+ causal_block_size = self.transformer.num_frame_per_block
430
+ fps_embeds = [fps] * prompt_embeds.shape[0]
431
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
432
+ transformer_dtype = self.transformer.dtype
433
+ # with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
434
+ if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
435
+ # short video generation
436
+ latent_shape = [16, latent_length, latent_height, latent_width]
437
+ latents = self.prepare_latents(
438
+ latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
439
+ )
440
+ latents = [latents]
441
+ if prefix_video is not None:
442
+ latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
443
+
444
+ if end_video is not None:
445
+ latents[0] = torch.cat([latents[0], end_video[0].to(transformer_dtype)], dim=1)
446
+
447
+ base_num_frames = num_frames
448
+ base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
449
+ if end_video is not None:
450
+ base_num_frames += end_video_latent_length
451
+ latent_length += end_video_latent_length
452
+
453
+
454
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
455
+ latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
456
+ )
457
+ if end_video is not None:
458
+ step_matrix[:, -end_video_latent_length:] = 0
459
+ step_update_mask[:, -end_video_latent_length:] = False
460
+
461
+ sample_schedulers = []
462
+ for _ in range(latent_length):
463
+ sample_scheduler = FlowUniPCMultistepScheduler(
464
+ num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
465
+ )
466
+ sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
467
+ sample_schedulers.append(sample_scheduler)
468
+ sample_schedulers_counter = [0] * latent_length
469
+ self.transformer.to(self.device)
470
+ for i, timestep_i in enumerate(tqdm(step_matrix)):
471
+ update_mask_i = step_update_mask[i]
472
+ valid_interval_i = valid_interval[i]
473
+ valid_interval_start, valid_interval_end = valid_interval_i
474
+ timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
475
+ latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
476
+ if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
477
+ noise_factor = 0.001 * addnoise_condition
478
+ timestep_for_noised_condition = addnoise_condition
479
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
480
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
481
+ + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
482
+ * noise_factor
483
+ )
484
+ timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
485
+ if not self.do_classifier_free_guidance:
486
+ noise_pred = self.transformer(
487
+ torch.stack([latent_model_input[0]]),
488
+ t=timestep,
489
+ context=prompt_embeds,
490
+ fps=fps_embeds,
491
+ **i2v_extra_kwrags,
492
+ )[0]
493
+ else:
494
+ noise_pred_cond = self.transformer(
495
+ torch.stack([latent_model_input[0]]),
496
+ t=timestep,
497
+ context=prompt_embeds,
498
+ fps=fps_embeds,
499
+ **i2v_extra_kwrags,
500
+ )[0]
501
+ noise_pred_uncond = self.transformer(
502
+ torch.stack([latent_model_input[0]]),
503
+ t=timestep,
504
+ context=negative_prompt_embeds,
505
+ fps=fps_embeds,
506
+ **i2v_extra_kwrags,
507
+ )[0]
508
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
509
+ for idx in range(valid_interval_start, valid_interval_end):
510
+ if update_mask_i[idx].item():
511
+ latents[0][:, idx] = sample_schedulers[idx].step(
512
+ noise_pred[:, idx - valid_interval_start],
513
+ timestep_i[idx],
514
+ latents[0][:, idx],
515
+ return_dict=False,
516
+ generator=generator,
517
+ )[0]
518
+ sample_schedulers_counter[idx] += 1
519
+ if self.offload:
520
+ self.transformer.cpu()
521
+ torch.cuda.empty_cache()
522
+ x0 = latents[0].unsqueeze(0)
523
+ if end_video is not None:
524
+ x0 = latents[0][:, :-end_video_latent_length].unsqueeze(0)
525
+
526
+ videos = self.vae.decode(x0)
527
+ videos = (videos / 2 + 0.5).clamp(0, 1)
528
+ videos = [video for video in videos]
529
+ videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
530
+ videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
531
+ return videos
532
+ else:
533
+ # long video generation
534
+ base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
535
+ overlap_history_frames = (overlap_history - 1) // 4 + 1
536
+ n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
537
+ print(f"n_iter:{n_iter}")
538
+ output_video = None
539
+ for i in range(n_iter):
540
+ if output_video is not None: # i !=0
541
+ prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
542
+ prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
543
+ if prefix_video[0].shape[1] % causal_block_size != 0:
544
+ truncate_len = prefix_video[0].shape[1] % causal_block_size
545
+ print("The length of the prefix video is truncated to align with the causal block size.")
546
+ prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
547
+ predix_video_latent_length = prefix_video[0].shape[1]
548
+ finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
549
+ left_frame_num = latent_length - finished_frame_num
550
+ base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
551
+ if ar_step > 0 and self.transformer.enable_teacache:
552
+ num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
553
+ self.transformer.num_steps = num_steps
554
+ else: # i == 0
555
+ base_num_frames_iter = base_num_frames
556
+ latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
557
+ latents = self.prepare_latents(
558
+ latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
559
+ )
560
+ latents = [latents]
561
+ if prefix_video is not None:
562
+ latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
563
+
564
+ if end_video is not None and i == n_iter - 1:
565
+ base_num_frames_iter += end_video_latent_length
566
+ latents[0] = torch.cat([latents[0], end_video[0].to(transformer_dtype)], dim=1)
567
+
568
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
569
+ base_num_frames_iter,
570
+ init_timesteps,
571
+ base_num_frames_iter,
572
+ ar_step,
573
+ predix_video_latent_length,
574
+ causal_block_size,
575
+ )
576
+ if end_video is not None and i == n_iter - 1:
577
+ step_matrix[:, -end_video_latent_length:] = 0
578
+ step_update_mask[:, -end_video_latent_length:] = False
579
+
580
+ sample_schedulers = []
581
+ for _ in range(base_num_frames_iter):
582
+ sample_scheduler = FlowUniPCMultistepScheduler(
583
+ num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
584
+ )
585
+ sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
586
+ sample_schedulers.append(sample_scheduler)
587
+ sample_schedulers_counter = [0] * base_num_frames_iter
588
+ self.transformer.to(self.device)
589
+ for i, timestep_i in enumerate(tqdm(step_matrix)):
590
+ update_mask_i = step_update_mask[i]
591
+ valid_interval_i = valid_interval[i]
592
+ valid_interval_start, valid_interval_end = valid_interval_i
593
+ timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
594
+ latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
595
+ if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
596
+ noise_factor = 0.001 * addnoise_condition
597
+ timestep_for_noised_condition = addnoise_condition
598
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
599
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
600
+ * (1.0 - noise_factor)
601
+ + torch.randn_like(
602
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
603
+ )
604
+ * noise_factor
605
+ )
606
+ timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
607
+ if not self.do_classifier_free_guidance:
608
+ noise_pred = self.transformer(
609
+ torch.stack([latent_model_input[0]]),
610
+ t=timestep,
611
+ context=prompt_embeds,
612
+ fps=fps_embeds,
613
+ **i2v_extra_kwrags,
614
+ )[0]
615
+ else:
616
+ noise_pred_cond = self.transformer(
617
+ torch.stack([latent_model_input[0]]),
618
+ t=timestep,
619
+ context=prompt_embeds,
620
+ fps=fps_embeds,
621
+ **i2v_extra_kwrags,
622
+ )[0]
623
+ noise_pred_uncond = self.transformer(
624
+ torch.stack([latent_model_input[0]]),
625
+ t=timestep,
626
+ context=negative_prompt_embeds,
627
+ fps=fps_embeds,
628
+ **i2v_extra_kwrags,
629
+ )[0]
630
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
631
+ for idx in range(valid_interval_start, valid_interval_end):
632
+ if update_mask_i[idx].item():
633
+ latents[0][:, idx] = sample_schedulers[idx].step(
634
+ noise_pred[:, idx - valid_interval_start],
635
+ timestep_i[idx],
636
+ latents[0][:, idx],
637
+ return_dict=False,
638
+ generator=generator,
639
+ )[0]
640
+ sample_schedulers_counter[idx] += 1
641
+ if self.offload:
642
+ self.transformer.cpu()
643
+ torch.cuda.empty_cache()
644
+ x0 = latents[0].unsqueeze(0)
645
+ if end_video is not None and i == n_iter - 1:
646
+ x0 = latents[0][:, :-end_video_latent_length].unsqueeze(0)
647
+
648
+ videos = [self.vae.decode(x0)[0]]
649
+ if output_video is None:
650
+ output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
651
+ else:
652
+ output_video = torch.cat(
653
+ [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
654
+ ) # c, f, h, w
655
+ output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
656
+ output_video = [video for video in output_video]
657
+ output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
658
+ output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video]
659
+ return output_video
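
A minimal end-to-end sketch of the pipeline above, assuming placeholder checkpoint paths and illustrative prompt and sampling values; the call returns a list of uint8 arrays shaped [frames, height, width, channels].

    import torch
    from skyreels_v2_infer.pipelines import DiffusionForcingPipeline

    pipe = DiffusionForcingPipeline(
        model_path="path/to/SkyReels-V2",  # directory holding Wan2.1_VAE.pth and the text encoder
        dit_path="path/to/dit",            # directory with config.json and *.safetensors weights
        device="cuda",
        weight_dtype=torch.bfloat16,
        offload=True,                      # keep the DiT and text encoder on CPU until needed
    )
    frames = pipe(
        prompt="a sailboat gliding across a calm bay at sunset, camera slowly panning right",
        height=480,
        width=832,
        num_frames=97,
        num_inference_steps=30,
        guidance_scale=5.0,
        ar_step=5,   # staggered (diffusion-forcing) schedule; 0 denoises all frames in lockstep
        fps=24,
        generator=torch.Generator("cuda").manual_seed(42),
    )[0]
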
skyreels_v2_infer/pipelines/image2video_pipeline.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ from typing import List
3
+ from typing import Optional
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.image_processor import PipelineImageInput
9
+ from diffusers.video_processor import VideoProcessor
10
+ from PIL import Image
11
+ from tqdm import tqdm
12
+
13
+ from ..modules import get_image_encoder
14
+ from ..modules import get_text_encoder
15
+ from ..modules import get_transformer
16
+ from ..modules import get_vae
17
+ from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
18
+
19
+
20
+ def resizecrop(image: Image.Image, th, tw):
21
+ w, h = image.size
22
+ if w == tw and h == th:
23
+ return image
24
+ if h / w > th / tw:
25
+ new_w = int(w)
26
+ new_h = int(new_w * th / tw)
27
+ else:
28
+ new_h = int(h)
29
+ new_w = int(new_h * tw / th)
30
+ left = (w - new_w) / 2
31
+ top = (h - new_h) / 2
32
+ right = (w + new_w) / 2
33
+ bottom = (h + new_h) / 2
34
+ image = image.crop((left, top, right, bottom))
35
+ return image
36
+
37
+
38
+ class Image2VideoPipeline:
39
+ def __init__(
40
+ self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
41
+ ):
42
+ load_device = "cpu" if offload else device
43
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype)
44
+ vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
45
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
46
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
47
+ self.clip = get_image_encoder(model_path, load_device, weight_dtype)
48
+ self.sp_size = 1
49
+ self.device = device
50
+ self.offload = offload
51
+ self.video_processor = VideoProcessor(vae_scale_factor=16)
52
+ if use_usp:
53
+ from xfuser.core.distributed import get_sequence_parallel_world_size
54
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
55
+ import types
56
+
57
+ for block in self.transformer.blocks:
58
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
59
+ self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
60
+ self.sp_size = get_sequence_parallel_world_size()
61
+
62
+ self.scheduler = FlowUniPCMultistepScheduler()
63
+ self.vae_stride = (4, 8, 8)
64
+ self.patch_size = (1, 2, 2)
65
+
66
+ @torch.no_grad()
67
+ def __call__(
68
+ self,
69
+ image: PipelineImageInput,
70
+ prompt: Union[str, List[str]] = None,
71
+ negative_prompt: Union[str, List[str]] = None,
72
+ height: int = 544,
73
+ width: int = 960,
74
+ num_frames: int = 97,
75
+ num_inference_steps: int = 50,
76
+ guidance_scale: float = 5.0,
77
+ shift: float = 5.0,
78
+ generator: Optional[torch.Generator] = None,
79
+ ):
80
+ F = num_frames
81
+
82
+ latent_height = height // 8 // 2 * 2
83
+ latent_width = width // 8 // 2 * 2
84
+ latent_length = (F - 1) // 4 + 1
85
+
86
+ h = latent_height * 8
87
+ w = latent_width * 8
88
+
89
+ img = self.video_processor.preprocess(image, height=h, width=w)
90
+
91
+ img = img.to(device=self.device, dtype=self.transformer.dtype)
92
+
93
+ padding_video = torch.zeros(img.shape[0], 3, F - 1, h, w, device=self.device)
94
+
95
+ img = img.unsqueeze(2)
96
+ img_cond = torch.concat([img, padding_video], dim=2)
97
+ img_cond = self.vae.encode(img_cond)
98
+ mask = torch.ones_like(img_cond)
99
+ mask[:, :, 1:] = 0
100
+ y = torch.cat([mask[:, :4], img_cond], dim=1)
101
+ self.clip.to(self.device)
102
+ clip_context = self.clip.encode_video(img)
103
+ if self.offload:
104
+ self.clip.cpu()
105
+ torch.cuda.empty_cache()
106
+
107
+ # preprocess
108
+ self.text_encoder.to(self.device)
109
+ context = self.text_encoder.encode(prompt).to(self.device)
110
+ context_null = self.text_encoder.encode(negative_prompt).to(self.device)
111
+ if self.offload:
112
+ self.text_encoder.cpu()
113
+ torch.cuda.empty_cache()
114
+
115
+ latent = torch.randn(
116
+ 16, latent_length, latent_height, latent_width, dtype=torch.float32, generator=generator, device=self.device
117
+ )
118
+
119
+ self.transformer.to(self.device)
120
+ with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
121
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
122
+ timesteps = self.scheduler.timesteps
123
+
124
+ arg_c = {
125
+ "context": context,
126
+ "clip_fea": clip_context,
127
+ "y": y,
128
+ }
129
+
130
+ arg_null = {
131
+ "context": context_null,
132
+ "clip_fea": clip_context,
133
+ "y": y,
134
+ }
135
+
136
+ self.transformer.to(self.device)
137
+ for _, t in enumerate(tqdm(timesteps)):
138
+ latent_model_input = torch.stack([latent]).to(self.device)
139
+ timestep = torch.stack([t]).to(self.device)
140
+ noise_pred_cond = self.transformer(latent_model_input, t=timestep, **arg_c)[0].to(self.device)
141
+ noise_pred_uncond = self.transformer(latent_model_input, t=timestep, **arg_null)[0].to(self.device)
142
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
143
+
144
+ temp_x0 = self.scheduler.step(
145
+ noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator
146
+ )[0]
147
+ latent = temp_x0.squeeze(0)
148
+ if self.offload:
149
+ self.transformer.cpu()
150
+ torch.cuda.empty_cache()
151
+ videos = self.vae.decode(latent)
152
+ videos = (videos / 2 + 0.5).clamp(0, 1)
153
+ videos = [video for video in videos]
154
+ videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
155
+ videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
156
+ return videos
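
A comparable sketch for the image-to-video pipeline, again with placeholder paths and inputs; `resizecrop` brings the source image to the target aspect ratio before preprocessing.

    import torch
    from PIL import Image
    from skyreels_v2_infer.pipelines import Image2VideoPipeline, resizecrop

    pipe = Image2VideoPipeline("path/to/SkyReels-V2", "path/to/dit", offload=True)
    image = resizecrop(Image.open("input.jpg").convert("RGB"), 544, 960)  # (target height, target width)
    videos = pipe(
        image=image,
        prompt="waves roll gently onto the shore as the camera pushes in",
        negative_prompt="blurry, distorted, low quality",
        height=544,
        width=960,
        num_frames=97,
        generator=torch.Generator("cuda").manual_seed(0),
    )  # list of uint8 arrays shaped [frames, height, width, channels]
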
skyreels_v2_infer/pipelines/prompt_enhancer.py ADDED
@@ -0,0 +1,65 @@
1
+ import argparse
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
+ sys_prompt = """
5
+ Transform the short prompt into a detailed video-generation caption using this structure:
6
+ ​​Opening shot type​​ (long/medium/close-up/extreme close-up/full shot)
7
+ ​​Primary subject(s)​​ with vivid attributes (colors, textures, actions, interactions)
8
+ ​​Dynamic elements​​ (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
9
+ ​​Scene composition​​ (background, environment, spatial relationships)
10
+ ​​Lighting/atmosphere​​ (natural/artificial, time of day, mood)
11
+ ​​Camera motion​​ (zooms, pans, static/handheld shots) if applicable.
12
+
13
+ Pattern Summary from Examples:
14
+ [Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
15
+
16
+ ​One case:
17
+ Short prompt: a person is playing football
18
+ Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
19
+
20
+ Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
21
+
22
+ ​​Now expand this short prompt:​​ [{}]. Please only output the final long prompt in English.
23
+ """
24
+
25
+ class PromptEnhancer:
26
+ def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct"):
27
+ self.model = AutoModelForCausalLM.from_pretrained(
28
+ model_name,
29
+ torch_dtype="auto",
30
+ device_map="cuda:0",
31
+ )
32
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
33
+
34
+ def __call__(self, prompt):
35
+ prompt = prompt.strip()
36
+ prompt = sys_prompt.format(prompt)
37
+ messages = [
38
+ {"role": "system", "content": "You are a helpful assistant."},
39
+ {"role": "user", "content": prompt}
40
+ ]
41
+ text = self.tokenizer.apply_chat_template(
42
+ messages,
43
+ tokenize=False,
44
+ add_generation_prompt=True
45
+ )
46
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
47
+ generated_ids = self.model.generate(
48
+ **model_inputs,
49
+ max_new_tokens=2048,
50
+ )
51
+ generated_ids = [
52
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
53
+ ]
54
+ rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
55
+ return rewritten_prompt
56
+
57
+ if __name__ == '__main__':
58
+ parser = argparse.ArgumentParser()
59
+ parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign")
60
+ args = parser.parse_args()
61
+
62
+ prompt_enhancer = PromptEnhancer()
63
+ enhanced_prompt = prompt_enhancer(args.prompt)
64
+ print(f'Original prompt: {args.prompt}')
65
+ print(f'Enhanced prompt: {enhanced_prompt}')
skyreels_v2_infer/pipelines/text2video_pipeline.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ from typing import List
3
+ from typing import Optional
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.video_processor import VideoProcessor
9
+ from tqdm import tqdm
10
+
11
+ from ..modules import get_text_encoder
12
+ from ..modules import get_transformer
13
+ from ..modules import get_vae
14
+ from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
15
+
16
+
17
+ class Text2VideoPipeline:
18
+ def __init__(
19
+ self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
20
+ ):
21
+ load_device = "cpu" if offload else device
22
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype)
23
+ vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
24
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
25
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
26
+ self.video_processor = VideoProcessor(vae_scale_factor=16)
27
+ self.sp_size = 1
28
+ self.device = device
29
+ self.offload = offload
30
+ if use_usp:
31
+ from xfuser.core.distributed import get_sequence_parallel_world_size
32
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
33
+ import types
34
+
35
+ for block in self.transformer.blocks:
36
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
37
+ self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
38
+ self.sp_size = get_sequence_parallel_world_size()
39
+
40
+ self.scheduler = FlowUniPCMultistepScheduler()
41
+ self.vae_stride = (4, 8, 8)
42
+ self.patch_size = (1, 2, 2)
43
+
44
+ @torch.no_grad()
45
+ def __call__(
46
+ self,
47
+ prompt: Union[str, List[str]] = None,
48
+ negative_prompt: Union[str, List[str]] = None,
49
+ width: int = 544,
50
+ height: int = 960,
51
+ num_frames: int = 97,
52
+ num_inference_steps: int = 50,
53
+ guidance_scale: float = 5.0,
54
+ shift: float = 5.0,
55
+ generator: Optional[torch.Generator] = None,
56
+ ):
57
+ # preprocess
58
+ F = num_frames
59
+ target_shape = (
60
+ self.vae.vae.z_dim,
61
+ (F - 1) // self.vae_stride[0] + 1,
62
+ height // self.vae_stride[1],
63
+ width // self.vae_stride[2],
64
+ )
65
+ self.text_encoder.to(self.device)
66
+ context = self.text_encoder.encode(prompt).to(self.device)
67
+ context_null = self.text_encoder.encode(negative_prompt).to(self.device)
68
+ if self.offload:
69
+ self.text_encoder.cpu()
70
+ torch.cuda.empty_cache()
71
+
72
+ latents = [
73
+ torch.randn(
74
+ target_shape[0],
75
+ target_shape[1],
76
+ target_shape[2],
77
+ target_shape[3],
78
+ dtype=torch.float32,
79
+ device=self.device,
80
+ generator=generator,
81
+ )
82
+ ]
83
+
84
+ # evaluation mode
85
+ self.transformer.to(self.device)
86
+ with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
87
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
88
+ timesteps = self.scheduler.timesteps
89
+
90
+ for _, t in enumerate(tqdm(timesteps)):
91
+ latent_model_input = torch.stack(latents)
92
+ timestep = torch.stack([t])
93
+ noise_pred_cond = self.transformer(latent_model_input, t=timestep, context=context)[0]
94
+ noise_pred_uncond = self.transformer(latent_model_input, t=timestep, context=context_null)[0]
95
+
96
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
97
+
98
+ temp_x0 = self.scheduler.step(
99
+ noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=generator
100
+ )[0]
101
+ latents = [temp_x0.squeeze(0)]
102
+ if self.offload:
103
+ self.transformer.cpu()
104
+ torch.cuda.empty_cache()
105
+ videos = self.vae.decode(latents[0])
106
+ videos = (videos / 2 + 0.5).clamp(0, 1)
107
+ videos = [video for video in videos]
108
+ videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
109
+ videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
110
+ return videos
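
And a matching sketch for the text-to-video pipeline, under the same placeholder assumptions; keyword arguments are passed explicitly since `width` precedes `height` in this signature.

    import torch
    from skyreels_v2_infer.pipelines import Text2VideoPipeline

    pipe = Text2VideoPipeline("path/to/SkyReels-V2", "path/to/dit", offload=True)
    videos = pipe(
        prompt="a time-lapse of clouds drifting over a mountain ridge at dawn",
        negative_prompt="blurry, distorted, low quality",
        width=960,
        height=544,
        num_frames=97,
        num_inference_steps=50,
        guidance_scale=5.0,
        shift=5.0,
        generator=torch.Generator("cuda").manual_seed(0),
    )
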
skyreels_v2_infer/scheduler/__init__.py ADDED
File without changes
skyreels_v2_infer/scheduler/fm_solvers_unipc.py ADDED
@@ -0,0 +1,759 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2
+ # Convert unipc for flow matching
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+ import math
5
+ from typing import List
6
+ from typing import Optional
7
+ from typing import Tuple
8
+ from typing import Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from diffusers.configuration_utils import ConfigMixin
13
+ from diffusers.configuration_utils import register_to_config
14
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
15
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
16
+ from diffusers.schedulers.scheduling_utils import SchedulerOutput
17
+ from diffusers.utils import deprecate
18
+
19
+
20
+ class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
21
+ """
22
+ `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
23
+
24
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
25
+ methods the library implements for all schedulers such as loading and saving.
26
+
27
+ Args:
28
+ num_train_timesteps (`int`, defaults to 1000):
29
+ The number of diffusion steps to train the model.
30
+ solver_order (`int`, default `2`):
31
+ The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
32
+ due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
33
+ unconditional sampling.
34
+ prediction_type (`str`, defaults to "flow_prediction"):
35
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
36
+ the flow of the diffusion process.
37
+ thresholding (`bool`, defaults to `False`):
38
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
39
+ as Stable Diffusion.
40
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
41
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
42
+ sample_max_value (`float`, defaults to 1.0):
43
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
44
+ predict_x0 (`bool`, defaults to `True`):
45
+ Whether to use the updating algorithm on the predicted x0.
46
+ solver_type (`str`, default `bh2`):
47
+ Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
48
+ otherwise.
49
+ lower_order_final (`bool`, default `True`):
50
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
51
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
52
+ disable_corrector (`list`, default `[]`):
53
+ Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
54
+ and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
55
+ usually disabled during the first few steps.
56
+ solver_p (`SchedulerMixin`, default `None`):
57
+ Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
58
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
59
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
60
+ the sigmas are determined according to a sequence of noise levels {σi}.
61
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
62
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
63
+ timestep_spacing (`str`, defaults to `"linspace"`):
64
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
65
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
66
+ steps_offset (`int`, defaults to 0):
67
+ An offset added to the inference steps, as required by some model families.
68
+ final_sigmas_type (`str`, defaults to `"zero"`):
69
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
70
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
71
+ """
72
+
73
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
74
+ order = 1
75
+
76
+ @register_to_config
77
+ def __init__(
78
+ self,
79
+ num_train_timesteps: int = 1000,
80
+ solver_order: int = 2,
81
+ prediction_type: str = "flow_prediction",
82
+ shift: Optional[float] = 1.0,
83
+ use_dynamic_shifting=False,
84
+ thresholding: bool = False,
85
+ dynamic_thresholding_ratio: float = 0.995,
86
+ sample_max_value: float = 1.0,
87
+ predict_x0: bool = True,
88
+ solver_type: str = "bh2",
89
+ lower_order_final: bool = True,
90
+ disable_corrector: List[int] = [],
91
+ solver_p: SchedulerMixin = None,
92
+ timestep_spacing: str = "linspace",
93
+ steps_offset: int = 0,
94
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
95
+ ):
96
+
97
+ if solver_type not in ["bh1", "bh2"]:
98
+ if solver_type in ["midpoint", "heun", "logrho"]:
99
+ self.register_to_config(solver_type="bh2")
100
+ else:
101
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
102
+
103
+ self.predict_x0 = predict_x0
104
+ # setable values
105
+ self.num_inference_steps = None
106
+ alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
107
+ sigmas = 1.0 - alphas
108
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
109
+
110
+ if not use_dynamic_shifting:
111
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
112
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
113
+
114
+ self.sigmas = sigmas
115
+ self.timesteps = sigmas * num_train_timesteps
116
+
117
+ self.model_outputs = [None] * solver_order
118
+ self.timestep_list = [None] * solver_order
119
+ self.lower_order_nums = 0
120
+ self.disable_corrector = disable_corrector
121
+ self.solver_p = solver_p
122
+ self.last_sample = None
123
+ self._step_index = None
124
+ self._begin_index = None
125
+
126
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
127
+ self.sigma_min = self.sigmas[-1].item()
128
+ self.sigma_max = self.sigmas[0].item()
129
+
130
+ @property
131
+ def step_index(self):
132
+ """
133
+ The index counter for current timestep. It will increase 1 after each scheduler step.
134
+ """
135
+ return self._step_index
136
+
137
+ @property
138
+ def begin_index(self):
139
+ """
140
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
141
+ """
142
+ return self._begin_index
143
+
144
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
145
+ def set_begin_index(self, begin_index: int = 0):
146
+ """
147
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
148
+
149
+ Args:
150
+ begin_index (`int`):
151
+ The begin index for the scheduler.
152
+ """
153
+ self._begin_index = begin_index
154
+
155
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
156
+ def set_timesteps(
157
+ self,
158
+ num_inference_steps: Union[int, None] = None,
159
+ device: Union[str, torch.device] = None,
160
+ sigmas: Optional[List[float]] = None,
161
+ mu: Optional[Union[float, None]] = None,
162
+ shift: Optional[Union[float, None]] = None,
163
+ ):
164
+ """
165
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
166
+ Args:
167
+ num_inference_steps (`int`):
168
+ The total number of sampling steps used when generating samples.
169
+ device (`str` or `torch.device`, *optional*):
170
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
171
+ """
172
+
173
+ if self.config.use_dynamic_shifting and mu is None:
174
+ raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`")
175
+
176
+ if sigmas is None:
177
+ sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore
178
+
179
+ if self.config.use_dynamic_shifting:
180
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
181
+ else:
182
+ if shift is None:
183
+ shift = self.config.shift
184
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
185
+
186
+ if self.config.final_sigmas_type == "sigma_min":
187
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
188
+ elif self.config.final_sigmas_type == "zero":
189
+ sigma_last = 0
190
+ else:
191
+ raise ValueError(
192
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
193
+ )
194
+
195
+ timesteps = sigmas * self.config.num_train_timesteps
196
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
197
+
198
+ self.sigmas = torch.from_numpy(sigmas)
199
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
200
+
201
+ self.num_inference_steps = len(timesteps)
202
+
203
+ self.model_outputs = [
204
+ None,
205
+ ] * self.config.solver_order
206
+ self.lower_order_nums = 0
207
+ self.last_sample = None
208
+ if self.solver_p:
209
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
210
+
211
+ # add an index counter for schedulers that allow duplicated timesteps
212
+ self._step_index = None
213
+ self._begin_index = None
214
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
215
+
216
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
217
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
218
+ """
219
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
220
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
221
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
222
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
223
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
224
+
225
+ https://arxiv.org/abs/2205.11487
226
+ """
227
+ dtype = sample.dtype
228
+ batch_size, channels, *remaining_dims = sample.shape
229
+
230
+ if dtype not in (torch.float32, torch.float64):
231
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
232
+
233
+ # Flatten sample for doing quantile calculation along each image
234
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
235
+
236
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
237
+
238
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
239
+ s = torch.clamp(
240
+ s, min=1, max=self.config.sample_max_value
241
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
242
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
243
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
244
+
245
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
246
+ sample = sample.to(dtype)
247
+
248
+ return sample
249
+
250
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
251
+ def _sigma_to_t(self, sigma):
252
+ return sigma * self.config.num_train_timesteps
253
+
254
+ def _sigma_to_alpha_sigma_t(self, sigma):
255
+ return 1 - sigma, sigma
256
+
257
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
258
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
259
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
260
+
261
+ def convert_model_output(
262
+ self,
263
+ model_output: torch.Tensor,
264
+ *args,
265
+ sample: torch.Tensor = None,
266
+ **kwargs,
267
+ ) -> torch.Tensor:
268
+ r"""
269
+ Convert the model output to the corresponding type the UniPC algorithm needs.
270
+
271
+ Args:
272
+ model_output (`torch.Tensor`):
273
+ The direct output from the learned diffusion model.
274
+ timestep (`int`):
275
+ The current discrete timestep in the diffusion chain.
276
+ sample (`torch.Tensor`):
277
+ A current instance of a sample created by the diffusion process.
278
+
279
+ Returns:
280
+ `torch.Tensor`:
281
+ The converted model output.
282
+ """
283
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
284
+ if sample is None:
285
+ if len(args) > 1:
286
+ sample = args[1]
287
+ else:
288
+ raise ValueError("missing `sample` as a required keyword argument")
289
+ if timestep is not None:
290
+ deprecate(
291
+ "timesteps",
292
+ "1.0.0",
293
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
294
+ )
295
+
296
+ sigma = self.sigmas[self.step_index]
297
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
298
+
299
+ if self.predict_x0:
300
+ if self.config.prediction_type == "flow_prediction":
301
+ sigma_t = self.sigmas[self.step_index]
302
+ x0_pred = sample - sigma_t * model_output
303
+ else:
304
+ raise ValueError(
305
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
306
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
307
+ )
308
+
309
+ if self.config.thresholding:
310
+ x0_pred = self._threshold_sample(x0_pred)
311
+
312
+ return x0_pred
313
+ else:
314
+ if self.config.prediction_type == "flow_prediction":
315
+ sigma_t = self.sigmas[self.step_index]
316
+ epsilon = sample - (1 - sigma_t) * model_output
317
+ else:
318
+ raise ValueError(
319
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
320
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
321
+ )
322
+
323
+ if self.config.thresholding:
324
+ sigma_t = self.sigmas[self.step_index]
325
+ x0_pred = sample - sigma_t * model_output
326
+ x0_pred = self._threshold_sample(x0_pred)
327
+ epsilon = model_output + x0_pred
328
+
329
+ return epsilon
330
+
331
+ def multistep_uni_p_bh_update(
332
+ self,
333
+ model_output: torch.Tensor,
334
+ *args,
335
+ sample: torch.Tensor = None,
336
+ order: int = None, # pyright: ignore
337
+ **kwargs,
338
+ ) -> torch.Tensor:
339
+ """
340
+ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
341
+
342
+ Args:
343
+ model_output (`torch.Tensor`):
344
+ The direct output from the learned diffusion model at the current timestep.
345
+ prev_timestep (`int`):
346
+ The previous discrete timestep in the diffusion chain.
347
+ sample (`torch.Tensor`):
348
+ A current instance of a sample created by the diffusion process.
349
+ order (`int`):
350
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
351
+
352
+ Returns:
353
+ `torch.Tensor`:
354
+ The sample tensor at the previous timestep.
355
+ """
356
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
357
+ if sample is None:
358
+ if len(args) > 1:
359
+ sample = args[1]
360
+ else:
361
+ raise ValueError(" missing `sample` as a required keyward argument")
362
+ if order is None:
363
+ if len(args) > 2:
364
+ order = args[2]
365
+ else:
366
+ raise ValueError(" missing `order` as a required keyward argument")
367
+ if prev_timestep is not None:
368
+ deprecate(
369
+ "prev_timestep",
370
+ "1.0.0",
371
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
372
+ )
373
+ model_output_list = self.model_outputs
374
+
375
+ s0 = self.timestep_list[-1]
376
+ m0 = model_output_list[-1]
377
+ x = sample
378
+
379
+ if self.solver_p:
380
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
381
+ return x_t
382
+
383
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore
384
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
385
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
386
+
387
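+ # lambda = log(alpha) - log(sigma) is the half log-SNR; the step size h below is measured in this lambda space.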
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
388
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
389
+
390
+ h = lambda_t - lambda_s0
391
+ device = sample.device
392
+
393
+ rks = []
394
+ D1s = []
395
+ for i in range(1, order):
396
+ si = self.step_index - i # pyright: ignore
397
+ mi = model_output_list[-(i + 1)]
398
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
399
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
400
+ rk = (lambda_si - lambda_s0) / h
401
+ rks.append(rk)
402
+ D1s.append((mi - m0) / rk) # pyright: ignore
403
+
404
+ rks.append(1.0)
405
+ rks = torch.tensor(rks, device=device)
406
+
407
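+ # Assemble the small linear system R @ rho = b over the ratios r_k; its solution gives the weights that combine the stored differences D1s (the B(h) expansion of UniPC).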
+ R = []
408
+ b = []
409
+
410
+ hh = -h if self.predict_x0 else h
411
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
412
+ h_phi_k = h_phi_1 / hh - 1
413
+
414
+ factorial_i = 1
415
+
416
+ if self.config.solver_type == "bh1":
417
+ B_h = hh
418
+ elif self.config.solver_type == "bh2":
419
+ B_h = torch.expm1(hh)
420
+ else:
421
+ raise NotImplementedError()
422
+
423
+ for i in range(1, order + 1):
424
+ R.append(torch.pow(rks, i - 1))
425
+ b.append(h_phi_k * factorial_i / B_h)
426
+ factorial_i *= i + 1
427
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
428
+
429
+ R = torch.stack(R)
430
+ b = torch.tensor(b, device=device)
431
+
432
+ if len(D1s) > 0:
433
+ D1s = torch.stack(D1s, dim=1) # (B, K, C, ...)
434
+ # for order 2, we use a simplified version
435
+ if order == 2:
436
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
437
+ else:
438
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
439
+ else:
440
+ D1s = None
441
+
442
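+ # UniP predictor update: advance the sample using the latest model output m0 plus the higher-order correction formed from the weighted differences D1s.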
+ if self.predict_x0:
443
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
444
+ if D1s is not None:
445
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
446
+ else:
447
+ pred_res = 0
448
+ x_t = x_t_ - alpha_t * B_h * pred_res
449
+ else:
450
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
451
+ if D1s is not None:
452
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
453
+ else:
454
+ pred_res = 0
455
+ x_t = x_t_ - sigma_t * B_h * pred_res
456
+
457
+ x_t = x_t.to(x.dtype)
458
+ return x_t
459
+
460
+ def multistep_uni_c_bh_update(
461
+ self,
462
+ this_model_output: torch.Tensor,
463
+ *args,
464
+ last_sample: torch.Tensor = None,
465
+ this_sample: torch.Tensor = None,
466
+ order: int = None, # pyright: ignore
467
+ **kwargs,
468
+ ) -> torch.Tensor:
469
+ """
470
+ One step for the UniC (B(h) version).
471
+
472
+ Args:
473
+ this_model_output (`torch.Tensor`):
474
+ The model outputs at `x_t`.
475
+ this_timestep (`int`):
476
+ The current timestep `t`.
477
+ last_sample (`torch.Tensor`):
478
+ The generated sample before the last predictor `x_{t-1}`.
479
+ this_sample (`torch.Tensor`):
480
+ The generated sample after the last predictor `x_{t}`.
481
+ order (`int`):
482
+ The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
483
+
484
+ Returns:
485
+ `torch.Tensor`:
486
+ The corrected sample tensor at the current timestep.
487
+ """
488
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
489
+ if last_sample is None:
490
+ if len(args) > 1:
491
+ last_sample = args[1]
492
+ else:
493
+ raise ValueError(" missing`last_sample` as a required keyward argument")
494
+ if this_sample is None:
495
+ if len(args) > 2:
496
+ this_sample = args[2]
497
+ else:
498
+ raise ValueError(" missing`this_sample` as a required keyward argument")
499
+ if order is None:
500
+ if len(args) > 3:
501
+ order = args[3]
502
+ else:
503
+ raise ValueError(" missing`order` as a required keyward argument")
504
+ if this_timestep is not None:
505
+ deprecate(
506
+ "this_timestep",
507
+ "1.0.0",
508
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
509
+ )
510
+
511
+ model_output_list = self.model_outputs
512
+
513
+ m0 = model_output_list[-1]
514
+ x = last_sample
515
+ x_t = this_sample
516
+ model_t = this_model_output
517
+
518
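+ # The corrector refines the result over the interval [sigma_{i-1}, sigma_i] that the previous predictor step just crossed.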
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore
519
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
520
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
521
+
522
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
523
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
524
+
525
+ h = lambda_t - lambda_s0
526
+ device = this_sample.device
527
+
528
+ rks = []
529
+ D1s = []
530
+ for i in range(1, order):
531
+ si = self.step_index - (i + 1) # pyright: ignore
532
+ mi = model_output_list[-(i + 1)]
533
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
534
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
535
+ rk = (lambda_si - lambda_s0) / h
536
+ rks.append(rk)
537
+ D1s.append((mi - m0) / rk) # pyright: ignore
538
+
539
+ rks.append(1.0)
540
+ rks = torch.tensor(rks, device=device)
541
+
542
+ R = []
543
+ b = []
544
+
545
+ hh = -h if self.predict_x0 else h
546
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
547
+ h_phi_k = h_phi_1 / hh - 1
548
+
549
+ factorial_i = 1
550
+
551
+ if self.config.solver_type == "bh1":
552
+ B_h = hh
553
+ elif self.config.solver_type == "bh2":
554
+ B_h = torch.expm1(hh)
555
+ else:
556
+ raise NotImplementedError()
557
+
558
+ for i in range(1, order + 1):
559
+ R.append(torch.pow(rks, i - 1))
560
+ b.append(h_phi_k * factorial_i / B_h)
561
+ factorial_i *= i + 1
562
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
563
+
564
+ R = torch.stack(R)
565
+ b = torch.tensor(b, device=device)
566
+
567
+ if len(D1s) > 0:
568
+ D1s = torch.stack(D1s, dim=1)
569
+ else:
570
+ D1s = None
571
+
572
+ # for order 1, we use a simplified version
573
+ if order == 1:
574
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
575
+ else:
576
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
577
+
578
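+ # UniC corrector update: same structure as the predictor, but it also uses the fresh model output at x_t through D1_t = model_t - m0.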
+ if self.predict_x0:
579
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
580
+ if D1s is not None:
581
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
582
+ else:
583
+ corr_res = 0
584
+ D1_t = model_t - m0
585
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
586
+ else:
587
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
588
+ if D1s is not None:
589
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
590
+ else:
591
+ corr_res = 0
592
+ D1_t = model_t - m0
593
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
594
+ x_t = x_t.to(x.dtype)
595
+ return x_t
596
+
597
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
598
+ if schedule_timesteps is None:
599
+ schedule_timesteps = self.timesteps
600
+
601
+ indices = (schedule_timesteps == timestep).nonzero()
602
+
603
+ # The sigma index that is taken for the **very** first `step`
604
+ # is always the second index (or the last index if there is only 1)
605
+ # This way we can ensure we don't accidentally skip a sigma in
606
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
607
+ pos = 1 if len(indices) > 1 else 0
608
+
609
+ return indices[pos].item()
610
+
611
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
612
+ def _init_step_index(self, timestep):
613
+ """
614
+ Initialize the step_index counter for the scheduler.
615
+ """
616
+
617
+ if self.begin_index is None:
618
+ if isinstance(timestep, torch.Tensor):
619
+ timestep = timestep.to(self.timesteps.device)
620
+ self._step_index = self.index_for_timestep(timestep)
621
+ else:
622
+ self._step_index = self._begin_index
623
+
624
+ def step(
625
+ self,
626
+ model_output: torch.Tensor,
627
+ timestep: Union[int, torch.Tensor],
628
+ sample: torch.Tensor,
629
+ return_dict: bool = True,
630
+ generator=None,
631
+ ) -> Union[SchedulerOutput, Tuple]:
632
+ """
633
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
634
+ the multistep UniPC.
635
+
636
+ Args:
637
+ model_output (`torch.Tensor`):
638
+ The direct output from the learned diffusion model.
639
+ timestep (`int`):
640
+ The current discrete timestep in the diffusion chain.
641
+ sample (`torch.Tensor`):
642
+ A current instance of a sample created by the diffusion process.
643
+ return_dict (`bool`):
644
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
645
+
646
+ Returns:
647
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
648
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
649
+ tuple is returned where the first element is the sample tensor.
650
+
651
+ """
652
+ if self.num_inference_steps is None:
653
+ raise ValueError(
654
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
655
+ )
656
+
657
+ if self.step_index is None:
658
+ self._init_step_index(timestep)
659
+
660
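+ # From the second step onward (unless disabled), first correct the previous prediction with the UniC update, using the model output just computed at the current sample.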
+ use_corrector = (
661
+ self.step_index > 0
662
+ and self.step_index - 1 not in self.disable_corrector
663
+ and self.last_sample is not None # pyright: ignore
664
+ )
665
+
666
+ model_output_convert = self.convert_model_output(model_output, sample=sample)
667
+ if use_corrector:
668
+ sample = self.multistep_uni_c_bh_update(
669
+ this_model_output=model_output_convert,
670
+ last_sample=self.last_sample,
671
+ this_sample=sample,
672
+ order=self.this_order,
673
+ )
674
+
675
+ for i in range(self.config.solver_order - 1):
676
+ self.model_outputs[i] = self.model_outputs[i + 1]
677
+ self.timestep_list[i] = self.timestep_list[i + 1]
678
+
679
+ self.model_outputs[-1] = model_output_convert
680
+ self.timestep_list[-1] = timestep # pyright: ignore
681
+
682
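+ # Optionally lower the solver order over the final steps; this stabilizes sampling when only a few inference steps are used.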
+ if self.config.lower_order_final:
683
+ this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore
684
+ else:
685
+ this_order = self.config.solver_order
686
+
687
+ self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
688
+ assert self.this_order > 0
689
+
690
+ self.last_sample = sample
691
+ prev_sample = self.multistep_uni_p_bh_update(
692
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
693
+ sample=sample,
694
+ order=self.this_order,
695
+ )
696
+
697
+ if self.lower_order_nums < self.config.solver_order:
698
+ self.lower_order_nums += 1
699
+
700
+ # upon completion increase step index by one
701
+ self._step_index += 1 # pyright: ignore
702
+
703
+ if not return_dict:
704
+ return (prev_sample,)
705
+
706
+ return SchedulerOutput(prev_sample=prev_sample)
707
+
708
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
709
+ """
710
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
711
+ current timestep.
712
+
713
+ Args:
714
+ sample (`torch.Tensor`):
715
+ The input sample.
716
+
717
+ Returns:
718
+ `torch.Tensor`:
719
+ A scaled input sample.
720
+ """
721
+ return sample
722
+
723
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
724
+ def add_noise(
725
+ self,
726
+ original_samples: torch.Tensor,
727
+ noise: torch.Tensor,
728
+ timesteps: torch.IntTensor,
729
+ ) -> torch.Tensor:
730
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
731
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
732
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
733
+ # mps does not support float64
734
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
735
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
736
+ else:
737
+ schedule_timesteps = self.timesteps.to(original_samples.device)
738
+ timesteps = timesteps.to(original_samples.device)
739
+
740
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
741
+ if self.begin_index is None:
742
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
743
+ elif self.step_index is not None:
744
+ # add_noise is called after first denoising step (for inpainting)
745
+ step_indices = [self.step_index] * timesteps.shape[0]
746
+ else:
747
+ # add_noise is called before the first denoising step to create the initial latent (img2img)
748
+ step_indices = [self.begin_index] * timesteps.shape[0]
749
+
750
+ sigma = sigmas[step_indices].flatten()
751
+ while len(sigma.shape) < len(original_samples.shape):
752
+ sigma = sigma.unsqueeze(-1)
753
+
754
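+ # Flow-matching forward process: x_t = (1 - sigma) * x_0 + sigma * noise.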
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
755
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
756
+ return noisy_samples
757
+
758
+ def __len__(self):
759
+ return self.config.num_train_timesteps