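# BRIA-2.3 generative-fill (inpainting) demo: an SDXL ControlNet fills a
# user-masked region, accelerated with an LCM scheduler and BRIA's FAST LoRA.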
import gradio as gr
from diffusers import AutoencoderKL, LCMScheduler
from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from controlnet import ControlNetModel
import torch
import numpy as np
from PIL import Image
from io import BytesIO
from torchvision import transforms
import requests
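# Note: pipeline_controlnet_sd_xl and controlnet are local modules shipped
# alongside this app (BRIA's customized SDXL ControlNet pipeline), not part
# of the diffusers package.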
# Utility functions
def resize_image_to_retain_ratio(image):
    """Resize to roughly 1024*1024 total pixels while keeping the aspect
    ratio; both sides are rounded down to multiples of 8, as required by
    the SDXL VAE/UNet."""
    pixel_number = 1024 * 1024
    granularity_val = 8
    ratio = image.width / image.height
    width = int((pixel_number * ratio) ** 0.5)
    width -= width % granularity_val
    height = int(pixel_number / width)
    height -= height % granularity_val
    return image.resize((width, height))
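# Worked example: a 1920x1080 input has ratio ~1.778, so
# width = int((1024*1024 * 1.778) ** 0.5) = 1365 -> 1360 after rounding,
# height = int(1024*1024 / 1360) = 771 -> 768; the image becomes 1360x768.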
def get_masked_image(image, mask):
    """Grey out the masked region (mask > 0.5 means "fill this area").

    Returns the greyed-out image, a visualization copy, and the float mask array.
    """
    image = np.array(image).astype(np.float32) / 255.0
    mask = np.array(mask.convert("L")).astype(np.float32) / 255.0
    masked_vis = image.copy()
    image[mask > 0.5] = 0.5       # masked pixels become mid-grey
    masked_vis[mask > 0.5] = 0.5  # same treatment for the visualization copy
    return (Image.fromarray((image * 255).astype(np.uint8)),
            Image.fromarray((masked_vis * 255).astype(np.uint8)),
            mask)
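# The mid-grey (0.5) fill follows BRIA's reference code, which greys out
# masked pixels before VAE encoding.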
# Load models once at startup (fp16 weights assume a CUDA device)
device = "cuda" if torch.cuda.is_available() else "cpu"
controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.3-ControlNet-Generative-Fill", torch_dtype=torch.float16
).to(device)
# fp16-safe SDXL VAE; moved to `device` together with the pipeline below
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.3",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
).to(device)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
pipe.fuse_lora()
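# With the LCM scheduler and the fused FAST LoRA, good results are expected
# in few steps at low guidance, hence the UI defaults of 12 steps and
# guidance 1.2 below.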
# Image transforms: the VAE expects inputs in [-1, 1], so normalize after ToTensor
image_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
def inference(init_img, mask_img, prompt, neg_prompt,
              steps, guidance, control_scale, seed):
    # Resize the input and make the mask match it pixel-for-pixel
    init_img = resize_image_to_retain_ratio(init_img)
    mask_img = mask_img.resize(init_img.size)
    masked_img, vis_img, mask_arr = get_masked_image(init_img, mask_img)
    # Encode the greyed-out image into VAE latent space
    tensor = image_transforms(masked_img).unsqueeze(0).to(device)
    latents = vae.encode(tensor.to(vae.dtype)).latent_dist.sample() * vae.config.scaling_factor
    # Downsample the mask to latent resolution and match the latents' dtype
    mask_t = torch.tensor(mask_arr)[None, None, ...].to(device)
    mask_resized = torch.nn.functional.interpolate(
        mask_t, size=(latents.shape[2], latents.shape[3]), mode="nearest"
    ).to(latents.dtype)
    # Control input: 4 latent channels + 1 mask channel
    control = torch.cat([latents, mask_resized], dim=1)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    # init_image / mask_image are extra arguments of BRIA's customized
    # pipeline (pipeline_controlnet_sd_xl); the stock diffusers SDXL
    # ControlNet pipeline does not accept them.
    output = pipe(
        prompt=prompt,
        negative_prompt=neg_prompt,
        controlnet_conditioning_scale=float(control_scale),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        image=control,
        init_image=init_img,
        mask_image=mask_t[:, 0],
        generator=generator,
        height=init_img.height,
        width=init_img.width,
    ).images[0]
    return output
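# Quick smoke test (sketch): call inference() directly, bypassing the UI.
# "input.png" and "mask.png" are placeholder filenames; the mask should be
# white where new content is to be generated.
#
#   result = inference(Image.open("input.png").convert("RGB"),
#                      Image.open("mask.png"), "a red sofa", "blurry",
#                      steps=12, guidance=1.2, control_scale=1.0, seed=123456)
#   result.save("output.png")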
# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## BRIA-2.3 ControlNet Inpainting Demo")
    with gr.Row():
        # `source=` is Gradio 3.x syntax (Gradio 4 renamed it to `sources=[...]`)
        inp = gr.Image(source="upload", type="pil", label="Input Image")
        msk = gr.Image(source="upload", type="pil", label="Mask Image")
    prompt = gr.Textbox(label="Prompt", placeholder="Describe the desired content")
    neg = gr.Textbox(label="Negative Prompt", value="blurry")
    steps = gr.Slider(1, 50, value=12, step=1, label="Inference Steps")
    guidance = gr.Slider(0.0, 10.0, value=1.2, step=0.1, label="Guidance Scale")
    scale = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="ControlNet Scale")
    seed = gr.Number(label="Seed", value=123456)
    btn = gr.Button("Generate")
    out = gr.Image(type="pil", label="Output")
    btn.click(
        fn=inference,
        inputs=[inp, msk, prompt, neg, steps, guidance, scale, seed],
        outputs=out,
    )
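# 0.0.0.0 exposes the app on all network interfaces; 7860 is Gradio's
# default port (and the one Hugging Face Spaces expects).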
demo.launch(server_name="0.0.0.0", server_port=7860)