# Adapted from https://github.com/MichalGeyer/pnp-diffusers/blob/main/preprocess.py
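#
# Inverts an input image into Stable Diffusion latent space with DDIM and saves
# the intermediate noisy latents to disk for the downstream plug-and-play
# feature-injection pipeline. Illustrative invocation (assuming this file is
# saved as preprocess.py; paths and values below are placeholders):
#
#   python preprocess.py --data_path data/source_2.png --save_dir latents \
#       --sd_version 2.1 --steps 50 --inversion_prompt "a photo of a cat"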

from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import os
from PIL import Image
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import argparse
from pathlib import Path
from pnp_utils import *
import torchvision.transforms as T
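

# Keep only the tail of an already-configured schedule, as selected by `strength`
# (strength=1.0 keeps every step); mirrors the timestep-trimming helper used by
# diffusers img2img pipelines.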
def get_timesteps(scheduler, num_inference_steps, strength, device):
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start:]
    return timesteps, num_inference_steps - t_start
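

# Wraps the frozen Stable Diffusion components (VAE, CLIP text encoder, UNet and
# DDIM scheduler) needed to invert an image and replay its denoising trajectory.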
class Preprocess(nn.Module):
    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__()
        self.device = device
        self.sd_version = sd_version
        self.use_depth = False

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        elif self.sd_version == 'depth':
            model_key = "stabilityai/stable-diffusion-2-depth"
            self.use_depth = True
        elif self.sd_version == '1.4':
            model_key = "CompVis/stable-diffusion-v1-4"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", revision="fp16",
                                                 torch_dtype=torch.float16).to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder", revision="fp16",
                                                          torch_dtype=torch.float16).to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", revision="fp16",
                                                         torch_dtype=torch.float16).to(self.device)
        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
        print(f'[INFO] loaded stable diffusion!')

        self.inversion_func = self.ddim_inversion
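
    # Encode prompt and negative prompt with CLIP and stack them as
    # [uncond, cond], the layout expected for classifier-free guidance.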
    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt, device="cuda"):
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings
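
    # Map latents back to RGB: undo the 0.18215 scaling used by the SD VAE and
    # shift the decoder output from [-1, 1] to [0, 1].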
    @torch.no_grad()
    def decode_latents(self, latents):
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            latents = 1 / 0.18215 * latents
            imgs = self.vae.decode(latents).sample
            imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def load_img(self, image_path):
        image_pil = T.Resize(512)(Image.open(image_path).convert("RGB"))
        image = T.ToTensor()(image_pil).unsqueeze(0).to(self.device)
        return image
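
    # Encode an RGB image in [0, 1] to SD latents: shift to [-1, 1], take the
    # posterior mean (deterministic, no sampling) and apply the 0.18215 scaling.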
    @torch.no_grad()
    def encode_imgs(self, imgs):
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            imgs = 2 * imgs - 1
            posterior = self.vae.encode(imgs).latent_dist
            latents = posterior.mean * 0.18215
        return latents
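
    # DDIM inversion: walk the schedule in reverse (clean -> noisy). With
    # mu_t = sqrt(alpha_bar_t) and sigma_t = sqrt(1 - alpha_bar_t), each step
    # estimates x0 from the previous, less noisy timestep,
    #     pred_x0 = (x_prev - sigma_prev * eps) / mu_prev,
    # and re-noises it to the current timestep,
    #     x_t = mu_t * pred_x0 + sigma_t * eps,
    # i.e. the deterministic (eta=0) DDIM update run backwards. Intermediate
    # latents are optionally written to disk for later feature injection.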
    @torch.no_grad()
    def ddim_inversion(self, cond, latent, save_path, save_latents=True,
                       timesteps_to_save=None):
        timesteps = reversed(self.scheduler.timesteps)
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for i, t in enumerate(tqdm(timesteps)):
                cond_batch = cond.repeat(latent.shape[0], 1, 1)

                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i - 1]]
                    if i > 0 else self.scheduler.final_alpha_cumprod
                )
                mu = alpha_prod_t ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

                eps = self.unet(latent, t, encoder_hidden_states=cond_batch).sample

                pred_x0 = (latent - sigma_prev * eps) / mu_prev
                latent = mu * pred_x0 + sigma * eps
                if save_latents:
                    torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))
        # save the fully inverted latent (the starting noise for sampling) once more
        torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))
        return latent
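
    # Deterministic DDIM sampling (noisy -> clean), the mirror of ddim_inversion:
    # predict x0 at timestep t and step to the next, less noisy timestep.
    # Replaying it from the inverted latent reconstructs the input image.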
    @torch.no_grad()
    def ddim_sample(self, x, cond, save_path, save_latents=False, timesteps_to_save=None):
        timesteps = self.scheduler.timesteps
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for i, t in enumerate(tqdm(timesteps)):
                cond_batch = cond.repeat(x.shape[0], 1, 1)

                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i + 1]]
                    if i < len(timesteps) - 1
                    else self.scheduler.final_alpha_cumprod
                )
                mu = alpha_prod_t ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

                eps = self.unet(x, t, encoder_hidden_states=cond_batch).sample

                pred_x0 = (x - sigma * eps) / mu
                x = mu_prev * pred_x0 + sigma_prev * eps
                if save_latents:
                    torch.save(x, os.path.join(save_path, f'noisy_latents_{t}.pt'))
        return x
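
    # Full pipeline for one image: encode to latents, DDIM-invert under the
    # inversion prompt (saving noisy latents along the way), then replay DDIM
    # sampling and decode the result as a reconstruction sanity check.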
    @torch.no_grad()
    def extract_latents(self, num_steps, data_path, save_path, timesteps_to_save,
                        inversion_prompt='', extract_reverse=False):
        self.scheduler.set_timesteps(num_steps)
        cond = self.get_text_embeds(inversion_prompt, "")[1].unsqueeze(0)
        image = self.load_img(data_path)
        latent = self.encode_imgs(image)

        inverted_x = self.inversion_func(cond, latent, save_path, save_latents=not extract_reverse,
                                         timesteps_to_save=timesteps_to_save)
        latent_reconstruction = self.ddim_sample(inverted_x, cond, save_path, save_latents=extract_reverse,
                                                 timesteps_to_save=timesteps_to_save)
        rgb_reconstruction = self.decode_latents(latent_reconstruction)
        return rgb_reconstruction  # , latent_reconstruction
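

# A scheduler configured with `save_steps` steps is used only to enumerate the
# timesteps passed through as `timesteps_to_save`; extraction itself runs with
# `opt.steps` steps, and the reconstruction is written next to the saved latents.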
def run(opt):
    device = 'cuda'
    # timesteps to save
    if opt.sd_version == '2.1':
        model_key = "stabilityai/stable-diffusion-2-1-base"
    elif opt.sd_version == '2.0':
        model_key = "stabilityai/stable-diffusion-2-base"
    elif opt.sd_version == '1.5':
        model_key = "runwayml/stable-diffusion-v1-5"
    elif opt.sd_version == 'depth':
        model_key = "stabilityai/stable-diffusion-2-depth"
    elif opt.sd_version == '1.4':
        model_key = "CompVis/stable-diffusion-v1-4"
    else:
        raise ValueError(f'Stable-diffusion version {opt.sd_version} not supported.')
    toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
    toy_scheduler.set_timesteps(opt.save_steps)
    timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=opt.save_steps,
                                                           strength=1.0,
                                                           device=device)
    seed_everything(opt.seed)

    extraction_path_prefix = "_reverse" if opt.extract_reverse else "_forward"
    save_path = os.path.join(opt.save_dir + extraction_path_prefix,
                             os.path.splitext(os.path.basename(opt.data_path))[0])
    os.makedirs(save_path, exist_ok=True)

    model = Preprocess(device, sd_version=opt.sd_version, hf_key=None)
    recon_image = model.extract_latents(data_path=opt.data_path,
                                        num_steps=opt.steps,
                                        save_path=save_path,
                                        timesteps_to_save=timesteps_to_save,
                                        inversion_prompt=opt.inversion_prompt,
                                        extract_reverse=opt.extract_reverse)

    T.ToPILImage()(recon_image[0]).save(os.path.join(save_path, 'recon.jpg'))


if __name__ == "__main__":
    device = 'cuda'
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str,
                        default='data/source_2.png')
    parser.add_argument('--save_dir', type=str, default='latents')
    parser.add_argument('--sd_version', type=str, default='1.4', choices=['1.5', '2.0', '2.1', '1.4'],
                        help="stable diffusion version")
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--steps', type=int, default=50)
    parser.add_argument('--save-steps', type=int, default=1000)
    parser.add_argument('--inversion_prompt', type=str, default='')
    parser.add_argument('--extract-reverse', default=False, action='store_true',
                        help="extract features during the denoising process")
    opt = parser.parse_args()

    run(opt)