from transformers import CLIPTextModel, CLIPTokenizer, logging

from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress the partial-model-loading warnings from transformers
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torchvision.transforms as T
import argparse
import numpy as np
from PIL import Image


def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    # work in latent-space coordinates (the VAE downsamples by a factor of 8)
    panorama_height /= 8
    panorama_width /= 8
    # number of overlapping window positions along each axis
    num_blocks_height = (panorama_height - window_size) // stride + 1
    num_blocks_width = (panorama_width - window_size) // stride + 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    views = []
    for i in range(total_num_blocks):
        # enumerate window_size x window_size crops in row-major order, sliding by `stride`
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size
        views.append((h_start, h_end, w_start, w_end))
    return views
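
# Worked example (illustrative): for a 512x2048 panorama the latent grid is 64x256, so
# get_views(512, 2048) returns 1 x 25 = 25 overlapping views, starting at (0, 64, 0, 64)
# and sliding 8 latent pixels at a time up to the last view (0, 64, 192, 256).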
					
						
class MultiDiffusion(nn.Module):
    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            # treat any other value as a custom Hugging Face model key
            model_key = self.sd_version

        # load the individual Stable Diffusion components
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_random_background(self, n_samples):
        # sample random constant-color 512x512 backgrounds and encode them to latent space
        backgrounds = torch.rand(n_samples, 3, device=self.device)[:, :, None, None].repeat(1, 1, 512, 512)
        return torch.cat([self.encode_imgs(bg.unsqueeze(0)) for bg in backgrounds])

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        # tokenize the prompts and compute the conditional text embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # do the same for the unconditional (negative-prompt) embeddings
        uncond_input = self.tokenizer(negative_prompt, padding='max_length',
                                      max_length=self.tokenizer.model_max_length, return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # concatenate for classifier-free guidance: unconditional rows first, conditional rows second
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings
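
    # Shape note (added for clarity, not part of the original source): for N prompts the
    # result is [2 * N, 77, D]; D is 768 for SD 1.5 and 1024 for the SD 2.x base checkpoints,
    # assuming their default CLIP text encoders.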
					
						
    @torch.no_grad()
    def encode_imgs(self, imgs):
        # map images from [0, 1] to [-1, 1], encode with the VAE, and apply the SD latent scale
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * 0.18215
        return latents

    @torch.no_grad()
    def decode_latents(self, latents):
        # undo the latent scaling, decode with the VAE, and map back to [0, 1]
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs
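
    # Round-trip sketch (illustrative; assumes `x` is a [1, 3, 512, 512] tensor in [0, 1] on self.device):
    #   latents = self.encode_imgs(x)         # [1, 4, 64, 64]
    #   recon = self.decode_latents(latents)  # [1, 3, 512, 512], approximately equal to x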
					
						
    @torch.no_grad()
    def generate(self, masks, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
                 guidance_scale=7.5, bootstrapping=20):
        # masks: [len(prompts), 1, height // 8, width // 8] latent-resolution region masks,
        # ordered [background, foreground_1, ...] to match `prompts`

        # random constant-color backgrounds used to bootstrap the foreground regions
        bootstrapping_backgrounds = self.get_random_background(bootstrapping)

        # prompts -> text embeddings for classifier-free guidance
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # define the full latent canvas, its sliding-window views, and the fusion buffers
        latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        noise = latent.clone().repeat(len(prompts) - 1, 1, 1, 1)
        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    masks_view = masks[:, :, h_start:h_end, w_start:w_end]
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end].repeat(len(prompts), 1, 1, 1)
                    if i < bootstrapping:
                        # during the first `bootstrapping` steps, composite each foreground latent
                        # onto a noised random background to anchor the object inside its mask
                        bg = bootstrapping_backgrounds[torch.randint(0, bootstrapping, (len(prompts) - 1,))]
                        bg = self.scheduler.add_noise(bg, noise[:, :, h_start:h_end, w_start:w_end], t)
                        latent_view[1:] = latent_view[1:] * masks_view[1:] + bg * (1 - masks_view[1:])

                    # duplicate the latents so the unconditional and conditional branches share one forward pass
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # classifier-free guidance
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # one denoising step per prompt on this view
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

                    # accumulate the masked per-prompt results and the per-pixel mask coverage
                    value[:, :, h_start:h_end, w_start:w_end] += (latents_view_denoised * masks_view).sum(dim=0, keepdim=True)
                    count[:, :, h_start:h_end, w_start:w_end] += masks_view.sum(dim=0, keepdim=True)

                # MultiDiffusion step: average the overlapping masked predictions per pixel
                latent = torch.where(count > 0, value / count, value)

        # latents -> image
        imgs = self.decode_latents(latent)
        img = T.ToPILImage()(imgs[0].cpu())
        return img


def preprocess_mask(mask_path, h, w, device):
    # load a grayscale mask, binarize it at 0.5, and resize it to the latent resolution
    mask = np.array(Image.open(mask_path).convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask).to(device)
    mask = torch.nn.functional.interpolate(mask, size=(h, w), mode='nearest')
    return mask


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--mask_paths', type=str, nargs='+', help="paths to the foreground region masks")
    parser.add_argument('--bg_prompt', type=str)
    parser.add_argument('--bg_negative', type=str)
    parser.add_argument('--fg_prompts', type=str, nargs='+')
    parser.add_argument('--fg_negative', type=str, nargs='+')
    parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0', '2.1'],
                        help="stable diffusion version")
    parser.add_argument('--H', type=int, default=768)
    parser.add_argument('--W', type=int, default=512)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--steps', type=int, default=50)
    parser.add_argument('--bootstrapping', type=int, default=20)
    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device('cuda')

    sd = MultiDiffusion(device, opt.sd_version)

    # one mask per foreground prompt, at latent resolution
    fg_masks = torch.cat([preprocess_mask(mask_path, opt.H // 8, opt.W // 8, device) for mask_path in opt.mask_paths])

    # the background covers everything not claimed by a foreground mask
    bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
    bg_mask[bg_mask < 0] = 0
    masks = torch.cat([bg_mask, fg_masks])

    prompts = [opt.bg_prompt] + opt.fg_prompts
    neg_prompts = [opt.bg_negative] + opt.fg_negative

    img = sd.generate(masks, prompts, neg_prompts, opt.H, opt.W, opt.steps, bootstrapping=opt.bootstrapping)

    # save the result
    img.save('out.png')
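
# Example invocation (illustrative file names and prompts; the script file name is an
# assumption, adjust it to wherever this file is saved):
#   python region_based.py \
#       --mask_paths mask1.png mask2.png \
#       --bg_prompt "a photo of a snowy mountain landscape" \
#       --bg_negative "low quality, blurry" \
#       --fg_prompts "a red car" "a wooden cabin" \
#       --fg_negative "low quality" "low quality" \
#       --H 768 --W 512 --seed 0 --steps 50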
					
						