import warnings
warnings.filterwarnings("ignore")

from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
import torch
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import gc
import gradio as gr
import numpy as np
import os
import pickle
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
import argparse

weights = {
    'down': {
        4096: 0.0,
        1024: 1.0,
        256: 1.0,
    },
    'mid': {
        64: 1.0,
    },
    'up': {
        256: 1.0,
        1024: 1.0,
        4096: 0.0,
    }
}

num_inference_steps = 10
model_id = "stabilityai/stable-diffusion-2-1-base"

pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

should_stop = False


def save_state_to_file(state):
    filename = "state.pkl"
    with open(filename, 'wb') as f:
        pickle.dump(state, f)
    return filename


def load_state_from_file(filename):
    with open(filename, 'rb') as f:
        state = pickle.load(f)
    return state


def stop_reconstruct():
    global should_stop
    should_stop = True


def reconstruct(input_img, caption):
    img = input_img

    cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((512, 512)),
        torchvision.transforms.ToTensor()
    ])

    loaded_image = transform(img).to("cuda").unsqueeze(0)

    if loaded_image.shape[1] == 4:
        loaded_image = loaded_image[:, :3, :, :]

    with torch.no_grad():
        encoded_image = pipe.vae.encode(loaded_image * 2 - 1)
        real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()

    guidance_scale = 1
    inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = inverse_scheduler.timesteps

    latents = real_image_latents
    inversed_latents = []

    with torch.no_grad():
        replace_attention_processor(pipe.unet, True)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
            inversed_latents.append(latents)

            latent_model_input = torch.cat([latents] * 2)

            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds_combined,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # initial state
    real_image_initial_latents = latents

    W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)
    QT = nn.Parameter(W_values.clone())

    guidance_scale = 7.5
    scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    optimizer = torch.optim.AdamW([QT], lr=0.008)

    pipe.vae.eval()
    pipe.vae.requires_grad_(False)
    pipe.unet.eval()
    pipe.unet.requires_grad_(False)

    last_loss = 1

    for epoch in range(50):
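        # Optimization loop: each epoch re-runs the full DDIM denoising chain from the
        # inverted initial latent, with the learned per-step embeddings QT standing in
        # for the unconditional prompt, and updates QT step by step so the denoising
        # trajectory matches the latents recorded during inversion. The learning rate
        # is reduced as the loss shrinks, training stops early once it falls below
        # 0.02, and a preview of the current reconstruction is decoded and yielded at
        # the end of every epoch.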
        gc.collect()
        torch.cuda.empty_cache()

        if last_loss < 0.02:
            break
        elif last_loss < 0.03:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.003
        elif last_loss < 0.035:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.006

        intermediate_values = real_image_initial_latents.clone()

        for i in range(num_inference_steps):
            latents = intermediate_values.detach().clone()
            t = timesteps[i]

            prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])

            latent_model_input = torch.cat([latents] * 2)

            noise_pred_model = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction="mean")
            last_loss = loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        global should_stop
        if should_stop:
            should_stop = False
            break

        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        image = (image / 2.0 + 0.5).clamp(0.0, 1.0)

        safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
        image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]

        image_np = image[0].squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
        image_np = (image_np * 255).astype(np.uint8)

        yield image_np, caption, [caption, real_image_initial_latents, QT]

    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image = (image / 2.0 + 0.5).clamp(0.0, 1.0)

    safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
    image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]

    image_np = image[0].squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
    image_np = (image_np * 255).astype(np.uint8)

    yield image_np, caption, [caption, real_image_initial_latents, QT]


class AttnReplaceProcessor(AttnProcessor2_0):

    def __init__(self, replace_all, weight):
        super().__init__()
        self.replace_all = replace_all
        self.weight = weight

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:

        residual = hidden_states

        is_cross = encoder_hidden_states is not None

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))
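        # For the self-attention layers, the batch is laid out as
        # [uncond_source, uncond_target, cond_source, cond_target]. When replacement is
        # enabled, the target attention maps are blended with the source maps using the
        # per-resolution weight, keyed by the sequence length (height * width) of the
        # feature map.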
        dimension_squared = hidden_states.shape[1]

        if not is_cross and self.replace_all:
            ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)

            attn_scores_dst.copy_(self.weight[dimension_squared] * attn_scores_src + (1.0 - self.weight[dimension_squared]) * attn_scores_dst)
            ucond_attn_scores_dst.copy_(self.weight[dimension_squared] * ucond_attn_scores_src + (1.0 - self.weight[dimension_squared]) * ucond_attn_scores_dst)

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        del attention_probs

        hidden_states = attn.to_out[0](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def replace_attention_processor(unet, clear=False):
    for name, module in unet.named_modules():
        if 'attn1' in name and 'to' not in name:
            layer_type = name.split('.')[0].split('_')[0]

            if not clear:
                if layer_type == 'down':
                    module.processor = AttnReplaceProcessor(True, weights['down'])
                elif layer_type == 'mid':
                    module.processor = AttnReplaceProcessor(True, weights['mid'])
                elif layer_type == 'up':
                    module.processor = AttnReplaceProcessor(True, weights['up'])
            else:
                module.processor = AttnReplaceProcessor(False, 0.0)


def apply_prompt(meta_data, new_prompt):
    caption, real_image_initial_latents, QT = meta_data
    inference_steps = len(QT)

    cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    # uncond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

    guidance_scale = 7.5
    scheduler.set_timesteps(inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    latents = torch.cat([real_image_initial_latents] * 2)

    with torch.no_grad():
        replace_attention_processor(pipe.unet)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):
            modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=modified_prompt_embeds,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        replace_attention_processor(pipe.unet, True)

    image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image = (image / 2.0 + 0.5).clamp(0.0, 1.0)

    safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
    image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]

    image_np = image[0].squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
    image_np = (image_np * 255).astype(np.uint8)

    return image_np


def on_image_change(filepath):
    # Extract the filename without extension
    filename = os.path.splitext(os.path.basename(filepath))[0]

    # Check if the filename matches one of the bundled examples
filename is "example1" or "example2" if filename in ["example1", "example2", "example3", "example4"]: meta_data_raw = load_state_from_file(f"assets/{filename}.pkl") _, _, QT_raw = meta_data_raw global num_inference_steps num_inference_steps = len(QT_raw) scale_value = 7 new_prompt = "" if filename == "example1": scale_value = 7 new_prompt = "a photo of a tree, summer, colourful" elif filename == "example2": scale_value = 8 new_prompt = "a photo of a panda, two ears, white background" elif filename == "example3": scale_value = 7 new_prompt = "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds" elif filename == "example4": scale_value = 7 new_prompt = "a photo of plastic bottle on some sand, beach background, sky background" update_scale(scale_value) img = apply_prompt(meta_data_raw, new_prompt) return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value def update_value(value, key, res): global weights weights[key][res] = value def update_step(value): global num_inference_steps num_inference_steps = value def update_scale(scale): values = [1.0] * 7 if scale == 9: return values reduction_steps = (9 - scale) * 0.5 for i in range(4): # There are 4 positions to reduce symmetrically if reduction_steps >= 1: values[i] = 0.0 values[-(i + 1)] = 0.0 reduction_steps -= 1 elif reduction_steps > 0: values[i] = 0.5 values[-(i + 1)] = 0.5 break global weights index = 0 for outer_key, inner_dict in weights.items(): for inner_key in inner_dict: inner_dict[inner_key] = values[index] index += 1 return weights['down'][4096], weights['down'][1024], weights['down'][256], weights['mid'][64], weights['up'][256], weights['up'][1024], weights['up'][4096] with gr.Blocks() as demo: gr.Markdown( '''
Out of AI presents a flexible tool for manipulating your images. This is the first version of our image modification tool: it edits an image through prompt manipulation, after reconstructing the image via the diffusion inversion process.
Specific Cross-Attention Influence weights can be manually modified for given resolutions (1.0 = Fully Source Attn, 0.0 = Fully Target Attn).