Spaces:
Running
on
L40S
Running
on
L40S
Update wan/modules/model.py
Browse files- wan/modules/model.py +1 -1
wan/modules/model.py
CHANGED
@@ -302,7 +302,7 @@ class WanAttentionBlock(nn.Module):
|
|
302 |
y = self.self_attn(
|
303 |
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
|
304 |
freqs)
|
305 |
-
with amp.autocast(dtype=torch.float32):
|
306 |
x = x + y * e[2]
|
307 |
|
308 |
# cross-attention & ffn function
|
|
|
302 |
y = self.self_attn(
|
303 |
self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
|
304 |
freqs)
|
305 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
306 |
x = x + y * e[2]
|
307 |
|
308 |
# cross-attention & ffn function
|