Spaces:
Running
on
L40S
Running
on
L40S
Update wan/vace.py
Browse files- wan/vace.py +3 -3
wan/vace.py
CHANGED
@@ -15,7 +15,7 @@ from PIL import Image
|
|
15 |
import torchvision.transforms.functional as TF
|
16 |
import torch
|
17 |
import torch.nn.functional as F
|
18 |
-
import torch.
|
19 |
import torch.distributed as dist
|
20 |
import torch.multiprocessing as mp
|
21 |
from tqdm import tqdm
|
@@ -362,7 +362,7 @@ class WanVace(WanT2V):
|
|
362 |
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
363 |
|
364 |
# evaluation mode
|
365 |
-
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
366 |
|
367 |
if sample_solver == 'unipc':
|
368 |
sample_scheduler = FlowUniPCMultistepScheduler(
|
@@ -616,7 +616,7 @@ class WanVaceMP(WanVace):
|
|
616 |
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
617 |
|
618 |
# evaluation mode
|
619 |
-
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
|
620 |
|
621 |
if sample_solver == 'unipc':
|
622 |
sample_scheduler = FlowUniPCMultistepScheduler(
|
|
|
15 |
import torchvision.transforms.functional as TF
|
16 |
import torch
|
17 |
import torch.nn.functional as F
|
18 |
+
import torch.amp as amp
|
19 |
import torch.distributed as dist
|
20 |
import torch.multiprocessing as mp
|
21 |
from tqdm import tqdm
|
|
|
362 |
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
363 |
|
364 |
# evaluation mode
|
365 |
+
with amp.autocast("cuda", dtype=self.param_dtype), torch.no_grad(), no_sync():
|
366 |
|
367 |
if sample_solver == 'unipc':
|
368 |
sample_scheduler = FlowUniPCMultistepScheduler(
|
|
|
616 |
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
617 |
|
618 |
# evaluation mode
|
619 |
+
with amp.autocast("cuda", dtype=param_dtype), torch.no_grad(), no_sync():
|
620 |
|
621 |
if sample_solver == 'unipc':
|
622 |
sample_scheduler = FlowUniPCMultistepScheduler(
|