File size: 4,899 Bytes
7d6af6f
2daf4de
 
 
 
 
 
 
 
4ad1442
f424be6
4ad1442
2daf4de
7d6af6f
 
2daf4de
 
f424be6
 
7d6af6f
d83fa6c
 
 
7d6af6f
 
4ad1442
7d6af6f
d83fa6c
4ad1442
 
 
7d6af6f
4ad1442
 
 
 
 
 
 
2daf4de
 
4ad1442
 
 
 
 
7d6af6f
4ad1442
 
 
 
 
 
 
 
7d6af6f
 
4ad1442
 
 
 
 
 
 
 
 
 
 
2daf4de
 
4ad1442
 
 
 
 
 
 
7d6af6f
4ad1442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d6af6f
4ad1442
 
 
 
 
7d6af6f
 
d83fa6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ad1442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d6af6f
4ad1442
 
 
 
 
 
 
 
7d6af6f
4ad1442
 
 
 
 
 
 
 
7d6af6f
4ad1442
 
 
 
 
 
7d6af6f
 
 
 
 
 
 
 
 
2daf4de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178

import gradio as gr
from load_image import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import os
import numpy as np

# Use the "high" float32 matmul precision setting (the alternative is "highest").
torch.set_float32_matmul_precision("high")

# Segmentation models.
# Only BiRefNet is loaded at the moment; RMBG-2.0 is kept below, disabled.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)

# RMBG2 = AutoModelForImageSegmentation.from_pretrained(
#     "briaai/RMBG-2.0", trust_remote_code=True
# )

# Registry mapping the name shown in the UI to the loaded model,
# so that process() can look a model up by the user's selection.
models_dict = {}
models_dict["BiRefNet"] = birefnet
# models_dict["RMBG-2.0"] = RMBG2

# Preprocessing applied to every input image before inference:
# resize to 1024x1024, convert to a tensor, then normalise each of the
# three channels with the given mean / std.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

@spaces.GPU
def process(image: Image.Image, model_choice: str) -> Image.Image:
    """
    Runs inference to remove the background (adds alpha)
    with the chosen segmentation model.

    Args:
        image: Input PIL image. It is left unmodified.
        model_choice: Key into ``models_dict`` selecting the model to run.

    Returns:
        A copy of ``image`` whose alpha channel is the predicted mask,
        resized back to the size of ``image``.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``models_dict``.
    """
    # Select the model
    current_model = models_dict[model_choice]

    # transform_image normalises with three per-channel statistics, so the
    # tensor fed to the model must come from a 3-channel RGB image.
    # Converting here keeps RGBA / L / P inputs from failing in Normalize.
    model_input = image if image.mode == "RGB" else image.convert("RGB")
    input_images = transform_image(model_input).unsqueeze(0)

    # Inference
    with torch.no_grad():
        # The last element of the model output is taken as the final
        # prediction; sigmoid maps it to [0, 1] before moving it to the CPU.
        preds = current_model(input_images)[-1].sigmoid().cpu()

    # Convert the single-channel prediction to a PIL mask and
    # resize it back to the original image size.
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image.size)

    # Attach the mask as alpha on a copy so the caller's image is not mutated.
    result = image.copy()
    result.putalpha(mask)
    return result

def fn(source: str, model_choice: str):
    """
    Shared handler for Tab 1 and Tab 2: returns the processed image with alpha.
    - 'source' is passed to load_img; it is either a local file path
      (image input with type="filepath") or a URL string (textbox).
    - 'model_choice' is the model name selected with the radio.
    """
    # Load the image (path or URL) as PIL and force three RGB channels
    rgb_image = load_img(source, output_type="pil").convert("RGB")

    # Run background removal with the selected model
    return process(rgb_image, model_choice)

def process_file(file_path: str, model_choice: str):
    """
    For Tab 3 (file output).
    - Accepts a local path, returns path to a new .png with alpha channel.
    - 'model_choice' is also passed in for selecting the model.

    The output is written next to the input, with the input's extension
    (if any) replaced by ".png".
    """
    # os.path.splitext only treats a dot in the last path component as an
    # extension separator, so a dotted directory name combined with an
    # extension-less file (e.g. "/tmp/a.b/img") is not truncated.
    name_path = os.path.splitext(file_path)[0] + ".png"

    im = load_img(file_path, output_type="pil")
    im = im.convert("RGB")

    # Run the chosen model
    transparent = process(im, model_choice)
    transparent.save(name_path)
    return name_path


# Gradio UI

# model_selector_1 = gr.Radio(
#     choices=["BiRefNet","RMBG-2.0"],
#     value="BiRefNet",
#     label="Select Model"
# )
# model_selector_2 = gr.Radio(
#     choices=["BiRefNet","RMBG-2.0"],
#     value="BiRefNet",
#     label="Select Model"
# )
# model_selector_3 = gr.Radio(
#     choices=["BiRefNet", "RMBG-2.0"],
#     value="BiRefNet",
#     label="Select Model"
# )

# Only one model is available, so every tab offers the same single choice.
radio_opts = ["BiRefNet"]

# One independent radio component per tab.
model_selector_1, model_selector_2, model_selector_3 = (
    gr.Radio(choices=radio_opts, value="BiRefNet", label="Select Model")
    for _ in range(3)
)

# Output components for tabs 1 & 2: a single processed image each
processed_img_upload = gr.Image(type="pil", label="Processed Image (Upload)")
processed_img_url = gr.Image(type="pil", label="Processed Image (URL)")

# Input components for local uploads (handlers receive a file path)
image_upload = gr.Image(type="filepath", label="Upload an image")
image_file_upload = gr.Image(type="filepath", label="Upload an image")

# Input component for Tab 2 (image URL)
url_input = gr.Textbox(label="Paste an image URL")

# Output component for Tab 3 (downloadable file)
output_file = gr.File(label="Output PNG File")

# Tab 1: uploaded image -> processed image
tab1 = gr.Interface(
    fn=fn,
    inputs=[image_upload, model_selector_1],
    outputs=processed_img_upload,
    description="Upload an image and choose your background removal model.",
    api_name="image",
)

# Tab 2: image URL -> processed image
tab2 = gr.Interface(
    fn=fn,
    inputs=[url_input, model_selector_2],
    outputs=processed_img_url,
    description="Paste an image URL and choose your background removal model.",
    api_name="text",
)

# Tab 3: uploaded image -> path of the generated .png file
tab3 = gr.Interface(
    fn=process_file,
    inputs=[image_file_upload, model_selector_3],
    outputs=output_file,
    description="Upload an image, choose a model, and get a transparent PNG.",
    api_name="png",
)

# Assemble the three interfaces into one tabbed application
demo = gr.TabbedInterface(
    interface_list=[tab1, tab2, tab3],
    tab_names=["Image Upload", "URL Input", "File Output"],
    title="Background Removal Tool",
)

if __name__ == "__main__":
    demo.launch(show_error=True, share=True)