# https://github.com/GaParmar/img2img-turbo/blob/main/src/pix2pix_turbo.py
import os
import requests
import sys
import pdb
import copy
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig

from pipelines.pix2pix.model import (
    make_1step_sched,
    my_vae_encoder_fwd,
    my_vae_decoder_fwd,
)

class TwinConv(torch.nn.Module):
    # Blends a frozen copy of the pretrained conv_in with the current
    # (finetuned) conv_in, using the mixing ratio r set externally per call.
    def __init__(self, convin_pretrained, convin_curr):
        super(TwinConv, self).__init__()
        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
        self.conv_in_curr = copy.deepcopy(convin_curr)
        self.r = None

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * (self.r)
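
# --------------------------------------------------------------------------
# Illustrative sketch (not in the upstream file): how TwinConv blends the two
# convolutions. With r = 0 the output is the frozen pretrained conv; with
# r = 1 it is the finetuned conv. The 4-channel shapes below are arbitrary
# assumptions for the demo, not SD-Turbo's real conv_in dimensions.
def _demo_twin_conv():
    conv_a = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)
    conv_b = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)
    twin = TwinConv(conv_a, conv_b)
    twin.r = 0.25  # 25% finetuned output, 75% frozen pretrained output
    out = twin(torch.randn(1, 4, 8, 8))
    print(out.shape)  # torch.Size([1, 4, 8, 8])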

class Pix2Pix_Turbo(torch.nn.Module):
    def __init__(self, name, ckpt_folder="checkpoints"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "stabilityai/sd-turbo", subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            "stabilityai/sd-turbo", subfolder="text_encoder"
        ).cuda()
        self.sched = make_1step_sched()

        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(
            "stabilityai/sd-turbo", subfolder="unet"
        )

        if name == "edge_to_image":
            # download the edge-to-image LoRA checkpoint if it is not cached
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
            if not os.path.exists(outf):
                print(f"Downloading checkpoint to {outf}")
                response = requests.get(url, stream=True)
                total_size_in_bytes = int(response.headers.get("content-length", 0))
                block_size = 1024  # 1 Kibibyte
                progress_bar = tqdm(
                    total=total_size_in_bytes, unit="iB", unit_scale=True
                )
                with open(outf, "wb") as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                progress_bar.close()
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    print("ERROR, something went wrong")
                print(f"Downloaded successfully to {outf}")
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location="cpu")
            unet_lora_config = LoraConfig(
                r=sd["rank_unet"],
                init_lora_weights="gaussian",
                target_modules=sd["unet_lora_target_modules"],
            )

        if name == "sketch_to_image_stochastic":
            # download from url
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
            if not os.path.exists(outf):
                print(f"Downloading checkpoint to {outf}")
                response = requests.get(url, stream=True)
                total_size_in_bytes = int(response.headers.get("content-length", 0))
                block_size = 1024  # 1 Kibibyte
                progress_bar = tqdm(
                    total=total_size_in_bytes, unit="iB", unit_scale=True
                )
                with open(outf, "wb") as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                progress_bar.close()
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    print("ERROR, something went wrong")
                print(f"Downloaded successfully to {outf}")
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location="cpu")
            unet_lora_config = LoraConfig(
                r=sd["rank_unet"],
                init_lora_weights="gaussian",
                target_modules=sd["unet_lora_target_modules"],
            )
            # swap conv_in for a TwinConv so the input can interpolate between
            # the pretrained and the finetuned first convolution
            convin_pretrained = copy.deepcopy(unet.conv_in)
            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)

        # patch the VAE encoder/decoder forwards so the encoder records its
        # down-block activations and the decoder can consume them as skips
        vae.encoder.forward = my_vae_encoder_fwd.__get__(
            vae.encoder, vae.encoder.__class__
        )
        vae.decoder.forward = my_vae_decoder_fwd.__get__(
            vae.decoder, vae.decoder.__class__
        )
        # add the skip connection convs
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(
            512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(
            256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(
            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(
            128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
        ).cuda()
        vae_lora_config = LoraConfig(
            r=sd["rank_vae"],
            init_lora_weights="gaussian",
            target_modules=sd["vae_lora_target_modules"],
        )
        vae.decoder.ignore_skip = False
        vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
        unet.add_adapter(unet_lora_config)
        # copy the checkpoint entries over the corresponding keys of the
        # adapter-augmented state dicts, then reload them
        _sd_unet = unet.state_dict()
        for k in sd["state_dict_unet"]:
            _sd_unet[k] = sd["state_dict_unet"][k]
        unet.load_state_dict(_sd_unet)
        unet.enable_xformers_memory_efficient_attention()
        _sd_vae = vae.state_dict()
        for k in sd["state_dict_vae"]:
            _sd_vae[k] = sd["state_dict_vae"][k]
        vae.load_state_dict(_sd_vae)
        unet.to("cuda")
        vae.to("cuda")
        unet.eval()
        vae.eval()
        self.unet, self.vae = unet, vae
        self.vae.decoder.gamma = 1
        self.timesteps = torch.tensor([999], device="cuda").long()
        self.last_prompt = ""
        self.caption_enc = None
        self.device = "cuda"

    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=1.0):
        # encode the text prompt (cached so repeated prompts are not re-encoded)
        if prompt != self.last_prompt:
            caption_tokens = self.tokenizer(
                prompt,
                max_length=self.tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.cuda()
            caption_enc = self.text_encoder(caption_tokens)[0]
            self.caption_enc = caption_enc
            self.last_prompt = prompt

        if deterministic:
            encoded_control = (
                self.vae.encode(c_t).latent_dist.sample()
                * self.vae.config.scaling_factor
            )
            # single denoising step at t=999
            model_pred = self.unet(
                encoded_control,
                self.timesteps,
                encoder_hidden_states=self.caption_enc,
            ).sample
            x_denoised = self.sched.step(
                model_pred, self.timesteps, encoded_control, return_dict=True
            ).prev_sample
            # feed the encoder's down-block activations to the decoder skips
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            output_image = (
                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
            ).clamp(-1, 1)
        else:
            # scale the lora weights based on the r value
            self.unet.set_adapters(["default"], weights=[r])
            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
            encoded_control = (
                self.vae.encode(c_t).latent_dist.sample()
                * self.vae.config.scaling_factor
            )
            # combine the input and noise
            unet_input = encoded_control * r + noise_map * (1 - r)
            self.unet.conv_in.r = r
            unet_output = self.unet(
                unet_input,
                self.timesteps,
                encoder_hidden_states=self.caption_enc,
            ).sample
            self.unet.conv_in.r = None
            x_denoised = self.sched.step(
                unet_output, self.timesteps, unet_input, return_dict=True
            ).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            self.vae.decoder.gamma = r
            output_image = (
                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
            ).clamp(-1, 1)
        return output_image
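
# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream file). It assumes a
# CUDA GPU, network access to download the "edge_to_image" checkpoint, and
# the xformers dependency required by the constructor; the random tensor
# stands in for a real edge map preprocessed to [-1, 1], and the prompt is
# an arbitrary example.
if __name__ == "__main__":
    model = Pix2Pix_Turbo("edge_to_image")
    c_t = torch.randn(1, 3, 512, 512, device="cuda")  # placeholder edge map
    with torch.no_grad():
        out = model(c_t, prompt="a photo of a red brick house")
    print(out.shape)  # torch.Size([1, 3, 512, 512]), values clamped to [-1, 1]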