# Source: jiuface's Hugging Face Space, src/app.py
# Commit 8468be4 "Update src/app.py" (verified); file size 10.3 kB; Hub malware scan: no virus.
import spaces
import torch
from io import BytesIO
import PIL.Image
import pillow_heif
import numpy as np
from pathlib import Path
import random
import gradio as gr
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion import Solver, solvers
import requests
from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
import time
import boto3
from datetime import datetime
import json
# Register pillow_heif's HEIF and AVIF openers with Pillow, so PIL.Image.open
# (used below for URL downloads) can read those formats too.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()
# Largest int32 value; used as the upper bound of the random suffix in uploaded
# R2 object names (see upload_image_to_r2).
MAX_SEED = np.iinfo(np.int32).max
# HTML rendered at the top of the Gradio page via gr.HTML(TITLE).
TITLE = """
Image Enhancer
"""
def _pinned_hub_file(repo_id: str, filename: str, revision: str) -> Path:
    """Download ``filename`` from Hub repo ``repo_id`` at commit ``revision`` and return its local path."""
    local_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
    return Path(local_path)


# Every weight file the enhancer needs, each pinned to an explicit commit hash via
# ``revision``. Files are fetched in the order the keyword arguments appear below.
CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=_pinned_hub_file(
        "refiners/juggernaut.reborn.sd1_5.unet",
        "model.safetensors",
        "347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
    ),
    clip_text_encoder=_pinned_hub_file(
        "refiners/juggernaut.reborn.sd1_5.text_encoder",
        "model.safetensors",
        "744ad6a5c0437ec02ad826df9f6ede102bb27481",
    ),
    lda=_pinned_hub_file(
        "refiners/juggernaut.reborn.sd1_5.autoencoder",
        "model.safetensors",
        "3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
    ),
    controlnet_tile=_pinned_hub_file(
        "refiners/controlnet.sd1_5.tile",
        "model.safetensors",
        "48ced6ff8bfa873a8976fa467c3629a240643387",
    ),
    esrgan=_pinned_hub_file(
        "philz1337x/upscaler",
        "4x-UltraSharp.pth",
        "011deacac8270114eb7d2eeff4fe6fa9a837be70",
    ),
    negative_embedding=_pinned_hub_file(
        "philz1337x/embeddings",
        "JuggernautNegative-neg.pt",
        "203caa7e9cc2bc225031a4021f6ab1ded283454a",
    ),
    negative_embedding_key="string_to_param.*",
    loras={
        "more_details": _pinned_hub_file(
            "philz1337x/loras",
            "more_details.safetensors",
            "a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
        "sdxl_render": _pinned_hub_file(
            "philz1337x/loras",
            "SDXLrender_v2.0.safetensors",
            "a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
    },
)
# Strength for each LoRA registered in CHECKPOINTS.loras; handed to
# enhancer.upscale(...) as ``loras_scale``.
LORA_SCALES = dict(
    more_details=0.5,
    sdxl_render=1.0,
)
# Initialize the enhancer on the CPU first.
DEVICE_CPU = torch.device("cpu")
# Use bfloat16 only when a CUDA device is present and supports it. Checking
# is_available() first avoids a crash on CPU-only hosts with some older torch
# releases, whose is_bf16_supported() queries the current CUDA device unconditionally.
DTYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float32
)
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE_CPU, dtype=DTYPE)
# "Move" the enhancer to the GPU; on a ZeroGPU Space this is handled by Zero GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enhancer.to(device=DEVICE, dtype=DTYPE)
def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
    """Upload an image as PNG to a Cloudflare R2 bucket and return the object key.

    R2 is reached through its S3-compatible endpoint
    ``https://<account_id>.r2.cloudflarestorage.com`` with a boto3 S3 client.

    Args:
        image: image to upload; anything with ``save(fp, "PNG")`` (a PIL image here).
        account_id: Cloudflare account id, used to build the endpoint URL.
        access_key: R2 access key id.
        secret_key: R2 secret access key.
        bucket_name: destination bucket.

    Returns:
        The object key, ``generated_images/<YYYYmmdd_HHMMSS>_<random int>.png``
        (local time, random int in ``[0, MAX_SEED]``).
    """
    # Never print the credentials themselves: they would end up in the Space's
    # logs, which may be readable by others.
    print("upload_image_to_s3", account_id, bucket_name)
    connection_url = f"https://{account_id}.r2.cloudflarestorage.com"
    s3 = boto3.client(
        "s3",
        endpoint_url=connection_url,
        region_name="auto",
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
    with BytesIO() as buffer:
        image.save(buffer, "PNG")
        buffer.seek(0)
        # Declare the PNG type explicitly. Without it, S3 stores a generic binary
        # type (R2 presumably behaves the same), and browsers download the file
        # instead of displaying it.
        s3.upload_fileobj(
            buffer,
            bucket_name,
            image_file,
            ExtraArgs={"ContentType": "image/png"},
        )
    print("upload finish", image_file)
    return image_file
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    Usage::

        with calculateDuration("enhancer"):
            ...

    On exit it prints ``Elapsed time for <activity_name>: <seconds> seconds``, or
    ``Elapsed time: <seconds> seconds`` when no name was given. The measured value
    is also kept in ``elapsed_time``. Exceptions raised in the body are not
    suppressed.

    ``start_time`` and ``end_time`` are ``time.perf_counter()`` readings. Only
    their difference is meaningful; they are not epoch timestamps.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution. time.time() follows the
        # wall clock, which can jump (NTP sync, manual changes) and yield wrong
        # or even negative durations.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # Falsy return: let any exception from the body propagate.
        return False
def _load_image_from_url(image_url: str, timeout: float = 30.0) -> Image.Image:
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        image_url: URL of the image to fetch.
        timeout: seconds requests waits to connect and for each read of the
            response (not a cap on total download time). Without it, a stalled
            server could hold the GPU slot until it expires.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    with calculateDuration("Download Image"):
        print("start to fetch image from url", image_url)
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        image = PIL.Image.open(BytesIO(response.content))
        # gr.Image(type="pil") delivers RGB images (Gradio's default image_mode);
        # convert downloads (RGBA PNG, palette, grayscale, ...) to match.
        image = image.convert("RGB")
        print("fetch image success")
    return image


@spaces.GPU(duration=120)
def process(
    input_image: Image.Image,
    image_url: str,
    prompt: str = "masterpiece, best quality, highres",
    negative_prompt: str = "worst quality, low quality, normal quality",
    seed: int = 42,
    upscale_factor: int = 2,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 18,
    solver: str = "DDIM",
    upload_to_r2: bool = True,
    account_id: str = "",
    access_key: str = "",
    secret_key: str = "",
    bucket_name: str = "",
) -> tuple[tuple[Image.Image, Image.Image], str]:
    """Upscale and refine an image with the module-level ``enhancer``.

    The image comes from ``image_url`` when one is given; otherwise
    ``input_image`` is used. All tuning parameters are forwarded to
    ``enhancer.upscale``. ``tile_width``/``tile_height`` are combined into
    ``tile_size=(tile_height, tile_width)``, and ``solver`` names a class in
    ``refiners...solvers``.

    Args:
        upload_to_r2: when true, upload the result with ``upload_image_to_r2``
            using ``account_id``, ``access_key``, ``secret_key`` and
            ``bucket_name`` (all four required).

    Returns:
        ``[input_image, enhanced_image]`` for the before/after slider, and a
        JSON string. The JSON is ``{"status": "success", "url": <object key>}``
        after an upload, or ``{"status": "success", "message": ...}`` otherwise.

    Raises:
        gr.Error: no image was supplied, ``solver`` is not a known solver, or
            an upload was requested without complete credentials.
        requests.HTTPError / requests.Timeout: ``image_url`` could not be fetched.
    """
    manual_seed(seed)

    image_url = (image_url or "").strip()
    if image_url:
        input_image = _load_image_from_url(image_url)
    if input_image is None:
        raise gr.Error("Please provide an input image or an image URL.")

    # ``solver`` can arrive through the API, not only via the Radio choices, so
    # only accept attributes of ``solvers`` that are actual Solver subclasses.
    solver_type = getattr(solvers, solver, None)
    if not (isinstance(solver_type, type) and issubclass(solver_type, Solver)):
        raise gr.Error(f"Unknown solver: {solver!r}")

    # Fail before spending GPU time when the upload could not succeed anyway.
    if upload_to_r2 and not all((account_id, access_key, secret_key, bucket_name)):
        raise gr.Error(
            "Uploading to R2 requires the account id, access key, secret key and bucket name."
        )

    print("start", prompt, upscale_factor)
    with calculateDuration("enhancer"):
        enhanced_image = enhancer.upscale(
            image=input_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            upscale_factor=upscale_factor,
            controlnet_scale=controlnet_scale,
            controlnet_scale_decay=controlnet_decay,
            condition_scale=condition_scale,
            tile_size=(tile_height, tile_width),
            denoise_strength=denoise_strength,
            num_inference_steps=num_inference_steps,
            loras_scale=LORA_SCALES,
            solver_type=solver_type,
        )
    print("enhancer finish")

    if upload_to_r2:
        url = upload_image_to_r2(enhanced_image, account_id, access_key, secret_key, bucket_name)
        result = {"status": "success", "url": url}
    else:
        result = {"status": "success", "message": "Image generated but not uploaded"}
    return [input_image, enhanced_image], json.dumps(result)
# Gradio UI: image/URL input, a before/after slider plus JSON log output,
# advanced tuning controls, and optional R2 upload settings.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            image_url = gr.Textbox(label="Image Url", placeholder="Enter image URL here (optional)")
            # A ClearButton clears the components added to it (output_slider, below)
            # when clicked; its click event is also wired to process() further down.
            run_button = gr.ClearButton(components=None, value="Enhance Image")
        with gr.Column():
            output_slider = ImageSlider(label="Generate image", type="pil", slider_color="pink")
            logs = gr.Textbox(label="logs")
    run_button.add(output_slider)
    with gr.Accordion("Advanced Options", open=False):
        # A placeholder is only a visual hint: an untouched box submits "", which
        # would override process()'s default prompts. Pre-fill the real values.
        prompt = gr.Textbox(
            label="Prompt",
            value="masterpiece, best quality, highres",
            placeholder="masterpiece, best quality, highres",
        )
        negative_prompt = gr.Textbox(
            label="Negative Prompt",
            value="worst quality, low quality, normal quality",
            placeholder="worst quality, low quality, normal quality",
        )
        seed = gr.Slider(
            minimum=0,
            maximum=10_000,
            value=42,
            step=1,
            label="Seed",
        )
        upscale_factor = gr.Slider(
            minimum=1,
            maximum=4,
            value=2,
            step=0.2,
            label="Upscale Factor",
        )
        controlnet_scale = gr.Slider(
            minimum=0,
            maximum=1.5,
            value=0.6,
            step=0.1,
            label="ControlNet Scale",
        )
        controlnet_decay = gr.Slider(
            minimum=0.5,
            maximum=1,
            value=1.0,
            step=0.025,
            label="ControlNet Scale Decay",
        )
        condition_scale = gr.Slider(
            minimum=2,
            maximum=20,
            value=6,
            step=1,
            label="Condition Scale",
        )
        tile_width = gr.Slider(
            minimum=64,
            maximum=200,
            value=112,
            step=1,
            label="Latent Tile Width",
        )
        tile_height = gr.Slider(
            minimum=64,
            maximum=200,
            value=144,
            step=1,
            label="Latent Tile Height",
        )
        denoise_strength = gr.Slider(
            minimum=0,
            maximum=1,
            value=0.35,
            step=0.1,
            label="Denoise Strength",
        )
        num_inference_steps = gr.Slider(
            minimum=1,
            maximum=30,
            value=18,
            step=1,
            label="Number of Inference Steps",
        )
        solver = gr.Radio(
            choices=["DDIM", "DPMSolver"],
            value="DDIM",
            label="Solver",
        )
        upload_to_r2 = gr.Checkbox(label="Upload generated image to R2", value=False)
        account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
        # Mask credentials so they are not shown on screen (or in screenshots).
        access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here", type="password")
        secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here", type="password")
        bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
    # Input order must match process()'s positional parameters.
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            image_url,
            prompt,
            negative_prompt,
            seed,
            upscale_factor,
            controlnet_scale,
            controlnet_decay,
            condition_scale,
            tile_width,
            tile_height,
            denoise_strength,
            num_inference_steps,
            solver,
            upload_to_r2,
            account_id,
            access_key,
            secret_key,
            bucket,
        ],
        outputs=[output_slider, logs],
    )
demo.launch(share=False)