"""Gradio demo: generative fill (inpainting) with BRIA-2.3 and its ControlNet.

The pipeline and ControlNet classes are imported from local modules
(pipeline_controlnet_sd_xl.py, controlnet.py) that must sit next to this script.
"""
import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, LCMScheduler
from PIL import Image
from torchvision import transforms

from controlnet import ControlNetModel
from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline


def resize_image_to_retain_ratio(image):
    """Resize to roughly 1024 * 1024 total pixels while keeping the aspect ratio.

    Both output dimensions are rounded down to multiples of 8 so they match the
    VAE's downsampling granularity.
    """
    pixel_number = 1024 * 1024
    granularity_val = 8
    ratio = image.width / image.height
    width = int((pixel_number * ratio) ** 0.5)
    width -= width % granularity_val
    height = int(pixel_number / width)
    height -= height % granularity_val
    return image.resize((width, height))
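# Worked example: a 1600x900 input has ratio 16/9, so width becomes
# int((1024 * 1024 * 16 / 9) ** 0.5) = 1365, rounded down to 1360 (a multiple of 8),
# and height becomes int(1024 * 1024 / 1360) = 771, rounded down to 768;
# the image is therefore resized to 1360x768 (about one megapixel).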
def get_masked_image(image, mask):
    """Grey out the region to fill and return (masked image, visualization, mask array).

    Bright (white) pixels in the mask mark the area to be generated.
    """
    # Make sure the mask matches the (already resized) input dimensions.
    mask = mask.convert("L").resize(image.size)
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    mask = np.array(mask).astype(np.float32) / 255.0
    masked_vis = image.copy()
    image[mask > 0.5] = 0.5  # grey out the fill region before VAE encoding
    masked_vis[mask > 0.5] = 0.5
    return (Image.fromarray((image * 255).astype(np.uint8)),
            Image.fromarray((masked_vis * 255).astype(np.uint8)),
            mask)


# float16 weights: a CUDA device is effectively required for reasonable speed.
device = "cuda" if torch.cuda.is_available() else "cpu"

controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.3-ControlNet-Generative-Fill", torch_dtype=torch.float16
).to(device)
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.3",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
).to(device)

# The LCM scheduler plus the FAST LoRA enable few-step, low-guidance sampling.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
pipe.fuse_lora()

# Map PIL images to tensors in [-1, 1], the input range the VAE expects.
image_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
def inference(init_img, mask_img, prompt, neg_prompt,
              steps, guidance, control_scale, seed):
    init_img = resize_image_to_retain_ratio(init_img)
    masked_img, _, mask_arr = get_masked_image(init_img, mask_img)

    # Encode the greyed-out image into the VAE latent space.
    tensor = image_transforms(masked_img).unsqueeze(0).to(device)
    with torch.no_grad():
        latents = vae.encode(tensor.to(vae.dtype)).latent_dist.sample()
    latents = latents * vae.config.scaling_factor

    # Downsample the mask to the latent resolution and stack it onto the latents
    # to build the ControlNet conditioning input.
    mask_t = torch.tensor(mask_arr)[None, None, ...].to(device)
    mask_resized = torch.nn.functional.interpolate(
        mask_t, size=(latents.shape[2], latents.shape[3]), mode="nearest"
    ).to(latents.dtype)
    control = torch.cat([latents, mask_resized], dim=1)

    generator = torch.Generator(device=device).manual_seed(int(seed))
    output = pipe(
        prompt=prompt,
        negative_prompt=neg_prompt,
        controlnet_conditioning_scale=float(control_scale),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        image=control,
        init_image=init_img,
        mask_image=mask_t[:, 0],
        generator=generator,
        height=init_img.height,
        width=init_img.width,
    ).images[0]
    return output
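# Sketch: calling inference() directly, without the Gradio UI. The file names
# below are hypothetical; white pixels in "mask.png" mark the area to fill.
#
#   init = Image.open("input.png").convert("RGB")
#   mask = Image.open("mask.png").convert("L")
#   result = inference(init, mask, "a wooden bench", "blurry",
#                      steps=12, guidance=1.2, control_scale=1.0, seed=123456)
#   result.save("output.png")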
with gr.Blocks() as demo:
    gr.Markdown("## BRIA-2.3 ControlNet Inpainting Demo")
    with gr.Row():
        # `source=` follows the Gradio 3.x Image API.
        inp = gr.Image(source="upload", type="pil", label="Input Image")
        msk = gr.Image(source="upload", type="pil", label="Mask Image")
    prompt = gr.Textbox(label="Prompt", placeholder="Describe the desired content")
    neg = gr.Textbox(label="Negative Prompt", value="blurry")
    steps = gr.Slider(1, 50, value=12, step=1, label="Inference Steps")
    guidance = gr.Slider(0.0, 10.0, value=1.2, step=0.1, label="Guidance Scale")
    scale = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="ControlNet Scale")
    seed = gr.Number(label="Seed", value=123456)
    btn = gr.Button("Generate")
    out = gr.Image(type="pil", label="Output")
    btn.click(
        fn=inference,
        inputs=[inp, msk, prompt, neg, steps, guidance, scale, seed],
        outputs=out,
    )

demo.launch(server_name="0.0.0.0", server_port=7860)