# Spaces: Runtime error
import os
import random
import traceback
from pathlib import Path

import einops
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
from gradio_imageslider import ImageSlider
from PIL import Image
from pytorch_lightning import seed_everything
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

from annotator.retinaface import RetinaFaceDetection
from myutils.misc import load_dreambooth_lora, rand_name
from myutils.wavelet_color_fix import wavelet_color_fix
from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
# Selects which PASD model definitions (models.pasd vs models.pasd_light) provide
# UNet2DConditionModel and ControlNetModel for the loaders further down.
use_pasd_light = False

# NOTE(review): face_detector is not referenced anywhere else in the visible code;
# confirm it is still needed before paying its construction cost at import time.
face_detector = RetinaFaceDetection()

if not use_pasd_light:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel
else:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel
# --- Checkpoint locations and runtime settings ----------------------------------
pretrained_model_path = "checkpoints/stable-diffusion-v1-5"
ckpt_path = "runs/pasd/checkpoint-100000"
# Personalized weights merged via load_dreambooth_lora. Other files tried before:
#   checkpoints/personalized_models/toonyou_beta3.safetensors
#   checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors
dreambooth_lora_path = "checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
weight_dtype = torch.float16
device = "cuda"

# --- Base Stable Diffusion components -------------------------------------------
scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")

# --- PASD UNet and ControlNet from the trained checkpoint -------------------------
unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")

# The demo only runs inference, so no network needs gradients.
for _module in (vae, text_encoder, unet, controlnet):
    _module.requires_grad_(False)

unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)

# Move the (possibly LoRA-updated) networks to the target device and precision.
for _module in (text_encoder, vae, unet, controlnet):
    _module.to(device, dtype=weight_dtype)
del _module

validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)
# Presumably enables tiled VAE decoding (224-pixel decoder tiles) to bound memory on
# large outputs -- _init_tiled_vae is project code, confirm in pipelines.pipeline_pasd.
# validation_pipeline.enable_vae_tiling()
validation_pipeline._init_tiled_vae(decoder_tile_size=224)
# ImageNet ResNet-50 that inference() uses to append a coarse content label to the prompt.
weights = ResNet50_Weights.DEFAULT
resnet = resnet50(weights=weights).eval()  # eval() returns the module itself
preprocess = weights.transforms()  # input transforms bundled with these weights
def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
    """Super-resolve ``input_image`` with the PASD pipeline and write the results to disk.

    Args:
        input_image: PIL image from the upload widget, or ``None`` if nothing was uploaded.
        prompt: User prompt. The ResNet-50 top-1 ImageNet label is appended to it
            when that class's probability is at least 0.1.
        a_prompt: Extra keywords appended after the prompt.
        n_prompt: Negative prompt passed to the pipeline.
        denoise_steps: Number of denoising steps.
        upscale: Integer upscaling factor applied to the input before diffusion.
        alpha: Conditioning scale passed to the pipeline.
        cfg: Classifier-free guidance scale.
        seed: Seed for the global RNGs and for the diffusion noise generator.

    Returns:
        ``(("input.jpg", "result.jpg"), "result.jpg")``: the image pair for the
        before/after slider and the path for the download widget.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image first.")

    # Gradio sliders may deliver floats; PIL sizes and scheduler step counts need ints.
    rscale = int(upscale)
    num_steps = int(denoise_steps)

    with torch.no_grad():
        # seed_everything returns the seed it actually applied (it substitutes a valid
        # seed when the requested one is out of range, e.g. the slider minimum of -1).
        used_seed = seed_everything(None if seed is None else int(seed))
        # A freshly constructed Generator starts from a fixed default seed, so it must be
        # seeded explicitly or the Seed control has no effect on the diffusion noise.
        generator = torch.Generator(device=device).manual_seed(used_seed)

        input_image = input_image.convert("RGB")

        # Append the classifier's top-1 ImageNet label when it is reasonably confident.
        probs = resnet(preprocess(input_image).unsqueeze(0)).squeeze(0).softmax(0)
        class_id = probs.argmax().item()
        if probs[class_id].item() >= 0.1:
            category_name = weights.meta["categories"][class_id]
            prompt = category_name if prompt == "" else f"{prompt}, {category_name}"
        prompt = a_prompt if prompt == "" else f"{prompt}, {a_prompt}"

        ori_width, ori_height = input_image.size
        target_size = (ori_width * rscale, ori_height * rscale)
        input_image = input_image.resize(target_size)
        # The pipeline runs at sides rounded down to a multiple of 8 (presumably the
        # VAE's downsampling factor); the output is resized back to target_size below.
        input_image = input_image.resize((input_image.width // 8 * 8, input_image.height // 8 * 8))
        width, height = input_image.size

        try:
            image = validation_pipeline(
                None, prompt, input_image,
                num_inference_steps=num_steps, generator=generator,
                height=height, width=width, guidance_scale=cfg,
                negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
            ).images[0]
            image = wavelet_color_fix(image, input_image)
            image = image.resize(target_size)
        except Exception:
            # Best effort: keep the UI responsive with a blank result, but log the full
            # traceback so the failure can be diagnosed.
            traceback.print_exc()
            image = Image.new(mode="RGB", size=(512, 512))

        image.save("result.jpg", "JPEG")
        input_image.save("input.jpg", "JPEG")

    return ("input.jpg", "result.jpg"), "result.jpg"
# Header text and styling used by the Blocks UI.
title = "Pixel-Aware Stable Diffusion for Real-ISR"
# The examples list below is disabled and the UI has no example gallery, so the
# instructions only mention uploading an image and pressing the Submit button.
description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image and click Submit."
article = "<p style='text-align: center'><a href='https://github.com/yangxy/PASD' target='_blank'>Github Repo Pytorch</a></p>"
# Example inputs (currently disabled):
#examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]
css = """
#col-container{
margin: 0 auto;
max-width: 720px;
}
"""
# Gradio UI: upload + prompt controls on the left, before/after slider and download on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # gr.HTML is a component, not a layout container, so it is created directly
        # rather than used as a `with` context manager.
        gr.HTML(f"""
        <h2 style="text-align: center;">
            {title}
        </h2>
        <p style="text-align: center;">
            {description} <br />
            {article}
        </p>
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", sources=["upload"])
                prompt_in = gr.Textbox(label="Prompt", value="Asian")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece')
                    neg_prompt = gr.Textbox(label="Negative Prompt", value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
                    denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
                    upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
                    condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
                    classifier_free_guidance = gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                b_a_slider = ImageSlider(label="B/A result", position=0.5)
                file_output = gr.File(label="Downloadable image result")

    # Input order must match inference()'s positional parameters.
    submit_btn.click(
        fn=inference,
        inputs=[
            input_image, prompt_in,
            added_prompt, neg_prompt,
            denoise_steps,
            upsample_scale, condition_scale,
            classifier_free_guidance, seed,
        ],
        outputs=[
            b_a_slider,
            file_output,
        ],
    )

demo.queue().launch()