fffiloni committed on
Commit
54bc97a
·
verified ·
1 Parent(s): 4144f09

Update wan/modules/model.py

Browse files
Files changed (1) hide show
  1. wan/modules/model.py +1 -1
wan/modules/model.py CHANGED
@@ -302,7 +302,7 @@ class WanAttentionBlock(nn.Module):
302
  y = self.self_attn(
303
  self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
304
  freqs)
305
- with amp.autocast(dtype=torch.float32):
306
  x = x + y * e[2]
307
 
308
  # cross-attention & ffn function
 
302
  y = self.self_attn(
303
  self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
304
  freqs)
305
+ with amp.autocast("cuda", dtype=torch.float32):
306
  x = x + y * e[2]
307
 
308
  # cross-attention & ffn function