# `spaces` must stay the first import: ZeroGPU needs it loaded before torch
# (or anything else that may touch CUDA) is imported.
import spaces

import json
import random
import time
from datetime import datetime
from io import BytesIO
from pathlib import Path

import boto3
import gradio as gr
import numpy as np
import PIL.Image
import pillow_heif
import requests
import torch
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion import Solver, solvers

from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints

# Let PIL open HEIC/HEIF and AVIF inputs (both uploads and URL downloads).
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

MAX_SEED = np.iinfo(np.int32).max

TITLE = """ Image Enhancer """

# Shared between the `process` defaults and the UI textboxes so both agree.
DEFAULT_PROMPT = "masterpiece, best quality, highres"
DEFAULT_NEGATIVE_PROMPT = "worst quality, low quality, normal quality"

# Seconds to wait on the remote host when fetching an input image by URL,
# so a slow/unresponsive server cannot pin the GPU worker indefinitely.
IMAGE_DOWNLOAD_TIMEOUT = 30

# All weights are pinned to exact Hub revisions for reproducible results.
CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.unet",
            filename="model.safetensors",
            revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
        )
    ),
    clip_text_encoder=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
            filename="model.safetensors",
            revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
        )
    ),
    lda=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
            filename="model.safetensors",
            revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
        )
    ),
    controlnet_tile=Path(
        hf_hub_download(
            repo_id="refiners/controlnet.sd1_5.tile",
            filename="model.safetensors",
            revision="48ced6ff8bfa873a8976fa467c3629a240643387",
        )
    ),
    esrgan=Path(
        hf_hub_download(
            repo_id="philz1337x/upscaler",
            filename="4x-UltraSharp.pth",
            revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
        )
    ),
    negative_embedding=Path(
        hf_hub_download(
            repo_id="philz1337x/embeddings",
            filename="JuggernautNegative-neg.pt",
            revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
        )
    ),
    negative_embedding_key="string_to_param.*",
    loras={
        "more_details": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="more_details.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
        "sdxl_render": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="SDXLrender_v2.0.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
    },
)

# Per-LoRA strength passed to `enhancer.upscale(loras_scale=...)`; keys must
# match the `loras` mapping above.
LORA_SCALES = {
    "more_details": 0.5,
    "sdxl_render": 1.0,
}

# initialize the enhancer, on the cpu
DEVICE_CPU = torch.device("cpu")
# Only probe bf16 support when CUDA is present; on CPU-only hosts the probe
# itself may raise, so fall back to float32 there.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float32
)
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE_CPU, dtype=DTYPE)

# "move" the enhancer to the gpu, this is handled by Zero GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enhancer.to(device=DEVICE, dtype=DTYPE)


def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
    """Upload ``image`` as a PNG to a Cloudflare R2 bucket.

    The object key is ``generated_images/<YYYYmmdd_HHMMSS>_<random int>.png``.

    Args:
        image: A PIL image (anything exposing ``save(buffer, "PNG")``).
        account_id: Cloudflare account id; selects the R2 endpoint host.
        access_key: R2 access key id.
        secret_key: R2 secret access key.
        bucket_name: Destination bucket.

    Returns:
        The object key the image was stored under.
    """
    # Never log credentials -- stdout ends up in the Space's public logs.
    print("upload_image_to_r2", bucket_name)
    connection_url = f"https://{account_id}.r2.cloudflarestorage.com"

    s3 = boto3.client(
        "s3",
        endpoint_url=connection_url,
        region_name="auto",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )

    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Random suffix keeps keys distinct when several uploads land in the same second.
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"

    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    s3.upload_fileobj(buffer, bucket_name, image_file)
    print("upload finish", image_file)
    return image_file


class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    After exit, ``elapsed_time`` holds the duration in seconds. Exceptions
    raised in the body are not suppressed.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution: the right clock for durations.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")


@spaces.GPU(duration=120)
def process(
    input_image: Image.Image,
    image_url: str,
    prompt: str = DEFAULT_PROMPT,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 42,
    upscale_factor: float = 2,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 18,
    solver: str = "DDIM",
    upload_to_r2: bool = True,
    account_id: str = "",
    access_key: str = "",
    secret_key: str = "",
    bucket_name: str = "",
) -> tuple[tuple[Image.Image, Image.Image], str]:
    """Enhance/upscale an image and optionally upload the result to R2.

    Args:
        input_image: Uploaded image; may be ``None`` when ``image_url`` is given.
        image_url: Optional URL; when non-empty it takes precedence over
            ``input_image``.
        prompt / negative_prompt: Text conditioning for the diffusion refiner.
        seed: Seed passed to ``manual_seed`` for reproducibility.
        upscale_factor ... num_inference_steps: Forwarded to ``enhancer.upscale``.
            The tile size is forwarded as ``(tile_height, tile_width)``.
        solver: Name of a class in ``refiners...solvers`` (e.g. "DDIM").
        upload_to_r2: Upload the result using the four R2 credential fields.

    Returns:
        ``((input_image, enhanced_image), json_status)``. ``json_status`` is a
        JSON string with ``status`` and either ``url`` (the R2 object key) or
        ``message``.

    Raises:
        gr.Error: No input image could be obtained, or the R2 upload was
            requested with incomplete credentials.
        requests.HTTPError / requests.Timeout: The image URL could not be fetched.
    """
    # Gradio sliders may deliver numbers as floats; these consumers expect ints.
    seed = int(seed)
    tile_width = int(tile_width)
    tile_height = int(tile_height)
    num_inference_steps = int(num_inference_steps)

    # Fail fast, before spending GPU time, if the upload cannot possibly succeed.
    if upload_to_r2 and not all((account_id, access_key, secret_key, bucket_name)):
        raise gr.Error(
            "Uploading to R2 requires the account id, access key, secret key and bucket name."
        )

    manual_seed(seed)

    if image_url:
        # fetch image from url
        # NOTE(review): this fetches an arbitrary user-supplied URL from the
        # server (SSRF surface); consider restricting schemes/hosts if that matters.
        with calculateDuration("Download Image"):
            print("start to fetch image from url", image_url)
            response = requests.get(image_url, timeout=IMAGE_DOWNLOAD_TIMEOUT)
            response.raise_for_status()
            input_image = PIL.Image.open(BytesIO(response.content))
            print("fetch image success")

    if input_image is None:
        raise gr.Error("Please provide an input image or an image URL.")

    print("start", prompt, upscale_factor)

    solver_type: type[Solver] = getattr(solvers, solver)

    with calculateDuration("enhancer"):
        enhanced_image = enhancer.upscale(
            image=input_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            upscale_factor=upscale_factor,
            controlnet_scale=controlnet_scale,
            controlnet_scale_decay=controlnet_decay,
            condition_scale=condition_scale,
            tile_size=(tile_height, tile_width),
            denoise_strength=denoise_strength,
            num_inference_steps=num_inference_steps,
            loras_scale=LORA_SCALES,
            solver_type=solver_type,
        )
        print("enhancer finish")

    if upload_to_r2:
        url = upload_image_to_r2(enhanced_image, account_id, access_key, secret_key, bucket_name)
        result = {"status": "success", "url": url}
    else:
        result = {"status": "success", "message": "Image generated but not uploaded"}

    return (input_image, enhanced_image), json.dumps(result)


with gr.Blocks() as demo:
    gr.HTML(TITLE)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            image_url = gr.Textbox(label="Image Url", placeholder="Enter image URL here (optional)")
            # A ClearButton so each click also clears the previous result
            # (output_slider is registered below) before `process` runs.
            run_button = gr.ClearButton(components=None, value="Enhance Image")
        with gr.Column():
            output_slider = ImageSlider(label="Generate image", type="pil", slider_color="pink")
            logs = gr.Textbox(label="logs")

    run_button.add(output_slider)

    with gr.Accordion("Advanced Options", open=False):
        # `value=` (not just `placeholder=`) so an untouched box sends the
        # default prompt instead of an empty string.
        prompt = gr.Textbox(
            label="Prompt",
            value=DEFAULT_PROMPT,
            placeholder=DEFAULT_PROMPT,
        )
        negative_prompt = gr.Textbox(
            label="Negative Prompt",
            value=DEFAULT_NEGATIVE_PROMPT,
            placeholder=DEFAULT_NEGATIVE_PROMPT,
        )
        seed = gr.Slider(
            minimum=0,
            maximum=10_000,
            value=42,
            step=1,
            label="Seed",
        )
        upscale_factor = gr.Slider(
            minimum=1,
            maximum=4,
            value=2,
            step=0.2,
            label="Upscale Factor",
        )
        controlnet_scale = gr.Slider(
            minimum=0,
            maximum=1.5,
            value=0.6,
            step=0.1,
            label="ControlNet Scale",
        )
        controlnet_decay = gr.Slider(
            minimum=0.5,
            maximum=1,
            value=1.0,
            step=0.025,
            label="ControlNet Scale Decay",
        )
        condition_scale = gr.Slider(
            minimum=2,
            maximum=20,
            value=6,
            step=1,
            label="Condition Scale",
        )
        tile_width = gr.Slider(
            minimum=64,
            maximum=200,
            value=112,
            step=1,
            label="Latent Tile Width",
        )
        tile_height = gr.Slider(
            minimum=64,
            maximum=200,
            value=144,
            step=1,
            label="Latent Tile Height",
        )
        denoise_strength = gr.Slider(
            minimum=0,
            maximum=1,
            value=0.35,
            step=0.1,
            label="Denoise Strength",
        )
        num_inference_steps = gr.Slider(
            minimum=1,
            maximum=30,
            value=18,
            step=1,
            label="Number of Inference Steps",
        )
        solver = gr.Radio(
            choices=["DDIM", "DPMSolver"],
            value="DDIM",
            label="Solver",
        )
        upload_to_r2 = gr.Checkbox(label="Upload generated image to R2", value=False)
        account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
        access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
        # Masked input so the secret is not displayed in clear text.
        secret_key = gr.Textbox(
            label="Secret Key",
            placeholder="Enter R2 secret key here",
            type="password",
        )
        bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")

    # Input order must match `process`'s positional parameters exactly.
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_url,
            prompt,
            negative_prompt,
            seed,
            upscale_factor,
            controlnet_scale,
            controlnet_decay,
            condition_scale,
            tile_width,
            tile_height,
            denoise_strength,
            num_inference_steps,
            solver,
            upload_to_r2,
            account_id,
            access_key,
            secret_key,
            bucket,
        ],
        outputs=[output_slider, logs],
    )

demo.launch(share=False)