Sutirtha committed on
Commit
5c781ca
·
verified ·
1 Parent(s): 4cfee16

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import requests
5
+ import numpy as np
6
+ import gradio as gr
7
+ from PIL import Image
8
+ import onnxruntime
9
+ import cv2
10
+
11
+ # ——————————————————————————————————————————————————————————————
12
+ # Configuration
13
+ # ——————————————————————————————————————————————————————————————
14
+
15
# BRIA API token, read from the environment. Indexing os.environ directly
# raises KeyError at import time when the variable is missing, so a
# misconfigured deployment fails immediately instead of on the first request.
HF_TOKEN = os.environ["HF_TOKEN_API_DEMO"]
# Header dict sent with every BRIA request.
AUTH_HEADERS = {"api_token": HF_TOKEN}
# HTTPS, so the token in AUTH_HEADERS is never transmitted in clear text.
BRIA_API_URL = "https://engine.prod.bria-api.com/v1/gen_fill"

# List your local ONNX upscaler model names (without .ort extension)
UPSCALE_MODELS = ["modelx2", "modelx4"]
21
+
22
+
23
+ # ——————————————————————————————————————————————————————————————
24
+ # Helper Functions
25
+ # ——————————————————————————————————————————————————————————————
26
+
27
def pil_to_base64(img: Image.Image) -> str:
    """Serialize *img* as PNG and return the bytes base64-encoded.

    The image is written to an in-memory buffer with ``format="PNG"``; the
    returned string is a single comma followed by the base64 text of those
    bytes (UTF-8 decoded).
    """
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return "," + encoded
33
+
34
def download_pil_image(url: str, timeout: float = 60.0) -> Image.Image:
    """Download the image at *url* and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot hang the caller indefinitely.

    Returns:
        The decoded image, converted to mode ``"RGB"``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors here rather than feeding an error page to the
    # image decoder, which would fail with an unrelated message.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content)).convert("RGB")
37
+
38
def gen_fill(image: Image.Image, mask: Image.Image, prompt: str) -> Image.Image:
    """Call the BRIA Generative Fill API and return the generated image.

    The image and mask are sent base64-encoded (via ``pil_to_base64``) in a
    JSON body together with the prompt; the first URL listed under ``"urls"``
    in the JSON response is then downloaded.

    Args:
        image: Picture to edit.
        mask: Mask image sent as ``mask_file``.
        prompt: Text prompt forwarded to the API.

    Returns:
        The downloaded result as an RGB PIL image.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
        RuntimeError: If the response contains no result URL.
    """
    payload = {
        "file": pil_to_base64(image),
        "mask_file": pil_to_base64(mask),
        "prompt": prompt,
        "steps_num": 12,
        "sync": True,
    }
    # A timeout keeps a stalled API call from blocking the worker forever.
    response = requests.post(
        BRIA_API_URL, json=payload, headers=AUTH_HEADERS, timeout=120
    )
    response.raise_for_status()
    result = response.json()

    # Error payloads may lack "urls"; report what came back instead of
    # failing with a bare KeyError/IndexError.
    urls = result.get("urls") if isinstance(result, dict) else None
    if not urls:
        raise RuntimeError(f"BRIA gen_fill returned no result URLs: {result!r}")
    return download_pil_image(urls[0])
49
+
50
def to_onnx_input(img: np.ndarray) -> np.ndarray:
    """Turn an image array into a normalized ``1 x 3 x H x W`` float32 tensor.

    Only the first three channels are used and their order is reversed
    (BGR -> RGB), values are divided by 255, and the layout is changed from
    height-width-channel to channel-height-width with a leading batch axis.

    Grayscale input (``H x W`` or ``H x W x 1``) is replicated to three
    channels instead of raising.

    Args:
        img: ``H x W``, ``H x W x 1`` or ``H x W x C`` (``C >= 3``) array with
            values on a 0-255 scale.

    Returns:
        A C-contiguous float32 array of shape ``(1, 3, H, W)``.

    Raises:
        ValueError: If *img* is not a 2-D/3-D image or has exactly two channels.
    """
    if img.ndim == 2:
        img = img[:, :, None]
    if img.ndim != 3 or img.shape[2] in (0, 2):
        raise ValueError(f"unsupported image shape: {img.shape}")
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)

    # Drop any alpha channel, then reverse channel order (BGR -> RGB). Plain
    # slicing gives the same result as cv2.COLOR_BGR2RGB for 3-channel data.
    rgb = img[:, :, :3][:, :, ::-1]
    scaled = rgb.astype(np.float32) / 255.0
    # HWC -> CHW, add the batch axis, and hand back contiguous memory.
    return np.ascontiguousarray(np.transpose(scaled, (2, 0, 1))[None, ...])
56
+
57
def from_onnx_output(arr: np.ndarray) -> np.ndarray:
    """Convert a ``1 x C x H x W`` model output into an ``H x W x C`` uint8 image.

    The batch axis is removed, values are clipped to ``[0, 1]``, scaled to
    ``[0, 255]`` and rounded to the nearest integer before the cast.

    Args:
        arr: Array whose first axis has length 1.

    Returns:
        A uint8 array in height-width-channel layout.

    Raises:
        ValueError: If the first axis of *arr* does not have length 1
            (raised by ``np.squeeze``).
    """
    chw = np.squeeze(arr, axis=0)
    scaled = np.clip(chw, 0, 1) * 255
    # Round before casting: astype(uint8) truncates, so float error such as
    # 199.99998 would otherwise come out as 199 and darken the image.
    return np.transpose(np.rint(scaled), (1, 2, 0)).astype(np.uint8)
62
+
63
# Loaded ONNX Runtime sessions keyed by model path, so each model file is
# read and initialized once instead of on every request.
_ORT_SESSIONS = {}


def upscale_image(img: Image.Image, model_name: str) -> Image.Image:
    """Run the ONNX upscaler ``models/<model_name>.ort`` on a PIL image.

    Args:
        img: Image to upscale; it is converted to RGB first, so other PIL
            modes (RGBA, L, ...) are accepted too.
        model_name: Model file name without the ``.ort`` extension.

    Returns:
        The model output as a PIL image.
    """
    model_path = f"models/{model_name}.ort"
    sess = _ORT_SESSIONS.get(model_path)
    if sess is None:
        sess = onnxruntime.InferenceSession(
            model_path, sess_options=onnxruntime.SessionOptions()
        )
        _ORT_SESSIONS[model_path] = sess

    # PIL is RGB; reverse the channels because to_onnx_input expects BGR.
    bgr = np.array(img.convert("RGB"))[:, :, ::-1]
    inp = to_onnx_input(bgr)
    out = sess.run(None, {sess.get_inputs()[0].name: inp})[0]
    arr = from_onnx_output(out)
    # The ONNX model outputs BGR; convert back to RGB
    # NOTE(review): to_onnx_input swaps BGR -> RGB, so the model is fed RGB.
    # The swap below is only right if the model really emits BGR - confirm
    # against the model files, otherwise red and blue end up exchanged.
    rgb = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
73
+
74
+ # ——————————————————————————————————————————————————————————————
75
+ # Gradio Interface
76
+ # ——————————————————————————————————————————————————————————————
77
+
78
def _editor_value_to_image_and_mask(ed_value):
    """Split a ``gr.ImageEditor`` value into ``(rgb_image, mask)`` PIL images.

    Accepts the editor's dict value (``"background"`` plus ``"layers"``) and,
    for backward compatibility, a bare RGBA array whose alpha channel is the
    mask. In the dict form every pixel with non-zero alpha in any drawn layer
    is marked 255 in the mask.

    Raises:
        gr.Error: If no image was provided or nothing was painted.
    """
    if isinstance(ed_value, np.ndarray):
        # Legacy form: [:, :, 0:3] = image, [:, :, 3] = mask.
        if ed_value.ndim != 3 or ed_value.shape[2] < 4:
            raise gr.Error("Expected an RGBA image with the mask in the alpha channel.")
        rgb = np.ascontiguousarray(ed_value[:, :, :3])
        mask = np.ascontiguousarray(ed_value[:, :, 3])
        return Image.fromarray(rgb), Image.fromarray(mask)

    if not ed_value or ed_value.get("background") is None:
        raise gr.Error("Please upload an image first.")

    background = np.asarray(ed_value["background"])
    mask = np.zeros(background.shape[:2], dtype=np.uint8)
    for layer in ed_value.get("layers") or []:
        layer = np.asarray(layer)
        # Drawn layers are transparent except where the brush touched them.
        if layer.ndim == 3 and layer.shape[2] == 4 and layer.shape[:2] == mask.shape:
            mask[layer[:, :, 3] > 0] = 255
    if not mask.any():
        raise gr.Error("Please draw a mask over the area to fill.")

    return Image.fromarray(background).convert("RGB"), Image.fromarray(mask)


def run_pipeline(ed_img, txt, model_name):
    """Fill the masked region with BRIA, then upscale the result.

    Args:
        ed_img: Value delivered by the ``gr.ImageEditor`` component.
        txt: Prompt text.
        model_name: Upscaler model name selected in the UI.

    Returns:
        The upscaled PIL image shown in the output component.
    """
    pil_in, pil_mask = _editor_value_to_image_and_mask(ed_img)
    filled = gen_fill(pil_in, pil_mask, txt)
    return upscale_image(filled, model_name)


with gr.Blocks(css="""
    .gradio-container {max-width: 900px;}
    #run_button {width:100%; height:48px;}
    #image_editor img {object-fit: contain; width:100%; height:auto;}
    #output_col img {object-fit: contain; width:100%; height:auto;}
""") as demo:

    gr.Markdown("## BRIA Generative Fill + ONNX Upscaler")
    gr.Markdown("1. Upload your image and draw a mask. 2. Enter a prompt. 3. Choose an upscaler and click **Run**.")

    with gr.Row():
        with gr.Column(scale=1):
            # gr.ImageEditor takes no `tool` argument and gr.Brush only knows
            # the "fixed" and "defaults" color modes; a single fixed white
            # brush is what a mask needs.
            editor = gr.ImageEditor(
                label="Input Image & Mask",
                type="numpy",
                image_mode="RGBA",
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
                height=400,
            )
            prompt = gr.Textbox(label="Prompt", placeholder="e.g. “Add a sunset sky”")
            upscaler = gr.Radio(
                choices=UPSCALE_MODELS,
                label="Select Upscaler Model",
                value=UPSCALE_MODELS[0],
            )
            btn = gr.Button("Run", elem_id="run_button")

        with gr.Column(scale=1, elem_id="output_col"):
            output = gr.Image(label="High-Def Output", height=400)

    btn.click(fn=run_pipeline, inputs=[editor, prompt, upscaler], outputs=[output])

demo.launch()