Spaces:
Running
on
L40S
Running
on
L40S
Update wan/modules/vace_model.py
Browse files
wan/modules/vace_model.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
import torch
|
3 |
-
import torch.
|
4 |
import torch.nn as nn
|
5 |
from diffusers.configuration_utils import register_to_config
|
6 |
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
|
@@ -190,7 +190,7 @@ class VaceWanModel(WanModel):
|
|
190 |
])
|
191 |
|
192 |
# time embeddings
|
193 |
-
with amp.autocast(dtype=torch.float32):
|
194 |
e = self.time_embedding(
|
195 |
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
196 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
|
|
1 |
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
import torch
|
3 |
+
import torch.amp as amp
|
4 |
import torch.nn as nn
|
5 |
from diffusers.configuration_utils import register_to_config
|
6 |
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
|
|
|
190 |
])
|
191 |
|
192 |
# time embeddings
|
193 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
194 |
e = self.time_embedding(
|
195 |
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
196 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|