|
|
|
|
|
|
|
import sys

from torchvision.transforms import functional as F

# Register the functional API under the ``functional_tensor`` module path so
# that later imports of that name resolve to it.
# NOTE(review): presumably required because basicsr (imported further down)
# still imports ``torchvision.transforms.functional_tensor``, which newer
# torchvision releases no longer provide -- confirm against the pinned
# versions. This alias must stay ahead of the basicsr/realesrgan imports.
sys.modules["torchvision.transforms.functional_tensor"] = F
|
|
|
import os |
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import cv2 |
|
from PIL import Image |
|
from diffusers import StableDiffusionInpaintPipeline |
|
|
|
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
from realesrgan.utils import RealESRGANer |
|
|
|
|
|
# Inpainting pipeline: load the pretrained weights as float32 and place the
# pipeline on the CPU. ``to`` hands back the pipeline itself, so loading and
# device placement are expressed as a single chained expression.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float32,
).to("cpu")
|
|
|
|
|
# Everything in this script runs on the CPU.
device = torch.device("cpu")

# Network passed to the upscaler below; its ``scale`` matches the upscaler's.
rrdb = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    scale=4,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
)

# 4x upscaler wrapping the network above; the weights are referenced by URL.
# NOTE(review): tile=0 presumably disables tiling and half=False keeps full
# precision -- confirm against the Real-ESRGAN documentation.
esrgan = RealESRGANer(
    model=rrdb,
    model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    scale=4,
    device=device,
    half=False,
    tile=0,
    tile_pad=10,
    pre_pad=10,
)
|
|
|
def fill_and_upscale(input_img: Image.Image,
                     mask_img: Image.Image,
                     prompt: str):
    """Inpaint the masked region of an image, then upscale the result.

    Args:
        input_img: Source picture; converted to RGB before inpainting.
        mask_img: Mask picture; converted to RGB before being handed to the
            pipeline (the UI labels white as the area to fill).
        prompt: Text prompt forwarded to the inpainting pipeline.

    Returns:
        A ``(filled, upscaled)`` tuple of PIL images: the first image produced
        by the inpainting pipeline and the enhanced version of that image.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of an ``AttributeError`` on ``None``.
    """
    # An image component left empty in the UI arrives here as None.
    if input_img is None or mask_img is None:
        raise gr.Error("Please provide both an input image and a mask image.")

    init = input_img.convert("RGB")
    mask = mask_img.convert("RGB")
    filled: Image.Image = pipe(
        prompt=prompt, image=init, mask_image=mask
    ).images[0]

    # The enhancer is fed a BGR array, so swap the channel order on the way
    # in and swap it back on the way out.
    arr = np.array(filled)
    bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)

    # outscale=None leaves the output size to the enhancer's own setting.
    out_bgr, _ = esrgan.enhance(bgr, outscale=None)
    out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
    upscaled = Image.fromarray(out_rgb)

    return filled, upscaled
|
|
|
|
|
# UI layout: inputs and the run button in the left column, both result
# images in the right column.
with gr.Blocks() as demo:
    gr.Markdown("## Inpaint + 4× Upscale (CPU Only)")

    with gr.Row():
        with gr.Column():
            inp = gr.Image(label="Input Image", type="pil")
            msk = gr.Image(label="Mask (white=fill)", type="pil")
            prompt = gr.Textbox(
                placeholder="e.g. A serene waterfall at dawn",
                label="Prompt",
            )
            btn = gr.Button(value="Run")

        with gr.Column():
            out1 = gr.Image(label="Inpainted", type="pil")
            out2 = gr.Image(label="Upscaled", type="pil")

    # Wire the button to the processing function; the event has to be
    # registered while the Blocks context is still open.
    btn.click(
        fn=fill_and_upscale,
        inputs=[inp, msk, prompt],
        outputs=[out1, out2],
    )


# Only start the web server when executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|