Spaces:
Running
on
Zero
Running
on
Zero
Upload 18 files
Browse files- skyreels_v2_infer/__init__.py +1 -0
- skyreels_v2_infer/distributed/__init__.py +0 -0
- skyreels_v2_infer/distributed/xdit_context_parallel.py +286 -0
- skyreels_v2_infer/modules/__init__.py +69 -0
- skyreels_v2_infer/modules/attention.py +179 -0
- skyreels_v2_infer/modules/clip.py +525 -0
- skyreels_v2_infer/modules/t5.py +454 -0
- skyreels_v2_infer/modules/tokenizers.py +78 -0
- skyreels_v2_infer/modules/transformer.py +839 -0
- skyreels_v2_infer/modules/vae.py +639 -0
- skyreels_v2_infer/modules/xlm_roberta.py +165 -0
- skyreels_v2_infer/pipelines/__init__.py +5 -0
- skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py +659 -0
- skyreels_v2_infer/pipelines/image2video_pipeline.py +156 -0
- skyreels_v2_infer/pipelines/prompt_enhancer.py +65 -0
- skyreels_v2_infer/pipelines/text2video_pipeline.py +110 -0
- skyreels_v2_infer/scheduler/__init__.py +0 -0
- skyreels_v2_infer/scheduler/fm_solvers_unipc.py +759 -0
skyreels_v2_infer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .pipelines import DiffusionForcingPipeline
|
skyreels_v2_infer/distributed/__init__.py
ADDED
File without changes
|
skyreels_v2_infer/distributed/xdit_context_parallel.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.amp as amp
|
4 |
+
from torch.backends.cuda import sdp_kernel
|
5 |
+
from xfuser.core.distributed import get_sequence_parallel_rank
|
6 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
7 |
+
from xfuser.core.distributed import get_sp_group
|
8 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
9 |
+
|
10 |
+
from ..modules.transformer import sinusoidal_embedding_1d
|
11 |
+
|
12 |
+
|
13 |
+
def pad_freqs(original_tensor, target_len):
|
14 |
+
seq_len, s1, s2 = original_tensor.shape
|
15 |
+
pad_size = target_len - seq_len
|
16 |
+
padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
|
17 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
18 |
+
return padded_tensor
|
19 |
+
|
20 |
+
|
21 |
+
@amp.autocast("cuda", enabled=False)
|
22 |
+
def rope_apply(x, grid_sizes, freqs):
|
23 |
+
"""
|
24 |
+
x: [B, L, N, C].
|
25 |
+
grid_sizes: [B, 3].
|
26 |
+
freqs: [M, C // 2].
|
27 |
+
"""
|
28 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
29 |
+
# split freqs
|
30 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
31 |
+
|
32 |
+
# loop over samples
|
33 |
+
output = []
|
34 |
+
grid = [grid_sizes.tolist()] * x.size(0)
|
35 |
+
for i, (f, h, w) in enumerate(grid):
|
36 |
+
seq_len = f * h * w
|
37 |
+
|
38 |
+
# precompute multipliers
|
39 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
|
40 |
+
freqs_i = torch.cat(
|
41 |
+
[
|
42 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
43 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
44 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
|
45 |
+
],
|
46 |
+
dim=-1,
|
47 |
+
).reshape(seq_len, 1, -1)
|
48 |
+
|
49 |
+
# apply rotary embedding
|
50 |
+
sp_size = get_sequence_parallel_world_size()
|
51 |
+
sp_rank = get_sequence_parallel_rank()
|
52 |
+
freqs_i = pad_freqs(freqs_i, s * sp_size)
|
53 |
+
s_per_rank = s
|
54 |
+
freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
|
55 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2)
|
56 |
+
x_i = torch.cat([x_i, x[i, s:]])
|
57 |
+
|
58 |
+
# append to collection
|
59 |
+
output.append(x_i)
|
60 |
+
return torch.stack(output).float()
|
61 |
+
|
62 |
+
|
63 |
+
def broadcast_should_calc(should_calc: bool) -> bool:
|
64 |
+
import torch.distributed as dist
|
65 |
+
|
66 |
+
device = torch.cuda.current_device()
|
67 |
+
int_should_calc = 1 if should_calc else 0
|
68 |
+
tensor = torch.tensor([int_should_calc], device=device, dtype=torch.int8)
|
69 |
+
dist.broadcast(tensor, src=0)
|
70 |
+
should_calc = tensor.item() == 1
|
71 |
+
return should_calc
|
72 |
+
|
73 |
+
|
74 |
+
def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
|
75 |
+
"""
|
76 |
+
x: A list of videos each with shape [C, T, H, W].
|
77 |
+
t: [B].
|
78 |
+
context: A list of text embeddings each with shape [L, C].
|
79 |
+
"""
|
80 |
+
if self.model_type == "i2v":
|
81 |
+
assert clip_fea is not None and y is not None
|
82 |
+
# params
|
83 |
+
device = self.patch_embedding.weight.device
|
84 |
+
if self.freqs.device != device:
|
85 |
+
self.freqs = self.freqs.to(device)
|
86 |
+
|
87 |
+
if y is not None:
|
88 |
+
x = torch.cat([x, y], dim=1)
|
89 |
+
|
90 |
+
# embeddings
|
91 |
+
x = self.patch_embedding(x)
|
92 |
+
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
|
93 |
+
x = x.flatten(2).transpose(1, 2)
|
94 |
+
|
95 |
+
if self.flag_causal_attention:
|
96 |
+
frame_num = grid_sizes[0]
|
97 |
+
height = grid_sizes[1]
|
98 |
+
width = grid_sizes[2]
|
99 |
+
block_num = frame_num // self.num_frame_per_block
|
100 |
+
range_tensor = torch.arange(block_num).view(-1, 1)
|
101 |
+
range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
|
102 |
+
casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
|
103 |
+
casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
|
104 |
+
casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
|
105 |
+
casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
|
106 |
+
self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
|
107 |
+
|
108 |
+
# time embeddings
|
109 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
110 |
+
if t.dim() == 2:
|
111 |
+
b, f = t.shape
|
112 |
+
_flag_df = True
|
113 |
+
else:
|
114 |
+
_flag_df = False
|
115 |
+
e = self.time_embedding(
|
116 |
+
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
|
117 |
+
) # b, dim
|
118 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
|
119 |
+
|
120 |
+
if self.inject_sample_info:
|
121 |
+
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
122 |
+
|
123 |
+
fps_emb = self.fps_embedding(fps).float()
|
124 |
+
if _flag_df:
|
125 |
+
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
126 |
+
else:
|
127 |
+
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
|
128 |
+
|
129 |
+
if _flag_df:
|
130 |
+
e = e.view(b, f, 1, 1, self.dim)
|
131 |
+
e0 = e0.view(b, f, 1, 1, 6, self.dim)
|
132 |
+
e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
|
133 |
+
e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
|
134 |
+
e0 = e0.transpose(1, 2).contiguous()
|
135 |
+
|
136 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
137 |
+
|
138 |
+
# context
|
139 |
+
context = self.text_embedding(context)
|
140 |
+
|
141 |
+
if clip_fea is not None:
|
142 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
143 |
+
context = torch.concat([context_clip, context], dim=1)
|
144 |
+
|
145 |
+
# arguments
|
146 |
+
if e0.ndim == 4:
|
147 |
+
e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()]
|
148 |
+
kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
|
149 |
+
|
150 |
+
if self.enable_teacache:
|
151 |
+
modulated_inp = e0 if self.use_ref_steps else e
|
152 |
+
# teacache
|
153 |
+
if self.cnt % 2 == 0: # even -> conditon
|
154 |
+
self.is_even = True
|
155 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
156 |
+
should_calc_even = True
|
157 |
+
self.accumulated_rel_l1_distance_even = 0
|
158 |
+
else:
|
159 |
+
rescale_func = np.poly1d(self.coefficients)
|
160 |
+
self.accumulated_rel_l1_distance_even += rescale_func(
|
161 |
+
((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
|
162 |
+
.cpu()
|
163 |
+
.item()
|
164 |
+
)
|
165 |
+
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
|
166 |
+
should_calc_even = False
|
167 |
+
else:
|
168 |
+
should_calc_even = True
|
169 |
+
self.accumulated_rel_l1_distance_even = 0
|
170 |
+
self.previous_e0_even = modulated_inp.clone()
|
171 |
+
else: # odd -> unconditon
|
172 |
+
self.is_even = False
|
173 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
174 |
+
should_calc_odd = True
|
175 |
+
self.accumulated_rel_l1_distance_odd = 0
|
176 |
+
else:
|
177 |
+
rescale_func = np.poly1d(self.coefficients)
|
178 |
+
self.accumulated_rel_l1_distance_odd += rescale_func(
|
179 |
+
((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
|
180 |
+
.cpu()
|
181 |
+
.item()
|
182 |
+
)
|
183 |
+
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
|
184 |
+
should_calc_odd = False
|
185 |
+
else:
|
186 |
+
should_calc_odd = True
|
187 |
+
self.accumulated_rel_l1_distance_odd = 0
|
188 |
+
self.previous_e0_odd = modulated_inp.clone()
|
189 |
+
|
190 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
191 |
+
if self.enable_teacache:
|
192 |
+
if self.is_even:
|
193 |
+
should_calc_even = broadcast_should_calc(should_calc_even)
|
194 |
+
if not should_calc_even:
|
195 |
+
x += self.previous_residual_even
|
196 |
+
else:
|
197 |
+
ori_x = x.clone()
|
198 |
+
for block in self.blocks:
|
199 |
+
x = block(x, **kwargs)
|
200 |
+
ori_x.mul_(-1)
|
201 |
+
ori_x.add_(x)
|
202 |
+
self.previous_residual_even = ori_x
|
203 |
+
else:
|
204 |
+
should_calc_odd = broadcast_should_calc(should_calc_odd)
|
205 |
+
if not should_calc_odd:
|
206 |
+
x += self.previous_residual_odd
|
207 |
+
else:
|
208 |
+
ori_x = x.clone()
|
209 |
+
for block in self.blocks:
|
210 |
+
x = block(x, **kwargs)
|
211 |
+
ori_x.mul_(-1)
|
212 |
+
ori_x.add_(x)
|
213 |
+
self.previous_residual_odd = ori_x
|
214 |
+
self.cnt += 1
|
215 |
+
if self.cnt >= self.num_steps:
|
216 |
+
self.cnt = 0
|
217 |
+
else:
|
218 |
+
# Context Parallel
|
219 |
+
for block in self.blocks:
|
220 |
+
x = block(x, **kwargs)
|
221 |
+
|
222 |
+
# head
|
223 |
+
if e.ndim == 3:
|
224 |
+
e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
225 |
+
x = self.head(x, e)
|
226 |
+
# Context Parallel
|
227 |
+
x = get_sp_group().all_gather(x, dim=1)
|
228 |
+
# unpatchify
|
229 |
+
x = self.unpatchify(x, grid_sizes)
|
230 |
+
return x.float()
|
231 |
+
|
232 |
+
|
233 |
+
def usp_attn_forward(self, x, grid_sizes, freqs, block_mask):
|
234 |
+
|
235 |
+
r"""
|
236 |
+
Args:
|
237 |
+
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
238 |
+
seq_lens(Tensor): Shape [B]
|
239 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
240 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
241 |
+
"""
|
242 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
243 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
244 |
+
|
245 |
+
def half(x):
|
246 |
+
return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
|
247 |
+
|
248 |
+
# query, key, value function
|
249 |
+
def qkv_fn(x):
|
250 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
251 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
252 |
+
v = self.v(x).view(b, s, n, d)
|
253 |
+
return q, k, v
|
254 |
+
|
255 |
+
x = x.to(self.q.weight.dtype)
|
256 |
+
q, k, v = qkv_fn(x)
|
257 |
+
|
258 |
+
if not self._flag_ar_attention:
|
259 |
+
q = rope_apply(q, grid_sizes, freqs)
|
260 |
+
k = rope_apply(k, grid_sizes, freqs)
|
261 |
+
else:
|
262 |
+
|
263 |
+
q = rope_apply(q, grid_sizes, freqs)
|
264 |
+
k = rope_apply(k, grid_sizes, freqs)
|
265 |
+
q = q.to(torch.bfloat16)
|
266 |
+
k = k.to(torch.bfloat16)
|
267 |
+
v = v.to(torch.bfloat16)
|
268 |
+
# x = torch.nn.functional.scaled_dot_product_attention(
|
269 |
+
# q.transpose(1, 2),
|
270 |
+
# k.transpose(1, 2),
|
271 |
+
# v.transpose(1, 2),
|
272 |
+
# ).transpose(1, 2).contiguous()
|
273 |
+
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
274 |
+
x = (
|
275 |
+
torch.nn.functional.scaled_dot_product_attention(
|
276 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
|
277 |
+
)
|
278 |
+
.transpose(1, 2)
|
279 |
+
.contiguous()
|
280 |
+
)
|
281 |
+
x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size)
|
282 |
+
|
283 |
+
# output
|
284 |
+
x = x.flatten(2)
|
285 |
+
x = self.o(x)
|
286 |
+
return x
|
skyreels_v2_infer/modules/__init__.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from safetensors.torch import load_file
|
6 |
+
|
7 |
+
from .clip import CLIPModel
|
8 |
+
from .t5 import T5EncoderModel
|
9 |
+
from .transformer import WanModel
|
10 |
+
from .vae import WanVAE
|
11 |
+
|
12 |
+
|
13 |
+
def download_model(model_id):
|
14 |
+
if not os.path.exists(model_id):
|
15 |
+
from huggingface_hub import snapshot_download
|
16 |
+
|
17 |
+
model_id = snapshot_download(repo_id=model_id)
|
18 |
+
return model_id
|
19 |
+
|
20 |
+
|
21 |
+
def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
|
22 |
+
vae = WanVAE(model_path).to(device).to(weight_dtype)
|
23 |
+
vae.vae.requires_grad_(False)
|
24 |
+
vae.vae.eval()
|
25 |
+
gc.collect()
|
26 |
+
torch.cuda.empty_cache()
|
27 |
+
return vae
|
28 |
+
|
29 |
+
|
30 |
+
def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
|
31 |
+
config_path = os.path.join(model_path, "config.json")
|
32 |
+
transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)
|
33 |
+
|
34 |
+
for file in os.listdir(model_path):
|
35 |
+
if file.endswith(".safetensors"):
|
36 |
+
file_path = os.path.join(model_path, file)
|
37 |
+
state_dict = load_file(file_path)
|
38 |
+
transformer.load_state_dict(state_dict, strict=False)
|
39 |
+
del state_dict
|
40 |
+
gc.collect()
|
41 |
+
torch.cuda.empty_cache()
|
42 |
+
|
43 |
+
transformer.requires_grad_(False)
|
44 |
+
transformer.eval()
|
45 |
+
gc.collect()
|
46 |
+
torch.cuda.empty_cache()
|
47 |
+
return transformer
|
48 |
+
|
49 |
+
|
50 |
+
def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
|
51 |
+
t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
|
52 |
+
tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
|
53 |
+
text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
|
54 |
+
text_encoder.requires_grad_(False)
|
55 |
+
text_encoder.eval()
|
56 |
+
gc.collect()
|
57 |
+
torch.cuda.empty_cache()
|
58 |
+
return text_encoder
|
59 |
+
|
60 |
+
|
61 |
+
def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
|
62 |
+
checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
|
63 |
+
tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
|
64 |
+
image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
|
65 |
+
image_enc.requires_grad_(False)
|
66 |
+
image_enc.eval()
|
67 |
+
gc.collect()
|
68 |
+
torch.cuda.empty_cache()
|
69 |
+
return image_enc
|
skyreels_v2_infer/modules/attention.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import torch
|
3 |
+
|
4 |
+
try:
|
5 |
+
import flash_attn_interface
|
6 |
+
|
7 |
+
FLASH_ATTN_3_AVAILABLE = True
|
8 |
+
except ModuleNotFoundError:
|
9 |
+
FLASH_ATTN_3_AVAILABLE = False
|
10 |
+
|
11 |
+
try:
|
12 |
+
import flash_attn
|
13 |
+
|
14 |
+
FLASH_ATTN_2_AVAILABLE = True
|
15 |
+
except ModuleNotFoundError:
|
16 |
+
FLASH_ATTN_2_AVAILABLE = False
|
17 |
+
|
18 |
+
import warnings
|
19 |
+
|
20 |
+
__all__ = [
|
21 |
+
"flash_attention",
|
22 |
+
"attention",
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
def flash_attention(
|
27 |
+
q,
|
28 |
+
k,
|
29 |
+
v,
|
30 |
+
q_lens=None,
|
31 |
+
k_lens=None,
|
32 |
+
dropout_p=0.0,
|
33 |
+
softmax_scale=None,
|
34 |
+
q_scale=None,
|
35 |
+
causal=False,
|
36 |
+
window_size=(-1, -1),
|
37 |
+
deterministic=False,
|
38 |
+
dtype=torch.bfloat16,
|
39 |
+
version=None,
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
q: [B, Lq, Nq, C1].
|
43 |
+
k: [B, Lk, Nk, C1].
|
44 |
+
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
|
45 |
+
q_lens: [B].
|
46 |
+
k_lens: [B].
|
47 |
+
dropout_p: float. Dropout probability.
|
48 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
49 |
+
causal: bool. Whether to apply causal attention mask.
|
50 |
+
window_size: (left right). If not (-1, -1), apply sliding window local attention.
|
51 |
+
deterministic: bool. If True, slightly slower and uses more memory.
|
52 |
+
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
|
53 |
+
"""
|
54 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
55 |
+
assert dtype in half_dtypes
|
56 |
+
assert q.device.type == "cuda" and q.size(-1) <= 256
|
57 |
+
|
58 |
+
# params
|
59 |
+
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
60 |
+
|
61 |
+
def half(x):
|
62 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
63 |
+
|
64 |
+
# preprocess query
|
65 |
+
|
66 |
+
q = half(q.flatten(0, 1))
|
67 |
+
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
|
68 |
+
|
69 |
+
# preprocess key, value
|
70 |
+
|
71 |
+
k = half(k.flatten(0, 1))
|
72 |
+
v = half(v.flatten(0, 1))
|
73 |
+
k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
|
74 |
+
|
75 |
+
q = q.to(v.dtype)
|
76 |
+
k = k.to(v.dtype)
|
77 |
+
|
78 |
+
if q_scale is not None:
|
79 |
+
q = q * q_scale
|
80 |
+
|
81 |
+
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
82 |
+
warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.")
|
83 |
+
|
84 |
+
torch.cuda.nvtx.range_push(f"{list(q.shape)}-{list(k.shape)}-{list(v.shape)}-{q.dtype}-{k.dtype}-{v.dtype}")
|
85 |
+
# apply attention
|
86 |
+
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
|
87 |
+
# Note: dropout_p, window_size are not supported in FA3 now.
|
88 |
+
x = flash_attn_interface.flash_attn_varlen_func(
|
89 |
+
q=q,
|
90 |
+
k=k,
|
91 |
+
v=v,
|
92 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
|
93 |
+
.cumsum(0, dtype=torch.int32)
|
94 |
+
.to(q.device, non_blocking=True),
|
95 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
|
96 |
+
.cumsum(0, dtype=torch.int32)
|
97 |
+
.to(q.device, non_blocking=True),
|
98 |
+
seqused_q=None,
|
99 |
+
seqused_k=None,
|
100 |
+
max_seqlen_q=lq,
|
101 |
+
max_seqlen_k=lk,
|
102 |
+
softmax_scale=softmax_scale,
|
103 |
+
causal=causal,
|
104 |
+
deterministic=deterministic,
|
105 |
+
)[0].unflatten(0, (b, lq))
|
106 |
+
else:
|
107 |
+
assert FLASH_ATTN_2_AVAILABLE
|
108 |
+
x = flash_attn.flash_attn_varlen_func(
|
109 |
+
q=q,
|
110 |
+
k=k,
|
111 |
+
v=v,
|
112 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
|
113 |
+
.cumsum(0, dtype=torch.int32)
|
114 |
+
.to(q.device, non_blocking=True),
|
115 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
|
116 |
+
.cumsum(0, dtype=torch.int32)
|
117 |
+
.to(q.device, non_blocking=True),
|
118 |
+
max_seqlen_q=lq,
|
119 |
+
max_seqlen_k=lk,
|
120 |
+
dropout_p=dropout_p,
|
121 |
+
softmax_scale=softmax_scale,
|
122 |
+
causal=causal,
|
123 |
+
window_size=window_size,
|
124 |
+
deterministic=deterministic,
|
125 |
+
).unflatten(0, (b, lq))
|
126 |
+
torch.cuda.nvtx.range_pop()
|
127 |
+
|
128 |
+
# output
|
129 |
+
return x
|
130 |
+
|
131 |
+
|
132 |
+
def attention(
|
133 |
+
q,
|
134 |
+
k,
|
135 |
+
v,
|
136 |
+
q_lens=None,
|
137 |
+
k_lens=None,
|
138 |
+
dropout_p=0.0,
|
139 |
+
softmax_scale=None,
|
140 |
+
q_scale=None,
|
141 |
+
causal=False,
|
142 |
+
window_size=(-1, -1),
|
143 |
+
deterministic=False,
|
144 |
+
dtype=torch.bfloat16,
|
145 |
+
fa_version=None,
|
146 |
+
):
|
147 |
+
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
148 |
+
return flash_attention(
|
149 |
+
q=q,
|
150 |
+
k=k,
|
151 |
+
v=v,
|
152 |
+
q_lens=q_lens,
|
153 |
+
k_lens=k_lens,
|
154 |
+
dropout_p=dropout_p,
|
155 |
+
softmax_scale=softmax_scale,
|
156 |
+
q_scale=q_scale,
|
157 |
+
causal=causal,
|
158 |
+
window_size=window_size,
|
159 |
+
deterministic=deterministic,
|
160 |
+
dtype=dtype,
|
161 |
+
version=fa_version,
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
if q_lens is not None or k_lens is not None:
|
165 |
+
warnings.warn(
|
166 |
+
"Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
|
167 |
+
)
|
168 |
+
attn_mask = None
|
169 |
+
|
170 |
+
q = q.transpose(1, 2).to(dtype)
|
171 |
+
k = k.transpose(1, 2).to(dtype)
|
172 |
+
v = v.transpose(1, 2).to(dtype)
|
173 |
+
|
174 |
+
out = torch.nn.functional.scaled_dot_product_attention(
|
175 |
+
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
|
176 |
+
)
|
177 |
+
|
178 |
+
out = out.transpose(1, 2).contiguous()
|
179 |
+
return out
|
skyreels_v2_infer/modules/clip.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
|
2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.transforms as T
|
10 |
+
from diffusers.models import ModelMixin
|
11 |
+
|
12 |
+
from .attention import flash_attention
|
13 |
+
from .tokenizers import HuggingfaceTokenizer
|
14 |
+
from .xlm_roberta import XLMRoberta
|
15 |
+
|
16 |
+
__all__ = [
|
17 |
+
"XLMRobertaCLIP",
|
18 |
+
"clip_xlm_roberta_vit_h_14",
|
19 |
+
"CLIPModel",
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
def pos_interpolate(pos, seq_len):
|
24 |
+
if pos.size(1) == seq_len:
|
25 |
+
return pos
|
26 |
+
else:
|
27 |
+
src_grid = int(math.sqrt(pos.size(1)))
|
28 |
+
tar_grid = int(math.sqrt(seq_len))
|
29 |
+
n = pos.size(1) - src_grid * src_grid
|
30 |
+
return torch.cat(
|
31 |
+
[
|
32 |
+
pos[:, :n],
|
33 |
+
F.interpolate(
|
34 |
+
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
|
35 |
+
size=(tar_grid, tar_grid),
|
36 |
+
mode="bicubic",
|
37 |
+
align_corners=False,
|
38 |
+
)
|
39 |
+
.flatten(2)
|
40 |
+
.transpose(1, 2),
|
41 |
+
],
|
42 |
+
dim=1,
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
class QuickGELU(nn.Module):
|
47 |
+
def forward(self, x):
|
48 |
+
return x * torch.sigmoid(1.702 * x)
|
49 |
+
|
50 |
+
|
51 |
+
class LayerNorm(nn.LayerNorm):
|
52 |
+
def forward(self, x):
|
53 |
+
return super().forward(x.float()).type_as(x)
|
54 |
+
|
55 |
+
|
56 |
+
class SelfAttention(nn.Module):
|
57 |
+
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
|
58 |
+
assert dim % num_heads == 0
|
59 |
+
super().__init__()
|
60 |
+
self.dim = dim
|
61 |
+
self.num_heads = num_heads
|
62 |
+
self.head_dim = dim // num_heads
|
63 |
+
self.causal = causal
|
64 |
+
self.attn_dropout = attn_dropout
|
65 |
+
self.proj_dropout = proj_dropout
|
66 |
+
|
67 |
+
# layers
|
68 |
+
self.to_qkv = nn.Linear(dim, dim * 3)
|
69 |
+
self.proj = nn.Linear(dim, dim)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
"""
|
73 |
+
x: [B, L, C].
|
74 |
+
"""
|
75 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
76 |
+
|
77 |
+
# compute query, key, value
|
78 |
+
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
79 |
+
|
80 |
+
# compute attention
|
81 |
+
p = self.attn_dropout if self.training else 0.0
|
82 |
+
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
83 |
+
x = x.reshape(b, s, c)
|
84 |
+
|
85 |
+
# output
|
86 |
+
x = self.proj(x)
|
87 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
class SwiGLU(nn.Module):
|
92 |
+
def __init__(self, dim, mid_dim):
|
93 |
+
super().__init__()
|
94 |
+
self.dim = dim
|
95 |
+
self.mid_dim = mid_dim
|
96 |
+
|
97 |
+
# layers
|
98 |
+
self.fc1 = nn.Linear(dim, mid_dim)
|
99 |
+
self.fc2 = nn.Linear(dim, mid_dim)
|
100 |
+
self.fc3 = nn.Linear(mid_dim, dim)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
x = F.silu(self.fc1(x)) * self.fc2(x)
|
104 |
+
x = self.fc3(x)
|
105 |
+
return x
|
106 |
+
|
107 |
+
|
108 |
+
class AttentionBlock(nn.Module):
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
dim,
|
112 |
+
mlp_ratio,
|
113 |
+
num_heads,
|
114 |
+
post_norm=False,
|
115 |
+
causal=False,
|
116 |
+
activation="quick_gelu",
|
117 |
+
attn_dropout=0.0,
|
118 |
+
proj_dropout=0.0,
|
119 |
+
norm_eps=1e-5,
|
120 |
+
):
|
121 |
+
assert activation in ["quick_gelu", "gelu", "swi_glu"]
|
122 |
+
super().__init__()
|
123 |
+
self.dim = dim
|
124 |
+
self.mlp_ratio = mlp_ratio
|
125 |
+
self.num_heads = num_heads
|
126 |
+
self.post_norm = post_norm
|
127 |
+
self.causal = causal
|
128 |
+
self.norm_eps = norm_eps
|
129 |
+
|
130 |
+
# layers
|
131 |
+
self.norm1 = LayerNorm(dim, eps=norm_eps)
|
132 |
+
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
|
133 |
+
self.norm2 = LayerNorm(dim, eps=norm_eps)
|
134 |
+
if activation == "swi_glu":
|
135 |
+
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
|
136 |
+
else:
|
137 |
+
self.mlp = nn.Sequential(
|
138 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
139 |
+
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
|
140 |
+
nn.Linear(int(dim * mlp_ratio), dim),
|
141 |
+
nn.Dropout(proj_dropout),
|
142 |
+
)
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
if self.post_norm:
|
146 |
+
x = x + self.norm1(self.attn(x))
|
147 |
+
x = x + self.norm2(self.mlp(x))
|
148 |
+
else:
|
149 |
+
x = x + self.attn(self.norm1(x))
|
150 |
+
x = x + self.mlp(self.norm2(x))
|
151 |
+
return x
|
152 |
+
|
153 |
+
|
154 |
+
class AttentionPool(nn.Module):
|
155 |
+
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
|
156 |
+
assert dim % num_heads == 0
|
157 |
+
super().__init__()
|
158 |
+
self.dim = dim
|
159 |
+
self.mlp_ratio = mlp_ratio
|
160 |
+
self.num_heads = num_heads
|
161 |
+
self.head_dim = dim // num_heads
|
162 |
+
self.proj_dropout = proj_dropout
|
163 |
+
self.norm_eps = norm_eps
|
164 |
+
|
165 |
+
# layers
|
166 |
+
gain = 1.0 / math.sqrt(dim)
|
167 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
168 |
+
self.to_q = nn.Linear(dim, dim)
|
169 |
+
self.to_kv = nn.Linear(dim, dim * 2)
|
170 |
+
self.proj = nn.Linear(dim, dim)
|
171 |
+
self.norm = LayerNorm(dim, eps=norm_eps)
|
172 |
+
self.mlp = nn.Sequential(
|
173 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
174 |
+
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
|
175 |
+
nn.Linear(int(dim * mlp_ratio), dim),
|
176 |
+
nn.Dropout(proj_dropout),
|
177 |
+
)
|
178 |
+
|
179 |
+
def forward(self, x):
|
180 |
+
"""
|
181 |
+
x: [B, L, C].
|
182 |
+
"""
|
183 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
184 |
+
|
185 |
+
# compute query, key, value
|
186 |
+
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
187 |
+
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
188 |
+
|
189 |
+
# compute attention
|
190 |
+
x = flash_attention(q, k, v, version=2)
|
191 |
+
x = x.reshape(b, 1, c)
|
192 |
+
|
193 |
+
# output
|
194 |
+
x = self.proj(x)
|
195 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
196 |
+
|
197 |
+
# mlp
|
198 |
+
x = x + self.mlp(self.norm(x))
|
199 |
+
return x[:, 0]
|
200 |
+
|
201 |
+
|
202 |
+
class VisionTransformer(nn.Module):
|
203 |
+
def __init__(
|
204 |
+
self,
|
205 |
+
image_size=224,
|
206 |
+
patch_size=16,
|
207 |
+
dim=768,
|
208 |
+
mlp_ratio=4,
|
209 |
+
out_dim=512,
|
210 |
+
num_heads=12,
|
211 |
+
num_layers=12,
|
212 |
+
pool_type="token",
|
213 |
+
pre_norm=True,
|
214 |
+
post_norm=False,
|
215 |
+
activation="quick_gelu",
|
216 |
+
attn_dropout=0.0,
|
217 |
+
proj_dropout=0.0,
|
218 |
+
embedding_dropout=0.0,
|
219 |
+
norm_eps=1e-5,
|
220 |
+
):
|
221 |
+
if image_size % patch_size != 0:
|
222 |
+
print("[WARNING] image_size is not divisible by patch_size", flush=True)
|
223 |
+
assert pool_type in ("token", "token_fc", "attn_pool")
|
224 |
+
out_dim = out_dim or dim
|
225 |
+
super().__init__()
|
226 |
+
self.image_size = image_size
|
227 |
+
self.patch_size = patch_size
|
228 |
+
self.num_patches = (image_size // patch_size) ** 2
|
229 |
+
self.dim = dim
|
230 |
+
self.mlp_ratio = mlp_ratio
|
231 |
+
self.out_dim = out_dim
|
232 |
+
self.num_heads = num_heads
|
233 |
+
self.num_layers = num_layers
|
234 |
+
self.pool_type = pool_type
|
235 |
+
self.post_norm = post_norm
|
236 |
+
self.norm_eps = norm_eps
|
237 |
+
|
238 |
+
# embeddings
|
239 |
+
gain = 1.0 / math.sqrt(dim)
|
240 |
+
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
|
241 |
+
if pool_type in ("token", "token_fc"):
|
242 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
243 |
+
self.pos_embedding = nn.Parameter(
|
244 |
+
gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim)
|
245 |
+
)
|
246 |
+
self.dropout = nn.Dropout(embedding_dropout)
|
247 |
+
|
248 |
+
# transformer
|
249 |
+
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
|
250 |
+
self.transformer = nn.Sequential(
|
251 |
+
*[
|
252 |
+
AttentionBlock(
|
253 |
+
dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps
|
254 |
+
)
|
255 |
+
for _ in range(num_layers)
|
256 |
+
]
|
257 |
+
)
|
258 |
+
self.post_norm = LayerNorm(dim, eps=norm_eps)
|
259 |
+
|
260 |
+
# head
|
261 |
+
if pool_type == "token":
|
262 |
+
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
263 |
+
elif pool_type == "token_fc":
|
264 |
+
self.head = nn.Linear(dim, out_dim)
|
265 |
+
elif pool_type == "attn_pool":
|
266 |
+
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
|
267 |
+
|
268 |
+
def forward(self, x, interpolation=False, use_31_block=False):
|
269 |
+
b = x.size(0)
|
270 |
+
|
271 |
+
# embeddings
|
272 |
+
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
273 |
+
if self.pool_type in ("token", "token_fc"):
|
274 |
+
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
|
275 |
+
if interpolation:
|
276 |
+
e = pos_interpolate(self.pos_embedding, x.size(1))
|
277 |
+
else:
|
278 |
+
e = self.pos_embedding
|
279 |
+
x = self.dropout(x + e)
|
280 |
+
if self.pre_norm is not None:
|
281 |
+
x = self.pre_norm(x)
|
282 |
+
|
283 |
+
# transformer
|
284 |
+
if use_31_block:
|
285 |
+
x = self.transformer[:-1](x)
|
286 |
+
return x
|
287 |
+
else:
|
288 |
+
x = self.transformer(x)
|
289 |
+
return x
|
290 |
+
|
291 |
+
|
292 |
+
class XLMRobertaWithHead(XLMRoberta):
|
293 |
+
def __init__(self, **kwargs):
|
294 |
+
self.out_dim = kwargs.pop("out_dim")
|
295 |
+
super().__init__(**kwargs)
|
296 |
+
|
297 |
+
# head
|
298 |
+
mid_dim = (self.dim + self.out_dim) // 2
|
299 |
+
self.head = nn.Sequential(
|
300 |
+
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)
|
301 |
+
)
|
302 |
+
|
303 |
+
def forward(self, ids):
|
304 |
+
# xlm-roberta
|
305 |
+
x = super().forward(ids)
|
306 |
+
|
307 |
+
# average pooling
|
308 |
+
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
|
309 |
+
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
|
310 |
+
|
311 |
+
# head
|
312 |
+
x = self.head(x)
|
313 |
+
return x
|
314 |
+
|
315 |
+
|
316 |
+
class XLMRobertaCLIP(nn.Module):
|
317 |
+
def __init__(
|
318 |
+
self,
|
319 |
+
embed_dim=1024,
|
320 |
+
image_size=224,
|
321 |
+
patch_size=14,
|
322 |
+
vision_dim=1280,
|
323 |
+
vision_mlp_ratio=4,
|
324 |
+
vision_heads=16,
|
325 |
+
vision_layers=32,
|
326 |
+
vision_pool="token",
|
327 |
+
vision_pre_norm=True,
|
328 |
+
vision_post_norm=False,
|
329 |
+
activation="gelu",
|
330 |
+
vocab_size=250002,
|
331 |
+
max_text_len=514,
|
332 |
+
type_size=1,
|
333 |
+
pad_id=1,
|
334 |
+
text_dim=1024,
|
335 |
+
text_heads=16,
|
336 |
+
text_layers=24,
|
337 |
+
text_post_norm=True,
|
338 |
+
text_dropout=0.1,
|
339 |
+
attn_dropout=0.0,
|
340 |
+
proj_dropout=0.0,
|
341 |
+
embedding_dropout=0.0,
|
342 |
+
norm_eps=1e-5,
|
343 |
+
):
|
344 |
+
super().__init__()
|
345 |
+
self.embed_dim = embed_dim
|
346 |
+
self.image_size = image_size
|
347 |
+
self.patch_size = patch_size
|
348 |
+
self.vision_dim = vision_dim
|
349 |
+
self.vision_mlp_ratio = vision_mlp_ratio
|
350 |
+
self.vision_heads = vision_heads
|
351 |
+
self.vision_layers = vision_layers
|
352 |
+
self.vision_pre_norm = vision_pre_norm
|
353 |
+
self.vision_post_norm = vision_post_norm
|
354 |
+
self.activation = activation
|
355 |
+
self.vocab_size = vocab_size
|
356 |
+
self.max_text_len = max_text_len
|
357 |
+
self.type_size = type_size
|
358 |
+
self.pad_id = pad_id
|
359 |
+
self.text_dim = text_dim
|
360 |
+
self.text_heads = text_heads
|
361 |
+
self.text_layers = text_layers
|
362 |
+
self.text_post_norm = text_post_norm
|
363 |
+
self.norm_eps = norm_eps
|
364 |
+
|
365 |
+
# models
|
366 |
+
self.visual = VisionTransformer(
|
367 |
+
image_size=image_size,
|
368 |
+
patch_size=patch_size,
|
369 |
+
dim=vision_dim,
|
370 |
+
mlp_ratio=vision_mlp_ratio,
|
371 |
+
out_dim=embed_dim,
|
372 |
+
num_heads=vision_heads,
|
373 |
+
num_layers=vision_layers,
|
374 |
+
pool_type=vision_pool,
|
375 |
+
pre_norm=vision_pre_norm,
|
376 |
+
post_norm=vision_post_norm,
|
377 |
+
activation=activation,
|
378 |
+
attn_dropout=attn_dropout,
|
379 |
+
proj_dropout=proj_dropout,
|
380 |
+
embedding_dropout=embedding_dropout,
|
381 |
+
norm_eps=norm_eps,
|
382 |
+
)
|
383 |
+
self.textual = XLMRobertaWithHead(
|
384 |
+
vocab_size=vocab_size,
|
385 |
+
max_seq_len=max_text_len,
|
386 |
+
type_size=type_size,
|
387 |
+
pad_id=pad_id,
|
388 |
+
dim=text_dim,
|
389 |
+
out_dim=embed_dim,
|
390 |
+
num_heads=text_heads,
|
391 |
+
num_layers=text_layers,
|
392 |
+
post_norm=text_post_norm,
|
393 |
+
dropout=text_dropout,
|
394 |
+
)
|
395 |
+
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
396 |
+
|
397 |
+
def forward(self, imgs, txt_ids):
|
398 |
+
"""
|
399 |
+
imgs: [B, 3, H, W] of torch.float32.
|
400 |
+
- mean: [0.48145466, 0.4578275, 0.40821073]
|
401 |
+
- std: [0.26862954, 0.26130258, 0.27577711]
|
402 |
+
txt_ids: [B, L] of torch.long.
|
403 |
+
Encoded by data.CLIPTokenizer.
|
404 |
+
"""
|
405 |
+
xi = self.visual(imgs)
|
406 |
+
xt = self.textual(txt_ids)
|
407 |
+
return xi, xt
|
408 |
+
|
409 |
+
def param_groups(self):
|
410 |
+
groups = [
|
411 |
+
{
|
412 |
+
"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")],
|
413 |
+
"weight_decay": 0.0,
|
414 |
+
},
|
415 |
+
{"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
|
416 |
+
]
|
417 |
+
return groups
|
418 |
+
|
419 |
+
|
420 |
+
def _clip(
|
421 |
+
pretrained=False,
|
422 |
+
pretrained_name=None,
|
423 |
+
model_cls=XLMRobertaCLIP,
|
424 |
+
return_transforms=False,
|
425 |
+
return_tokenizer=False,
|
426 |
+
tokenizer_padding="eos",
|
427 |
+
dtype=torch.float32,
|
428 |
+
device="cpu",
|
429 |
+
**kwargs,
|
430 |
+
):
|
431 |
+
# init a model on device
|
432 |
+
with torch.device(device):
|
433 |
+
model = model_cls(**kwargs)
|
434 |
+
|
435 |
+
# set device
|
436 |
+
model = model.to(dtype=dtype, device=device)
|
437 |
+
output = (model,)
|
438 |
+
|
439 |
+
# init transforms
|
440 |
+
if return_transforms:
|
441 |
+
# mean and std
|
442 |
+
if "siglip" in pretrained_name.lower():
|
443 |
+
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
|
444 |
+
else:
|
445 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
446 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
447 |
+
|
448 |
+
# transforms
|
449 |
+
transforms = T.Compose(
|
450 |
+
[
|
451 |
+
T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
|
452 |
+
T.ToTensor(),
|
453 |
+
T.Normalize(mean=mean, std=std),
|
454 |
+
]
|
455 |
+
)
|
456 |
+
output += (transforms,)
|
457 |
+
return output[0] if len(output) == 1 else output
|
458 |
+
|
459 |
+
|
460 |
+
def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
|
461 |
+
cfg = dict(
|
462 |
+
embed_dim=1024,
|
463 |
+
image_size=224,
|
464 |
+
patch_size=14,
|
465 |
+
vision_dim=1280,
|
466 |
+
vision_mlp_ratio=4,
|
467 |
+
vision_heads=16,
|
468 |
+
vision_layers=32,
|
469 |
+
vision_pool="token",
|
470 |
+
activation="gelu",
|
471 |
+
vocab_size=250002,
|
472 |
+
max_text_len=514,
|
473 |
+
type_size=1,
|
474 |
+
pad_id=1,
|
475 |
+
text_dim=1024,
|
476 |
+
text_heads=16,
|
477 |
+
text_layers=24,
|
478 |
+
text_post_norm=True,
|
479 |
+
text_dropout=0.1,
|
480 |
+
attn_dropout=0.0,
|
481 |
+
proj_dropout=0.0,
|
482 |
+
embedding_dropout=0.0,
|
483 |
+
)
|
484 |
+
cfg.update(**kwargs)
|
485 |
+
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
486 |
+
|
487 |
+
|
488 |
+
class CLIPModel(ModelMixin):
|
489 |
+
def __init__(self, checkpoint_path, tokenizer_path):
|
490 |
+
self.checkpoint_path = checkpoint_path
|
491 |
+
self.tokenizer_path = tokenizer_path
|
492 |
+
|
493 |
+
super().__init__()
|
494 |
+
# init model
|
495 |
+
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
496 |
+
pretrained=False, return_transforms=True, return_tokenizer=False
|
497 |
+
)
|
498 |
+
self.model = self.model.eval().requires_grad_(False)
|
499 |
+
logging.info(f"loading {checkpoint_path}")
|
500 |
+
self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
|
501 |
+
|
502 |
+
# init tokenizer
|
503 |
+
self.tokenizer = HuggingfaceTokenizer(
|
504 |
+
name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace"
|
505 |
+
)
|
506 |
+
|
507 |
+
def encode_video(self, video):
|
508 |
+
# preprocess
|
509 |
+
b, c, t, h, w = video.shape
|
510 |
+
video = video.transpose(1, 2)
|
511 |
+
video = video.reshape(b * t, c, h, w)
|
512 |
+
size = (self.model.image_size,) * 2
|
513 |
+
video = F.interpolate(
|
514 |
+
video,
|
515 |
+
size=size,
|
516 |
+
mode='bicubic',
|
517 |
+
align_corners=False)
|
518 |
+
|
519 |
+
video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))
|
520 |
+
|
521 |
+
# forward
|
522 |
+
with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
|
523 |
+
out = self.model.visual(video, use_31_block=True)
|
524 |
+
|
525 |
+
return out
|
skyreels_v2_infer/modules/t5.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from transformers.models.t5.modeling_t5
|
2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from diffusers.models import ModelMixin
|
10 |
+
|
11 |
+
from .tokenizers import HuggingfaceTokenizer
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"T5Model",
|
15 |
+
"T5Encoder",
|
16 |
+
"T5Decoder",
|
17 |
+
"T5EncoderModel",
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def fp16_clamp(x):
|
22 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
23 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
24 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
25 |
+
return x
|
26 |
+
|
27 |
+
|
28 |
+
def init_weights(m):
|
29 |
+
if isinstance(m, T5LayerNorm):
|
30 |
+
nn.init.ones_(m.weight)
|
31 |
+
elif isinstance(m, T5Model):
|
32 |
+
nn.init.normal_(m.token_embedding.weight, std=1.0)
|
33 |
+
elif isinstance(m, T5FeedForward):
|
34 |
+
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
35 |
+
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
36 |
+
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
37 |
+
elif isinstance(m, T5Attention):
|
38 |
+
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
|
39 |
+
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
40 |
+
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
41 |
+
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
|
42 |
+
elif isinstance(m, T5RelativeEmbedding):
|
43 |
+
nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
|
44 |
+
|
45 |
+
|
46 |
+
class GELU(nn.Module):
|
47 |
+
def forward(self, x):
|
48 |
+
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
49 |
+
|
50 |
+
|
51 |
+
class T5LayerNorm(nn.Module):
|
52 |
+
def __init__(self, dim, eps=1e-6):
|
53 |
+
super(T5LayerNorm, self).__init__()
|
54 |
+
self.dim = dim
|
55 |
+
self.eps = eps
|
56 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
60 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
61 |
+
x = x.type_as(self.weight)
|
62 |
+
return self.weight * x
|
63 |
+
|
64 |
+
|
65 |
+
class T5Attention(nn.Module):
|
66 |
+
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
67 |
+
assert dim_attn % num_heads == 0
|
68 |
+
super(T5Attention, self).__init__()
|
69 |
+
self.dim = dim
|
70 |
+
self.dim_attn = dim_attn
|
71 |
+
self.num_heads = num_heads
|
72 |
+
self.head_dim = dim_attn // num_heads
|
73 |
+
|
74 |
+
# layers
|
75 |
+
self.q = nn.Linear(dim, dim_attn, bias=False)
|
76 |
+
self.k = nn.Linear(dim, dim_attn, bias=False)
|
77 |
+
self.v = nn.Linear(dim, dim_attn, bias=False)
|
78 |
+
self.o = nn.Linear(dim_attn, dim, bias=False)
|
79 |
+
self.dropout = nn.Dropout(dropout)
|
80 |
+
|
81 |
+
def forward(self, x, context=None, mask=None, pos_bias=None):
|
82 |
+
"""
|
83 |
+
x: [B, L1, C].
|
84 |
+
context: [B, L2, C] or None.
|
85 |
+
mask: [B, L2] or [B, L1, L2] or None.
|
86 |
+
"""
|
87 |
+
# check inputs
|
88 |
+
context = x if context is None else context
|
89 |
+
b, n, c = x.size(0), self.num_heads, self.head_dim
|
90 |
+
|
91 |
+
# compute query, key, value
|
92 |
+
q = self.q(x).view(b, -1, n, c)
|
93 |
+
k = self.k(context).view(b, -1, n, c)
|
94 |
+
v = self.v(context).view(b, -1, n, c)
|
95 |
+
|
96 |
+
# attention bias
|
97 |
+
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
98 |
+
if pos_bias is not None:
|
99 |
+
attn_bias += pos_bias
|
100 |
+
if mask is not None:
|
101 |
+
assert mask.ndim in [2, 3]
|
102 |
+
mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
|
103 |
+
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
104 |
+
|
105 |
+
# compute attention (T5 does not use scaling)
|
106 |
+
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
|
107 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
108 |
+
x = torch.einsum("bnij,bjnc->binc", attn, v)
|
109 |
+
|
110 |
+
# output
|
111 |
+
x = x.reshape(b, -1, n * c)
|
112 |
+
x = self.o(x)
|
113 |
+
x = self.dropout(x)
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
class T5FeedForward(nn.Module):
|
118 |
+
def __init__(self, dim, dim_ffn, dropout=0.1):
|
119 |
+
super(T5FeedForward, self).__init__()
|
120 |
+
self.dim = dim
|
121 |
+
self.dim_ffn = dim_ffn
|
122 |
+
|
123 |
+
# layers
|
124 |
+
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
125 |
+
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
126 |
+
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
127 |
+
self.dropout = nn.Dropout(dropout)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
x = self.fc1(x) * self.gate(x)
|
131 |
+
x = self.dropout(x)
|
132 |
+
x = self.fc2(x)
|
133 |
+
x = self.dropout(x)
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
class T5SelfAttention(nn.Module):
|
138 |
+
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
|
139 |
+
super(T5SelfAttention, self).__init__()
|
140 |
+
self.dim = dim
|
141 |
+
self.dim_attn = dim_attn
|
142 |
+
self.dim_ffn = dim_ffn
|
143 |
+
self.num_heads = num_heads
|
144 |
+
self.num_buckets = num_buckets
|
145 |
+
self.shared_pos = shared_pos
|
146 |
+
|
147 |
+
# layers
|
148 |
+
self.norm1 = T5LayerNorm(dim)
|
149 |
+
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
150 |
+
self.norm2 = T5LayerNorm(dim)
|
151 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
152 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
153 |
+
|
154 |
+
def forward(self, x, mask=None, pos_bias=None):
|
155 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
|
156 |
+
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
157 |
+
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
158 |
+
return x
|
159 |
+
|
160 |
+
|
161 |
+
class T5CrossAttention(nn.Module):
|
162 |
+
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
|
163 |
+
super(T5CrossAttention, self).__init__()
|
164 |
+
self.dim = dim
|
165 |
+
self.dim_attn = dim_attn
|
166 |
+
self.dim_ffn = dim_ffn
|
167 |
+
self.num_heads = num_heads
|
168 |
+
self.num_buckets = num_buckets
|
169 |
+
self.shared_pos = shared_pos
|
170 |
+
|
171 |
+
# layers
|
172 |
+
self.norm1 = T5LayerNorm(dim)
|
173 |
+
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
174 |
+
self.norm2 = T5LayerNorm(dim)
|
175 |
+
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
176 |
+
self.norm3 = T5LayerNorm(dim)
|
177 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
178 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
|
179 |
+
|
180 |
+
def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
|
181 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
|
182 |
+
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
183 |
+
x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
|
184 |
+
x = fp16_clamp(x + self.ffn(self.norm3(x)))
|
185 |
+
return x
|
186 |
+
|
187 |
+
|
188 |
+
class T5RelativeEmbedding(nn.Module):
|
189 |
+
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
190 |
+
super(T5RelativeEmbedding, self).__init__()
|
191 |
+
self.num_buckets = num_buckets
|
192 |
+
self.num_heads = num_heads
|
193 |
+
self.bidirectional = bidirectional
|
194 |
+
self.max_dist = max_dist
|
195 |
+
|
196 |
+
# layers
|
197 |
+
self.embedding = nn.Embedding(num_buckets, num_heads)
|
198 |
+
|
199 |
+
def forward(self, lq, lk):
|
200 |
+
device = self.embedding.weight.device
|
201 |
+
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
202 |
+
# torch.arange(lq).unsqueeze(1).to(device)
|
203 |
+
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
|
204 |
+
rel_pos = self._relative_position_bucket(rel_pos)
|
205 |
+
rel_pos_embeds = self.embedding(rel_pos)
|
206 |
+
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
|
207 |
+
return rel_pos_embeds.contiguous()
|
208 |
+
|
209 |
+
def _relative_position_bucket(self, rel_pos):
|
210 |
+
# preprocess
|
211 |
+
if self.bidirectional:
|
212 |
+
num_buckets = self.num_buckets // 2
|
213 |
+
rel_buckets = (rel_pos > 0).long() * num_buckets
|
214 |
+
rel_pos = torch.abs(rel_pos)
|
215 |
+
else:
|
216 |
+
num_buckets = self.num_buckets
|
217 |
+
rel_buckets = 0
|
218 |
+
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
219 |
+
|
220 |
+
# embeddings for small and large positions
|
221 |
+
max_exact = num_buckets // 2
|
222 |
+
rel_pos_large = (
|
223 |
+
max_exact
|
224 |
+
+ (
|
225 |
+
torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)
|
226 |
+
).long()
|
227 |
+
)
|
228 |
+
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
229 |
+
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
230 |
+
return rel_buckets
|
231 |
+
|
232 |
+
|
233 |
+
class T5Encoder(nn.Module):
|
234 |
+
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
|
235 |
+
super(T5Encoder, self).__init__()
|
236 |
+
self.dim = dim
|
237 |
+
self.dim_attn = dim_attn
|
238 |
+
self.dim_ffn = dim_ffn
|
239 |
+
self.num_heads = num_heads
|
240 |
+
self.num_layers = num_layers
|
241 |
+
self.num_buckets = num_buckets
|
242 |
+
self.shared_pos = shared_pos
|
243 |
+
|
244 |
+
# layers
|
245 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
|
246 |
+
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
247 |
+
self.dropout = nn.Dropout(dropout)
|
248 |
+
self.blocks = nn.ModuleList(
|
249 |
+
[
|
250 |
+
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
|
251 |
+
for _ in range(num_layers)
|
252 |
+
]
|
253 |
+
)
|
254 |
+
self.norm = T5LayerNorm(dim)
|
255 |
+
|
256 |
+
# initialize weights
|
257 |
+
self.apply(init_weights)
|
258 |
+
|
259 |
+
def forward(self, ids, mask=None):
|
260 |
+
x = self.token_embedding(ids)
|
261 |
+
x = self.dropout(x)
|
262 |
+
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
|
263 |
+
for block in self.blocks:
|
264 |
+
x = block(x, mask, pos_bias=e)
|
265 |
+
x = self.norm(x)
|
266 |
+
x = self.dropout(x)
|
267 |
+
return x
|
268 |
+
|
269 |
+
|
270 |
+
class T5Decoder(nn.Module):
|
271 |
+
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
|
272 |
+
super(T5Decoder, self).__init__()
|
273 |
+
self.dim = dim
|
274 |
+
self.dim_attn = dim_attn
|
275 |
+
self.dim_ffn = dim_ffn
|
276 |
+
self.num_heads = num_heads
|
277 |
+
self.num_layers = num_layers
|
278 |
+
self.num_buckets = num_buckets
|
279 |
+
self.shared_pos = shared_pos
|
280 |
+
|
281 |
+
# layers
|
282 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
|
283 |
+
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
|
284 |
+
self.dropout = nn.Dropout(dropout)
|
285 |
+
self.blocks = nn.ModuleList(
|
286 |
+
[
|
287 |
+
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
|
288 |
+
for _ in range(num_layers)
|
289 |
+
]
|
290 |
+
)
|
291 |
+
self.norm = T5LayerNorm(dim)
|
292 |
+
|
293 |
+
# initialize weights
|
294 |
+
self.apply(init_weights)
|
295 |
+
|
296 |
+
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
|
297 |
+
b, s = ids.size()
|
298 |
+
|
299 |
+
# causal mask
|
300 |
+
if mask is None:
|
301 |
+
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
|
302 |
+
elif mask.ndim == 2:
|
303 |
+
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
|
304 |
+
|
305 |
+
# layers
|
306 |
+
x = self.token_embedding(ids)
|
307 |
+
x = self.dropout(x)
|
308 |
+
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
|
309 |
+
for block in self.blocks:
|
310 |
+
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
|
311 |
+
x = self.norm(x)
|
312 |
+
x = self.dropout(x)
|
313 |
+
return x
|
314 |
+
|
315 |
+
|
316 |
+
class T5Model(nn.Module):
|
317 |
+
def __init__(
|
318 |
+
self,
|
319 |
+
vocab_size,
|
320 |
+
dim,
|
321 |
+
dim_attn,
|
322 |
+
dim_ffn,
|
323 |
+
num_heads,
|
324 |
+
encoder_layers,
|
325 |
+
decoder_layers,
|
326 |
+
num_buckets,
|
327 |
+
shared_pos=True,
|
328 |
+
dropout=0.1,
|
329 |
+
):
|
330 |
+
super(T5Model, self).__init__()
|
331 |
+
self.vocab_size = vocab_size
|
332 |
+
self.dim = dim
|
333 |
+
self.dim_attn = dim_attn
|
334 |
+
self.dim_ffn = dim_ffn
|
335 |
+
self.num_heads = num_heads
|
336 |
+
self.encoder_layers = encoder_layers
|
337 |
+
self.decoder_layers = decoder_layers
|
338 |
+
self.num_buckets = num_buckets
|
339 |
+
|
340 |
+
# layers
|
341 |
+
self.token_embedding = nn.Embedding(vocab_size, dim)
|
342 |
+
self.encoder = T5Encoder(
|
343 |
+
self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout
|
344 |
+
)
|
345 |
+
self.decoder = T5Decoder(
|
346 |
+
self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout
|
347 |
+
)
|
348 |
+
self.head = nn.Linear(dim, vocab_size, bias=False)
|
349 |
+
|
350 |
+
# initialize weights
|
351 |
+
self.apply(init_weights)
|
352 |
+
|
353 |
+
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
|
354 |
+
x = self.encoder(encoder_ids, encoder_mask)
|
355 |
+
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
|
356 |
+
x = self.head(x)
|
357 |
+
return x
|
358 |
+
|
359 |
+
|
360 |
+
def _t5(
|
361 |
+
name,
|
362 |
+
encoder_only=False,
|
363 |
+
decoder_only=False,
|
364 |
+
return_tokenizer=False,
|
365 |
+
tokenizer_kwargs={},
|
366 |
+
dtype=torch.float32,
|
367 |
+
device="cpu",
|
368 |
+
**kwargs,
|
369 |
+
):
|
370 |
+
# sanity check
|
371 |
+
assert not (encoder_only and decoder_only)
|
372 |
+
|
373 |
+
# params
|
374 |
+
if encoder_only:
|
375 |
+
model_cls = T5Encoder
|
376 |
+
kwargs["vocab"] = kwargs.pop("vocab_size")
|
377 |
+
kwargs["num_layers"] = kwargs.pop("encoder_layers")
|
378 |
+
_ = kwargs.pop("decoder_layers")
|
379 |
+
elif decoder_only:
|
380 |
+
model_cls = T5Decoder
|
381 |
+
kwargs["vocab"] = kwargs.pop("vocab_size")
|
382 |
+
kwargs["num_layers"] = kwargs.pop("decoder_layers")
|
383 |
+
_ = kwargs.pop("encoder_layers")
|
384 |
+
else:
|
385 |
+
model_cls = T5Model
|
386 |
+
|
387 |
+
# init model
|
388 |
+
with torch.device(device):
|
389 |
+
model = model_cls(**kwargs)
|
390 |
+
|
391 |
+
# set device
|
392 |
+
model = model.to(dtype=dtype, device=device)
|
393 |
+
|
394 |
+
# init tokenizer
|
395 |
+
if return_tokenizer:
|
396 |
+
from .tokenizers import HuggingfaceTokenizer
|
397 |
+
|
398 |
+
tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
|
399 |
+
return model, tokenizer
|
400 |
+
else:
|
401 |
+
return model
|
402 |
+
|
403 |
+
|
404 |
+
def umt5_xxl(**kwargs):
|
405 |
+
cfg = dict(
|
406 |
+
vocab_size=256384,
|
407 |
+
dim=4096,
|
408 |
+
dim_attn=4096,
|
409 |
+
dim_ffn=10240,
|
410 |
+
num_heads=64,
|
411 |
+
encoder_layers=24,
|
412 |
+
decoder_layers=24,
|
413 |
+
num_buckets=32,
|
414 |
+
shared_pos=False,
|
415 |
+
dropout=0.1,
|
416 |
+
)
|
417 |
+
cfg.update(**kwargs)
|
418 |
+
return _t5("umt5-xxl", **cfg)
|
419 |
+
|
420 |
+
|
421 |
+
class T5EncoderModel(ModelMixin):
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
checkpoint_path=None,
|
425 |
+
tokenizer_path=None,
|
426 |
+
text_len=512,
|
427 |
+
shard_fn=None,
|
428 |
+
):
|
429 |
+
self.text_len = text_len
|
430 |
+
self.checkpoint_path = checkpoint_path
|
431 |
+
self.tokenizer_path = tokenizer_path
|
432 |
+
|
433 |
+
super().__init__()
|
434 |
+
# init model
|
435 |
+
model = umt5_xxl(encoder_only=True, return_tokenizer=False)
|
436 |
+
logging.info(f"loading {checkpoint_path}")
|
437 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
|
438 |
+
self.model = model
|
439 |
+
if shard_fn is not None:
|
440 |
+
self.model = shard_fn(self.model, sync_module_states=False)
|
441 |
+
else:
|
442 |
+
self.model.eval().requires_grad_(False)
|
443 |
+
# init tokenizer
|
444 |
+
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
|
445 |
+
|
446 |
+
def encode(self, texts):
|
447 |
+
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
|
448 |
+
ids = ids.to(self.device)
|
449 |
+
mask = mask.to(self.device)
|
450 |
+
# seq_lens = mask.gt(0).sum(dim=1).long()
|
451 |
+
context = self.model(ids, mask)
|
452 |
+
context = context * mask.unsqueeze(-1).cuda()
|
453 |
+
|
454 |
+
return context
|
skyreels_v2_infer/modules/tokenizers.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import html
|
3 |
+
import string
|
4 |
+
|
5 |
+
import ftfy
|
6 |
+
import regex as re
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
|
9 |
+
__all__ = ["HuggingfaceTokenizer"]
|
10 |
+
|
11 |
+
|
12 |
+
def basic_clean(text):
|
13 |
+
text = ftfy.fix_text(text)
|
14 |
+
text = html.unescape(html.unescape(text))
|
15 |
+
return text.strip()
|
16 |
+
|
17 |
+
|
18 |
+
def whitespace_clean(text):
|
19 |
+
text = re.sub(r"\s+", " ", text)
|
20 |
+
text = text.strip()
|
21 |
+
return text
|
22 |
+
|
23 |
+
|
24 |
+
def canonicalize(text, keep_punctuation_exact_string=None):
|
25 |
+
text = text.replace("_", " ")
|
26 |
+
if keep_punctuation_exact_string:
|
27 |
+
text = keep_punctuation_exact_string.join(
|
28 |
+
part.translate(str.maketrans("", "", string.punctuation))
|
29 |
+
for part in text.split(keep_punctuation_exact_string)
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
text = text.translate(str.maketrans("", "", string.punctuation))
|
33 |
+
text = text.lower()
|
34 |
+
text = re.sub(r"\s+", " ", text)
|
35 |
+
return text.strip()
|
36 |
+
|
37 |
+
|
38 |
+
class HuggingfaceTokenizer:
|
39 |
+
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
40 |
+
assert clean in (None, "whitespace", "lower", "canonicalize")
|
41 |
+
self.name = name
|
42 |
+
self.seq_len = seq_len
|
43 |
+
self.clean = clean
|
44 |
+
|
45 |
+
# init tokenizer
|
46 |
+
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
47 |
+
self.vocab_size = self.tokenizer.vocab_size
|
48 |
+
|
49 |
+
def __call__(self, sequence, **kwargs):
|
50 |
+
return_mask = kwargs.pop("return_mask", False)
|
51 |
+
|
52 |
+
# arguments
|
53 |
+
_kwargs = {"return_tensors": "pt"}
|
54 |
+
if self.seq_len is not None:
|
55 |
+
_kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
|
56 |
+
_kwargs.update(**kwargs)
|
57 |
+
|
58 |
+
# tokenization
|
59 |
+
if isinstance(sequence, str):
|
60 |
+
sequence = [sequence]
|
61 |
+
if self.clean:
|
62 |
+
sequence = [self._clean(u) for u in sequence]
|
63 |
+
ids = self.tokenizer(sequence, **_kwargs)
|
64 |
+
|
65 |
+
# output
|
66 |
+
if return_mask:
|
67 |
+
return ids.input_ids, ids.attention_mask
|
68 |
+
else:
|
69 |
+
return ids.input_ids
|
70 |
+
|
71 |
+
def _clean(self, text):
|
72 |
+
if self.clean == "whitespace":
|
73 |
+
text = whitespace_clean(basic_clean(text))
|
74 |
+
elif self.clean == "lower":
|
75 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
76 |
+
elif self.clean == "canonicalize":
|
77 |
+
text = canonicalize(basic_clean(text))
|
78 |
+
return text
|
skyreels_v2_infer/modules/transformer.py
ADDED
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.amp as amp
|
6 |
+
import torch.nn as nn
|
7 |
+
from diffusers.configuration_utils import ConfigMixin
|
8 |
+
from diffusers.configuration_utils import register_to_config
|
9 |
+
from diffusers.loaders import PeftAdapterMixin
|
10 |
+
from diffusers.models.modeling_utils import ModelMixin
|
11 |
+
from torch.backends.cuda import sdp_kernel
|
12 |
+
from torch.nn.attention.flex_attention import BlockMask
|
13 |
+
from torch.nn.attention.flex_attention import create_block_mask
|
14 |
+
from torch.nn.attention.flex_attention import flex_attention
|
15 |
+
|
16 |
+
from .attention import flash_attention
|
17 |
+
|
18 |
+
|
19 |
+
flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")
|
20 |
+
|
21 |
+
DISABLE_COMPILE = False # get os env
|
22 |
+
|
23 |
+
__all__ = ["WanModel"]
|
24 |
+
|
25 |
+
|
26 |
+
def sinusoidal_embedding_1d(dim, position):
|
27 |
+
# preprocess
|
28 |
+
assert dim % 2 == 0
|
29 |
+
half = dim // 2
|
30 |
+
position = position.type(torch.float64)
|
31 |
+
|
32 |
+
# calculation
|
33 |
+
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
34 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
35 |
+
return x
|
36 |
+
|
37 |
+
|
38 |
+
@amp.autocast("cuda", enabled=False)
|
39 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
40 |
+
assert dim % 2 == 0
|
41 |
+
freqs = torch.outer(
|
42 |
+
torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
|
43 |
+
)
|
44 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
45 |
+
return freqs
|
46 |
+
|
47 |
+
|
48 |
+
@amp.autocast("cuda", enabled=False)
|
49 |
+
def rope_apply(x, grid_sizes, freqs):
|
50 |
+
n, c = x.size(2), x.size(3) // 2
|
51 |
+
bs = x.size(0)
|
52 |
+
|
53 |
+
# split freqs
|
54 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
55 |
+
|
56 |
+
# loop over samples
|
57 |
+
f, h, w = grid_sizes.tolist()
|
58 |
+
seq_len = f * h * w
|
59 |
+
|
60 |
+
# precompute multipliers
|
61 |
+
|
62 |
+
x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2))
|
63 |
+
freqs_i = torch.cat(
|
64 |
+
[
|
65 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
66 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
67 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
|
68 |
+
],
|
69 |
+
dim=-1,
|
70 |
+
).reshape(seq_len, 1, -1)
|
71 |
+
|
72 |
+
# apply rotary embedding
|
73 |
+
x = torch.view_as_real(x * freqs_i).flatten(3)
|
74 |
+
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
@torch.compile(dynamic=True, disable=DISABLE_COMPILE)
|
79 |
+
def fast_rms_norm(x, weight, eps):
|
80 |
+
x = x.float()
|
81 |
+
x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
|
82 |
+
x = x.type_as(x) * weight
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class WanRMSNorm(nn.Module):
|
87 |
+
def __init__(self, dim, eps=1e-5):
|
88 |
+
super().__init__()
|
89 |
+
self.dim = dim
|
90 |
+
self.eps = eps
|
91 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
r"""
|
95 |
+
Args:
|
96 |
+
x(Tensor): Shape [B, L, C]
|
97 |
+
"""
|
98 |
+
return fast_rms_norm(x, self.weight, self.eps)
|
99 |
+
|
100 |
+
def _norm(self, x):
|
101 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
102 |
+
|
103 |
+
|
104 |
+
class WanLayerNorm(nn.LayerNorm):
|
105 |
+
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
106 |
+
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
r"""
|
110 |
+
Args:
|
111 |
+
x(Tensor): Shape [B, L, C]
|
112 |
+
"""
|
113 |
+
return super().forward(x)
|
114 |
+
|
115 |
+
|
116 |
+
class WanSelfAttention(nn.Module):
|
117 |
+
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
|
118 |
+
assert dim % num_heads == 0
|
119 |
+
super().__init__()
|
120 |
+
self.dim = dim
|
121 |
+
self.num_heads = num_heads
|
122 |
+
self.head_dim = dim // num_heads
|
123 |
+
self.window_size = window_size
|
124 |
+
self.qk_norm = qk_norm
|
125 |
+
self.eps = eps
|
126 |
+
|
127 |
+
# layers
|
128 |
+
self.q = nn.Linear(dim, dim)
|
129 |
+
self.k = nn.Linear(dim, dim)
|
130 |
+
self.v = nn.Linear(dim, dim)
|
131 |
+
self.o = nn.Linear(dim, dim)
|
132 |
+
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
133 |
+
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
134 |
+
|
135 |
+
self._flag_ar_attention = False
|
136 |
+
|
137 |
+
def set_ar_attention(self):
|
138 |
+
self._flag_ar_attention = True
|
139 |
+
|
140 |
+
def forward(self, x, grid_sizes, freqs, block_mask):
|
141 |
+
r"""
|
142 |
+
Args:
|
143 |
+
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
144 |
+
seq_lens(Tensor): Shape [B]
|
145 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
146 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
147 |
+
"""
|
148 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
149 |
+
|
150 |
+
# query, key, value function
|
151 |
+
def qkv_fn(x):
|
152 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
153 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
154 |
+
v = self.v(x).view(b, s, n, d)
|
155 |
+
return q, k, v
|
156 |
+
|
157 |
+
x = x.to(self.q.weight.dtype)
|
158 |
+
q, k, v = qkv_fn(x)
|
159 |
+
|
160 |
+
if not self._flag_ar_attention:
|
161 |
+
q = rope_apply(q, grid_sizes, freqs)
|
162 |
+
k = rope_apply(k, grid_sizes, freqs)
|
163 |
+
x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
|
164 |
+
else:
|
165 |
+
q = rope_apply(q, grid_sizes, freqs)
|
166 |
+
k = rope_apply(k, grid_sizes, freqs)
|
167 |
+
q = q.to(torch.bfloat16)
|
168 |
+
k = k.to(torch.bfloat16)
|
169 |
+
v = v.to(torch.bfloat16)
|
170 |
+
|
171 |
+
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
172 |
+
x = (
|
173 |
+
torch.nn.functional.scaled_dot_product_attention(
|
174 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
|
175 |
+
)
|
176 |
+
.transpose(1, 2)
|
177 |
+
.contiguous()
|
178 |
+
)
|
179 |
+
|
180 |
+
# output
|
181 |
+
x = x.flatten(2)
|
182 |
+
x = self.o(x)
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class WanT2VCrossAttention(WanSelfAttention):
|
187 |
+
def forward(self, x, context):
|
188 |
+
r"""
|
189 |
+
Args:
|
190 |
+
x(Tensor): Shape [B, L1, C]
|
191 |
+
context(Tensor): Shape [B, L2, C]
|
192 |
+
context_lens(Tensor): Shape [B]
|
193 |
+
"""
|
194 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
195 |
+
|
196 |
+
# compute query, key, value
|
197 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
198 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
199 |
+
v = self.v(context).view(b, -1, n, d)
|
200 |
+
|
201 |
+
# compute attention
|
202 |
+
x = flash_attention(q, k, v)
|
203 |
+
|
204 |
+
# output
|
205 |
+
x = x.flatten(2)
|
206 |
+
x = self.o(x)
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
210 |
+
class WanI2VCrossAttention(WanSelfAttention):
|
211 |
+
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
|
212 |
+
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
213 |
+
|
214 |
+
self.k_img = nn.Linear(dim, dim)
|
215 |
+
self.v_img = nn.Linear(dim, dim)
|
216 |
+
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
217 |
+
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
218 |
+
|
219 |
+
def forward(self, x, context):
|
220 |
+
r"""
|
221 |
+
Args:
|
222 |
+
x(Tensor): Shape [B, L1, C]
|
223 |
+
context(Tensor): Shape [B, L2, C]
|
224 |
+
context_lens(Tensor): Shape [B]
|
225 |
+
"""
|
226 |
+
context_img = context[:, :257]
|
227 |
+
context = context[:, 257:]
|
228 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
229 |
+
|
230 |
+
# compute query, key, value
|
231 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
232 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
233 |
+
v = self.v(context).view(b, -1, n, d)
|
234 |
+
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
|
235 |
+
v_img = self.v_img(context_img).view(b, -1, n, d)
|
236 |
+
img_x = flash_attention(q, k_img, v_img)
|
237 |
+
# compute attention
|
238 |
+
x = flash_attention(q, k, v)
|
239 |
+
|
240 |
+
# output
|
241 |
+
x = x.flatten(2)
|
242 |
+
img_x = img_x.flatten(2)
|
243 |
+
x = x + img_x
|
244 |
+
x = self.o(x)
|
245 |
+
return x
|
246 |
+
|
247 |
+
|
248 |
+
WAN_CROSSATTENTION_CLASSES = {
|
249 |
+
"t2v_cross_attn": WanT2VCrossAttention,
|
250 |
+
"i2v_cross_attn": WanI2VCrossAttention,
|
251 |
+
}
|
252 |
+
|
253 |
+
|
254 |
+
def mul_add(x, y, z):
|
255 |
+
return x.float() + y.float() * z.float()
|
256 |
+
|
257 |
+
|
258 |
+
def mul_add_add(x, y, z):
|
259 |
+
return x.float() * (1 + y) + z
|
260 |
+
|
261 |
+
|
262 |
+
mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE)
|
263 |
+
mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE)
|
264 |
+
|
265 |
+
|
266 |
+
class WanAttentionBlock(nn.Module):
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
cross_attn_type,
|
270 |
+
dim,
|
271 |
+
ffn_dim,
|
272 |
+
num_heads,
|
273 |
+
window_size=(-1, -1),
|
274 |
+
qk_norm=True,
|
275 |
+
cross_attn_norm=False,
|
276 |
+
eps=1e-6,
|
277 |
+
):
|
278 |
+
super().__init__()
|
279 |
+
self.dim = dim
|
280 |
+
self.ffn_dim = ffn_dim
|
281 |
+
self.num_heads = num_heads
|
282 |
+
self.window_size = window_size
|
283 |
+
self.qk_norm = qk_norm
|
284 |
+
self.cross_attn_norm = cross_attn_norm
|
285 |
+
self.eps = eps
|
286 |
+
|
287 |
+
# layers
|
288 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
289 |
+
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
290 |
+
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
291 |
+
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps)
|
292 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
293 |
+
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
|
294 |
+
|
295 |
+
# modulation
|
296 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
297 |
+
|
298 |
+
def set_ar_attention(self):
|
299 |
+
self.self_attn.set_ar_attention()
|
300 |
+
|
301 |
+
def forward(
|
302 |
+
self,
|
303 |
+
x,
|
304 |
+
e,
|
305 |
+
grid_sizes,
|
306 |
+
freqs,
|
307 |
+
context,
|
308 |
+
block_mask,
|
309 |
+
):
|
310 |
+
r"""
|
311 |
+
Args:
|
312 |
+
x(Tensor): Shape [B, L, C]
|
313 |
+
e(Tensor): Shape [B, 6, C]
|
314 |
+
seq_lens(Tensor): Shape [B], length of each sequence in batch
|
315 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
316 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
317 |
+
"""
|
318 |
+
if e.dim() == 3:
|
319 |
+
modulation = self.modulation # 1, 6, dim
|
320 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
321 |
+
e = (modulation + e).chunk(6, dim=1)
|
322 |
+
elif e.dim() == 4:
|
323 |
+
modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim
|
324 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
325 |
+
e = (modulation + e).chunk(6, dim=1)
|
326 |
+
e = [ei.squeeze(1) for ei in e]
|
327 |
+
|
328 |
+
# self-attention
|
329 |
+
out = mul_add_add_compile(self.norm1(x), e[1], e[0])
|
330 |
+
y = self.self_attn(out, grid_sizes, freqs, block_mask)
|
331 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
332 |
+
x = mul_add_compile(x, y, e[2])
|
333 |
+
|
334 |
+
# cross-attention & ffn function
|
335 |
+
def cross_attn_ffn(x, context, e):
|
336 |
+
dtype = context.dtype
|
337 |
+
x = x + self.cross_attn(self.norm3(x.to(dtype)), context)
|
338 |
+
y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype))
|
339 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
340 |
+
x = mul_add_compile(x, y, e[5])
|
341 |
+
return x
|
342 |
+
|
343 |
+
x = cross_attn_ffn(x, context, e)
|
344 |
+
return x.to(torch.bfloat16)
|
345 |
+
|
346 |
+
|
347 |
+
class Head(nn.Module):
|
348 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
349 |
+
super().__init__()
|
350 |
+
self.dim = dim
|
351 |
+
self.out_dim = out_dim
|
352 |
+
self.patch_size = patch_size
|
353 |
+
self.eps = eps
|
354 |
+
|
355 |
+
# layers
|
356 |
+
out_dim = math.prod(patch_size) * out_dim
|
357 |
+
self.norm = WanLayerNorm(dim, eps)
|
358 |
+
self.head = nn.Linear(dim, out_dim)
|
359 |
+
|
360 |
+
# modulation
|
361 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
362 |
+
|
363 |
+
def forward(self, x, e):
|
364 |
+
r"""
|
365 |
+
Args:
|
366 |
+
x(Tensor): Shape [B, L1, C]
|
367 |
+
e(Tensor): Shape [B, C]
|
368 |
+
"""
|
369 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
370 |
+
if e.dim() == 2:
|
371 |
+
modulation = self.modulation # 1, 2, dim
|
372 |
+
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
373 |
+
|
374 |
+
elif e.dim() == 3:
|
375 |
+
modulation = self.modulation.unsqueeze(2) # 1, 2, seq, dim
|
376 |
+
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
377 |
+
e = [ei.squeeze(1) for ei in e]
|
378 |
+
x = self.head(self.norm(x) * (1 + e[1]) + e[0])
|
379 |
+
return x
|
380 |
+
|
381 |
+
|
382 |
+
class MLPProj(torch.nn.Module):
|
383 |
+
def __init__(self, in_dim, out_dim):
|
384 |
+
super().__init__()
|
385 |
+
|
386 |
+
self.proj = torch.nn.Sequential(
|
387 |
+
torch.nn.LayerNorm(in_dim),
|
388 |
+
torch.nn.Linear(in_dim, in_dim),
|
389 |
+
torch.nn.GELU(),
|
390 |
+
torch.nn.Linear(in_dim, out_dim),
|
391 |
+
torch.nn.LayerNorm(out_dim),
|
392 |
+
)
|
393 |
+
|
394 |
+
def forward(self, image_embeds):
|
395 |
+
clip_extra_context_tokens = self.proj(image_embeds)
|
396 |
+
return clip_extra_context_tokens
|
397 |
+
|
398 |
+
|
399 |
+
class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
400 |
+
r"""
|
401 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
402 |
+
"""
|
403 |
+
|
404 |
+
ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
|
405 |
+
_no_split_modules = ["WanAttentionBlock"]
|
406 |
+
|
407 |
+
_supports_gradient_checkpointing = True
|
408 |
+
|
409 |
+
@register_to_config
|
410 |
+
def __init__(
|
411 |
+
self,
|
412 |
+
model_type="t2v",
|
413 |
+
patch_size=(1, 2, 2),
|
414 |
+
text_len=512,
|
415 |
+
in_dim=16,
|
416 |
+
dim=2048,
|
417 |
+
ffn_dim=8192,
|
418 |
+
freq_dim=256,
|
419 |
+
text_dim=4096,
|
420 |
+
out_dim=16,
|
421 |
+
num_heads=16,
|
422 |
+
num_layers=32,
|
423 |
+
window_size=(-1, -1),
|
424 |
+
qk_norm=True,
|
425 |
+
cross_attn_norm=True,
|
426 |
+
inject_sample_info=False,
|
427 |
+
eps=1e-6,
|
428 |
+
):
|
429 |
+
r"""
|
430 |
+
Initialize the diffusion model backbone.
|
431 |
+
|
432 |
+
Args:
|
433 |
+
model_type (`str`, *optional*, defaults to 't2v'):
|
434 |
+
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
435 |
+
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
436 |
+
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
437 |
+
text_len (`int`, *optional*, defaults to 512):
|
438 |
+
Fixed length for text embeddings
|
439 |
+
in_dim (`int`, *optional*, defaults to 16):
|
440 |
+
Input video channels (C_in)
|
441 |
+
dim (`int`, *optional*, defaults to 2048):
|
442 |
+
Hidden dimension of the transformer
|
443 |
+
ffn_dim (`int`, *optional*, defaults to 8192):
|
444 |
+
Intermediate dimension in feed-forward network
|
445 |
+
freq_dim (`int`, *optional*, defaults to 256):
|
446 |
+
Dimension for sinusoidal time embeddings
|
447 |
+
text_dim (`int`, *optional*, defaults to 4096):
|
448 |
+
Input dimension for text embeddings
|
449 |
+
out_dim (`int`, *optional*, defaults to 16):
|
450 |
+
Output video channels (C_out)
|
451 |
+
num_heads (`int`, *optional*, defaults to 16):
|
452 |
+
Number of attention heads
|
453 |
+
num_layers (`int`, *optional*, defaults to 32):
|
454 |
+
Number of transformer blocks
|
455 |
+
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
456 |
+
Window size for local attention (-1 indicates global attention)
|
457 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
458 |
+
Enable query/key normalization
|
459 |
+
cross_attn_norm (`bool`, *optional*, defaults to False):
|
460 |
+
Enable cross-attention normalization
|
461 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
462 |
+
Epsilon value for normalization layers
|
463 |
+
"""
|
464 |
+
|
465 |
+
super().__init__()
|
466 |
+
|
467 |
+
assert model_type in ["t2v", "i2v"]
|
468 |
+
self.model_type = model_type
|
469 |
+
|
470 |
+
self.patch_size = patch_size
|
471 |
+
self.text_len = text_len
|
472 |
+
self.in_dim = in_dim
|
473 |
+
self.dim = dim
|
474 |
+
self.ffn_dim = ffn_dim
|
475 |
+
self.freq_dim = freq_dim
|
476 |
+
self.text_dim = text_dim
|
477 |
+
self.out_dim = out_dim
|
478 |
+
self.num_heads = num_heads
|
479 |
+
self.num_layers = num_layers
|
480 |
+
self.window_size = window_size
|
481 |
+
self.qk_norm = qk_norm
|
482 |
+
self.cross_attn_norm = cross_attn_norm
|
483 |
+
self.eps = eps
|
484 |
+
self.num_frame_per_block = 1
|
485 |
+
self.flag_causal_attention = False
|
486 |
+
self.block_mask = None
|
487 |
+
self.enable_teacache = False
|
488 |
+
|
489 |
+
# embeddings
|
490 |
+
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
491 |
+
self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
|
492 |
+
|
493 |
+
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
494 |
+
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
495 |
+
|
496 |
+
if inject_sample_info:
|
497 |
+
self.fps_embedding = nn.Embedding(2, dim)
|
498 |
+
self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
|
499 |
+
|
500 |
+
# blocks
|
501 |
+
cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
|
502 |
+
self.blocks = nn.ModuleList(
|
503 |
+
[
|
504 |
+
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
505 |
+
for _ in range(num_layers)
|
506 |
+
]
|
507 |
+
)
|
508 |
+
|
509 |
+
# head
|
510 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
511 |
+
|
512 |
+
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
513 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
514 |
+
d = dim // num_heads
|
515 |
+
self.freqs = torch.cat(
|
516 |
+
[rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))],
|
517 |
+
dim=1,
|
518 |
+
)
|
519 |
+
|
520 |
+
if model_type == "i2v":
|
521 |
+
self.img_emb = MLPProj(1280, dim)
|
522 |
+
|
523 |
+
self.gradient_checkpointing = False
|
524 |
+
|
525 |
+
self.cpu_offloading = False
|
526 |
+
|
527 |
+
self.inject_sample_info = inject_sample_info
|
528 |
+
# initialize weights
|
529 |
+
self.init_weights()
|
530 |
+
|
531 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
532 |
+
self.gradient_checkpointing = value
|
533 |
+
|
534 |
+
def zero_init_i2v_cross_attn(self):
|
535 |
+
print("zero init i2v cross attn")
|
536 |
+
for i in range(self.num_layers):
|
537 |
+
self.blocks[i].cross_attn.v_img.weight.data.zero_()
|
538 |
+
self.blocks[i].cross_attn.v_img.bias.data.zero_()
|
539 |
+
|
540 |
+
@staticmethod
|
541 |
+
def _prepare_blockwise_causal_attn_mask(
|
542 |
+
device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1
|
543 |
+
) -> BlockMask:
|
544 |
+
"""
|
545 |
+
we will divide the token sequence into the following format
|
546 |
+
[1 latent frame] [1 latent frame] ... [1 latent frame]
|
547 |
+
We use flexattention to construct the attention mask
|
548 |
+
"""
|
549 |
+
total_length = num_frames * frame_seqlen
|
550 |
+
|
551 |
+
# we do right padding to get to a multiple of 128
|
552 |
+
padded_length = math.ceil(total_length / 128) * 128 - total_length
|
553 |
+
|
554 |
+
ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
|
555 |
+
|
556 |
+
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
|
557 |
+
frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device)
|
558 |
+
|
559 |
+
for tmp in frame_indices:
|
560 |
+
ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block
|
561 |
+
|
562 |
+
def attention_mask(b, h, q_idx, kv_idx):
|
563 |
+
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
|
564 |
+
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
|
565 |
+
|
566 |
+
block_mask = create_block_mask(
|
567 |
+
attention_mask,
|
568 |
+
B=None,
|
569 |
+
H=None,
|
570 |
+
Q_LEN=total_length + padded_length,
|
571 |
+
KV_LEN=total_length + padded_length,
|
572 |
+
_compile=False,
|
573 |
+
device=device,
|
574 |
+
)
|
575 |
+
|
576 |
+
return block_mask
|
577 |
+
|
578 |
+
def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''):
|
579 |
+
self.enable_teacache = enable_teacache
|
580 |
+
print('using teacache')
|
581 |
+
self.cnt = 0
|
582 |
+
self.num_steps = num_steps
|
583 |
+
self.teacache_thresh = teacache_thresh
|
584 |
+
self.accumulated_rel_l1_distance_even = 0
|
585 |
+
self.accumulated_rel_l1_distance_odd = 0
|
586 |
+
self.previous_e0_even = None
|
587 |
+
self.previous_e0_odd = None
|
588 |
+
self.previous_residual_even = None
|
589 |
+
self.previous_residual_odd = None
|
590 |
+
self.use_ref_steps = use_ret_steps
|
591 |
+
if "I2V" in ckpt_dir:
|
592 |
+
if use_ret_steps:
|
593 |
+
if '540P' in ckpt_dir:
|
594 |
+
self.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
|
595 |
+
if '720P' in ckpt_dir:
|
596 |
+
self.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
|
597 |
+
self.ret_steps = 5*2
|
598 |
+
self.cutoff_steps = num_steps*2
|
599 |
+
else:
|
600 |
+
if '540P' in ckpt_dir:
|
601 |
+
self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
602 |
+
if '720P' in ckpt_dir:
|
603 |
+
self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
604 |
+
self.ret_steps = 1*2
|
605 |
+
self.cutoff_steps = num_steps*2 - 2
|
606 |
+
else:
|
607 |
+
if use_ret_steps:
|
608 |
+
if '1.3B' in ckpt_dir:
|
609 |
+
self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
|
610 |
+
if '14B' in ckpt_dir:
|
611 |
+
self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
|
612 |
+
self.ret_steps = 5*2
|
613 |
+
self.cutoff_steps = num_steps*2
|
614 |
+
else:
|
615 |
+
if '1.3B' in ckpt_dir:
|
616 |
+
self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
|
617 |
+
if '14B' in ckpt_dir:
|
618 |
+
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
|
619 |
+
self.ret_steps = 1*2
|
620 |
+
self.cutoff_steps = num_steps*2 - 2
|
621 |
+
|
622 |
+
def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
|
623 |
+
r"""
|
624 |
+
Forward pass through the diffusion model
|
625 |
+
|
626 |
+
Args:
|
627 |
+
x (List[Tensor]):
|
628 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
629 |
+
t (Tensor):
|
630 |
+
Diffusion timesteps tensor of shape [B]
|
631 |
+
context (List[Tensor]):
|
632 |
+
List of text embeddings each with shape [L, C]
|
633 |
+
seq_len (`int`):
|
634 |
+
Maximum sequence length for positional encoding
|
635 |
+
clip_fea (Tensor, *optional*):
|
636 |
+
CLIP image features for image-to-video mode
|
637 |
+
y (List[Tensor], *optional*):
|
638 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
639 |
+
|
640 |
+
Returns:
|
641 |
+
List[Tensor]:
|
642 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
643 |
+
"""
|
644 |
+
if self.model_type == "i2v":
|
645 |
+
assert clip_fea is not None and y is not None
|
646 |
+
# params
|
647 |
+
device = self.patch_embedding.weight.device
|
648 |
+
if self.freqs.device != device:
|
649 |
+
self.freqs = self.freqs.to(device)
|
650 |
+
|
651 |
+
if y is not None:
|
652 |
+
x = torch.cat([x, y], dim=1)
|
653 |
+
|
654 |
+
# embeddings
|
655 |
+
x = self.patch_embedding(x)
|
656 |
+
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
|
657 |
+
x = x.flatten(2).transpose(1, 2)
|
658 |
+
|
659 |
+
if self.flag_causal_attention:
|
660 |
+
frame_num = grid_sizes[0]
|
661 |
+
height = grid_sizes[1]
|
662 |
+
width = grid_sizes[2]
|
663 |
+
block_num = frame_num // self.num_frame_per_block
|
664 |
+
range_tensor = torch.arange(block_num).view(-1, 1)
|
665 |
+
range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
|
666 |
+
casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
|
667 |
+
casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
|
668 |
+
casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
|
669 |
+
casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
|
670 |
+
self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
|
671 |
+
|
672 |
+
# time embeddings
|
673 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
674 |
+
if t.dim() == 2:
|
675 |
+
b, f = t.shape
|
676 |
+
_flag_df = True
|
677 |
+
else:
|
678 |
+
_flag_df = False
|
679 |
+
|
680 |
+
e = self.time_embedding(
|
681 |
+
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
|
682 |
+
) # b, dim
|
683 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
|
684 |
+
|
685 |
+
if self.inject_sample_info:
|
686 |
+
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
687 |
+
|
688 |
+
fps_emb = self.fps_embedding(fps).float()
|
689 |
+
if _flag_df:
|
690 |
+
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
691 |
+
else:
|
692 |
+
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
|
693 |
+
|
694 |
+
if _flag_df:
|
695 |
+
e = e.view(b, f, 1, 1, self.dim)
|
696 |
+
e0 = e0.view(b, f, 1, 1, 6, self.dim)
|
697 |
+
e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
|
698 |
+
e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
|
699 |
+
e0 = e0.transpose(1, 2).contiguous()
|
700 |
+
|
701 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
702 |
+
|
703 |
+
# context
|
704 |
+
context = self.text_embedding(context)
|
705 |
+
|
706 |
+
if clip_fea is not None:
|
707 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
708 |
+
context = torch.concat([context_clip, context], dim=1)
|
709 |
+
|
710 |
+
# arguments
|
711 |
+
kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
|
712 |
+
if self.enable_teacache:
|
713 |
+
modulated_inp = e0 if self.use_ref_steps else e
|
714 |
+
# teacache
|
715 |
+
if self.cnt%2==0: # even -> conditon
|
716 |
+
self.is_even = True
|
717 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
718 |
+
should_calc_even = True
|
719 |
+
self.accumulated_rel_l1_distance_even = 0
|
720 |
+
else:
|
721 |
+
rescale_func = np.poly1d(self.coefficients)
|
722 |
+
self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
|
723 |
+
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
|
724 |
+
should_calc_even = False
|
725 |
+
else:
|
726 |
+
should_calc_even = True
|
727 |
+
self.accumulated_rel_l1_distance_even = 0
|
728 |
+
self.previous_e0_even = modulated_inp.clone()
|
729 |
+
|
730 |
+
else: # odd -> unconditon
|
731 |
+
self.is_even = False
|
732 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
733 |
+
should_calc_odd = True
|
734 |
+
self.accumulated_rel_l1_distance_odd = 0
|
735 |
+
else:
|
736 |
+
rescale_func = np.poly1d(self.coefficients)
|
737 |
+
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
|
738 |
+
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
|
739 |
+
should_calc_odd = False
|
740 |
+
else:
|
741 |
+
should_calc_odd = True
|
742 |
+
self.accumulated_rel_l1_distance_odd = 0
|
743 |
+
self.previous_e0_odd = modulated_inp.clone()
|
744 |
+
|
745 |
+
if self.enable_teacache:
|
746 |
+
if self.is_even:
|
747 |
+
if not should_calc_even:
|
748 |
+
x += self.previous_residual_even
|
749 |
+
else:
|
750 |
+
ori_x = x.clone()
|
751 |
+
for block in self.blocks:
|
752 |
+
x = block(x, **kwargs)
|
753 |
+
self.previous_residual_even = x - ori_x
|
754 |
+
else:
|
755 |
+
if not should_calc_odd:
|
756 |
+
x += self.previous_residual_odd
|
757 |
+
else:
|
758 |
+
ori_x = x.clone()
|
759 |
+
for block in self.blocks:
|
760 |
+
x = block(x, **kwargs)
|
761 |
+
self.previous_residual_odd = x - ori_x
|
762 |
+
|
763 |
+
self.cnt += 1
|
764 |
+
if self.cnt >= self.num_steps:
|
765 |
+
self.cnt = 0
|
766 |
+
else:
|
767 |
+
for block in self.blocks:
|
768 |
+
x = block(x, **kwargs)
|
769 |
+
|
770 |
+
x = self.head(x, e)
|
771 |
+
|
772 |
+
# unpatchify
|
773 |
+
x = self.unpatchify(x, grid_sizes)
|
774 |
+
|
775 |
+
return x.float()
|
776 |
+
|
777 |
+
def unpatchify(self, x, grid_sizes):
|
778 |
+
r"""
|
779 |
+
Reconstruct video tensors from patch embeddings.
|
780 |
+
|
781 |
+
Args:
|
782 |
+
x (List[Tensor]):
|
783 |
+
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
784 |
+
grid_sizes (Tensor):
|
785 |
+
Original spatial-temporal grid dimensions before patching,
|
786 |
+
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
787 |
+
|
788 |
+
Returns:
|
789 |
+
List[Tensor]:
|
790 |
+
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
791 |
+
"""
|
792 |
+
|
793 |
+
c = self.out_dim
|
794 |
+
bs = x.shape[0]
|
795 |
+
x = x.view(bs, *grid_sizes, *self.patch_size, c)
|
796 |
+
x = torch.einsum("bfhwpqrc->bcfphqwr", x)
|
797 |
+
x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
798 |
+
|
799 |
+
return x
|
800 |
+
|
801 |
+
def set_ar_attention(self, causal_block_size):
|
802 |
+
self.num_frame_per_block = causal_block_size
|
803 |
+
self.flag_causal_attention = True
|
804 |
+
for block in self.blocks:
|
805 |
+
block.set_ar_attention()
|
806 |
+
|
807 |
+
def init_weights(self):
|
808 |
+
r"""
|
809 |
+
Initialize model parameters using Xavier initialization.
|
810 |
+
"""
|
811 |
+
|
812 |
+
# basic init
|
813 |
+
for m in self.modules():
|
814 |
+
if isinstance(m, nn.Linear):
|
815 |
+
nn.init.xavier_uniform_(m.weight)
|
816 |
+
if m.bias is not None:
|
817 |
+
nn.init.zeros_(m.bias)
|
818 |
+
|
819 |
+
# init embeddings
|
820 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
821 |
+
for m in self.text_embedding.modules():
|
822 |
+
if isinstance(m, nn.Linear):
|
823 |
+
nn.init.normal_(m.weight, std=0.02)
|
824 |
+
for m in self.time_embedding.modules():
|
825 |
+
if isinstance(m, nn.Linear):
|
826 |
+
nn.init.normal_(m.weight, std=0.02)
|
827 |
+
|
828 |
+
if self.inject_sample_info:
|
829 |
+
nn.init.normal_(self.fps_embedding.weight, std=0.02)
|
830 |
+
|
831 |
+
for m in self.fps_projection.modules():
|
832 |
+
if isinstance(m, nn.Linear):
|
833 |
+
nn.init.normal_(m.weight, std=0.02)
|
834 |
+
|
835 |
+
nn.init.zeros_(self.fps_projection[-1].weight)
|
836 |
+
nn.init.zeros_(self.fps_projection[-1].bias)
|
837 |
+
|
838 |
+
# init output layer
|
839 |
+
nn.init.zeros_(self.head.head.weight)
|
skyreels_v2_infer/modules/vae.py
ADDED
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
"WanVAE",
|
12 |
+
]
|
13 |
+
|
14 |
+
CACHE_T = 2
|
15 |
+
|
16 |
+
|
17 |
+
class CausalConv3d(nn.Conv3d):
|
18 |
+
"""
|
19 |
+
Causal 3d convolusion.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, *args, **kwargs):
|
23 |
+
super().__init__(*args, **kwargs)
|
24 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
|
25 |
+
self.padding = (0, 0, 0)
|
26 |
+
|
27 |
+
def forward(self, x, cache_x=None):
|
28 |
+
padding = list(self._padding)
|
29 |
+
if cache_x is not None and self._padding[4] > 0:
|
30 |
+
cache_x = cache_x.to(x.device)
|
31 |
+
x = torch.cat([cache_x, x], dim=2)
|
32 |
+
padding[4] -= cache_x.shape[2]
|
33 |
+
x = F.pad(x, padding)
|
34 |
+
|
35 |
+
return super().forward(x)
|
36 |
+
|
37 |
+
|
38 |
+
class RMS_norm(nn.Module):
|
39 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
40 |
+
super().__init__()
|
41 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
42 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
43 |
+
|
44 |
+
self.channel_first = channel_first
|
45 |
+
self.scale = dim**0.5
|
46 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
47 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
51 |
+
|
52 |
+
|
53 |
+
class Upsample(nn.Upsample):
|
54 |
+
def forward(self, x):
|
55 |
+
"""
|
56 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
57 |
+
"""
|
58 |
+
return super().forward(x.float()).type_as(x)
|
59 |
+
|
60 |
+
|
61 |
+
class Resample(nn.Module):
|
62 |
+
def __init__(self, dim, mode):
|
63 |
+
assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
|
64 |
+
super().__init__()
|
65 |
+
self.dim = dim
|
66 |
+
self.mode = mode
|
67 |
+
|
68 |
+
# layers
|
69 |
+
if mode == "upsample2d":
|
70 |
+
self.resample = nn.Sequential(
|
71 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
72 |
+
)
|
73 |
+
elif mode == "upsample3d":
|
74 |
+
self.resample = nn.Sequential(
|
75 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
76 |
+
)
|
77 |
+
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
78 |
+
|
79 |
+
elif mode == "downsample2d":
|
80 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
81 |
+
elif mode == "downsample3d":
|
82 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
83 |
+
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
84 |
+
|
85 |
+
else:
|
86 |
+
self.resample = nn.Identity()
|
87 |
+
|
88 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
89 |
+
b, c, t, h, w = x.size()
|
90 |
+
if self.mode == "upsample3d":
|
91 |
+
if feat_cache is not None:
|
92 |
+
idx = feat_idx[0]
|
93 |
+
if feat_cache[idx] is None:
|
94 |
+
feat_cache[idx] = "Rep"
|
95 |
+
feat_idx[0] += 1
|
96 |
+
else:
|
97 |
+
|
98 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
99 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
100 |
+
# cache last frame of last two chunk
|
101 |
+
cache_x = torch.cat(
|
102 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
103 |
+
)
|
104 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
|
105 |
+
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
106 |
+
if feat_cache[idx] == "Rep":
|
107 |
+
x = self.time_conv(x)
|
108 |
+
else:
|
109 |
+
x = self.time_conv(x, feat_cache[idx])
|
110 |
+
feat_cache[idx] = cache_x
|
111 |
+
feat_idx[0] += 1
|
112 |
+
|
113 |
+
x = x.reshape(b, 2, c, t, h, w)
|
114 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
115 |
+
x = x.reshape(b, c, t * 2, h, w)
|
116 |
+
t = x.shape[2]
|
117 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
118 |
+
x = self.resample(x)
|
119 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
120 |
+
|
121 |
+
if self.mode == "downsample3d":
|
122 |
+
if feat_cache is not None:
|
123 |
+
idx = feat_idx[0]
|
124 |
+
if feat_cache[idx] is None:
|
125 |
+
feat_cache[idx] = x.clone()
|
126 |
+
feat_idx[0] += 1
|
127 |
+
else:
|
128 |
+
|
129 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
130 |
+
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
131 |
+
# # cache last frame of last two chunk
|
132 |
+
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
133 |
+
|
134 |
+
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
135 |
+
feat_cache[idx] = cache_x
|
136 |
+
feat_idx[0] += 1
|
137 |
+
return x
|
138 |
+
|
139 |
+
def init_weight(self, conv):
|
140 |
+
conv_weight = conv.weight
|
141 |
+
nn.init.zeros_(conv_weight)
|
142 |
+
c1, c2, t, h, w = conv_weight.size()
|
143 |
+
one_matrix = torch.eye(c1, c2)
|
144 |
+
init_matrix = one_matrix
|
145 |
+
nn.init.zeros_(conv_weight)
|
146 |
+
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
147 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
148 |
+
conv.weight.data.copy_(conv_weight)
|
149 |
+
nn.init.zeros_(conv.bias.data)
|
150 |
+
|
151 |
+
def init_weight2(self, conv):
|
152 |
+
conv_weight = conv.weight.data
|
153 |
+
nn.init.zeros_(conv_weight)
|
154 |
+
c1, c2, t, h, w = conv_weight.size()
|
155 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
156 |
+
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
157 |
+
conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
|
158 |
+
conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
|
159 |
+
conv.weight.data.copy_(conv_weight)
|
160 |
+
nn.init.zeros_(conv.bias.data)
|
161 |
+
|
162 |
+
|
163 |
+
class ResidualBlock(nn.Module):
|
164 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
165 |
+
super().__init__()
|
166 |
+
self.in_dim = in_dim
|
167 |
+
self.out_dim = out_dim
|
168 |
+
|
169 |
+
# layers
|
170 |
+
self.residual = nn.Sequential(
|
171 |
+
RMS_norm(in_dim, images=False),
|
172 |
+
nn.SiLU(),
|
173 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
174 |
+
RMS_norm(out_dim, images=False),
|
175 |
+
nn.SiLU(),
|
176 |
+
nn.Dropout(dropout),
|
177 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1),
|
178 |
+
)
|
179 |
+
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
180 |
+
|
181 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
182 |
+
h = self.shortcut(x)
|
183 |
+
for layer in self.residual:
|
184 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
185 |
+
idx = feat_idx[0]
|
186 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
187 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
188 |
+
# cache last frame of last two chunk
|
189 |
+
cache_x = torch.cat(
|
190 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
191 |
+
)
|
192 |
+
x = layer(x, feat_cache[idx])
|
193 |
+
feat_cache[idx] = cache_x
|
194 |
+
feat_idx[0] += 1
|
195 |
+
else:
|
196 |
+
x = layer(x)
|
197 |
+
return x + h
|
198 |
+
|
199 |
+
|
200 |
+
class AttentionBlock(nn.Module):
|
201 |
+
"""
|
202 |
+
Causal self-attention with a single head.
|
203 |
+
"""
|
204 |
+
|
205 |
+
def __init__(self, dim):
|
206 |
+
super().__init__()
|
207 |
+
self.dim = dim
|
208 |
+
|
209 |
+
# layers
|
210 |
+
self.norm = RMS_norm(dim)
|
211 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
212 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
213 |
+
|
214 |
+
# zero out the last layer params
|
215 |
+
nn.init.zeros_(self.proj.weight)
|
216 |
+
|
217 |
+
def forward(self, x):
|
218 |
+
identity = x
|
219 |
+
b, c, t, h, w = x.size()
|
220 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
221 |
+
x = self.norm(x)
|
222 |
+
# compute query, key, value
|
223 |
+
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
224 |
+
|
225 |
+
# apply attention
|
226 |
+
x = F.scaled_dot_product_attention(
|
227 |
+
q,
|
228 |
+
k,
|
229 |
+
v,
|
230 |
+
)
|
231 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
232 |
+
|
233 |
+
# output
|
234 |
+
x = self.proj(x)
|
235 |
+
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
|
236 |
+
return x + identity
|
237 |
+
|
238 |
+
|
239 |
+
class Encoder3d(nn.Module):
|
240 |
+
def __init__(
|
241 |
+
self,
|
242 |
+
dim=128,
|
243 |
+
z_dim=4,
|
244 |
+
dim_mult=[1, 2, 4, 4],
|
245 |
+
num_res_blocks=2,
|
246 |
+
attn_scales=[],
|
247 |
+
temperal_downsample=[True, True, False],
|
248 |
+
dropout=0.0,
|
249 |
+
):
|
250 |
+
super().__init__()
|
251 |
+
self.dim = dim
|
252 |
+
self.z_dim = z_dim
|
253 |
+
self.dim_mult = dim_mult
|
254 |
+
self.num_res_blocks = num_res_blocks
|
255 |
+
self.attn_scales = attn_scales
|
256 |
+
self.temperal_downsample = temperal_downsample
|
257 |
+
|
258 |
+
# dimensions
|
259 |
+
dims = [dim * u for u in [1] + dim_mult]
|
260 |
+
scale = 1.0
|
261 |
+
|
262 |
+
# init block
|
263 |
+
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
264 |
+
|
265 |
+
# downsample blocks
|
266 |
+
downsamples = []
|
267 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
268 |
+
# residual (+attention) blocks
|
269 |
+
for _ in range(num_res_blocks):
|
270 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
271 |
+
if scale in attn_scales:
|
272 |
+
downsamples.append(AttentionBlock(out_dim))
|
273 |
+
in_dim = out_dim
|
274 |
+
|
275 |
+
# downsample block
|
276 |
+
if i != len(dim_mult) - 1:
|
277 |
+
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
278 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
279 |
+
scale /= 2.0
|
280 |
+
self.downsamples = nn.Sequential(*downsamples)
|
281 |
+
|
282 |
+
# middle blocks
|
283 |
+
self.middle = nn.Sequential(
|
284 |
+
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
|
285 |
+
)
|
286 |
+
|
287 |
+
# output blocks
|
288 |
+
self.head = nn.Sequential(
|
289 |
+
RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)
|
290 |
+
)
|
291 |
+
|
292 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
293 |
+
if feat_cache is not None:
|
294 |
+
idx = feat_idx[0]
|
295 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
296 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
297 |
+
# cache last frame of last two chunk
|
298 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
299 |
+
x = self.conv1(x, feat_cache[idx])
|
300 |
+
feat_cache[idx] = cache_x
|
301 |
+
feat_idx[0] += 1
|
302 |
+
else:
|
303 |
+
x = self.conv1(x)
|
304 |
+
|
305 |
+
## downsamples
|
306 |
+
for layer in self.downsamples:
|
307 |
+
if feat_cache is not None:
|
308 |
+
x = layer(x, feat_cache, feat_idx)
|
309 |
+
else:
|
310 |
+
x = layer(x)
|
311 |
+
|
312 |
+
## middle
|
313 |
+
for layer in self.middle:
|
314 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
315 |
+
x = layer(x, feat_cache, feat_idx)
|
316 |
+
else:
|
317 |
+
x = layer(x)
|
318 |
+
|
319 |
+
## head
|
320 |
+
for layer in self.head:
|
321 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
322 |
+
idx = feat_idx[0]
|
323 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
324 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
325 |
+
# cache last frame of last two chunk
|
326 |
+
cache_x = torch.cat(
|
327 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
328 |
+
)
|
329 |
+
x = layer(x, feat_cache[idx])
|
330 |
+
feat_cache[idx] = cache_x
|
331 |
+
feat_idx[0] += 1
|
332 |
+
else:
|
333 |
+
x = layer(x)
|
334 |
+
return x
|
335 |
+
|
336 |
+
|
337 |
+
class Decoder3d(nn.Module):
|
338 |
+
def __init__(
|
339 |
+
self,
|
340 |
+
dim=128,
|
341 |
+
z_dim=4,
|
342 |
+
dim_mult=[1, 2, 4, 4],
|
343 |
+
num_res_blocks=2,
|
344 |
+
attn_scales=[],
|
345 |
+
temperal_upsample=[False, True, True],
|
346 |
+
dropout=0.0,
|
347 |
+
):
|
348 |
+
super().__init__()
|
349 |
+
self.dim = dim
|
350 |
+
self.z_dim = z_dim
|
351 |
+
self.dim_mult = dim_mult
|
352 |
+
self.num_res_blocks = num_res_blocks
|
353 |
+
self.attn_scales = attn_scales
|
354 |
+
self.temperal_upsample = temperal_upsample
|
355 |
+
|
356 |
+
# dimensions
|
357 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
358 |
+
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
359 |
+
|
360 |
+
# init block
|
361 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
362 |
+
|
363 |
+
# middle blocks
|
364 |
+
self.middle = nn.Sequential(
|
365 |
+
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
|
366 |
+
)
|
367 |
+
|
368 |
+
# upsample blocks
|
369 |
+
upsamples = []
|
370 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
371 |
+
# residual (+attention) blocks
|
372 |
+
if i == 1 or i == 2 or i == 3:
|
373 |
+
in_dim = in_dim // 2
|
374 |
+
for _ in range(num_res_blocks + 1):
|
375 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
376 |
+
if scale in attn_scales:
|
377 |
+
upsamples.append(AttentionBlock(out_dim))
|
378 |
+
in_dim = out_dim
|
379 |
+
|
380 |
+
# upsample block
|
381 |
+
if i != len(dim_mult) - 1:
|
382 |
+
mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
383 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
384 |
+
scale *= 2.0
|
385 |
+
self.upsamples = nn.Sequential(*upsamples)
|
386 |
+
|
387 |
+
# output blocks
|
388 |
+
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
|
389 |
+
|
390 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
391 |
+
## conv1
|
392 |
+
if feat_cache is not None:
|
393 |
+
idx = feat_idx[0]
|
394 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
395 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
396 |
+
# cache last frame of last two chunk
|
397 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
398 |
+
x = self.conv1(x, feat_cache[idx])
|
399 |
+
feat_cache[idx] = cache_x
|
400 |
+
feat_idx[0] += 1
|
401 |
+
else:
|
402 |
+
x = self.conv1(x)
|
403 |
+
|
404 |
+
## middle
|
405 |
+
for layer in self.middle:
|
406 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
407 |
+
x = layer(x, feat_cache, feat_idx)
|
408 |
+
else:
|
409 |
+
x = layer(x)
|
410 |
+
|
411 |
+
## upsamples
|
412 |
+
for layer in self.upsamples:
|
413 |
+
if feat_cache is not None:
|
414 |
+
x = layer(x, feat_cache, feat_idx)
|
415 |
+
else:
|
416 |
+
x = layer(x)
|
417 |
+
|
418 |
+
## head
|
419 |
+
for layer in self.head:
|
420 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
421 |
+
idx = feat_idx[0]
|
422 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
423 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
424 |
+
# cache last frame of last two chunk
|
425 |
+
cache_x = torch.cat(
|
426 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
427 |
+
)
|
428 |
+
x = layer(x, feat_cache[idx])
|
429 |
+
feat_cache[idx] = cache_x
|
430 |
+
feat_idx[0] += 1
|
431 |
+
else:
|
432 |
+
x = layer(x)
|
433 |
+
return x
|
434 |
+
|
435 |
+
|
436 |
+
def count_conv3d(model):
|
437 |
+
count = 0
|
438 |
+
for m in model.modules():
|
439 |
+
if isinstance(m, CausalConv3d):
|
440 |
+
count += 1
|
441 |
+
return count
|
442 |
+
|
443 |
+
|
444 |
+
class WanVAE_(nn.Module):
|
445 |
+
def __init__(
|
446 |
+
self,
|
447 |
+
dim=128,
|
448 |
+
z_dim=4,
|
449 |
+
dim_mult=[1, 2, 4, 4],
|
450 |
+
num_res_blocks=2,
|
451 |
+
attn_scales=[],
|
452 |
+
temperal_downsample=[True, True, False],
|
453 |
+
dropout=0.0,
|
454 |
+
):
|
455 |
+
super().__init__()
|
456 |
+
self.dim = dim
|
457 |
+
self.z_dim = z_dim
|
458 |
+
self.dim_mult = dim_mult
|
459 |
+
self.num_res_blocks = num_res_blocks
|
460 |
+
self.attn_scales = attn_scales
|
461 |
+
self.temperal_downsample = temperal_downsample
|
462 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
463 |
+
|
464 |
+
# modules
|
465 |
+
self.encoder = Encoder3d(
|
466 |
+
dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
467 |
+
)
|
468 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
469 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
470 |
+
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
|
471 |
+
|
472 |
+
def forward(self, x):
|
473 |
+
mu, log_var = self.encode(x)
|
474 |
+
z = self.reparameterize(mu, log_var)
|
475 |
+
x_recon = self.decode(z)
|
476 |
+
return x_recon, mu, log_var
|
477 |
+
|
478 |
+
def encode(self, x, scale):
|
479 |
+
self.clear_cache()
|
480 |
+
## cache
|
481 |
+
t = x.shape[2]
|
482 |
+
iter_ = 1 + (t - 1) // 4
|
483 |
+
## 对encode输入的x,按时间拆分为1、4、4、4....
|
484 |
+
for i in range(iter_):
|
485 |
+
self._enc_conv_idx = [0]
|
486 |
+
if i == 0:
|
487 |
+
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
488 |
+
else:
|
489 |
+
out_ = self.encoder(
|
490 |
+
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
|
491 |
+
feat_cache=self._enc_feat_map,
|
492 |
+
feat_idx=self._enc_conv_idx,
|
493 |
+
)
|
494 |
+
out = torch.cat([out, out_], 2)
|
495 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
496 |
+
if isinstance(scale[0], torch.Tensor):
|
497 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
498 |
+
else:
|
499 |
+
mu = (mu - scale[0]) * scale[1]
|
500 |
+
self.clear_cache()
|
501 |
+
return mu
|
502 |
+
|
503 |
+
def decode(self, z, scale):
|
504 |
+
self.clear_cache()
|
505 |
+
# z: [b,c,t,h,w]
|
506 |
+
if isinstance(scale[0], torch.Tensor):
|
507 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
508 |
+
else:
|
509 |
+
z = z / scale[1] + scale[0]
|
510 |
+
iter_ = z.shape[2]
|
511 |
+
x = self.conv2(z)
|
512 |
+
for i in range(iter_):
|
513 |
+
self._conv_idx = [0]
|
514 |
+
if i == 0:
|
515 |
+
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
516 |
+
else:
|
517 |
+
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
518 |
+
out = torch.cat([out, out_], 2)
|
519 |
+
self.clear_cache()
|
520 |
+
return out
|
521 |
+
|
522 |
+
def reparameterize(self, mu, log_var):
|
523 |
+
std = torch.exp(0.5 * log_var)
|
524 |
+
eps = torch.randn_like(std)
|
525 |
+
return eps * std + mu
|
526 |
+
|
527 |
+
def sample(self, imgs, deterministic=False):
|
528 |
+
mu, log_var = self.encode(imgs)
|
529 |
+
if deterministic:
|
530 |
+
return mu
|
531 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
532 |
+
return mu + std * torch.randn_like(std)
|
533 |
+
|
534 |
+
def clear_cache(self):
|
535 |
+
self._conv_num = count_conv3d(self.decoder)
|
536 |
+
self._conv_idx = [0]
|
537 |
+
self._feat_map = [None] * self._conv_num
|
538 |
+
# cache encode
|
539 |
+
self._enc_conv_num = count_conv3d(self.encoder)
|
540 |
+
self._enc_conv_idx = [0]
|
541 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
542 |
+
|
543 |
+
|
544 |
+
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
|
545 |
+
"""
|
546 |
+
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
|
547 |
+
"""
|
548 |
+
# params
|
549 |
+
cfg = dict(
|
550 |
+
dim=96,
|
551 |
+
z_dim=z_dim,
|
552 |
+
dim_mult=[1, 2, 4, 4],
|
553 |
+
num_res_blocks=2,
|
554 |
+
attn_scales=[],
|
555 |
+
temperal_downsample=[False, True, True],
|
556 |
+
dropout=0.0,
|
557 |
+
)
|
558 |
+
cfg.update(**kwargs)
|
559 |
+
|
560 |
+
# init model
|
561 |
+
with torch.device("meta"):
|
562 |
+
model = WanVAE_(**cfg)
|
563 |
+
|
564 |
+
# load checkpoint
|
565 |
+
logging.info(f"loading {pretrained_path}")
|
566 |
+
model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
|
567 |
+
|
568 |
+
return model
|
569 |
+
|
570 |
+
|
571 |
+
class WanVAE:
|
572 |
+
def __init__(self, vae_pth="cache/vae_step_411000.pth", z_dim=16):
|
573 |
+
|
574 |
+
mean = [
|
575 |
+
-0.7571,
|
576 |
+
-0.7089,
|
577 |
+
-0.9113,
|
578 |
+
0.1075,
|
579 |
+
-0.1745,
|
580 |
+
0.9653,
|
581 |
+
-0.1517,
|
582 |
+
1.5508,
|
583 |
+
0.4134,
|
584 |
+
-0.0715,
|
585 |
+
0.5517,
|
586 |
+
-0.3632,
|
587 |
+
-0.1922,
|
588 |
+
-0.9497,
|
589 |
+
0.2503,
|
590 |
+
-0.2921,
|
591 |
+
]
|
592 |
+
std = [
|
593 |
+
2.8184,
|
594 |
+
1.4541,
|
595 |
+
2.3275,
|
596 |
+
2.6558,
|
597 |
+
1.2196,
|
598 |
+
1.7708,
|
599 |
+
2.6052,
|
600 |
+
2.0743,
|
601 |
+
3.2687,
|
602 |
+
2.1526,
|
603 |
+
2.8652,
|
604 |
+
1.5579,
|
605 |
+
1.6382,
|
606 |
+
1.1253,
|
607 |
+
2.8251,
|
608 |
+
1.9160,
|
609 |
+
]
|
610 |
+
self.vae_stride = (4, 8, 8)
|
611 |
+
self.mean = torch.tensor(mean)
|
612 |
+
self.std = torch.tensor(std)
|
613 |
+
self.scale = [self.mean, 1.0 / self.std]
|
614 |
+
|
615 |
+
# init model
|
616 |
+
self.vae = (
|
617 |
+
_video_vae(
|
618 |
+
pretrained_path=vae_pth,
|
619 |
+
z_dim=z_dim,
|
620 |
+
)
|
621 |
+
.eval()
|
622 |
+
.requires_grad_(False)
|
623 |
+
)
|
624 |
+
|
625 |
+
def encode(self, video):
|
626 |
+
"""
|
627 |
+
videos: A list of videos each with shape [C, T, H, W].
|
628 |
+
"""
|
629 |
+
return self.vae.encode(video, self.scale).float()
|
630 |
+
|
631 |
+
def to(self, *args, **kwargs):
|
632 |
+
self.mean = self.mean.to(*args, **kwargs)
|
633 |
+
self.std = self.std.to(*args, **kwargs)
|
634 |
+
self.scale = [self.mean, 1.0 / self.std]
|
635 |
+
self.vae = self.vae.to(*args, **kwargs)
|
636 |
+
return self
|
637 |
+
|
638 |
+
def decode(self, z):
|
639 |
+
return self.vae.decode(z, self.scale).float().clamp_(-1, 1)
|
skyreels_v2_infer/modules/xlm_roberta.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
|
2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
__all__ = ["XLMRoberta", "xlm_roberta_large"]
|
8 |
+
|
9 |
+
|
10 |
+
class SelfAttention(nn.Module):
|
11 |
+
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
|
12 |
+
assert dim % num_heads == 0
|
13 |
+
super().__init__()
|
14 |
+
self.dim = dim
|
15 |
+
self.num_heads = num_heads
|
16 |
+
self.head_dim = dim // num_heads
|
17 |
+
self.eps = eps
|
18 |
+
|
19 |
+
# layers
|
20 |
+
self.q = nn.Linear(dim, dim)
|
21 |
+
self.k = nn.Linear(dim, dim)
|
22 |
+
self.v = nn.Linear(dim, dim)
|
23 |
+
self.o = nn.Linear(dim, dim)
|
24 |
+
self.dropout = nn.Dropout(dropout)
|
25 |
+
|
26 |
+
def forward(self, x, mask):
|
27 |
+
"""
|
28 |
+
x: [B, L, C].
|
29 |
+
"""
|
30 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
31 |
+
|
32 |
+
# compute query, key, value
|
33 |
+
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
34 |
+
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
35 |
+
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
36 |
+
|
37 |
+
# compute attention
|
38 |
+
p = self.dropout.p if self.training else 0.0
|
39 |
+
x = F.scaled_dot_product_attention(q, k, v, mask, p)
|
40 |
+
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
|
41 |
+
|
42 |
+
# output
|
43 |
+
x = self.o(x)
|
44 |
+
x = self.dropout(x)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class AttentionBlock(nn.Module):
|
49 |
+
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
|
50 |
+
super().__init__()
|
51 |
+
self.dim = dim
|
52 |
+
self.num_heads = num_heads
|
53 |
+
self.post_norm = post_norm
|
54 |
+
self.eps = eps
|
55 |
+
|
56 |
+
# layers
|
57 |
+
self.attn = SelfAttention(dim, num_heads, dropout, eps)
|
58 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps)
|
59 |
+
self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
|
60 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps)
|
61 |
+
|
62 |
+
def forward(self, x, mask):
|
63 |
+
if self.post_norm:
|
64 |
+
x = self.norm1(x + self.attn(x, mask))
|
65 |
+
x = self.norm2(x + self.ffn(x))
|
66 |
+
else:
|
67 |
+
x = x + self.attn(self.norm1(x), mask)
|
68 |
+
x = x + self.ffn(self.norm2(x))
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
class XLMRoberta(nn.Module):
|
73 |
+
"""
|
74 |
+
XLMRobertaModel with no pooler and no LM head.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
vocab_size=250002,
|
80 |
+
max_seq_len=514,
|
81 |
+
type_size=1,
|
82 |
+
pad_id=1,
|
83 |
+
dim=1024,
|
84 |
+
num_heads=16,
|
85 |
+
num_layers=24,
|
86 |
+
post_norm=True,
|
87 |
+
dropout=0.1,
|
88 |
+
eps=1e-5,
|
89 |
+
):
|
90 |
+
super().__init__()
|
91 |
+
self.vocab_size = vocab_size
|
92 |
+
self.max_seq_len = max_seq_len
|
93 |
+
self.type_size = type_size
|
94 |
+
self.pad_id = pad_id
|
95 |
+
self.dim = dim
|
96 |
+
self.num_heads = num_heads
|
97 |
+
self.num_layers = num_layers
|
98 |
+
self.post_norm = post_norm
|
99 |
+
self.eps = eps
|
100 |
+
|
101 |
+
# embeddings
|
102 |
+
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
|
103 |
+
self.type_embedding = nn.Embedding(type_size, dim)
|
104 |
+
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
|
105 |
+
self.dropout = nn.Dropout(dropout)
|
106 |
+
|
107 |
+
# blocks
|
108 |
+
self.blocks = nn.ModuleList(
|
109 |
+
[AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)]
|
110 |
+
)
|
111 |
+
|
112 |
+
# norm layer
|
113 |
+
self.norm = nn.LayerNorm(dim, eps=eps)
|
114 |
+
|
115 |
+
def forward(self, ids):
|
116 |
+
"""
|
117 |
+
ids: [B, L] of torch.LongTensor.
|
118 |
+
"""
|
119 |
+
b, s = ids.shape
|
120 |
+
mask = ids.ne(self.pad_id).long()
|
121 |
+
|
122 |
+
# embeddings
|
123 |
+
x = (
|
124 |
+
self.token_embedding(ids)
|
125 |
+
+ self.type_embedding(torch.zeros_like(ids))
|
126 |
+
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
|
127 |
+
)
|
128 |
+
if self.post_norm:
|
129 |
+
x = self.norm(x)
|
130 |
+
x = self.dropout(x)
|
131 |
+
|
132 |
+
# blocks
|
133 |
+
mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
|
134 |
+
for block in self.blocks:
|
135 |
+
x = block(x, mask)
|
136 |
+
|
137 |
+
# output
|
138 |
+
if not self.post_norm:
|
139 |
+
x = self.norm(x)
|
140 |
+
return x
|
141 |
+
|
142 |
+
|
143 |
+
def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
|
144 |
+
"""
|
145 |
+
XLMRobertaLarge adapted from Huggingface.
|
146 |
+
"""
|
147 |
+
# params
|
148 |
+
cfg = dict(
|
149 |
+
vocab_size=250002,
|
150 |
+
max_seq_len=514,
|
151 |
+
type_size=1,
|
152 |
+
pad_id=1,
|
153 |
+
dim=1024,
|
154 |
+
num_heads=16,
|
155 |
+
num_layers=24,
|
156 |
+
post_norm=True,
|
157 |
+
dropout=0.1,
|
158 |
+
eps=1e-5,
|
159 |
+
)
|
160 |
+
cfg.update(**kwargs)
|
161 |
+
|
162 |
+
# init a model on device
|
163 |
+
with torch.device(device):
|
164 |
+
model = XLMRoberta(**cfg)
|
165 |
+
return model
|
skyreels_v2_infer/pipelines/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .diffusion_forcing_pipeline import DiffusionForcingPipeline
|
2 |
+
from .image2video_pipeline import Image2VideoPipeline
|
3 |
+
from .image2video_pipeline import resizecrop
|
4 |
+
from .prompt_enhancer import PromptEnhancer
|
5 |
+
from .text2video_pipeline import Text2VideoPipeline
|
skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import List
|
4 |
+
from typing import Optional
|
5 |
+
from typing import Tuple
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from diffusers.image_processor import PipelineImageInput
|
11 |
+
from diffusers.utils.torch_utils import randn_tensor
|
12 |
+
from diffusers.video_processor import VideoProcessor
|
13 |
+
from tqdm import tqdm
|
14 |
+
import decord
|
15 |
+
from decord import VideoReader
|
16 |
+
|
17 |
+
from ..modules import get_text_encoder
|
18 |
+
from ..modules import get_transformer
|
19 |
+
from ..modules import get_vae
|
20 |
+
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
class DiffusionForcingPipeline:
|
26 |
+
"""
|
27 |
+
A pipeline for diffusion-based video generation tasks.
|
28 |
+
|
29 |
+
This pipeline supports two main tasks:
|
30 |
+
- Image-to-Video (i2v): Generates a video sequence from a source image
|
31 |
+
- Text-to-Video (t2v): Generates a video sequence from a text description
|
32 |
+
|
33 |
+
The pipeline integrates multiple components including:
|
34 |
+
- A transformer model for diffusion
|
35 |
+
- A VAE for encoding/decoding
|
36 |
+
- A text encoder for processing text prompts
|
37 |
+
- An image encoder for processing image inputs (i2v mode only)
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
model_path: str,
|
43 |
+
dit_path: str,
|
44 |
+
device: str = "cuda",
|
45 |
+
weight_dtype=torch.bfloat16,
|
46 |
+
use_usp=False,
|
47 |
+
offload=False,
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
Initialize the diffusion forcing pipeline class
|
51 |
+
|
52 |
+
Args:
|
53 |
+
model_path (str): Path to the model
|
54 |
+
dit_path (str): Path to the DIT model, containing model configuration file (config.json) and weight file (*.safetensor)
|
55 |
+
device (str): Device to run on, defaults to 'cuda'
|
56 |
+
weight_dtype: Weight data type, defaults to torch.bfloat16
|
57 |
+
"""
|
58 |
+
load_device = "cpu" if offload else device
|
59 |
+
self.transformer = get_transformer(dit_path, load_device, weight_dtype)
|
60 |
+
vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
|
61 |
+
self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
|
62 |
+
self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
|
63 |
+
self.video_processor = VideoProcessor(vae_scale_factor=16)
|
64 |
+
self.device = device
|
65 |
+
self.offload = offload
|
66 |
+
|
67 |
+
if use_usp:
|
68 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
69 |
+
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
70 |
+
import types
|
71 |
+
|
72 |
+
for block in self.transformer.blocks:
|
73 |
+
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
74 |
+
self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
|
75 |
+
self.sp_size = get_sequence_parallel_world_size()
|
76 |
+
|
77 |
+
self.scheduler = FlowUniPCMultistepScheduler()
|
78 |
+
|
79 |
+
@property
|
80 |
+
def do_classifier_free_guidance(self) -> bool:
|
81 |
+
return self._guidance_scale > 1
|
82 |
+
|
83 |
+
def encode_image(
|
84 |
+
self, image: PipelineImageInput, height: int, width: int, num_frames: int
|
85 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
86 |
+
|
87 |
+
# prefix_video
|
88 |
+
prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
|
89 |
+
prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
|
90 |
+
if prefix_video.dtype == torch.uint8:
|
91 |
+
prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
|
92 |
+
prefix_video = prefix_video.to(self.device)
|
93 |
+
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
|
94 |
+
causal_block_size = self.transformer.num_frame_per_block
|
95 |
+
if prefix_video[0].shape[1] % causal_block_size != 0:
|
96 |
+
truncate_len = prefix_video[0].shape[1] % causal_block_size
|
97 |
+
print("the length of prefix video is truncated for the casual block size alignment.")
|
98 |
+
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
|
99 |
+
predix_video_latent_length = prefix_video[0].shape[1]
|
100 |
+
return prefix_video, predix_video_latent_length
|
101 |
+
|
102 |
+
def prepare_latents(
|
103 |
+
self,
|
104 |
+
shape: Tuple[int],
|
105 |
+
dtype: Optional[torch.dtype] = None,
|
106 |
+
device: Optional[torch.device] = None,
|
107 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
108 |
+
) -> torch.Tensor:
|
109 |
+
return randn_tensor(shape, generator, device=device, dtype=dtype)
|
110 |
+
|
111 |
+
def generate_timestep_matrix(
|
112 |
+
self,
|
113 |
+
num_frames,
|
114 |
+
step_template,
|
115 |
+
base_num_frames,
|
116 |
+
ar_step=5,
|
117 |
+
num_pre_ready=0,
|
118 |
+
casual_block_size=1,
|
119 |
+
shrink_interval_with_mask=False,
|
120 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
|
121 |
+
step_matrix, step_index = [], []
|
122 |
+
update_mask, valid_interval = [], []
|
123 |
+
num_iterations = len(step_template) + 1
|
124 |
+
num_frames_block = num_frames // casual_block_size
|
125 |
+
base_num_frames_block = base_num_frames // casual_block_size
|
126 |
+
if base_num_frames_block < num_frames_block:
|
127 |
+
infer_step_num = len(step_template)
|
128 |
+
gen_block = base_num_frames_block
|
129 |
+
min_ar_step = infer_step_num / gen_block
|
130 |
+
assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
|
131 |
+
# print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
|
132 |
+
step_template = torch.cat(
|
133 |
+
[
|
134 |
+
torch.tensor([999], dtype=torch.int64, device=step_template.device),
|
135 |
+
step_template.long(),
|
136 |
+
torch.tensor([0], dtype=torch.int64, device=step_template.device),
|
137 |
+
]
|
138 |
+
) # to handle the counter in row works starting from 1
|
139 |
+
pre_row = torch.zeros(num_frames_block, dtype=torch.long)
|
140 |
+
if num_pre_ready > 0:
|
141 |
+
pre_row[: num_pre_ready // casual_block_size] = num_iterations
|
142 |
+
|
143 |
+
while torch.all(pre_row >= (num_iterations - 1)) == False:
|
144 |
+
new_row = torch.zeros(num_frames_block, dtype=torch.long)
|
145 |
+
for i in range(num_frames_block):
|
146 |
+
if i == 0 or pre_row[i - 1] >= (
|
147 |
+
num_iterations - 1
|
148 |
+
): # the first frame or the last frame is completely denoised
|
149 |
+
new_row[i] = pre_row[i] + 1
|
150 |
+
else:
|
151 |
+
new_row[i] = new_row[i - 1] - ar_step
|
152 |
+
new_row = new_row.clamp(0, num_iterations)
|
153 |
+
|
154 |
+
update_mask.append(
|
155 |
+
(new_row != pre_row) & (new_row != num_iterations)
|
156 |
+
) # False: no need to update, True: need to update
|
157 |
+
step_index.append(new_row)
|
158 |
+
step_matrix.append(step_template[new_row])
|
159 |
+
pre_row = new_row
|
160 |
+
|
161 |
+
# for long video we split into several sequences, base_num_frames is set to the model max length (for training)
|
162 |
+
terminal_flag = base_num_frames_block
|
163 |
+
if shrink_interval_with_mask:
|
164 |
+
idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
|
165 |
+
update_mask = update_mask[0]
|
166 |
+
update_mask_idx = idx_sequence[update_mask]
|
167 |
+
last_update_idx = update_mask_idx[-1].item()
|
168 |
+
terminal_flag = last_update_idx + 1
|
169 |
+
# for i in range(0, len(update_mask)):
|
170 |
+
for curr_mask in update_mask:
|
171 |
+
if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
|
172 |
+
terminal_flag += 1
|
173 |
+
valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
|
174 |
+
|
175 |
+
step_update_mask = torch.stack(update_mask, dim=0)
|
176 |
+
step_index = torch.stack(step_index, dim=0)
|
177 |
+
step_matrix = torch.stack(step_matrix, dim=0)
|
178 |
+
|
179 |
+
if casual_block_size > 1:
|
180 |
+
step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
|
181 |
+
step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
|
182 |
+
step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
|
183 |
+
valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
|
184 |
+
|
185 |
+
return step_matrix, step_index, step_update_mask, valid_interval
|
186 |
+
|
187 |
+
def get_video_as_tensor(self, video_path, width, height):
|
188 |
+
"""
|
189 |
+
Loads a video from the given path and returns it as a tensor with proper channel ordering.
|
190 |
+
Args:
|
191 |
+
video_path (str): Path to the video file
|
192 |
+
Returns:
|
193 |
+
torch.Tensor: Video tensor in [C, T, H, W] format (channels first)
|
194 |
+
"""
|
195 |
+
|
196 |
+
# Set Decord to use CPU for video decoding
|
197 |
+
decord.bridge.set_bridge('torch')
|
198 |
+
|
199 |
+
# Load video
|
200 |
+
vr = VideoReader(video_path, width=width, height=height)
|
201 |
+
total_frames = len(vr)
|
202 |
+
|
203 |
+
# Read all frames
|
204 |
+
video_frames = vr.get_batch(list(range(total_frames)))
|
205 |
+
|
206 |
+
# Convert from [T, H, W, C] to [C, T, H, W] format
|
207 |
+
video_tensor = video_frames.permute(0, 3, 1, 2).float()
|
208 |
+
|
209 |
+
return video_tensor
|
210 |
+
|
211 |
+
@torch.no_grad()
|
212 |
+
def extend_video(
|
213 |
+
self,
|
214 |
+
prompt: Union[str, List[str]],
|
215 |
+
negative_prompt: Union[str, List[str]] = "",
|
216 |
+
prefix_video_path: List[torch.Tensor] = None,
|
217 |
+
height: int = 480,
|
218 |
+
width: int = 832,
|
219 |
+
num_frames: int = 97,
|
220 |
+
num_inference_steps: int = 50,
|
221 |
+
shift: float = 1.0,
|
222 |
+
guidance_scale: float = 5.0,
|
223 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
224 |
+
overlap_history: int = None,
|
225 |
+
addnoise_condition: int = 0,
|
226 |
+
base_num_frames: int = 97,
|
227 |
+
ar_step: int = 5,
|
228 |
+
causal_block_size: int = None,
|
229 |
+
fps: int = 24,
|
230 |
+
):
|
231 |
+
latent_height = height // 8
|
232 |
+
latent_width = width // 8
|
233 |
+
latent_length = (num_frames - 1) // 4 + 1
|
234 |
+
|
235 |
+
self._guidance_scale = guidance_scale
|
236 |
+
|
237 |
+
i2v_extra_kwrags = {}
|
238 |
+
prefix_video = None
|
239 |
+
predix_video_latent_length = 0
|
240 |
+
|
241 |
+
self.text_encoder.to(self.device)
|
242 |
+
prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
|
243 |
+
if self.do_classifier_free_guidance:
|
244 |
+
negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
|
245 |
+
if self.offload:
|
246 |
+
self.text_encoder.cpu()
|
247 |
+
torch.cuda.empty_cache()
|
248 |
+
|
249 |
+
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
|
250 |
+
init_timesteps = self.scheduler.timesteps
|
251 |
+
if causal_block_size is None:
|
252 |
+
causal_block_size = self.transformer.num_frame_per_block
|
253 |
+
fps_embeds = [fps] * prompt_embeds.shape[0]
|
254 |
+
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
|
255 |
+
transformer_dtype = self.transformer.dtype
|
256 |
+
# with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
|
257 |
+
|
258 |
+
prefix_video = self.get_video_as_tensor(prefix_video_path, width, height)
|
259 |
+
prefix_frame = torch.tensor(prefix_video, device=self.device)
|
260 |
+
start_video = (prefix_frame.float() / (255.0 / 2.0)) - 1.0
|
261 |
+
start_video = start_video.transpose(0, 1)
|
262 |
+
|
263 |
+
# long video generation
|
264 |
+
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
|
265 |
+
overlap_history_frames = (overlap_history - 1) // 4 + 1
|
266 |
+
n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
|
267 |
+
print(f"n_iter:{n_iter}")
|
268 |
+
output_video = start_video.cpu()
|
269 |
+
for i in range(n_iter):
|
270 |
+
prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
|
271 |
+
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
|
272 |
+
if prefix_video[0].shape[1] % causal_block_size != 0:
|
273 |
+
truncate_len = prefix_video[0].shape[1] % causal_block_size
|
274 |
+
print("the length of prefix video is truncated for the casual block size alignment.")
|
275 |
+
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
|
276 |
+
predix_video_latent_length = prefix_video[0].shape[1]
|
277 |
+
finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
|
278 |
+
left_frame_num = latent_length - finished_frame_num
|
279 |
+
base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
|
280 |
+
if ar_step > 0 and self.transformer.enable_teacache:
|
281 |
+
num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
|
282 |
+
self.transformer.num_steps = num_steps
|
283 |
+
|
284 |
+
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
|
285 |
+
latents = self.prepare_latents(
|
286 |
+
latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
|
287 |
+
)
|
288 |
+
latents = [latents]
|
289 |
+
if prefix_video is not None:
|
290 |
+
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
|
291 |
+
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
292 |
+
base_num_frames_iter,
|
293 |
+
init_timesteps,
|
294 |
+
base_num_frames_iter,
|
295 |
+
ar_step,
|
296 |
+
predix_video_latent_length,
|
297 |
+
causal_block_size,
|
298 |
+
)
|
299 |
+
sample_schedulers = []
|
300 |
+
for _ in range(base_num_frames_iter):
|
301 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
302 |
+
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
|
303 |
+
)
|
304 |
+
sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
|
305 |
+
sample_schedulers.append(sample_scheduler)
|
306 |
+
sample_schedulers_counter = [0] * base_num_frames_iter
|
307 |
+
self.transformer.to(self.device)
|
308 |
+
for i, timestep_i in enumerate(tqdm(step_matrix)):
|
309 |
+
update_mask_i = step_update_mask[i]
|
310 |
+
valid_interval_i = valid_interval[i]
|
311 |
+
valid_interval_start, valid_interval_end = valid_interval_i
|
312 |
+
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
|
313 |
+
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
|
314 |
+
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
|
315 |
+
noise_factor = 0.001 * addnoise_condition
|
316 |
+
timestep_for_noised_condition = addnoise_condition
|
317 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
|
318 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
|
319 |
+
* (1.0 - noise_factor)
|
320 |
+
+ torch.randn_like(
|
321 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
|
322 |
+
)
|
323 |
+
* noise_factor
|
324 |
+
)
|
325 |
+
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
|
326 |
+
if not self.do_classifier_free_guidance:
|
327 |
+
noise_pred = self.transformer(
|
328 |
+
torch.stack([latent_model_input[0]]),
|
329 |
+
t=timestep,
|
330 |
+
context=prompt_embeds,
|
331 |
+
fps=fps_embeds,
|
332 |
+
**i2v_extra_kwrags,
|
333 |
+
)[0]
|
334 |
+
else:
|
335 |
+
noise_pred_cond = self.transformer(
|
336 |
+
torch.stack([latent_model_input[0]]),
|
337 |
+
t=timestep,
|
338 |
+
context=prompt_embeds,
|
339 |
+
fps=fps_embeds,
|
340 |
+
**i2v_extra_kwrags,
|
341 |
+
)[0]
|
342 |
+
noise_pred_uncond = self.transformer(
|
343 |
+
torch.stack([latent_model_input[0]]),
|
344 |
+
t=timestep,
|
345 |
+
context=negative_prompt_embeds,
|
346 |
+
fps=fps_embeds,
|
347 |
+
**i2v_extra_kwrags,
|
348 |
+
)[0]
|
349 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
350 |
+
for idx in range(valid_interval_start, valid_interval_end):
|
351 |
+
if update_mask_i[idx].item():
|
352 |
+
latents[0][:, idx] = sample_schedulers[idx].step(
|
353 |
+
noise_pred[:, idx - valid_interval_start],
|
354 |
+
timestep_i[idx],
|
355 |
+
latents[0][:, idx],
|
356 |
+
return_dict=False,
|
357 |
+
generator=generator,
|
358 |
+
)[0]
|
359 |
+
sample_schedulers_counter[idx] += 1
|
360 |
+
if self.offload:
|
361 |
+
self.transformer.cpu()
|
362 |
+
torch.cuda.empty_cache()
|
363 |
+
x0 = latents[0].unsqueeze(0)
|
364 |
+
videos = [self.vae.decode(x0)[0]]
|
365 |
+
if output_video is None:
|
366 |
+
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
|
367 |
+
else:
|
368 |
+
output_video = torch.cat(
|
369 |
+
[output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
|
370 |
+
) # c, f, h, w
|
371 |
+
output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
|
372 |
+
output_video = [video for video in output_video]
|
373 |
+
output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
|
374 |
+
output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video]
|
375 |
+
|
376 |
+
return output_video
|
377 |
+
|
378 |
+
|
379 |
+
@torch.no_grad()
|
380 |
+
def __call__(
|
381 |
+
self,
|
382 |
+
prompt: Union[str, List[str]],
|
383 |
+
negative_prompt: Union[str, List[str]] = "",
|
384 |
+
image: PipelineImageInput = None,
|
385 |
+
end_image: PipelineImageInput = None,
|
386 |
+
height: int = 480,
|
387 |
+
width: int = 832,
|
388 |
+
num_frames: int = 97,
|
389 |
+
num_inference_steps: int = 50,
|
390 |
+
shift: float = 1.0,
|
391 |
+
guidance_scale: float = 5.0,
|
392 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
393 |
+
overlap_history: int = None,
|
394 |
+
addnoise_condition: int = 0,
|
395 |
+
base_num_frames: int = 97,
|
396 |
+
ar_step: int = 5,
|
397 |
+
causal_block_size: int = None,
|
398 |
+
fps: int = 24,
|
399 |
+
):
|
400 |
+
latent_height = height // 8
|
401 |
+
latent_width = width // 8
|
402 |
+
latent_length = (num_frames - 1) // 4 + 1
|
403 |
+
|
404 |
+
self._guidance_scale = guidance_scale
|
405 |
+
|
406 |
+
i2v_extra_kwrags = {}
|
407 |
+
prefix_video = None
|
408 |
+
predix_video_latent_length = 0
|
409 |
+
end_video = None
|
410 |
+
end_video_latent_length = 0
|
411 |
+
|
412 |
+
if image:
|
413 |
+
prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)
|
414 |
+
|
415 |
+
if end_image:
|
416 |
+
end_video, end_video_latent_length = self.encode_image(end_image, height, width, num_frames)
|
417 |
+
|
418 |
+
self.text_encoder.to(self.device)
|
419 |
+
prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
|
420 |
+
if self.do_classifier_free_guidance:
|
421 |
+
negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
|
422 |
+
if self.offload:
|
423 |
+
self.text_encoder.cpu()
|
424 |
+
torch.cuda.empty_cache()
|
425 |
+
|
426 |
+
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
|
427 |
+
init_timesteps = self.scheduler.timesteps
|
428 |
+
if causal_block_size is None:
|
429 |
+
causal_block_size = self.transformer.num_frame_per_block
|
430 |
+
fps_embeds = [fps] * prompt_embeds.shape[0]
|
431 |
+
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
|
432 |
+
transformer_dtype = self.transformer.dtype
|
433 |
+
# with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
|
434 |
+
if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
|
435 |
+
# short video generation
|
436 |
+
latent_shape = [16, latent_length, latent_height, latent_width]
|
437 |
+
latents = self.prepare_latents(
|
438 |
+
latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
|
439 |
+
)
|
440 |
+
latents = [latents]
|
441 |
+
if prefix_video is not None:
|
442 |
+
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
|
443 |
+
|
444 |
+
if end_video is not None:
|
445 |
+
latents[0] = torch.cat([latents[0], end_video[0].to(transformer_dtype)], dim=1)
|
446 |
+
|
447 |
+
base_num_frames = num_frames
|
448 |
+
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
|
449 |
+
if end_video is not None:
|
450 |
+
base_num_frames += end_video_latent_length
|
451 |
+
latent_length += end_video_latent_length
|
452 |
+
|
453 |
+
|
454 |
+
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
455 |
+
latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
|
456 |
+
)
|
457 |
+
if end_video is not None:
|
458 |
+
step_matrix[:, -end_video_latent_length:] = 0
|
459 |
+
step_update_mask[:, -end_video_latent_length:] = False
|
460 |
+
|
461 |
+
sample_schedulers = []
|
462 |
+
for _ in range(latent_length):
|
463 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
464 |
+
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
|
465 |
+
)
|
466 |
+
sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
|
467 |
+
sample_schedulers.append(sample_scheduler)
|
468 |
+
sample_schedulers_counter = [0] * latent_length
|
469 |
+
self.transformer.to(self.device)
|
470 |
+
for i, timestep_i in enumerate(tqdm(step_matrix)):
|
471 |
+
update_mask_i = step_update_mask[i]
|
472 |
+
valid_interval_i = valid_interval[i]
|
473 |
+
valid_interval_start, valid_interval_end = valid_interval_i
|
474 |
+
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
|
475 |
+
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
|
476 |
+
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
|
477 |
+
noise_factor = 0.001 * addnoise_condition
|
478 |
+
timestep_for_noised_condition = addnoise_condition
|
479 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
|
480 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
|
481 |
+
+ torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
|
482 |
+
* noise_factor
|
483 |
+
)
|
484 |
+
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
|
485 |
+
if not self.do_classifier_free_guidance:
|
486 |
+
noise_pred = self.transformer(
|
487 |
+
torch.stack([latent_model_input[0]]),
|
488 |
+
t=timestep,
|
489 |
+
context=prompt_embeds,
|
490 |
+
fps=fps_embeds,
|
491 |
+
**i2v_extra_kwrags,
|
492 |
+
)[0]
|
493 |
+
else:
|
494 |
+
noise_pred_cond = self.transformer(
|
495 |
+
torch.stack([latent_model_input[0]]),
|
496 |
+
t=timestep,
|
497 |
+
context=prompt_embeds,
|
498 |
+
fps=fps_embeds,
|
499 |
+
**i2v_extra_kwrags,
|
500 |
+
)[0]
|
501 |
+
noise_pred_uncond = self.transformer(
|
502 |
+
torch.stack([latent_model_input[0]]),
|
503 |
+
t=timestep,
|
504 |
+
context=negative_prompt_embeds,
|
505 |
+
fps=fps_embeds,
|
506 |
+
**i2v_extra_kwrags,
|
507 |
+
)[0]
|
508 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
509 |
+
for idx in range(valid_interval_start, valid_interval_end):
|
510 |
+
if update_mask_i[idx].item():
|
511 |
+
latents[0][:, idx] = sample_schedulers[idx].step(
|
512 |
+
noise_pred[:, idx - valid_interval_start],
|
513 |
+
timestep_i[idx],
|
514 |
+
latents[0][:, idx],
|
515 |
+
return_dict=False,
|
516 |
+
generator=generator,
|
517 |
+
)[0]
|
518 |
+
sample_schedulers_counter[idx] += 1
|
519 |
+
if self.offload:
|
520 |
+
self.transformer.cpu()
|
521 |
+
torch.cuda.empty_cache()
|
522 |
+
x0 = latents[0].unsqueeze(0)
|
523 |
+
if end_video is not None:
|
524 |
+
x0 = latents[0][:, :-end_video_latent_length].unsqueeze(0)
|
525 |
+
|
526 |
+
videos = self.vae.decode(x0)
|
527 |
+
videos = (videos / 2 + 0.5).clamp(0, 1)
|
528 |
+
videos = [video for video in videos]
|
529 |
+
videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
|
530 |
+
videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
|
531 |
+
return videos
|
532 |
+
else:
|
533 |
+
# long video generation
|
534 |
+
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
|
535 |
+
overlap_history_frames = (overlap_history - 1) // 4 + 1
|
536 |
+
n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
|
537 |
+
print(f"n_iter:{n_iter}")
|
538 |
+
output_video = None
|
539 |
+
for i in range(n_iter):
|
540 |
+
if output_video is not None: # i !=0
|
541 |
+
prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
|
542 |
+
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
|
543 |
+
if prefix_video[0].shape[1] % causal_block_size != 0:
|
544 |
+
truncate_len = prefix_video[0].shape[1] % causal_block_size
|
545 |
+
print("the length of prefix video is truncated for the casual block size alignment.")
|
546 |
+
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
|
547 |
+
predix_video_latent_length = prefix_video[0].shape[1]
|
548 |
+
finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
|
549 |
+
left_frame_num = latent_length - finished_frame_num
|
550 |
+
base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
|
551 |
+
if ar_step > 0 and self.transformer.enable_teacache:
|
552 |
+
num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
|
553 |
+
self.transformer.num_steps = num_steps
|
554 |
+
else: # i == 0
|
555 |
+
base_num_frames_iter = base_num_frames
|
556 |
+
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
|
557 |
+
latents = self.prepare_latents(
|
558 |
+
latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
|
559 |
+
)
|
560 |
+
latents = [latents]
|
561 |
+
if prefix_video is not None:
|
562 |
+
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
|
563 |
+
|
564 |
+
if end_video is not None and i == n_iter - 1:
|
565 |
+
base_num_frames_iter += end_video_latent_length
|
566 |
+
latents[0] = torch.cat([latents[0], end_video[0].to(transformer_dtype)], dim=1)
|
567 |
+
|
568 |
+
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
569 |
+
base_num_frames_iter,
|
570 |
+
init_timesteps,
|
571 |
+
base_num_frames_iter,
|
572 |
+
ar_step,
|
573 |
+
predix_video_latent_length,
|
574 |
+
causal_block_size,
|
575 |
+
)
|
576 |
+
if end_video is not None and i == n_iter - 1:
|
577 |
+
step_matrix[:, -end_video_latent_length:] = 0
|
578 |
+
step_update_mask[:, -end_video_latent_length:] = False
|
579 |
+
|
580 |
+
sample_schedulers = []
|
581 |
+
for _ in range(base_num_frames_iter):
|
582 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
583 |
+
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
|
584 |
+
)
|
585 |
+
sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
|
586 |
+
sample_schedulers.append(sample_scheduler)
|
587 |
+
sample_schedulers_counter = [0] * base_num_frames_iter
|
588 |
+
self.transformer.to(self.device)
|
589 |
+
for i, timestep_i in enumerate(tqdm(step_matrix)):
|
590 |
+
update_mask_i = step_update_mask[i]
|
591 |
+
valid_interval_i = valid_interval[i]
|
592 |
+
valid_interval_start, valid_interval_end = valid_interval_i
|
593 |
+
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
|
594 |
+
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
|
595 |
+
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
|
596 |
+
noise_factor = 0.001 * addnoise_condition
|
597 |
+
timestep_for_noised_condition = addnoise_condition
|
598 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
|
599 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
|
600 |
+
* (1.0 - noise_factor)
|
601 |
+
+ torch.randn_like(
|
602 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
|
603 |
+
)
|
604 |
+
* noise_factor
|
605 |
+
)
|
606 |
+
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
|
607 |
+
if not self.do_classifier_free_guidance:
|
608 |
+
noise_pred = self.transformer(
|
609 |
+
torch.stack([latent_model_input[0]]),
|
610 |
+
t=timestep,
|
611 |
+
context=prompt_embeds,
|
612 |
+
fps=fps_embeds,
|
613 |
+
**i2v_extra_kwrags,
|
614 |
+
)[0]
|
615 |
+
else:
|
616 |
+
noise_pred_cond = self.transformer(
|
617 |
+
torch.stack([latent_model_input[0]]),
|
618 |
+
t=timestep,
|
619 |
+
context=prompt_embeds,
|
620 |
+
fps=fps_embeds,
|
621 |
+
**i2v_extra_kwrags,
|
622 |
+
)[0]
|
623 |
+
noise_pred_uncond = self.transformer(
|
624 |
+
torch.stack([latent_model_input[0]]),
|
625 |
+
t=timestep,
|
626 |
+
context=negative_prompt_embeds,
|
627 |
+
fps=fps_embeds,
|
628 |
+
**i2v_extra_kwrags,
|
629 |
+
)[0]
|
630 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
631 |
+
for idx in range(valid_interval_start, valid_interval_end):
|
632 |
+
if update_mask_i[idx].item():
|
633 |
+
latents[0][:, idx] = sample_schedulers[idx].step(
|
634 |
+
noise_pred[:, idx - valid_interval_start],
|
635 |
+
timestep_i[idx],
|
636 |
+
latents[0][:, idx],
|
637 |
+
return_dict=False,
|
638 |
+
generator=generator,
|
639 |
+
)[0]
|
640 |
+
sample_schedulers_counter[idx] += 1
|
641 |
+
if self.offload:
|
642 |
+
self.transformer.cpu()
|
643 |
+
torch.cuda.empty_cache()
|
644 |
+
x0 = latents[0].unsqueeze(0)
|
645 |
+
if end_video is not None and i == n_iter - 1:
|
646 |
+
x0 = latents[0][:, :-end_video_latent_length].unsqueeze(0)
|
647 |
+
|
648 |
+
videos = [self.vae.decode(x0)[0]]
|
649 |
+
if output_video is None:
|
650 |
+
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
|
651 |
+
else:
|
652 |
+
output_video = torch.cat(
|
653 |
+
[output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
|
654 |
+
) # c, f, h, w
|
655 |
+
output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
|
656 |
+
output_video = [video for video in output_video]
|
657 |
+
output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
|
658 |
+
output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video]
|
659 |
+
return output_video
|
skyreels_v2_infer/pipelines/image2video_pipeline.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
from typing import Optional
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from diffusers.image_processor import PipelineImageInput
|
9 |
+
from diffusers.video_processor import VideoProcessor
|
10 |
+
from PIL import Image
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from ..modules import get_image_encoder
|
14 |
+
from ..modules import get_text_encoder
|
15 |
+
from ..modules import get_transformer
|
16 |
+
from ..modules import get_vae
|
17 |
+
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
18 |
+
|
19 |
+
|
20 |
+
def resizecrop(image: Image.Image, th, tw):
|
21 |
+
w, h = image.size
|
22 |
+
if w == tw and h == th:
|
23 |
+
return image
|
24 |
+
if h / w > th / tw:
|
25 |
+
new_w = int(w)
|
26 |
+
new_h = int(new_w * th / tw)
|
27 |
+
else:
|
28 |
+
new_h = int(h)
|
29 |
+
new_w = int(new_h * tw / th)
|
30 |
+
left = (w - new_w) / 2
|
31 |
+
top = (h - new_h) / 2
|
32 |
+
right = (w + new_w) / 2
|
33 |
+
bottom = (h + new_h) / 2
|
34 |
+
image = image.crop((left, top, right, bottom))
|
35 |
+
return image
|
36 |
+
|
37 |
+
|
38 |
+
class Image2VideoPipeline:
|
39 |
+
def __init__(
|
40 |
+
self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
|
41 |
+
):
|
42 |
+
load_device = "cpu" if offload else device
|
43 |
+
self.transformer = get_transformer(dit_path, load_device, weight_dtype)
|
44 |
+
vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
|
45 |
+
self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
|
46 |
+
self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
|
47 |
+
self.clip = get_image_encoder(model_path, load_device, weight_dtype)
|
48 |
+
self.sp_size = 1
|
49 |
+
self.device = device
|
50 |
+
self.offload = offload
|
51 |
+
self.video_processor = VideoProcessor(vae_scale_factor=16)
|
52 |
+
if use_usp:
|
53 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
54 |
+
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
55 |
+
import types
|
56 |
+
|
57 |
+
for block in self.transformer.blocks:
|
58 |
+
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
59 |
+
self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
|
60 |
+
self.sp_size = get_sequence_parallel_world_size()
|
61 |
+
|
62 |
+
self.scheduler = FlowUniPCMultistepScheduler()
|
63 |
+
self.vae_stride = (4, 8, 8)
|
64 |
+
self.patch_size = (1, 2, 2)
|
65 |
+
|
66 |
+
@torch.no_grad()
|
67 |
+
def __call__(
|
68 |
+
self,
|
69 |
+
image: PipelineImageInput,
|
70 |
+
prompt: Union[str, List[str]] = None,
|
71 |
+
negative_prompt: Union[str, List[str]] = None,
|
72 |
+
height: int = 544,
|
73 |
+
width: int = 960,
|
74 |
+
num_frames: int = 97,
|
75 |
+
num_inference_steps: int = 50,
|
76 |
+
guidance_scale: float = 5.0,
|
77 |
+
shift: float = 5.0,
|
78 |
+
generator: Optional[torch.Generator] = None,
|
79 |
+
):
|
80 |
+
F = num_frames
|
81 |
+
|
82 |
+
latent_height = height // 8 // 2 * 2
|
83 |
+
latent_width = width // 8 // 2 * 2
|
84 |
+
latent_length = (F - 1) // 4 + 1
|
85 |
+
|
86 |
+
h = latent_height * 8
|
87 |
+
w = latent_width * 8
|
88 |
+
|
89 |
+
img = self.video_processor.preprocess(image, height=h, width=w)
|
90 |
+
|
91 |
+
img = img.to(device=self.device, dtype=self.transformer.dtype)
|
92 |
+
|
93 |
+
padding_video = torch.zeros(img.shape[0], 3, F - 1, h, w, device=self.device)
|
94 |
+
|
95 |
+
img = img.unsqueeze(2)
|
96 |
+
img_cond = torch.concat([img, padding_video], dim=2)
|
97 |
+
img_cond = self.vae.encode(img_cond)
|
98 |
+
mask = torch.ones_like(img_cond)
|
99 |
+
mask[:, :, 1:] = 0
|
100 |
+
y = torch.cat([mask[:, :4], img_cond], dim=1)
|
101 |
+
self.clip.to(self.device)
|
102 |
+
clip_context = self.clip.encode_video(img)
|
103 |
+
if self.offload:
|
104 |
+
self.clip.cpu()
|
105 |
+
torch.cuda.empty_cache()
|
106 |
+
|
107 |
+
# preprocess
|
108 |
+
self.text_encoder.to(self.device)
|
109 |
+
context = self.text_encoder.encode(prompt).to(self.device)
|
110 |
+
context_null = self.text_encoder.encode(negative_prompt).to(self.device)
|
111 |
+
if self.offload:
|
112 |
+
self.text_encoder.cpu()
|
113 |
+
torch.cuda.empty_cache()
|
114 |
+
|
115 |
+
latent = torch.randn(
|
116 |
+
16, latent_length, latent_height, latent_width, dtype=torch.float32, generator=generator, device=self.device
|
117 |
+
)
|
118 |
+
|
119 |
+
self.transformer.to(self.device)
|
120 |
+
with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
|
121 |
+
self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
|
122 |
+
timesteps = self.scheduler.timesteps
|
123 |
+
|
124 |
+
arg_c = {
|
125 |
+
"context": context,
|
126 |
+
"clip_fea": clip_context,
|
127 |
+
"y": y,
|
128 |
+
}
|
129 |
+
|
130 |
+
arg_null = {
|
131 |
+
"context": context_null,
|
132 |
+
"clip_fea": clip_context,
|
133 |
+
"y": y,
|
134 |
+
}
|
135 |
+
|
136 |
+
self.transformer.to(self.device)
|
137 |
+
for _, t in enumerate(tqdm(timesteps)):
|
138 |
+
latent_model_input = torch.stack([latent]).to(self.device)
|
139 |
+
timestep = torch.stack([t]).to(self.device)
|
140 |
+
noise_pred_cond = self.transformer(latent_model_input, t=timestep, **arg_c)[0].to(self.device)
|
141 |
+
noise_pred_uncond = self.transformer(latent_model_input, t=timestep, **arg_null)[0].to(self.device)
|
142 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
143 |
+
|
144 |
+
temp_x0 = self.scheduler.step(
|
145 |
+
noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator
|
146 |
+
)[0]
|
147 |
+
latent = temp_x0.squeeze(0)
|
148 |
+
if self.offload:
|
149 |
+
self.transformer.cpu()
|
150 |
+
torch.cuda.empty_cache()
|
151 |
+
videos = self.vae.decode(latent)
|
152 |
+
videos = (videos / 2 + 0.5).clamp(0, 1)
|
153 |
+
videos = [video for video in videos]
|
154 |
+
videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
|
155 |
+
videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
|
156 |
+
return videos
|
skyreels_v2_infer/pipelines/prompt_enhancer.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
|
4 |
+
sys_prompt = """
|
5 |
+
Transform the short prompt into a detailed video-generation caption using this structure:
|
6 |
+
Opening shot type (long/medium/close-up/extreme close-up/full shot)
|
7 |
+
Primary subject(s) with vivid attributes (colors, textures, actions, interactions)
|
8 |
+
Dynamic elements (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
|
9 |
+
Scene composition (background, environment, spatial relationships)
|
10 |
+
Lighting/atmosphere (natural/artificial, time of day, mood)
|
11 |
+
Camera motion (zooms, pans, static/handheld shots) if applicable.
|
12 |
+
|
13 |
+
Pattern Summary from Examples:
|
14 |
+
[Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
|
15 |
+
|
16 |
+
One case:
|
17 |
+
Short prompt: a person is playing football
|
18 |
+
Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
|
19 |
+
|
20 |
+
Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
|
21 |
+
|
22 |
+
Now expand this short prompt: [{}]. Please only output the final long prompt in English.
|
23 |
+
"""
|
24 |
+
|
25 |
+
class PromptEnhancer:
|
26 |
+
def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct"):
|
27 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
28 |
+
model_name,
|
29 |
+
torch_dtype="auto",
|
30 |
+
device_map="cuda:0",
|
31 |
+
)
|
32 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
33 |
+
|
34 |
+
def __call__(self, prompt):
|
35 |
+
prompt = prompt.strip()
|
36 |
+
prompt = sys_prompt.format(prompt)
|
37 |
+
messages = [
|
38 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
39 |
+
{"role": "user", "content": prompt}
|
40 |
+
]
|
41 |
+
text = self.tokenizer.apply_chat_template(
|
42 |
+
messages,
|
43 |
+
tokenize=False,
|
44 |
+
add_generation_prompt=True
|
45 |
+
)
|
46 |
+
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
47 |
+
generated_ids = self.model.generate(
|
48 |
+
**model_inputs,
|
49 |
+
max_new_tokens=2048,
|
50 |
+
)
|
51 |
+
generated_ids = [
|
52 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
53 |
+
]
|
54 |
+
rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
55 |
+
return rewritten_prompt
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
parser = argparse.ArgumentParser()
|
59 |
+
parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign")
|
60 |
+
args = parser.parse_args()
|
61 |
+
|
62 |
+
prompt_enhancer = PromptEnhancer()
|
63 |
+
enhanced_prompt = prompt_enhancer(args.prompt)
|
64 |
+
print(f'Original prompt: {args.prompt}')
|
65 |
+
print(f'Enhanced prompt: {enhanced_prompt}')
|
skyreels_v2_infer/pipelines/text2video_pipeline.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
from typing import Optional
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from diffusers.video_processor import VideoProcessor
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from ..modules import get_text_encoder
|
12 |
+
from ..modules import get_transformer
|
13 |
+
from ..modules import get_vae
|
14 |
+
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
15 |
+
|
16 |
+
|
17 |
+
class Text2VideoPipeline:
|
18 |
+
def __init__(
|
19 |
+
self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
|
20 |
+
):
|
21 |
+
load_device = "cpu" if offload else device
|
22 |
+
self.transformer = get_transformer(dit_path, load_device, weight_dtype)
|
23 |
+
vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
|
24 |
+
self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
|
25 |
+
self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
|
26 |
+
self.video_processor = VideoProcessor(vae_scale_factor=16)
|
27 |
+
self.sp_size = 1
|
28 |
+
self.device = device
|
29 |
+
self.offload = offload
|
30 |
+
if use_usp:
|
31 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
32 |
+
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
33 |
+
import types
|
34 |
+
|
35 |
+
for block in self.transformer.blocks:
|
36 |
+
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
37 |
+
self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
|
38 |
+
self.sp_size = get_sequence_parallel_world_size()
|
39 |
+
|
40 |
+
self.scheduler = FlowUniPCMultistepScheduler()
|
41 |
+
self.vae_stride = (4, 8, 8)
|
42 |
+
self.patch_size = (1, 2, 2)
|
43 |
+
|
44 |
+
@torch.no_grad()
|
45 |
+
def __call__(
|
46 |
+
self,
|
47 |
+
prompt: Union[str, List[str]] = None,
|
48 |
+
negative_prompt: Union[str, List[str]] = None,
|
49 |
+
width: int = 544,
|
50 |
+
height: int = 960,
|
51 |
+
num_frames: int = 97,
|
52 |
+
num_inference_steps: int = 50,
|
53 |
+
guidance_scale: float = 5.0,
|
54 |
+
shift: float = 5.0,
|
55 |
+
generator: Optional[torch.Generator] = None,
|
56 |
+
):
|
57 |
+
# preprocess
|
58 |
+
F = num_frames
|
59 |
+
target_shape = (
|
60 |
+
self.vae.vae.z_dim,
|
61 |
+
(F - 1) // self.vae_stride[0] + 1,
|
62 |
+
height // self.vae_stride[1],
|
63 |
+
width // self.vae_stride[2],
|
64 |
+
)
|
65 |
+
self.text_encoder.to(self.device)
|
66 |
+
context = self.text_encoder.encode(prompt).to(self.device)
|
67 |
+
context_null = self.text_encoder.encode(negative_prompt).to(self.device)
|
68 |
+
if self.offload:
|
69 |
+
self.text_encoder.cpu()
|
70 |
+
torch.cuda.empty_cache()
|
71 |
+
|
72 |
+
latents = [
|
73 |
+
torch.randn(
|
74 |
+
target_shape[0],
|
75 |
+
target_shape[1],
|
76 |
+
target_shape[2],
|
77 |
+
target_shape[3],
|
78 |
+
dtype=torch.float32,
|
79 |
+
device=self.device,
|
80 |
+
generator=generator,
|
81 |
+
)
|
82 |
+
]
|
83 |
+
|
84 |
+
# evaluation mode
|
85 |
+
self.transformer.to(self.device)
|
86 |
+
with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
|
87 |
+
self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
|
88 |
+
timesteps = self.scheduler.timesteps
|
89 |
+
|
90 |
+
for _, t in enumerate(tqdm(timesteps)):
|
91 |
+
latent_model_input = torch.stack(latents)
|
92 |
+
timestep = torch.stack([t])
|
93 |
+
noise_pred_cond = self.transformer(latent_model_input, t=timestep, context=context)[0]
|
94 |
+
noise_pred_uncond = self.transformer(latent_model_input, t=timestep, context=context_null)[0]
|
95 |
+
|
96 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
97 |
+
|
98 |
+
temp_x0 = self.scheduler.step(
|
99 |
+
noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=generator
|
100 |
+
)[0]
|
101 |
+
latents = [temp_x0.squeeze(0)]
|
102 |
+
if self.offload:
|
103 |
+
self.transformer.cpu()
|
104 |
+
torch.cuda.empty_cache()
|
105 |
+
videos = self.vae.decode(latents[0])
|
106 |
+
videos = (videos / 2 + 0.5).clamp(0, 1)
|
107 |
+
videos = [video for video in videos]
|
108 |
+
videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
|
109 |
+
videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
|
110 |
+
return videos
|
skyreels_v2_infer/scheduler/__init__.py
ADDED
File without changes
|
skyreels_v2_infer/scheduler/fm_solvers_unipc.py
ADDED
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
|
2 |
+
# Convert unipc for flow matching
|
3 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
4 |
+
import math
|
5 |
+
from typing import List
|
6 |
+
from typing import Optional
|
7 |
+
from typing import Tuple
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from diffusers.configuration_utils import ConfigMixin
|
13 |
+
from diffusers.configuration_utils import register_to_config
|
14 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
|
15 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
16 |
+
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
17 |
+
from diffusers.utils import deprecate
|
18 |
+
|
19 |
+
|
20 |
+
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
21 |
+
"""
|
22 |
+
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
|
23 |
+
|
24 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
25 |
+
methods the library implements for all schedulers such as loading and saving.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
num_train_timesteps (`int`, defaults to 1000):
|
29 |
+
The number of diffusion steps to train the model.
|
30 |
+
solver_order (`int`, default `2`):
|
31 |
+
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
|
32 |
+
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
|
33 |
+
unconditional sampling.
|
34 |
+
prediction_type (`str`, defaults to "flow_prediction"):
|
35 |
+
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
|
36 |
+
the flow of the diffusion process.
|
37 |
+
thresholding (`bool`, defaults to `False`):
|
38 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
39 |
+
as Stable Diffusion.
|
40 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
41 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
42 |
+
sample_max_value (`float`, defaults to 1.0):
|
43 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
|
44 |
+
predict_x0 (`bool`, defaults to `True`):
|
45 |
+
Whether to use the updating algorithm on the predicted x0.
|
46 |
+
solver_type (`str`, default `bh2`):
|
47 |
+
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
|
48 |
+
otherwise.
|
49 |
+
lower_order_final (`bool`, default `True`):
|
50 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
51 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
52 |
+
disable_corrector (`list`, default `[]`):
|
53 |
+
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
|
54 |
+
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
|
55 |
+
usually disabled during the first few steps.
|
56 |
+
solver_p (`SchedulerMixin`, default `None`):
|
57 |
+
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
|
58 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
59 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
60 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
61 |
+
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
62 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
63 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
64 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
65 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
66 |
+
steps_offset (`int`, defaults to 0):
|
67 |
+
An offset added to the inference steps, as required by some model families.
|
68 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
69 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
70 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
71 |
+
"""
|
72 |
+
|
73 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
74 |
+
order = 1
|
75 |
+
|
76 |
+
@register_to_config
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
num_train_timesteps: int = 1000,
|
80 |
+
solver_order: int = 2,
|
81 |
+
prediction_type: str = "flow_prediction",
|
82 |
+
shift: Optional[float] = 1.0,
|
83 |
+
use_dynamic_shifting=False,
|
84 |
+
thresholding: bool = False,
|
85 |
+
dynamic_thresholding_ratio: float = 0.995,
|
86 |
+
sample_max_value: float = 1.0,
|
87 |
+
predict_x0: bool = True,
|
88 |
+
solver_type: str = "bh2",
|
89 |
+
lower_order_final: bool = True,
|
90 |
+
disable_corrector: List[int] = [],
|
91 |
+
solver_p: SchedulerMixin = None,
|
92 |
+
timestep_spacing: str = "linspace",
|
93 |
+
steps_offset: int = 0,
|
94 |
+
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
95 |
+
):
|
96 |
+
|
97 |
+
if solver_type not in ["bh1", "bh2"]:
|
98 |
+
if solver_type in ["midpoint", "heun", "logrho"]:
|
99 |
+
self.register_to_config(solver_type="bh2")
|
100 |
+
else:
|
101 |
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
102 |
+
|
103 |
+
self.predict_x0 = predict_x0
|
104 |
+
# setable values
|
105 |
+
self.num_inference_steps = None
|
106 |
+
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
|
107 |
+
sigmas = 1.0 - alphas
|
108 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
|
109 |
+
|
110 |
+
if not use_dynamic_shifting:
|
111 |
+
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
112 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
|
113 |
+
|
114 |
+
self.sigmas = sigmas
|
115 |
+
self.timesteps = sigmas * num_train_timesteps
|
116 |
+
|
117 |
+
self.model_outputs = [None] * solver_order
|
118 |
+
self.timestep_list = [None] * solver_order
|
119 |
+
self.lower_order_nums = 0
|
120 |
+
self.disable_corrector = disable_corrector
|
121 |
+
self.solver_p = solver_p
|
122 |
+
self.last_sample = None
|
123 |
+
self._step_index = None
|
124 |
+
self._begin_index = None
|
125 |
+
|
126 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
127 |
+
self.sigma_min = self.sigmas[-1].item()
|
128 |
+
self.sigma_max = self.sigmas[0].item()
|
129 |
+
|
130 |
+
@property
|
131 |
+
def step_index(self):
|
132 |
+
"""
|
133 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
134 |
+
"""
|
135 |
+
return self._step_index
|
136 |
+
|
137 |
+
@property
|
138 |
+
def begin_index(self):
|
139 |
+
"""
|
140 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
141 |
+
"""
|
142 |
+
return self._begin_index
|
143 |
+
|
144 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
145 |
+
def set_begin_index(self, begin_index: int = 0):
|
146 |
+
"""
|
147 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
begin_index (`int`):
|
151 |
+
The begin index for the scheduler.
|
152 |
+
"""
|
153 |
+
self._begin_index = begin_index
|
154 |
+
|
155 |
+
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
|
156 |
+
def set_timesteps(
|
157 |
+
self,
|
158 |
+
num_inference_steps: Union[int, None] = None,
|
159 |
+
device: Union[str, torch.device] = None,
|
160 |
+
sigmas: Optional[List[float]] = None,
|
161 |
+
mu: Optional[Union[float, None]] = None,
|
162 |
+
shift: Optional[Union[float, None]] = None,
|
163 |
+
):
|
164 |
+
"""
|
165 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
166 |
+
Args:
|
167 |
+
num_inference_steps (`int`):
|
168 |
+
Total number of the spacing of the time steps.
|
169 |
+
device (`str` or `torch.device`, *optional*):
|
170 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
171 |
+
"""
|
172 |
+
|
173 |
+
if self.config.use_dynamic_shifting and mu is None:
|
174 |
+
raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
175 |
+
|
176 |
+
if sigmas is None:
|
177 |
+
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore
|
178 |
+
|
179 |
+
if self.config.use_dynamic_shifting:
|
180 |
+
sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
|
181 |
+
else:
|
182 |
+
if shift is None:
|
183 |
+
shift = self.config.shift
|
184 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
|
185 |
+
|
186 |
+
if self.config.final_sigmas_type == "sigma_min":
|
187 |
+
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
188 |
+
elif self.config.final_sigmas_type == "zero":
|
189 |
+
sigma_last = 0
|
190 |
+
else:
|
191 |
+
raise ValueError(
|
192 |
+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
193 |
+
)
|
194 |
+
|
195 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
196 |
+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
|
197 |
+
|
198 |
+
self.sigmas = torch.from_numpy(sigmas)
|
199 |
+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
200 |
+
|
201 |
+
self.num_inference_steps = len(timesteps)
|
202 |
+
|
203 |
+
self.model_outputs = [
|
204 |
+
None,
|
205 |
+
] * self.config.solver_order
|
206 |
+
self.lower_order_nums = 0
|
207 |
+
self.last_sample = None
|
208 |
+
if self.solver_p:
|
209 |
+
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
210 |
+
|
211 |
+
# add an index counter for schedulers that allow duplicated timesteps
|
212 |
+
self._step_index = None
|
213 |
+
self._begin_index = None
|
214 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
215 |
+
|
216 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
217 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
218 |
+
"""
|
219 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
220 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
221 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
222 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
223 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
224 |
+
|
225 |
+
https://arxiv.org/abs/2205.11487
|
226 |
+
"""
|
227 |
+
dtype = sample.dtype
|
228 |
+
batch_size, channels, *remaining_dims = sample.shape
|
229 |
+
|
230 |
+
if dtype not in (torch.float32, torch.float64):
|
231 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
232 |
+
|
233 |
+
# Flatten sample for doing quantile calculation along each image
|
234 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
235 |
+
|
236 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
237 |
+
|
238 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
239 |
+
s = torch.clamp(
|
240 |
+
s, min=1, max=self.config.sample_max_value
|
241 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
242 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
243 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
244 |
+
|
245 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
246 |
+
sample = sample.to(dtype)
|
247 |
+
|
248 |
+
return sample
|
249 |
+
|
250 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
|
251 |
+
def _sigma_to_t(self, sigma):
|
252 |
+
return sigma * self.config.num_train_timesteps
|
253 |
+
|
254 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
255 |
+
return 1 - sigma, sigma
|
256 |
+
|
257 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
|
258 |
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
259 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
260 |
+
|
261 |
+
def convert_model_output(
|
262 |
+
self,
|
263 |
+
model_output: torch.Tensor,
|
264 |
+
*args,
|
265 |
+
sample: torch.Tensor = None,
|
266 |
+
**kwargs,
|
267 |
+
) -> torch.Tensor:
|
268 |
+
r"""
|
269 |
+
Convert the model output to the corresponding type the UniPC algorithm needs.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
model_output (`torch.Tensor`):
|
273 |
+
The direct output from the learned diffusion model.
|
274 |
+
timestep (`int`):
|
275 |
+
The current discrete timestep in the diffusion chain.
|
276 |
+
sample (`torch.Tensor`):
|
277 |
+
A current instance of a sample created by the diffusion process.
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
`torch.Tensor`:
|
281 |
+
The converted model output.
|
282 |
+
"""
|
283 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
284 |
+
if sample is None:
|
285 |
+
if len(args) > 1:
|
286 |
+
sample = args[1]
|
287 |
+
else:
|
288 |
+
raise ValueError("missing `sample` as a required keyward argument")
|
289 |
+
if timestep is not None:
|
290 |
+
deprecate(
|
291 |
+
"timesteps",
|
292 |
+
"1.0.0",
|
293 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
294 |
+
)
|
295 |
+
|
296 |
+
sigma = self.sigmas[self.step_index]
|
297 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
298 |
+
|
299 |
+
if self.predict_x0:
|
300 |
+
if self.config.prediction_type == "flow_prediction":
|
301 |
+
sigma_t = self.sigmas[self.step_index]
|
302 |
+
x0_pred = sample - sigma_t * model_output
|
303 |
+
else:
|
304 |
+
raise ValueError(
|
305 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
|
306 |
+
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
|
307 |
+
)
|
308 |
+
|
309 |
+
if self.config.thresholding:
|
310 |
+
x0_pred = self._threshold_sample(x0_pred)
|
311 |
+
|
312 |
+
return x0_pred
|
313 |
+
else:
|
314 |
+
if self.config.prediction_type == "flow_prediction":
|
315 |
+
sigma_t = self.sigmas[self.step_index]
|
316 |
+
epsilon = sample - (1 - sigma_t) * model_output
|
317 |
+
else:
|
318 |
+
raise ValueError(
|
319 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
|
320 |
+
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
|
321 |
+
)
|
322 |
+
|
323 |
+
if self.config.thresholding:
|
324 |
+
sigma_t = self.sigmas[self.step_index]
|
325 |
+
x0_pred = sample - sigma_t * model_output
|
326 |
+
x0_pred = self._threshold_sample(x0_pred)
|
327 |
+
epsilon = model_output + x0_pred
|
328 |
+
|
329 |
+
return epsilon
|
330 |
+
|
331 |
+
def multistep_uni_p_bh_update(
|
332 |
+
self,
|
333 |
+
model_output: torch.Tensor,
|
334 |
+
*args,
|
335 |
+
sample: torch.Tensor = None,
|
336 |
+
order: int = None, # pyright: ignore
|
337 |
+
**kwargs,
|
338 |
+
) -> torch.Tensor:
|
339 |
+
"""
|
340 |
+
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
model_output (`torch.Tensor`):
|
344 |
+
The direct output from the learned diffusion model at the current timestep.
|
345 |
+
prev_timestep (`int`):
|
346 |
+
The previous discrete timestep in the diffusion chain.
|
347 |
+
sample (`torch.Tensor`):
|
348 |
+
A current instance of a sample created by the diffusion process.
|
349 |
+
order (`int`):
|
350 |
+
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
`torch.Tensor`:
|
354 |
+
The sample tensor at the previous timestep.
|
355 |
+
"""
|
356 |
+
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
357 |
+
if sample is None:
|
358 |
+
if len(args) > 1:
|
359 |
+
sample = args[1]
|
360 |
+
else:
|
361 |
+
raise ValueError(" missing `sample` as a required keyward argument")
|
362 |
+
if order is None:
|
363 |
+
if len(args) > 2:
|
364 |
+
order = args[2]
|
365 |
+
else:
|
366 |
+
raise ValueError(" missing `order` as a required keyward argument")
|
367 |
+
if prev_timestep is not None:
|
368 |
+
deprecate(
|
369 |
+
"prev_timestep",
|
370 |
+
"1.0.0",
|
371 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
372 |
+
)
|
373 |
+
model_output_list = self.model_outputs
|
374 |
+
|
375 |
+
s0 = self.timestep_list[-1]
|
376 |
+
m0 = model_output_list[-1]
|
377 |
+
x = sample
|
378 |
+
|
379 |
+
if self.solver_p:
|
380 |
+
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
381 |
+
return x_t
|
382 |
+
|
383 |
+
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore
|
384 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
385 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
386 |
+
|
387 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
388 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
389 |
+
|
390 |
+
h = lambda_t - lambda_s0
|
391 |
+
device = sample.device
|
392 |
+
|
393 |
+
rks = []
|
394 |
+
D1s = []
|
395 |
+
for i in range(1, order):
|
396 |
+
si = self.step_index - i # pyright: ignore
|
397 |
+
mi = model_output_list[-(i + 1)]
|
398 |
+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
399 |
+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
400 |
+
rk = (lambda_si - lambda_s0) / h
|
401 |
+
rks.append(rk)
|
402 |
+
D1s.append((mi - m0) / rk) # pyright: ignore
|
403 |
+
|
404 |
+
rks.append(1.0)
|
405 |
+
rks = torch.tensor(rks, device=device)
|
406 |
+
|
407 |
+
R = []
|
408 |
+
b = []
|
409 |
+
|
410 |
+
hh = -h if self.predict_x0 else h
|
411 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
412 |
+
h_phi_k = h_phi_1 / hh - 1
|
413 |
+
|
414 |
+
factorial_i = 1
|
415 |
+
|
416 |
+
if self.config.solver_type == "bh1":
|
417 |
+
B_h = hh
|
418 |
+
elif self.config.solver_type == "bh2":
|
419 |
+
B_h = torch.expm1(hh)
|
420 |
+
else:
|
421 |
+
raise NotImplementedError()
|
422 |
+
|
423 |
+
for i in range(1, order + 1):
|
424 |
+
R.append(torch.pow(rks, i - 1))
|
425 |
+
b.append(h_phi_k * factorial_i / B_h)
|
426 |
+
factorial_i *= i + 1
|
427 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
428 |
+
|
429 |
+
R = torch.stack(R)
|
430 |
+
b = torch.tensor(b, device=device)
|
431 |
+
|
432 |
+
if len(D1s) > 0:
|
433 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
434 |
+
# for order 2, we use a simplified version
|
435 |
+
if order == 2:
|
436 |
+
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
437 |
+
else:
|
438 |
+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
|
439 |
+
else:
|
440 |
+
D1s = None
|
441 |
+
|
442 |
+
if self.predict_x0:
|
443 |
+
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
444 |
+
if D1s is not None:
|
445 |
+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
|
446 |
+
else:
|
447 |
+
pred_res = 0
|
448 |
+
x_t = x_t_ - alpha_t * B_h * pred_res
|
449 |
+
else:
|
450 |
+
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
451 |
+
if D1s is not None:
|
452 |
+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
|
453 |
+
else:
|
454 |
+
pred_res = 0
|
455 |
+
x_t = x_t_ - sigma_t * B_h * pred_res
|
456 |
+
|
457 |
+
x_t = x_t.to(x.dtype)
|
458 |
+
return x_t
|
459 |
+
|
460 |
+
def multistep_uni_c_bh_update(
|
461 |
+
self,
|
462 |
+
this_model_output: torch.Tensor,
|
463 |
+
*args,
|
464 |
+
last_sample: torch.Tensor = None,
|
465 |
+
this_sample: torch.Tensor = None,
|
466 |
+
order: int = None, # pyright: ignore
|
467 |
+
**kwargs,
|
468 |
+
) -> torch.Tensor:
|
469 |
+
"""
|
470 |
+
One step for the UniC (B(h) version).
|
471 |
+
|
472 |
+
Args:
|
473 |
+
this_model_output (`torch.Tensor`):
|
474 |
+
The model outputs at `x_t`.
|
475 |
+
this_timestep (`int`):
|
476 |
+
The current timestep `t`.
|
477 |
+
last_sample (`torch.Tensor`):
|
478 |
+
The generated sample before the last predictor `x_{t-1}`.
|
479 |
+
this_sample (`torch.Tensor`):
|
480 |
+
The generated sample after the last predictor `x_{t}`.
|
481 |
+
order (`int`):
|
482 |
+
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
|
483 |
+
|
484 |
+
Returns:
|
485 |
+
`torch.Tensor`:
|
486 |
+
The corrected sample tensor at the current timestep.
|
487 |
+
"""
|
488 |
+
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
489 |
+
if last_sample is None:
|
490 |
+
if len(args) > 1:
|
491 |
+
last_sample = args[1]
|
492 |
+
else:
|
493 |
+
raise ValueError(" missing`last_sample` as a required keyward argument")
|
494 |
+
if this_sample is None:
|
495 |
+
if len(args) > 2:
|
496 |
+
this_sample = args[2]
|
497 |
+
else:
|
498 |
+
raise ValueError(" missing`this_sample` as a required keyward argument")
|
499 |
+
if order is None:
|
500 |
+
if len(args) > 3:
|
501 |
+
order = args[3]
|
502 |
+
else:
|
503 |
+
raise ValueError(" missing`order` as a required keyward argument")
|
504 |
+
if this_timestep is not None:
|
505 |
+
deprecate(
|
506 |
+
"this_timestep",
|
507 |
+
"1.0.0",
|
508 |
+
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
509 |
+
)
|
510 |
+
|
511 |
+
model_output_list = self.model_outputs
|
512 |
+
|
513 |
+
m0 = model_output_list[-1]
|
514 |
+
x = last_sample
|
515 |
+
x_t = this_sample
|
516 |
+
model_t = this_model_output
|
517 |
+
|
518 |
+
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore
|
519 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
520 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
521 |
+
|
522 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
523 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
524 |
+
|
525 |
+
h = lambda_t - lambda_s0
|
526 |
+
device = this_sample.device
|
527 |
+
|
528 |
+
rks = []
|
529 |
+
D1s = []
|
530 |
+
for i in range(1, order):
|
531 |
+
si = self.step_index - (i + 1) # pyright: ignore
|
532 |
+
mi = model_output_list[-(i + 1)]
|
533 |
+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
534 |
+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
535 |
+
rk = (lambda_si - lambda_s0) / h
|
536 |
+
rks.append(rk)
|
537 |
+
D1s.append((mi - m0) / rk) # pyright: ignore
|
538 |
+
|
539 |
+
rks.append(1.0)
|
540 |
+
rks = torch.tensor(rks, device=device)
|
541 |
+
|
542 |
+
R = []
|
543 |
+
b = []
|
544 |
+
|
545 |
+
hh = -h if self.predict_x0 else h
|
546 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
547 |
+
h_phi_k = h_phi_1 / hh - 1
|
548 |
+
|
549 |
+
factorial_i = 1
|
550 |
+
|
551 |
+
if self.config.solver_type == "bh1":
|
552 |
+
B_h = hh
|
553 |
+
elif self.config.solver_type == "bh2":
|
554 |
+
B_h = torch.expm1(hh)
|
555 |
+
else:
|
556 |
+
raise NotImplementedError()
|
557 |
+
|
558 |
+
for i in range(1, order + 1):
|
559 |
+
R.append(torch.pow(rks, i - 1))
|
560 |
+
b.append(h_phi_k * factorial_i / B_h)
|
561 |
+
factorial_i *= i + 1
|
562 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
563 |
+
|
564 |
+
R = torch.stack(R)
|
565 |
+
b = torch.tensor(b, device=device)
|
566 |
+
|
567 |
+
if len(D1s) > 0:
|
568 |
+
D1s = torch.stack(D1s, dim=1)
|
569 |
+
else:
|
570 |
+
D1s = None
|
571 |
+
|
572 |
+
# for order 1, we use a simplified version
|
573 |
+
if order == 1:
|
574 |
+
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
575 |
+
else:
|
576 |
+
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
|
577 |
+
|
578 |
+
if self.predict_x0:
|
579 |
+
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
580 |
+
if D1s is not None:
|
581 |
+
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
582 |
+
else:
|
583 |
+
corr_res = 0
|
584 |
+
D1_t = model_t - m0
|
585 |
+
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
586 |
+
else:
|
587 |
+
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
588 |
+
if D1s is not None:
|
589 |
+
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
590 |
+
else:
|
591 |
+
corr_res = 0
|
592 |
+
D1_t = model_t - m0
|
593 |
+
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
594 |
+
x_t = x_t.to(x.dtype)
|
595 |
+
return x_t
|
596 |
+
|
597 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
598 |
+
if schedule_timesteps is None:
|
599 |
+
schedule_timesteps = self.timesteps
|
600 |
+
|
601 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
602 |
+
|
603 |
+
# The sigma index that is taken for the **very** first `step`
|
604 |
+
# is always the second index (or the last index if there is only 1)
|
605 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
606 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
607 |
+
pos = 1 if len(indices) > 1 else 0
|
608 |
+
|
609 |
+
return indices[pos].item()
|
610 |
+
|
611 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
612 |
+
def _init_step_index(self, timestep):
|
613 |
+
"""
|
614 |
+
Initialize the step_index counter for the scheduler.
|
615 |
+
"""
|
616 |
+
|
617 |
+
if self.begin_index is None:
|
618 |
+
if isinstance(timestep, torch.Tensor):
|
619 |
+
timestep = timestep.to(self.timesteps.device)
|
620 |
+
self._step_index = self.index_for_timestep(timestep)
|
621 |
+
else:
|
622 |
+
self._step_index = self._begin_index
|
623 |
+
|
624 |
+
def step(
|
625 |
+
self,
|
626 |
+
model_output: torch.Tensor,
|
627 |
+
timestep: Union[int, torch.Tensor],
|
628 |
+
sample: torch.Tensor,
|
629 |
+
return_dict: bool = True,
|
630 |
+
generator=None,
|
631 |
+
) -> Union[SchedulerOutput, Tuple]:
|
632 |
+
"""
|
633 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
|
634 |
+
the multistep UniPC.
|
635 |
+
|
636 |
+
Args:
|
637 |
+
model_output (`torch.Tensor`):
|
638 |
+
The direct output from learned diffusion model.
|
639 |
+
timestep (`int`):
|
640 |
+
The current discrete timestep in the diffusion chain.
|
641 |
+
sample (`torch.Tensor`):
|
642 |
+
A current instance of a sample created by the diffusion process.
|
643 |
+
return_dict (`bool`):
|
644 |
+
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
645 |
+
|
646 |
+
Returns:
|
647 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
648 |
+
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
649 |
+
tuple is returned where the first element is the sample tensor.
|
650 |
+
|
651 |
+
"""
|
652 |
+
if self.num_inference_steps is None:
|
653 |
+
raise ValueError(
|
654 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
655 |
+
)
|
656 |
+
|
657 |
+
if self.step_index is None:
|
658 |
+
self._init_step_index(timestep)
|
659 |
+
|
660 |
+
use_corrector = (
|
661 |
+
self.step_index > 0
|
662 |
+
and self.step_index - 1 not in self.disable_corrector
|
663 |
+
and self.last_sample is not None # pyright: ignore
|
664 |
+
)
|
665 |
+
|
666 |
+
model_output_convert = self.convert_model_output(model_output, sample=sample)
|
667 |
+
if use_corrector:
|
668 |
+
sample = self.multistep_uni_c_bh_update(
|
669 |
+
this_model_output=model_output_convert,
|
670 |
+
last_sample=self.last_sample,
|
671 |
+
this_sample=sample,
|
672 |
+
order=self.this_order,
|
673 |
+
)
|
674 |
+
|
675 |
+
for i in range(self.config.solver_order - 1):
|
676 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
677 |
+
self.timestep_list[i] = self.timestep_list[i + 1]
|
678 |
+
|
679 |
+
self.model_outputs[-1] = model_output_convert
|
680 |
+
self.timestep_list[-1] = timestep # pyright: ignore
|
681 |
+
|
682 |
+
if self.config.lower_order_final:
|
683 |
+
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore
|
684 |
+
else:
|
685 |
+
this_order = self.config.solver_order
|
686 |
+
|
687 |
+
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
|
688 |
+
assert self.this_order > 0
|
689 |
+
|
690 |
+
self.last_sample = sample
|
691 |
+
prev_sample = self.multistep_uni_p_bh_update(
|
692 |
+
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
693 |
+
sample=sample,
|
694 |
+
order=self.this_order,
|
695 |
+
)
|
696 |
+
|
697 |
+
if self.lower_order_nums < self.config.solver_order:
|
698 |
+
self.lower_order_nums += 1
|
699 |
+
|
700 |
+
# upon completion increase step index by one
|
701 |
+
self._step_index += 1 # pyright: ignore
|
702 |
+
|
703 |
+
if not return_dict:
|
704 |
+
return (prev_sample,)
|
705 |
+
|
706 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
707 |
+
|
708 |
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
709 |
+
"""
|
710 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
711 |
+
current timestep.
|
712 |
+
|
713 |
+
Args:
|
714 |
+
sample (`torch.Tensor`):
|
715 |
+
The input sample.
|
716 |
+
|
717 |
+
Returns:
|
718 |
+
`torch.Tensor`:
|
719 |
+
A scaled input sample.
|
720 |
+
"""
|
721 |
+
return sample
|
722 |
+
|
723 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
724 |
+
def add_noise(
|
725 |
+
self,
|
726 |
+
original_samples: torch.Tensor,
|
727 |
+
noise: torch.Tensor,
|
728 |
+
timesteps: torch.IntTensor,
|
729 |
+
) -> torch.Tensor:
|
730 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
731 |
+
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
732 |
+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
733 |
+
# mps does not support float64
|
734 |
+
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
735 |
+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
736 |
+
else:
|
737 |
+
schedule_timesteps = self.timesteps.to(original_samples.device)
|
738 |
+
timesteps = timesteps.to(original_samples.device)
|
739 |
+
|
740 |
+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
741 |
+
if self.begin_index is None:
|
742 |
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
743 |
+
elif self.step_index is not None:
|
744 |
+
# add_noise is called after first denoising step (for inpainting)
|
745 |
+
step_indices = [self.step_index] * timesteps.shape[0]
|
746 |
+
else:
|
747 |
+
# add noise is called before first denoising step to create initial latent(img2img)
|
748 |
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
749 |
+
|
750 |
+
sigma = sigmas[step_indices].flatten()
|
751 |
+
while len(sigma.shape) < len(original_samples.shape):
|
752 |
+
sigma = sigma.unsqueeze(-1)
|
753 |
+
|
754 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
755 |
+
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
756 |
+
return noisy_samples
|
757 |
+
|
758 |
+
def __len__(self):
|
759 |
+
return self.config.num_train_timesteps
|