1inkusFace commited on
Commit
a19b88a
·
verified ·
1 Parent(s): 4ad821a

Update skyreels_v2_infer/modules/transformer.py

Browse files
skyreels_v2_infer/modules/transformer.py CHANGED
@@ -233,9 +233,9 @@ class WanI2VCrossAttention(WanSelfAttention):
233
  v = self.v(context).view(b, -1, n, d)
234
  k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
235
  v_img = self.v_img(context_img).view(b, -1, n, d)
236
- img_x = flash_attention(q, k_img, v_img)
237
  # compute attention
238
- x = flash_attention(q, k, v)
239
 
240
  # output
241
  x = x.flatten(2)
 
233
  v = self.v(context).view(b, -1, n, d)
234
  k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
235
  v_img = self.v_img(context_img).view(b, -1, n, d)
236
+ img_x = attention(q, k_img, v_img)
237
  # compute attention
238
+ x = attention(q, k, v)
239
 
240
  # output
241
  x = x.flatten(2)