fffiloni commited on
Commit
46e54d8
·
verified ·
1 Parent(s): e681a74

Update wan/vace.py

Browse files
Files changed (1) hide show
  1. wan/vace.py +3 -3
wan/vace.py CHANGED
@@ -15,7 +15,7 @@ from PIL import Image
15
  import torchvision.transforms.functional as TF
16
  import torch
17
  import torch.nn.functional as F
18
- import torch.cuda.amp as amp
19
  import torch.distributed as dist
20
  import torch.multiprocessing as mp
21
  from tqdm import tqdm
@@ -362,7 +362,7 @@ class WanVace(WanT2V):
362
  no_sync = getattr(self.model, 'no_sync', noop_no_sync)
363
 
364
  # evaluation mode
365
- with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
366
 
367
  if sample_solver == 'unipc':
368
  sample_scheduler = FlowUniPCMultistepScheduler(
@@ -616,7 +616,7 @@ class WanVaceMP(WanVace):
616
  no_sync = getattr(model, 'no_sync', noop_no_sync)
617
 
618
  # evaluation mode
619
- with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
620
 
621
  if sample_solver == 'unipc':
622
  sample_scheduler = FlowUniPCMultistepScheduler(
 
15
  import torchvision.transforms.functional as TF
16
  import torch
17
  import torch.nn.functional as F
18
+ import torch.amp as amp
19
  import torch.distributed as dist
20
  import torch.multiprocessing as mp
21
  from tqdm import tqdm
 
362
  no_sync = getattr(self.model, 'no_sync', noop_no_sync)
363
 
364
  # evaluation mode
365
+ with amp.autocast("cuda", dtype=self.param_dtype), torch.no_grad(), no_sync():
366
 
367
  if sample_solver == 'unipc':
368
  sample_scheduler = FlowUniPCMultistepScheduler(
 
616
  no_sync = getattr(model, 'no_sync', noop_no_sync)
617
 
618
  # evaluation mode
619
+ with amp.autocast("cuda", dtype=param_dtype), torch.no_grad(), no_sync():
620
 
621
  if sample_solver == 'unipc':
622
  sample_scheduler = FlowUniPCMultistepScheduler(