Sutirtha committed on
Commit f18b0fd · verified · 1 Parent(s): 8ded0c3

Updated Bria

Files changed (1):
  1. app.py +154 -113

app.py CHANGED
@@ -1,116 +1,157 @@
- import os
- import io
- import base64
- import requests
- import numpy as np
  import gradio as gr
  from PIL import Image
- import onnxruntime
- import cv2
-
- # ----------------------------------------------------------------------
- # Configuration
- # ----------------------------------------------------------------------
-
- HF_TOKEN = os.environ["HF_TOKEN_API_DEMO"]
- AUTH_HEADERS = {"api_token": HF_TOKEN}
- BRIA_API_URL = "http://engine.prod.bria-api.com/v1/gen_fill"
-
- # List your local ONNX upscaler model names (without .ort extension)
- UPSCALE_MODELS = ["modelx2", "modelx4"]
-
-
- # ----------------------------------------------------------------------
- # Helper Functions
- # ----------------------------------------------------------------------
-
- def pil_to_base64(img: Image.Image) -> str:
-     """Convert a PIL image to a base64 string prefixed with a comma."""
-     buf = io.BytesIO()
-     img.save(buf, format="PNG")
-     b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
-     return f",{b64}"
-
- def download_pil_image(url: str) -> Image.Image:
-     r = requests.get(url)
-     return Image.open(io.BytesIO(r.content)).convert("RGB")
-
- def gen_fill(image: Image.Image, mask: Image.Image, prompt: str) -> Image.Image:
-     """Call the BRIA Generative Fill API."""
-     payload = {
-         "file": pil_to_base64(image),
-         "mask_file": pil_to_base64(mask),
-         "prompt": prompt,
-         "steps_num": 12,
-         "sync": True,
-     }
-     res = requests.post(BRIA_API_URL, json=payload, headers=AUTH_HEADERS).json()
-     return download_pil_image(res["urls"][0])
-
- def to_onnx_input(img: np.ndarray) -> np.ndarray:
-     img = img[:, :, :3]  # BGR or RGB first three channels
-     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # ensure RGB
-     img = img.astype(np.float32) / 255.0
-     img = np.transpose(img, (2, 0, 1))[None, ...]
-     return img
-
- def from_onnx_output(arr: np.ndarray) -> np.ndarray:
-     arr = np.squeeze(arr, axis=0)
-     arr = np.clip(arr, 0, 1) * 255
-     arr = np.transpose(arr, (1, 2, 0)).astype(np.uint8)
-     return arr
-
- def upscale_image(img: Image.Image, model_name: str) -> Image.Image:
-     """Run ONNX upscaler on a PIL image."""
-     model_path = f"models/{model_name}.ort"
-     sess = onnxruntime.InferenceSession(model_path, sess_options=onnxruntime.SessionOptions())
-     inp = to_onnx_input(np.array(img)[:, :, ::-1])  # PIL is RGB, convert to BGR
-     out = sess.run(None, {sess.get_inputs()[0].name: inp})[0]
-     arr = from_onnx_output(out)
-     # The ONNX model outputs BGR; convert back to RGB
-     rgb = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
-     return Image.fromarray(rgb)
-
- # ----------------------------------------------------------------------
- # Gradio Interface
- # ----------------------------------------------------------------------
-
- with gr.Blocks(css="""
-     .gradio-container {max-width: 900px;}
-     #run_button {width:100%; height:48px;}
-     #image_editor img {object-fit: contain; width:100%; height:auto;}
-     #output_col img {object-fit: contain; width:100%; height:auto;}
- """) as demo:
-
-     gr.Markdown("## BRIA Generative Fill + ONNX Upscaler")
-     gr.Markdown("1. Upload your image and draw a mask. 2. Enter a prompt. 3. Choose an upscaler and click **Run**.")
-
      with gr.Row():
-         with gr.Column(scale=1):
-             editor = gr.ImageEditor(
-                 label="Input Image & Mask",
-                 tool="editor", brush=gr.Brush(color_mode="binary"),
-                 height=400
-             )
-             prompt = gr.Textbox(label="Prompt", placeholder="e.g. “Add a sunset sky”")
-             upscaler = gr.Radio(
-                 choices=UPSCALE_MODELS,
-                 label="Select Upscaler Model",
-                 value=UPSCALE_MODELS[0]
-             )
-             btn = gr.Button("Run", elem_id="run_button")
-
-         with gr.Column(scale=1, elem_id="output_col"):
-             output = gr.Image(label="High-Def Output", height=400)
-
-     def run_pipeline(ed_img, txt, model_name):
-         # ed_img is an RGBA numpy array: [:, :, 0:3] = image, [:, :, 3] = mask
-         pil_in = Image.fromarray(ed_img[:, :, :3], "RGB")
-         pil_mask = Image.fromarray(ed_img[:, :, 3], "L")
-         filled = gen_fill(pil_in, pil_mask, txt)
-         up_img = upscale_image(filled, model_name)
-         return up_img
-
-     btn.click(fn=run_pipeline, inputs=[editor, prompt, upscaler], outputs=[output])
-
- demo.launch()
  import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
  from PIL import Image
+ from io import BytesIO
+ import requests
+ from torchvision import transforms
+ from diffusers import AutoencoderKL, LCMScheduler
+ from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
+ from controlnet import ControlNetModel
+
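+ # NOTE: pipeline_controlnet_sd_xl and controlnet appear to be local modules
+ # shipped with this Space, not the stock diffusers classes of the same name.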
+ # -- Utility Functions --
+ def resize_image_to_retain_ratio(image: Image.Image) -> Image.Image:
+     pixel_number = 1024 * 1024
+     granularity = 8
+     ratio = image.width / image.height
+     width = int((pixel_number * ratio) ** 0.5)
+     width -= width % granularity
+     height = int(pixel_number / width)
+     height -= height % granularity
+     return image.resize((width, height))
+
+
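+ # The resize above keeps the pixel budget near 1024x1024 while preserving the
+ # aspect ratio; width and height are snapped to multiples of 8 so they divide
+ # evenly under the VAE's 8x downsampling.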
+ def prepare_mask(image: Image.Image, mask: Image.Image) -> Image.Image:
+     return mask.convert("L").resize(image.size)
+
+
+ def download_image(url: str) -> Image.Image:
+     resp = requests.get(url)
+     return Image.open(BytesIO(resp.content)).convert("RGB")
+
+ # -- Model & Pipeline Initialization --
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load ControlNet model
+ controlnet = ControlNetModel.from_pretrained(
+     "briaai/BRIA-2.3-ControlNet-Generative-Fill", torch_dtype=torch.float16
+ ).to(device)
+
+ # Load VAE
+ vae = AutoencoderKL.from_pretrained(
+     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+ ).to(device)
+
+ # Load Stable Diffusion XL with ControlNet
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+     "briaai/BRIA-2.3",
+     controlnet=controlnet,
+     torch_dtype=torch.float16,
+     vae=vae,
+ ).to(device)
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+ pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
+ pipe.fuse_lora()
+ pipe.enable_xformers_memory_efficient_attention()
+
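+ # The LCM scheduler plus the fused BRIA-2.3-FAST-LORA weights are what make
+ # the 12-step default below practical.
+ # enable_xformers_memory_efficient_attention() assumes the xformers package is
+ # installed and a CUDA device is available; guard it with try/except elsewhere.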
+ # Tensor transform
+ to_tensor = transforms.ToTensor()
+
+ # -- Inference Function --
+ def generative_fill(
+     image: Image.Image,
+     mask: Image.Image,
+     prompt: str,
+     negative_prompt: str = "blurry",
+     num_inference_steps: int = 12,
+     controlnet_conditioning_scale: float = 1.0,
+     guidance_scale: float = 1.2,
+     seed: int = 123456,
+ ) -> Image.Image:
+     # Preprocess image & mask
+     image = image.convert("RGB")
+     image = resize_image_to_retain_ratio(image)
+     width, height = image.size
+     mask = prepare_mask(image, mask)
+
+     # Create masked version: paint the fill region mid-gray
+     img_arr = np.array(image).astype(np.float32) / 255.0
+     mask_arr = np.array(mask).astype(np.float32) / 255.0
+     masked_arr = img_arr.copy()
+     masked_arr[mask_arr > 0.5] = 0.5
+     masked_pil = Image.fromarray((masked_arr * 255).astype(np.uint8))
+
+     # Encode latents (normalize to [-1, 1] and cast to the VAE's fp16 dtype;
+     # a float32 input would not match the half-precision weights)
+     input_tensor = to_tensor(masked_pil)
+     input_tensor = (input_tensor - 0.5) / 0.5
+     input_tensor = input_tensor.unsqueeze(0).to(device, dtype=pipe.vae.dtype)
+     latents = pipe.vae.encode(input_tensor[:, :3]).latent_dist.sample() * pipe.vae.config.scaling_factor
+
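+     # NOTE: latent_dist.sample() draws unseeded noise, so the encoding is not
+     # reproducible even with a fixed seed below; passing a generator to
+     # sample() (or using .mode()) would make this step deterministic.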
+     # Prepare mask tensor and downsample it to the latent resolution
+     mask_tensor = torch.tensor(mask_arr[None, None], dtype=torch.float32, device=device)
+     mask_resized = F.interpolate(mask_tensor, size=(latents.shape[2], latents.shape[3]), mode="nearest")
+
+     # Combine latents & mask for ControlNet (cast the mask to the latents'
+     # fp16 dtype; torch.cat rejects mixed dtypes)
+     control_latents = latents
+     control_image = torch.cat([control_latents, mask_resized.to(latents.dtype)], dim=1)
+
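+     # The conditioning tensor is therefore 5 channels (4 latent + 1 mask),
+     # which appears to be what this Generative-Fill ControlNet expects.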
+     # Generate
+     generator = torch.Generator(device=device).manual_seed(int(seed))  # gr.Number delivers a float
+     output = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         num_inference_steps=num_inference_steps,
+         height=height,
+         width=width,
+         image=control_image,
+         init_image=image,        # extra kwargs, presumably handled by the local pipeline variant
+         mask_image=mask_tensor,
+         controlnet_conditioning_scale=controlnet_conditioning_scale,
+         guidance_scale=guidance_scale,
+         generator=generator,
+     ).images[0]
+     return output
+
+ # -- Gradio Interface --
+ with gr.Blocks() as demo:
+     gr.Markdown("## BRIA 2.3 ControlNet Generative Fill")
      with gr.Row():
+         inp_image = gr.Image(type="pil", label="Input Image")
+         inp_mask = gr.Image(type="pil", label="Mask (white = fill area)")
+
+     prompt_input = gr.Textbox(label="Prompt", placeholder="Describe what to fill...")
+     neg_prompt_input = gr.Textbox(label="Negative Prompt", value="blurry")
+     steps = gr.Slider(1, 50, value=12, step=1, label="Inference Steps")
+     c_scale = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="ControlNet Scale")
+     g_scale = gr.Slider(0.0, 20.0, value=1.2, step=0.1, label="Guidance Scale")
+     seed_input = gr.Number(value=123456, label="Seed")
+     run_btn = gr.Button("Generate")
+     output_image = gr.Image(type="pil", label="Generated Image")
+
+     run_btn.click(
+         generative_fill,
+         inputs=[
+             inp_image,
+             inp_mask,
+             prompt_input,
+             neg_prompt_input,
+             steps,
+             c_scale,
+             g_scale,
+             seed_input,
+         ],
+         outputs=output_image,
+     )
+     gr.Markdown("Model by BRIA AI | [Hugging Face](https://huggingface.co/briaai/BRIA-2.3-ControlNet-Generative-Fill)")
+
+ demo.launch(server_name="0.0.0.0", share=True)
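+ # NOTE: server_name="0.0.0.0" binds all network interfaces (useful inside a
+ # container); share=True requests a public gradio.live tunnel, which is
+ # unnecessary when the app already runs on a hosted Space.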