Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Changed validation device to cpu
Browse files- train_dreambooth_lora.py +2 -2
    	
        train_dreambooth_lora.py
    CHANGED
    
    | @@ -940,11 +940,11 @@ def main(args): | |
| 940 | 
             
                            torch_dtype=weight_dtype,
         | 
| 941 | 
             
                        )
         | 
| 942 | 
             
                        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
         | 
| 943 | 
            -
                        pipeline = pipeline.to( | 
| 944 | 
             
                        pipeline.set_progress_bar_config(disable=True)
         | 
| 945 |  | 
| 946 | 
             
                        # run inference
         | 
| 947 | 
            -
                        generator = torch.Generator(device= | 
| 948 | 
             
                        prompt = args.num_validation_images * [args.validation_prompt]
         | 
| 949 | 
             
                        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
         | 
| 950 |  | 
|  | |
| 940 | 
             
                            torch_dtype=weight_dtype,
         | 
| 941 | 
             
                        )
         | 
| 942 | 
             
                        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
         | 
| 943 | 
            +
                        pipeline = pipeline.to('cpu')
         | 
| 944 | 
             
                        pipeline.set_progress_bar_config(disable=True)
         | 
| 945 |  | 
| 946 | 
             
                        # run inference
         | 
| 947 | 
            +
                        generator = torch.Generator(device='cpu').manual_seed(args.seed)
         | 
| 948 | 
             
                        prompt = args.num_validation_images * [args.validation_prompt]
         | 
| 949 | 
             
                        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
         | 
| 950 |  | 
