Update skyreels_v2_infer/modules/transformer.py
skyreels_v2_infer/modules/transformer.py (CHANGED)
@@ -13,7 +13,7 @@ from torch.nn.attention.flex_attention import BlockMask
 from torch.nn.attention.flex_attention import create_block_mask
 from torch.nn.attention.flex_attention import flex_attention
 
-from .attention import flash_attention
+from .attention import flash_attention, attention
 
 
 flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")
@@ -160,7 +160,7 @@ class WanSelfAttention(nn.Module):
         if not self._flag_ar_attention:
             q = rope_apply(q, grid_sizes, freqs)
             k = rope_apply(k, grid_sizes, freqs)
-            x =
+            x = attention(q=q, k=k, v=v, window_size=self.window_size)
         else:
             q = rope_apply(q, grid_sizes, freqs)
             k = rope_apply(k, grid_sizes, freqs)
@@ -199,7 +199,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x = attention(q, k, v)
+        x = attention(q, k, v)
 
         # output
         x = x.flatten(2)
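The commit switches the self-attention and cross-attention call sites from the flash-attention-only path to a generic `attention` helper imported from `.attention`, but the helper itself is not part of this diff. As a minimal sketch of what such a dispatcher commonly looks like (an assumption for illustration, not the repository's actual implementation), it would try FlashAttention when the package is installed and otherwise fall back to PyTorch's scaled_dot_product_attention:

import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # optional dependency
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False


def attention(q, k, v, window_size=(-1, -1)):
    """Hypothetical dispatcher: FlashAttention if available, else SDPA.

    q, k, v are assumed to be in (batch, seq_len, num_heads, head_dim)
    layout, matching the .view(b, -1, n, d) reshapes in transformer.py.
    """
    if FLASH_ATTN_AVAILABLE and q.is_cuda:
        # flash_attn_func consumes (B, L, H, D) tensors directly and
        # supports sliding-window attention via window_size.
        return flash_attn_func(q, k, v, window_size=window_size)

    # Fallback: SDPA expects (B, H, L, D); transpose in and out of that
    # layout. Note the sliding-window restriction is not applied here.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)

Under that assumption, the two edited call sites read naturally: the self-attention branch forwards its `window_size`, while the cross-attention call uses the default (no window), and both return tensors in the same layout the surrounding code flattens with `x.flatten(2)`.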