Spaces:
Runtime error
Runtime error
Update train_dreambooth_lora.py
Browse files- train_dreambooth_lora.py +2 -2
train_dreambooth_lora.py
CHANGED
|
@@ -940,11 +940,11 @@ def main(args):
|
|
| 940 |
torch_dtype=weight_dtype,
|
| 941 |
)
|
| 942 |
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
| 943 |
-
pipeline = pipeline.to(
|
| 944 |
pipeline.set_progress_bar_config(disable=True)
|
| 945 |
|
| 946 |
# run inference
|
| 947 |
-
generator = torch.Generator(device=
|
| 948 |
prompt = args.num_validation_images * [args.validation_prompt]
|
| 949 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
| 950 |
|
|
|
|
| 940 |
torch_dtype=weight_dtype,
|
| 941 |
)
|
| 942 |
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
| 943 |
+
pipeline = pipeline.to(accelerator.device)
|
| 944 |
pipeline.set_progress_bar_config(disable=True)
|
| 945 |
|
| 946 |
# run inference
|
| 947 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
| 948 |
prompt = args.num_validation_images * [args.validation_prompt]
|
| 949 |
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
|
| 950 |
|