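"""Gradio demo for real-image editing with MasaCtrl.

The demo inverts a source image into a DDIM noise map, reconstructs it as a
reference, and synthesizes an edited image under the target prompt with
mutual self-attention control.
"""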
import os

import numpy as np
import gradio as gr
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler
from torchvision.io import read_image
from pytorch_lightning import seed_everything

from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import (AttentionBase,
                                     regiter_attention_editor_diffusers)

from .app_utils import global_context

torch.set_grad_enabled(False)
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
#     "cpu")
# model_path = "CompVis/stable-diffusion-v1-4"
# scheduler = DDIMScheduler(beta_start=0.00085,
#                           beta_end=0.012,
#                           beta_schedule="scaled_linear",
#                           clip_sample=False,
#                           set_alpha_to_one=False)
# model = MasaCtrlPipeline.from_pretrained(model_path,
#                                          scheduler=scheduler).to(device)
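# The live model and device come from `global_context` (imported from
# .app_utils), which is expected to hold "model" and "device" keys,
# presumably initialized with a setup equivalent to the commented block above.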


def load_image(image_path):
    """Load an image as a (1, 3, 512, 512) float tensor scaled to [-1, 1]."""
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    image = read_image(image_path)
    image = image[:3].unsqueeze_(0).float() / 127.5 - 1.  # [-1, 1]
    image = F.interpolate(image, (512, 512))
    image = image.to(device)
    return image
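# Example usage (path is illustrative; the helper is not called by the demo
# below, which receives images from Gradio as numpy arrays):
#   source = load_image("images/corgi.jpg")  # (1, 3, 512, 512) tensor in [-1, 1]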


def real_image_editing(source_image, target_prompt,
                       starting_step, starting_layer, ddim_steps, scale, seed,
                       appended_prompt, negative_prompt):
    # NOTE: negative_prompt is collected from the UI but currently unused.
    from masactrl.masactrl import MutualSelfAttentionControl

    model = global_context["model"]
    device = global_context["device"]
    seed_everything(seed)

    with torch.no_grad():
        if appended_prompt is not None:
            target_prompt += appended_prompt
        ref_prompt = ""
        prompts = [ref_prompt, target_prompt]

        # invert the image into a noise map
        if isinstance(source_image, np.ndarray):
            source_image = torch.from_numpy(source_image).to(device) / 127.5 - 1.
            source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2)
            source_image = F.interpolate(source_image, (512, 512))
        start_code, latents_list = model.invert(source_image,
                                                ref_prompt,
                                                guidance_scale=scale,
                                                num_inference_steps=ddim_steps,
                                                return_intermediates=True)
        start_code = start_code.expand(len(prompts), -1, -1, -1)

        # reconstruct the image from the inverted DDIM noise map
        editor = AttentionBase()
        regiter_attention_editor_diffusers(model, editor)
        image_fixed = model([target_prompt],
                            latents=start_code[-1:],
                            num_inference_steps=ddim_steps,
                            guidance_scale=scale)
        image_fixed = image_fixed.cpu().permute(0, 2, 3, 1).numpy()
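        # MasaCtrl mechanism (per the MasaCtrl paper): from `starting_step`
        # onward and in U-Net layers >= `starting_layer`, the target branch's
        # self-attention queries attend to the source branch's keys/values,
        # preserving source content while the prompt drives the new layout.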
        # synthesize the edited image with MasaCtrl:
        # hijack the attention module
        controller = MutualSelfAttentionControl(starting_step, starting_layer)
        regiter_attention_editor_diffusers(model, controller)
        image_masactrl = model(prompts,
                               latents=start_code,
                               guidance_scale=scale)
        image_masactrl = image_masactrl.cpu().permute(0, 2, 3, 1).numpy()

    return [
        image_masactrl[0],
        image_fixed[0],
        image_masactrl[1],
    ]  # source reconstruction, fixed seed, masactrl


def create_demo_editing():
    with gr.Blocks() as demo:
        gr.Markdown("## **Input Settings**")
        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Source Image",
                                        value=os.path.join(os.path.dirname(__file__), "images/corgi.jpg"),
                                        interactive=True)
                target_prompt = gr.Textbox(label="Target Prompt",
                                           value='A photo of a running corgi',
                                           interactive=True)
                with gr.Row():
                    ddim_steps = gr.Slider(label="DDIM Steps",
                                           minimum=1,
                                           maximum=999,
                                           value=50,
                                           step=1)
                    starting_step = gr.Slider(label="Step of MasaCtrl",
                                              minimum=0,
                                              maximum=999,
                                              value=4,
                                              step=1)
                    starting_layer = gr.Slider(label="Layer of MasaCtrl",
                                               minimum=0,
                                               maximum=16,
                                               value=10,
                                               step=1)
                run_btn = gr.Button("Run")
            with gr.Column():
                appended_prompt = gr.Textbox(label="Appended Prompt", value='')
                negative_prompt = gr.Textbox(label="Negative Prompt", value='')
                with gr.Row():
                    scale = gr.Slider(label="CFG Scale",
                                      minimum=0.1,
                                      maximum=30.0,
                                      value=7.5,
                                      step=0.1)
                    seed = gr.Slider(label="Seed",
                                     minimum=-1,
                                     maximum=2147483647,
                                     value=42,
                                     step=1)

        gr.Markdown("## **Output**")
        with gr.Row():
            image_recons = gr.Image(label="Source Image")
            image_fixed = gr.Image(label="Image with Fixed Seed")
            image_masactrl = gr.Image(label="Image with MasaCtrl")

        inputs = [
            source_image, target_prompt, starting_step, starting_layer, ddim_steps,
            scale, seed, appended_prompt, negative_prompt
        ]
        run_btn.click(real_image_editing, inputs,
                      [image_recons, image_fixed, image_masactrl])

        gr.Examples(
            [[os.path.join(os.path.dirname(__file__), "images/corgi.jpg"),
              "A photo of a running corgi"],
             [os.path.join(os.path.dirname(__file__), "images/person.png"),
              "A photo of a person, black t-shirt, raising hand"],
            ],
            [source_image, target_prompt]
        )
    return demo


if __name__ == "__main__":
    demo_editing = create_demo_editing()
    demo_editing.launch()
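    # To expose the demo publicly, Gradio's launch() accepts options such as
    # share=True or server_name="0.0.0.0" (both standard Gradio parameters).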