Spaces:
Runtime error
Runtime error
update fix
Browse files
utils/stable_diffusion_controlnet_inpaint.py
CHANGED
|
@@ -1046,7 +1046,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixi
|
|
| 1046 |
do_classifier_free_guidance,
|
| 1047 |
)
|
| 1048 |
if self.unet.config.in_channels==4:
|
| 1049 |
-
init_masked_image_latents
|
| 1050 |
image,
|
| 1051 |
batch_size * num_images_per_prompt,
|
| 1052 |
height,
|
|
@@ -1055,8 +1055,10 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixi
|
|
| 1055 |
device,
|
| 1056 |
generator,
|
| 1057 |
do_classifier_free_guidance,
|
| 1058 |
-
)
|
| 1059 |
-
|
|
|
|
|
|
|
| 1060 |
_, _, w, h = mask_image.shape
|
| 1061 |
mask_image = torch.nn.functional.interpolate(mask_image, ((w // 8, h // 8)), mode='nearest')
|
| 1062 |
mask_image = mask_image.to(latents.device).type_as(latents)
|
|
|
|
| 1046 |
do_classifier_free_guidance,
|
| 1047 |
)
|
| 1048 |
if self.unet.config.in_channels==4:
|
| 1049 |
+
init_masked_image_latents = self.prepare_masked_image_latents(
|
| 1050 |
image,
|
| 1051 |
batch_size * num_images_per_prompt,
|
| 1052 |
height,
|
|
|
|
| 1055 |
device,
|
| 1056 |
generator,
|
| 1057 |
do_classifier_free_guidance,
|
| 1058 |
+
)
|
| 1059 |
+
if do_classifier_free_guidance:
|
| 1060 |
+
init_masked_image_latents, _ = init_masked_image_latents.chunk(2)
|
| 1061 |
+
# print(type(mask_image), mask_image.shape)
|
| 1062 |
_, _, w, h = mask_image.shape
|
| 1063 |
mask_image = torch.nn.functional.interpolate(mask_image, ((w // 8, h // 8)), mode='nearest')
|
| 1064 |
mask_image = mask_image.to(latents.device).type_as(latents)
|