import gradio as gr
from diffusers import AutoencoderKL, LCMScheduler
from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from controlnet import ControlNetModel
import torch
import numpy as np
from PIL import Image
from torchvision import transforms


# Utility functions
def resize_image_to_retain_ratio(image):
    """Resize to roughly 1024*1024 total pixels while keeping the aspect ratio;
    both sides are rounded down to a multiple of 8."""
    pixel_number = 1024 * 1024
    granularity_val = 8
    ratio = image.width / image.height
    width = int((pixel_number * ratio) ** 0.5)
    width -= width % granularity_val
    height = int(pixel_number / width)
    height -= height % granularity_val
    return image.resize((width, height))


def get_masked_image(image, mask):
    """Grey out the masked region; return (masked image, visualization copy, mask array)."""
    image = np.array(image).astype(np.float32) / 255.0
    mask = np.array(mask.convert("L")).astype(np.float32) / 255.0
    masked_vis = image.copy()
    image[mask > 0.5] = 0.5       # pixels to be regenerated are set to mid-grey
    masked_vis[mask > 0.5] = 0.5
    return (
        Image.fromarray((image * 255).astype(np.uint8)),
        Image.fromarray((masked_vis * 255).astype(np.uint8)),
        mask,
    )


# Load the models once at startup
device = "cuda" if torch.cuda.is_available() else "cpu"

controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.3-ControlNet-Generative-Fill", torch_dtype=torch.float16
).to(device)

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.3", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
).to(device)  # also moves the VAE to the device

# Fast inference: LCM scheduler with fused LCM-LoRA weights
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
pipe.fuse_lora()

# Image transforms (PIL -> float tensor in [0, 1])
image_transforms = transforms.Compose([transforms.ToTensor()])


def inference(init_img, mask_img, prompt, neg_prompt, steps, guidance, control_scale, seed):
    # Resize the input and grey out the region to be filled
    init_img = resize_image_to_retain_ratio(init_img)
    masked_img, vis_img, mask_arr = get_masked_image(init_img, mask_img)

    # Encode the masked image with the VAE, which expects inputs in [-1, 1]
    tensor = image_transforms(masked_img).unsqueeze(0).to(device)
    tensor = tensor * 2.0 - 1.0
    with torch.no_grad():
        latents = vae.encode(tensor.to(vae.dtype)).latent_dist.sample() * vae.config.scaling_factor

    # Downsample the mask to the latent resolution
    mask_t = torch.tensor(mask_arr)[None, None, ...].to(device)
    mask_resized = torch.nn.functional.interpolate(
        mask_t, size=(latents.shape[2], latents.shape[3]), mode="nearest"
    )

    # ControlNet conditioning: masked-image latents concatenated with the downsampled mask
    control = torch.cat([latents, mask_resized.to(latents.dtype)], dim=1)

    generator = torch.Generator(device=device).manual_seed(int(seed))
    output = pipe(
        prompt=prompt,
        negative_prompt=neg_prompt,
        controlnet_conditioning_scale=float(control_scale),
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        image=control,
        init_image=init_img,
        mask_image=mask_t[:, 0],
        generator=generator,
        height=init_img.height,
        width=init_img.width,
    ).images[0]
    return output


# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## BRIA-2.3 ControlNet Inpainting Demo")
    with gr.Row():
        inp = gr.Image(source="upload", type="pil", label="Input Image")
        msk = gr.Image(source="upload", type="pil", label="Mask Image")
    prompt = gr.Textbox(label="Prompt", placeholder="Describe the desired content")
    neg = gr.Textbox(label="Negative Prompt", value="blurry")
    steps = gr.Slider(1, 50, value=12, step=1, label="Inference Steps")
    guidance = gr.Slider(0.0, 10.0, value=1.2, step=0.1, label="Guidance Scale")
    scale = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="ControlNet Scale")
    seed = gr.Number(label="Seed", value=123456)
    btn = gr.Button("Generate")
    out = gr.Image(type="pil", label="Output")
    btn.click(
        fn=inference,
        inputs=[inp, msk, prompt, neg, steps, guidance, scale, seed],
        outputs=out,
    )
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
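
# Sketch of calling inference() directly, bypassing the UI. The file names
# "input.png" and "mask.png" are hypothetical placeholders (white = region to fill);
# the keyword values mirror the demo's defaults.
#
#   from PIL import Image
#   result = inference(Image.open("input.png").convert("RGB"),
#                      Image.open("mask.png"),
#                      prompt="a red leather armchair", neg_prompt="blurry",
#                      steps=12, guidance=1.2, control_scale=1.0, seed=123456)
#   result.save("output.png")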