fffiloni commited on
Commit
4144f09
·
verified ·
1 Parent(s): 46e54d8

Update wan/modules/vace_model.py

Browse files
Files changed (1) hide show
  1. wan/modules/vace_model.py +2 -2
wan/modules/vace_model.py CHANGED
@@ -1,6 +1,6 @@
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import torch
3
- import torch.cuda.amp as amp
4
  import torch.nn as nn
5
  from diffusers.configuration_utils import register_to_config
6
  from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
@@ -190,7 +190,7 @@ class VaceWanModel(WanModel):
190
  ])
191
 
192
  # time embeddings
193
- with amp.autocast(dtype=torch.float32):
194
  e = self.time_embedding(
195
  sinusoidal_embedding_1d(self.freq_dim, t).float())
196
  e0 = self.time_projection(e).unflatten(1, (6, self.dim))
 
1
  # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
  import torch
3
+ import torch.amp as amp
4
  import torch.nn as nn
5
  from diffusers.configuration_utils import register_to_config
6
  from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
 
190
  ])
191
 
192
  # time embeddings
193
+ with amp.autocast("cuda", dtype=torch.float32):
194
  e = self.time_embedding(
195
  sinusoidal_embedding_1d(self.freq_dim, t).float())
196
  e0 = self.time_projection(e).unflatten(1, (6, self.dim))