dadai / app.py
Sutirtha's picture
updated app.py
c63316c verified
raw
history blame
2.62 kB
# app.py
# ── Monkey‐patch missing torchvision module ──
import sys
import torchvision.transforms.functional as F

# Some torchvision releases no longer ship
# `torchvision.transforms.functional_tensor`. Alias it to the public
# `functional` module ONLY when the real module cannot be imported, so
# installs that still have it keep the genuine implementation.
# NOTE(review): presumably needed by basicsr (imported below), which refers
# to this module in some versions — confirm. Keep this shim above the
# basicsr / realesrgan imports so the alias exists before they load.
try:
    import torchvision.transforms.functional_tensor  # noqa: F401
except ImportError:
    sys.modules['torchvision.transforms.functional_tensor'] = F
import os
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline
# Import the RealESRGANer helper and architecture
from basicsr.archs.rrdbnet_arch import RRDBNet  # RRDB backbone
from realesrgan.utils import RealESRGANer  # RealESRGANer class
# 1. Initialize Stable Diffusion InpaintPipeline on CPU.
# A single device object is shared by both models so the target device is
# configured in exactly one place.
device = torch.device("cpu")

# NOTE(review): the "runwayml/..." repo id may no longer resolve on the
# Hugging Face Hub; "stable-diffusion-v1-5/stable-diffusion-inpainting" is a
# community mirror of the same weights — confirm before deploying.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float32,  # full precision: this app is CPU-only
)
pipe.to(device)

# 2. Build the RRDBNet model and RealESRGANer (4x) on the same device.
# These architecture hyper-parameters must match the RealESRGAN_x4plus
# checkpoint loaded below.
rrdb = RRDBNet(
    num_in_ch=3, num_out_ch=3,
    num_feat=64, num_block=23,
    num_grow_ch=32, scale=4,
)

# Passing a GitHub release URL as model_path makes RealESRGANer download
# the weights itself. tile=0 processes the whole image in one pass (no
# tiling); half=False keeps float32, consistent with the pipeline above.
esrgan = RealESRGANer(
    scale=4,
    model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    model=rrdb,
    tile=0, tile_pad=10, pre_pad=10,
    half=False,
    device=device,
)
def fill_and_upscale(input_img: Image.Image,
                     mask_img: Image.Image,
                     prompt: str):
    """Inpaint the masked region of an image, then upscale the result.

    Args:
        input_img: Source image; converted to RGB before inpainting.
        mask_img: Inpainting mask; white pixels mark the region to fill.
            Converted to single-channel grayscale before inpainting.
        prompt: Text prompt guiding the inpainting. ``None`` is treated as
            an empty prompt.

    Returns:
        A ``(filled, upscaled)`` tuple of PIL images: the raw inpainting
        output and that output passed through the module-level Real-ESRGAN
        upsampler (configured for 4x).

    Raises:
        gr.Error: If the input image or the mask is missing. Gradio passes
            ``None`` for an image component the user left empty, and
            ``gr.Error`` surfaces a readable message in the UI instead of
            an opaque ``AttributeError``.
    """
    if input_img is None:
        raise gr.Error("Please upload an input image.")
    if mask_img is None:
        raise gr.Error("Please upload a mask image.")

    # Inpaint
    init = input_img.convert("RGB")
    # A mask only needs one channel; white pixels mark the area to repaint.
    mask = mask_img.convert("L")
    # NOTE(review): the pipeline's output size may differ from the input
    # size (diffusers defaults to the model's native resolution) — confirm
    # whether the original resolution must be preserved.
    filled: Image.Image = pipe(
        prompt=prompt or "", image=init, mask_image=mask
    ).images[0]

    # RealESRGANer expects an OpenCV-style BGR numpy array.
    bgr = cv2.cvtColor(np.array(filled), cv2.COLOR_RGB2BGR)

    # Upscale; outscale=None applies no extra resize beyond the model scale.
    out_bgr, _ = esrgan.enhance(bgr, outscale=None)
    upscaled = Image.fromarray(cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB))
    return filled, upscaled
# 3. Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Inpaint + 4× Upscale (CPU Only)")
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Input Image")
            msk = gr.Image(type="pil", label="Mask (white=fill)")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g. A serene waterfall at dawn"
            )
            btn = gr.Button("Run")
        with gr.Column():
            out1 = gr.Image(type="pil", label="Inpainted")
            out2 = gr.Image(type="pil", label="Upscaled")
    # Event listeners must be registered inside the Blocks context.
    btn.click(fill_and_upscale, [inp, msk, prompt], [out1, out2])

# CPU-only diffusion + upscaling is slow per request. An explicit queue
# processes requests in turn (one at a time by default) instead of running
# them concurrently on the same CPU. On older Gradio versions, where the
# queue is off by default, it also avoids HTTP timeouts on long calls.
# Enabled at module level so it also applies when `demo` is imported.
demo.queue()

if __name__ == "__main__":
    demo.launch()