shemayons committed on
Commit
7d6af6f
·
verified ·
1 Parent(s): 206d6ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -38
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from load_image import load_img
3
  import spaces
@@ -10,17 +11,25 @@ import numpy as np
10
 
11
  torch.set_float32_matmul_precision(["high", "highest"][0])
12
 
13
- # load model
 
14
  birefnet = AutoModelForImageSegmentation.from_pretrained(
15
  "ZhengPeng7/BiRefNet", trust_remote_code=True
16
  )
17
 
18
- # Keep model in a dict for easy switching
 
 
 
 
 
19
  models_dict = {
20
- "BiRefNet": birefnet
 
21
  }
22
 
23
  # Transform
 
24
  transform_image = transforms.Compose(
25
  [
26
  transforms.Resize((1024, 1024)),
@@ -35,6 +44,7 @@ def process(image: Image.Image, model_choice: str):
35
  Runs inference to remove the background (adds alpha)
36
  with the chosen segmentation model.
37
  """
 
38
  current_model = models_dict[model_choice]
39
 
40
  # Prepare image
@@ -43,6 +53,8 @@ def process(image: Image.Image, model_choice: str):
43
 
44
  # Inference
45
  with torch.no_grad():
 
 
46
  preds = current_model(input_images)[-1].sigmoid().cpu()
47
 
48
  # Convert single-channel pred to a PIL mask
@@ -56,7 +68,6 @@ def process(image: Image.Image, model_choice: str):
56
  image.putalpha(mask)
57
  return image
58
 
59
-
60
  def fn(source: str, model_choice: str):
61
  """
62
  Used by Tab 1 & Tab 2 to produce a processed image with alpha.
@@ -64,6 +75,7 @@ def fn(source: str, model_choice: str):
64
  a URL string (textbox).
65
  - 'model_choice' is the user's selection from the radio.
66
  """
 
67
  im = load_img(source, output_type="pil")
68
  im = im.convert("RGB")
69
 
@@ -71,7 +83,6 @@ def fn(source: str, model_choice: str):
71
  processed_image = process(im, model_choice)
72
  return processed_image
73
 
74
-
75
  def process_file(file_path: str, model_choice: str):
76
  """
77
  For Tab 3 (file output).
@@ -82,16 +93,29 @@ def process_file(file_path: str, model_choice: str):
82
  im = load_img(file_path, output_type="pil")
83
  im = im.convert("RGB")
84
 
 
85
  transparent = process(im, model_choice)
86
  transparent.save(name_path)
87
  return name_path
88
 
89
 
90
- # Gradio UI
91
- model_selector = gr.Radio(
92
- choices=["BiRefNet"],
93
- value="BiRefNet",
94
- label="Select Model")
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  # Outputs for tabs 1 & 2: single processed image
97
  processed_img_upload = gr.Image(label="Processed Image (Upload)", type="pil")
@@ -110,7 +134,7 @@ output_file = gr.File(label="Output PNG File")
110
  # Tab 1: local image -> processed image
111
  tab1 = gr.Interface(
112
  fn=fn,
113
- inputs=[image_upload, model_selector],
114
  outputs=processed_img_upload,
115
  api_name="image",
116
  description="Upload an image and choose your background removal model."
@@ -119,7 +143,7 @@ tab1 = gr.Interface(
119
  # Tab 2: URL input -> processed image
120
  tab2 = gr.Interface(
121
  fn=fn,
122
- inputs=[url_input, model_selector],
123
  outputs=processed_img_url,
124
  api_name="text",
125
  description="Paste an image URL and choose your background removal model."
@@ -128,36 +152,20 @@ tab2 = gr.Interface(
128
  # Tab 3: file output -> returns path to .png
129
  tab3 = gr.Interface(
130
  fn=process_file,
131
- inputs=[image_file_upload, model_selector],
132
  outputs=output_file,
133
  api_name="png",
134
  description="Upload an image, choose a model, and get a transparent PNG."
135
  )
136
 
137
  # Combine all tabs
138
- with gr.Blocks() as demo:
139
- tabs = gr.Tabs()
140
- with tabs:
141
- with gr.TabItem("Image Upload"):
142
- inp1 = gr.Image(type="filepath", label="Upload an image")
143
- mdl1 = gr.Radio(["BiRefNet"], value="BiRefNet", label="Select Model")
144
- out1 = gr.Image(type="pil", label="Processed Image (Upload)")
145
- btn1 = gr.Button("Run")
146
- btn1.click(fn=fn, inputs=[inp1,mdl1], outputs=out1)
147
-
148
- with gr.TabItem("URL Input"):
149
- inp2 = gr.Textbox(label="Paste an image URL")
150
- mdl2 = gr.Radio(["BiRefNet"], value="BiRefNet", label="Select Model")
151
- out2 = gr.Image(type="pil", label="Processed Image (URL)")
152
- btn2 = gr.Button("Run")
153
- btn2.click(fn=fn, inputs=[inp2,mdl2], outputs=out2)
154
-
155
- with gr.TabItem("File Output"):
156
- inp3 = gr.Image(type="filepath", label="Upload an image")
157
- mdl3 = gr.Radio(["BiRefNet"], value="BiRefNet", label="Select Model")
158
- out3 = gr.File(label="Output PNG File")
159
- btn3 = gr.Button("Run")
160
- btn3.click(fn=process_file, inputs=[inp3,mdl3], outputs=out3)
161
-
162
- demo.launch(show_error=True, share=True)
163
 
 
1
+
2
  import gradio as gr
3
  from load_image import load_img
4
  import spaces
 
11
 
12
  torch.set_float32_matmul_precision(["high", "highest"][0])
13
 
14
+ # load 2 models
15
+
16
  birefnet = AutoModelForImageSegmentation.from_pretrained(
17
  "ZhengPeng7/BiRefNet", trust_remote_code=True
18
  )
19
 
20
+
21
+ RMBG2 = AutoModelForImageSegmentation.from_pretrained(
22
+ "briaai/RMBG-2.0", trust_remote_code=True
23
+ )
24
+
25
+ # Keep them in a dict to switch easily
26
  models_dict = {
27
+ "BiRefNet": birefnet,
28
+ "RMBG-2.0": RMBG2
29
  }
30
 
31
  # Transform
32
+
33
  transform_image = transforms.Compose(
34
  [
35
  transforms.Resize((1024, 1024)),
 
44
  Runs inference to remove the background (adds alpha)
45
  with the chosen segmentation model.
46
  """
47
+ # Select the model
48
  current_model = models_dict[model_choice]
49
 
50
  # Prepare image
 
53
 
54
  # Inference
55
  with torch.no_grad():
56
+ # Each model returns a list of preds in its forward,
57
+ # so we take the last element, apply sigmoid, and move to CPU
58
  preds = current_model(input_images)[-1].sigmoid().cpu()
59
 
60
  # Convert single-channel pred to a PIL mask
 
68
  image.putalpha(mask)
69
  return image
70
 
 
71
  def fn(source: str, model_choice: str):
72
  """
73
  Used by Tab 1 & Tab 2 to produce a processed image with alpha.
 
75
  a URL string (textbox).
76
  - 'model_choice' is the user's selection from the radio.
77
  """
78
+ # Load from local path or URL
79
  im = load_img(source, output_type="pil")
80
  im = im.convert("RGB")
81
 
 
83
  processed_image = process(im, model_choice)
84
  return processed_image
85
 
 
86
  def process_file(file_path: str, model_choice: str):
87
  """
88
  For Tab 3 (file output).
 
93
  im = load_img(file_path, output_type="pil")
94
  im = im.convert("RGB")
95
 
96
+ # Run the chosen model
97
  transparent = process(im, model_choice)
98
  transparent.save(name_path)
99
  return name_path
100
 
101
 
102
+ # Gradio UI
103
+
104
+ model_selector_1 = gr.Radio(
105
+ choices=["BiRefNet","RMBG-2.0"],
106
+ value="BiRefNet",
107
+ label="Select Model"
108
+ )
109
+ model_selector_2 = gr.Radio(
110
+ choices=["BiRefNet","RMBG-2.0"],
111
+ value="BiRefNet",
112
+ label="Select Model"
113
+ )
114
+ model_selector_3 = gr.Radio(
115
+ choices=["BiRefNet", "RMBG-2.0"],
116
+ value="BiRefNet",
117
+ label="Select Model"
118
+ )
119
 
120
  # Outputs for tabs 1 & 2: single processed image
121
  processed_img_upload = gr.Image(label="Processed Image (Upload)", type="pil")
 
134
  # Tab 1: local image -> processed image
135
  tab1 = gr.Interface(
136
  fn=fn,
137
+ inputs=[image_upload, model_selector_1],
138
  outputs=processed_img_upload,
139
  api_name="image",
140
  description="Upload an image and choose your background removal model."
 
143
  # Tab 2: URL input -> processed image
144
  tab2 = gr.Interface(
145
  fn=fn,
146
+ inputs=[url_input, model_selector_2],
147
  outputs=processed_img_url,
148
  api_name="text",
149
  description="Paste an image URL and choose your background removal model."
 
152
  # Tab 3: file output -> returns path to .png
153
  tab3 = gr.Interface(
154
  fn=process_file,
155
+ inputs=[image_file_upload, model_selector_3],
156
  outputs=output_file,
157
  api_name="png",
158
  description="Upload an image, choose a model, and get a transparent PNG."
159
  )
160
 
161
  # Combine all tabs
162
+ demo = gr.TabbedInterface(
163
+ [tab1, tab2, tab3],
164
+ ["Image Upload", "URL Input", "File Output"],
165
+ title="Background Removal Tool"
166
+ )
167
+
168
+ if __name__ == "__main__":
169
+ demo.launch(show_error=True, share=True)
170
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171