1inkusFace committed
Commit 4ad821a · verified · Parent(s): f2e32cd

Update skyreels_v2_infer/modules/transformer.py

skyreels_v2_infer/modules/transformer.py CHANGED
@@ -13,7 +13,7 @@ from torch.nn.attention.flex_attention import BlockMask
 from torch.nn.attention.flex_attention import create_block_mask
 from torch.nn.attention.flex_attention import flex_attention
 
-from .attention import flash_attention
+from .attention import flash_attention, attention
 
 
 flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")
@@ -160,7 +160,7 @@ class WanSelfAttention(nn.Module):
         if not self._flag_ar_attention:
             q = rope_apply(q, grid_sizes, freqs)
             k = rope_apply(k, grid_sizes, freqs)
-            x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
+            x = attention(q=q, k=k, v=v, window_size=self.window_size)
         else:
             q = rope_apply(q, grid_sizes, freqs)
             k = rope_apply(k, grid_sizes, freqs)
@@ -199,7 +199,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x = flash_attention(q, k, v)
+        x = attention(q, k, v)
 
         # output
         x = x.flatten(2)
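For reference, this commit only changes which helper is imported from .attention and called in the two attention paths; the body of attention itself is not part of the diff. Below is a minimal sketch of what such a wrapper could look like, assuming it accepts the same [B, L, heads, dim] tensors as flash_attention and falls back to PyTorch's built-in scaled_dot_product_attention. The fallback choice and the handling of window_size here are assumptions for illustration, not the repository's actual code.

# Hypothetical sketch of the `attention` helper imported above; the real
# implementation lives in skyreels_v2_infer/modules/attention.py and is not
# shown in this diff.
import torch
import torch.nn.functional as F

def attention(q, k, v, window_size=(-1, -1)):
    # q, k, v: [batch, seq_len, num_heads, head_dim], as produced by
    # `.view(b, -1, n, d)` in WanSelfAttention / WanT2VCrossAttention.
    # NOTE: window_size is accepted for signature compatibility with
    # flash_attention but ignored by this SDPA fallback (assumption).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [B, heads, L, dim]
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)  # back to [B, L, heads, dim]

Keeping the call sites identical (attention(q=q, k=k, v=v, window_size=self.window_size) and attention(q, k, v)) means the switch away from the hard flash_attention dependency requires no other changes in transformer.py.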