Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Update wan/vace.py
Browse files- wan/vace.py +3 -3
    	
        wan/vace.py
    CHANGED
    
    | @@ -15,7 +15,7 @@ from PIL import Image | |
| 15 | 
             
            import torchvision.transforms.functional as TF
         | 
| 16 | 
             
            import torch
         | 
| 17 | 
             
            import torch.nn.functional as F
         | 
| 18 | 
            -
            import torch. | 
| 19 | 
             
            import torch.distributed as dist
         | 
| 20 | 
             
            import torch.multiprocessing as mp
         | 
| 21 | 
             
            from tqdm import tqdm
         | 
| @@ -362,7 +362,7 @@ class WanVace(WanT2V): | |
| 362 | 
             
                    no_sync = getattr(self.model, 'no_sync', noop_no_sync)
         | 
| 363 |  | 
| 364 | 
             
                    # evaluation mode
         | 
| 365 | 
            -
                    with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
         | 
| 366 |  | 
| 367 | 
             
                        if sample_solver == 'unipc':
         | 
| 368 | 
             
                            sample_scheduler = FlowUniPCMultistepScheduler(
         | 
| @@ -616,7 +616,7 @@ class WanVaceMP(WanVace): | |
| 616 | 
             
                            no_sync = getattr(model, 'no_sync', noop_no_sync)
         | 
| 617 |  | 
| 618 | 
             
                            # evaluation mode
         | 
| 619 | 
            -
                            with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
         | 
| 620 |  | 
| 621 | 
             
                                if sample_solver == 'unipc':
         | 
| 622 | 
             
                                    sample_scheduler = FlowUniPCMultistepScheduler(
         | 
|  | |
| 15 | 
             
            import torchvision.transforms.functional as TF
         | 
| 16 | 
             
            import torch
         | 
| 17 | 
             
            import torch.nn.functional as F
         | 
| 18 | 
            +
            import torch.amp as amp
         | 
| 19 | 
             
            import torch.distributed as dist
         | 
| 20 | 
             
            import torch.multiprocessing as mp
         | 
| 21 | 
             
            from tqdm import tqdm
         | 
|  | |
| 362 | 
             
                    no_sync = getattr(self.model, 'no_sync', noop_no_sync)
         | 
| 363 |  | 
| 364 | 
             
                    # evaluation mode
         | 
| 365 | 
            +
                    with amp.autocast("cuda", dtype=self.param_dtype), torch.no_grad(), no_sync():
         | 
| 366 |  | 
| 367 | 
             
                        if sample_solver == 'unipc':
         | 
| 368 | 
             
                            sample_scheduler = FlowUniPCMultistepScheduler(
         | 
|  | |
| 616 | 
             
                            no_sync = getattr(model, 'no_sync', noop_no_sync)
         | 
| 617 |  | 
| 618 | 
             
                            # evaluation mode
         | 
| 619 | 
            +
                            with amp.autocast("cuda", dtype=param_dtype), torch.no_grad(), no_sync():
         | 
| 620 |  | 
| 621 | 
             
                                if sample_solver == 'unipc':
         | 
| 622 | 
             
                                    sample_scheduler = FlowUniPCMultistepScheduler(
         | 
