Sutirtha committed on
Commit 9608f17 · verified · 1 Parent(s): c8f3a35

Update app.py

Files changed (1):
  app.py (+93, -71)
app.py CHANGED
@@ -1,84 +1,106 @@
- # app.py
-
- # ── Monkey-patch missing torchvision module ──
- import sys
- import torchvision.transforms.functional as F
- sys.modules['torchvision.transforms.functional_tensor'] = F
-
- import os
  import gradio as gr
  import torch
  import numpy as np
- import cv2
  from PIL import Image
- from diffusers import StableDiffusionInpaintPipeline
-
- # Import the RealESRGANer helper and architecture
- from basicsr.archs.rrdbnet_arch import RRDBNet  # RRDB backbone
- from realesrgan.utils import RealESRGANer  # RealESRGANer class
-
- # 1. Initialize the Stable Diffusion inpainting pipeline on CPU
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
-     "runwayml/stable-diffusion-inpainting",
-     torch_dtype=torch.float32,
- )
- pipe.to("cpu")
-
- # 2. Build the RRDBNet model and RealESRGANer (4×) on CPU
- device = torch.device("cpu")
- rrdb = RRDBNet(
-     num_in_ch=3, num_out_ch=3,
-     num_feat=64, num_block=23,
-     num_grow_ch=32, scale=4
  )
- # Pass a GitHub URL so the weights are downloaded under the hood
- esrgan = RealESRGANer(
-     scale=4,
-     model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
-     model=rrdb,
-     tile=0, tile_pad=10, pre_pad=10,
-     half=False,
-     device=device,
- )
-
- def fill_and_upscale(input_img: Image.Image,
-                      mask_img: Image.Image,
-                      prompt: str):
-     # Inpaint
-     init = input_img.convert("RGB")
-     mask = mask_img.convert("RGB")
-     filled: Image.Image = pipe(
-         prompt=prompt, image=init, mask_image=mask
-     ).images[0]
-
-     # Prepare for RealESRGANer (expects a BGR numpy array)
-     arr = np.array(filled)
-     bgr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
-
-     # Upscale
-     out_bgr, _ = esrgan.enhance(bgr, outscale=None)
-     out_rgb = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)
-     upscaled = Image.fromarray(out_rgb)
-
-     return filled, upscaled
-
- # 3. Gradio UI
  with gr.Blocks() as demo:
-     gr.Markdown("## Inpaint + Upscale (CPU Only)")
      with gr.Row():
-         with gr.Column():
-             inp = gr.Image(type="pil", label="Input Image")
-             msk = gr.Image(type="pil", label="Mask (white=fill)")
-             prompt = gr.Textbox(
-                 label="Prompt",
-                 placeholder="e.g. A serene waterfall at dawn"
-             )
-             btn = gr.Button("Run")
-         with gr.Column():
-             out1 = gr.Image(type="pil", label="Inpainted")
-             out2 = gr.Image(type="pil", label="Upscaled")
-
-     btn.click(fill_and_upscale, [inp, msk, prompt], [out1, out2])
-
- if __name__ == "__main__":
-     demo.launch()
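
Note on the removed monkey-patch: basicsr's degradation utilities import
rgb_to_grayscale from torchvision.transforms.functional_tensor, a module that
torchvision 0.17+ no longer ships. Aliasing the public functional module under
the old name (the same trick the removed file used) keeps that import working,
since functional also exports rgb_to_grayscale; it must run before basicsr is
imported.

    import sys
    import torchvision.transforms.functional as F

    # Register the public module under the removed private name so that
    # "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
    # inside basicsr still resolves.
    sys.modules["torchvision.transforms.functional_tensor"] = F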
  import gradio as gr
+ from diffusers import AutoencoderKL, LCMScheduler
+ from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
+ from controlnet import ControlNetModel
  import torch
  import numpy as np
  from PIL import Image
+ from io import BytesIO
+ from torchvision import transforms
+ import requests
+
+ # Utility functions
+ def resize_image_to_retain_ratio(image):
+     # Rescale so the area stays near 1024*1024 while keeping the aspect
+     # ratio; both sides are rounded down to a multiple of 8.
+     pixel_number = 1024 * 1024
+     granularity_val = 8
+     ratio = image.width / image.height
+     width = int((pixel_number * ratio) ** 0.5)
+     width -= width % granularity_val
+     height = int(pixel_number / width)
+     height -= height % granularity_val
+     return image.resize((width, height))
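
A quick example of what the resize helper produces (values computed from the
function above; the 1920x1080 input is arbitrary). The multiple-of-8 rounding
matters because the SDXL VAE downsamples by 8x.

    from PIL import Image

    probe = Image.new("RGB", (1920, 1080))           # 16:9 input
    print(resize_image_to_retain_ratio(probe).size)  # (1360, 768): ~1 MP, both sides divisible by 8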
+
+ def get_masked_image(image, mask):
+     # Gray out the region to fill (mask > 0.5) and return the masked image,
+     # a visualization copy, and the float mask array.
+     image = np.array(image).astype(np.float32) / 255.0
+     mask = np.array(mask.convert("L")).astype(np.float32) / 255.0
+     masked_vis = image.copy()
+     image[mask > 0.5] = 0.5
+     masked_vis[mask > 0.5] = 0.5
+     return (Image.fromarray((image * 255).astype(np.uint8)),
+             Image.fromarray((masked_vis * 255).astype(np.uint8)),
+             mask)
+
+ # Load model once
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ controlnet = ControlNetModel.from_pretrained(
+     "briaai/BRIA-2.3-ControlNet-Generative-Fill", torch_dtype=torch.float16
+ ).to(device)
+ vae = AutoencoderKL.from_pretrained(
+     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
  )
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+     "briaai/BRIA-2.3",
+     controlnet=controlnet,
+     vae=vae,
+     torch_dtype=torch.float16
+ ).to(device)
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+ pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
+ pipe.fuse_lora()
+
+ # Image transforms
+ image_transforms = transforms.Compose([transforms.ToTensor()])
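
The fp16-fix VAE is swapped in because the stock SDXL VAE overflows to NaN in
float16, and the LCM scheduler plus the fused FAST-LORA is what makes the low
default step count below viable. A minimal NaN probe (a sketch run inside this
file; assumes the CUDA path, and that pipe.to(device) has already moved the
shared vae object):

    x = torch.randn(1, 3, 64, 64, dtype=torch.float16, device=device)
    with torch.no_grad():
        sample = vae.encode(x).latent_dist.sample()
    assert not torch.isnan(sample).any()  # the fp16-fix VAE stays finite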
 
 
+ def inference(init_img, mask_img, prompt, neg_prompt,
+               steps, guidance, control_scale, seed):
+     # Resize the input and keep the mask aligned with the resized image
+     init_img = resize_image_to_retain_ratio(init_img)
+     mask_img = mask_img.resize(init_img.size)
+     masked_img, vis_img, mask_arr = get_masked_image(init_img, mask_img)
+
+     # Encode the masked image into VAE latents
+     tensor = image_transforms(masked_img).unsqueeze(0).to(device)
+     latents = vae.encode(tensor.to(vae.dtype)).latent_dist.sample() * vae.config.scaling_factor
+
+     # Downsample the mask to the latent resolution, matching the latent dtype
+     mask_t = torch.tensor(mask_arr)[None, None, ...].to(device, dtype=latents.dtype)
+     mask_resized = torch.nn.functional.interpolate(
+         mask_t, size=(latents.shape[2], latents.shape[3]), mode='nearest'
+     )
+
+     # Control image: masked-image latents (4 ch) + downsampled mask (1 ch)
+     control = torch.cat([latents, mask_resized], dim=1)
+
+     generator = torch.Generator(device=device).manual_seed(int(seed))
+     output = pipe(
+         prompt=prompt,
+         negative_prompt=neg_prompt,
+         controlnet_conditioning_scale=float(control_scale),
+         num_inference_steps=int(steps),
+         guidance_scale=guidance,
+         image=control,
+         init_image=init_img,
+         mask_image=mask_t[:, 0],
+         generator=generator,
+         height=init_img.height,
+         width=init_img.width,
+     ).images[0]
+     return output
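
The control "image" here is not an RGB hint but a 5-channel latent-space
tensor: 4 channels of masked-image latents plus the mask, both at 1/8 of the
pixel resolution. A standalone shape check (sketch; sizes are illustrative):

    import torch

    lat = torch.randn(1, 4, 128, 128)    # latents of a 1024x1024 image
    m = torch.rand(1, 1, 1024, 1024)     # mask in [0, 1]
    m_small = torch.nn.functional.interpolate(m, size=lat.shape[2:], mode="nearest")
    ctrl = torch.cat([lat, m_small], dim=1)
    print(ctrl.shape)                    # torch.Size([1, 5, 128, 128])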
+
+ # Build Gradio interface
  with gr.Blocks() as demo:
+     gr.Markdown("## BRIA-2.3 ControlNet Inpainting Demo")
      with gr.Row():
+         inp = gr.Image(source="upload", type="pil", label="Input Image")
+         msk = gr.Image(source="upload", type="pil", label="Mask Image")
+     prompt = gr.Textbox(label="Prompt", placeholder="Describe the desired content")
+     neg = gr.Textbox(label="Negative Prompt", value="blurry")
+     steps = gr.Slider(1, 50, value=12, step=1, label="Inference Steps")
+     guidance = gr.Slider(0.0, 10.0, value=1.2, step=0.1, label="Guidance Scale")
+     scale = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="ControlNet Scale")
+     seed = gr.Number(label="Seed", value=123456)
+     btn = gr.Button("Generate")
+     out = gr.Image(type="pil", label="Output")
+     btn.click(
+         fn=inference,
+         inputs=[inp, msk, prompt, neg, steps, guidance, scale, seed],
+         outputs=out,
+     )
+ demo.launch(server_name="0.0.0.0", server_port=7860)
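
For a quick end-to-end check without the UI, demo.launch() can be temporarily
replaced with a direct call to inference (a sketch; the drawn rectangle mask,
prompt, and output path are placeholders):

    from PIL import ImageDraw

    img = Image.new("RGB", (1024, 768), "lightgray")
    mask = Image.new("L", (1024, 768), 0)
    ImageDraw.Draw(mask).rectangle([300, 200, 700, 500], fill=255)  # white = region to fill
    result = inference(img, mask, "a vase of flowers on a wooden table", "blurry",
                       steps=12, guidance=1.2, control_scale=1.0, seed=123456)
    result.save("smoke_test.png")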