ihabooe commited on
Commit
6629ac2
·
verified ·
1 Parent(s): 14ec6bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -12
app.py CHANGED
@@ -12,6 +12,7 @@ import time
12
  import uuid
13
  import shutil
14
 
 
15
  print("Loading model...")
16
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -19,22 +20,185 @@ net.to(device)
19
  net.eval()
20
  print(f"Model loaded on {device}")
21
 
 
22
  OUTPUT_DIR = "output_images"
23
  os.makedirs(OUTPUT_DIR, exist_ok=True)
24
 
25
- def resize_image(image, max_size=1024):
26
- width, height = image.size
27
- aspect_ratio = width / height
28
- if width > max_size or height > max_size:
29
- if width > height:
30
- new_width = max_size
31
- new_height = int(max_size / aspect_ratio)
32
- else:
33
- new_height = max_size
34
- new_width = int(max_size * aspect_ratio)
35
- image = resize_image.resize((new_width, new_height), Image.LANCZOS)
36
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def process(image, progress=gr.Progress()):
39
  if image is None:
40
  return None, None
 
12
  import uuid
13
  import shutil
14
 
15
+ # Load the pre-trained model
16
  print("Loading model...")
17
  net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
20
  net.eval()
21
  print(f"Model loaded on {device}")
22
 
23
+ # Create output directory if it doesn't exist
24
  OUTPUT_DIR = "output_images"
25
  os.makedirs(OUTPUT_DIR, exist_ok=True)
26
 
27
+ def process(image, progress=gr.Progress()):
28
+ if image is None:
29
+ return None, None
30
+ try:
31
+ progress(0, desc="Starting processing...")
32
+ orig_image = Image.fromarray(image)
33
+ original_size = orig_image.size
34
+
35
+ progress(0.2, desc="Preparing image...")
36
+ process_image = orig_image.resize(original_size, Image.LANCZOS)
37
+ w, h = process_image.size
38
+
39
+ im_np = np.array(process_image)
40
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
41
+ im_tensor = torch.unsqueeze(im_tensor, 0)
42
+ im_tensor = torch.divide(im_tensor, 255.0)
43
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
44
+
45
+ progress(0.4, desc="Processing with AI model...")
46
+ if torch.cuda.is_available():
47
+ im_tensor = im_tensor.cuda()
48
+
49
+ with torch.no_grad():
50
+ result = net(im_tensor)
51
+
52
+ progress(0.6, desc="Post-processing...")
53
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
54
+ ma = torch.max(result)
55
+ mi = torch.min(result)
56
+ result = (result - mi) / (ma - mi)
57
+
58
+ result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
59
+ pil_mask = Image.fromarray(np.squeeze(result_array))
60
+
61
+ if pil_mask.size != original_size:
62
+ pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
63
+
64
+ new_im = orig_image.copy()
65
+ new_im.putalpha(pil_mask)
66
+
67
+ progress(0.8, desc="Saving result...")
68
+ unique_id = str(uuid.uuid4())[:8]
69
+ filename = f"background_removed_{unique_id}.png"
70
+ filepath = os.path.join(OUTPUT_DIR, filename)
71
+ new_im.save(filepath, format='PNG', quality=100)
72
+
73
+ # Convert to numpy array for display
74
+ output_array = np.array(new_im)
75
+
76
+ progress(1.0, desc="Done!")
77
+ return output_array, filepath
78
+
79
+ except Exception as e:
80
+ print(f"Error processing image: {str(e)}")
81
+ return None, None
82
+
83
+ css = """
84
+ @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
85
+
86
+ .container { max-width: 850px; margin: 0 auto; padding: 20px; }
87
 
88
+ .title-text {
89
+ color: #ff00de;
90
+ font-family: 'Orbitron', sans-serif;
91
+ font-size: 2.5em;
92
+ text-align: center;
93
+ margin: 20px 0;
94
+ text-shadow: 0 0 10px rgba(255, 0, 222, 0.7);
95
+ animation: glow 2s ease-in-out infinite alternate;
96
+ }
97
+
98
+ .subtitle-text {
99
+ color: #00ffff;
100
+ text-align: center;
101
+ margin-bottom: 30px;
102
+ font-size: 1.2em;
103
+ text-shadow: 0 0 8px rgba(0, 255, 255, 0.7);
104
+ }
105
+
106
+ .image-container {
107
+ background: rgba(10, 10, 30, 0.3);
108
+ border-radius: 15px;
109
+ padding: 20px;
110
+ margin: 10px 0;
111
+ border: 2px solid #00ffff;
112
+ box-shadow: 0 0 15px rgba(0, 255, 255, 0.2);
113
+ transition: all 0.3s ease;
114
+ }
115
+
116
+ .image-container img {
117
+ max-width: 100%;
118
+ height: auto;
119
+ display: block;
120
+ margin: 0 auto;
121
+ }
122
+
123
+ .image-container:hover {
124
+ box-shadow: 0 0 20px rgba(0, 255, 255, 0.4);
125
+ transform: translateY(-2px);
126
+ }
127
+
128
+ .download-btn {
129
+ background: linear-gradient(45deg, #00ffff, #ff00de);
130
+ border: none;
131
+ padding: 12px 25px;
132
+ border-radius: 8px;
133
+ color: white;
134
+ font-family: 'Orbitron', sans-serif;
135
+ cursor: pointer;
136
+ transition: all 0.3s ease;
137
+ margin-top: 10px;
138
+ text-align: center;
139
+ text-transform: uppercase;
140
+ letter-spacing: 1px;
141
+ display: block;
142
+ width: 100%;
143
+ }
144
+
145
+ .download-btn:hover {
146
+ transform: translateY(-2px);
147
+ box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4);
148
+ }
149
+
150
+ @keyframes glow {
151
+ from {
152
+ text-shadow: 0 0 5px #ff00de, 0 0 10px #ff00de, 0 0 15px #ff00de;
153
+ }
154
+ to {
155
+ text-shadow: 0 0 10px #ff00de, 0 0 20px #ff00de, 0 0 30px #ff00de;
156
+ }
157
+ }
158
+
159
+ @media (max-width: 768px) {
160
+ .title-text { font-size: 1.8em; }
161
+ .subtitle-text { font-size: 1em; }
162
+ .image-container { padding: 10px; }
163
+ .download-btn { padding: 10px 20px; }
164
+ }
165
+ """
166
+
167
+ with gr.Blocks(css=css) as demo:
168
+ gr.Markdown("""
169
+ <h1 class="title-text">AI Background Removal</h1>
170
+ <p class="subtitle-text">Remove backgrounds instantly using advanced AI technology</p>
171
+ """)
172
+
173
+ with gr.Row():
174
+ with gr.Column():
175
+ input_image = gr.Image(
176
+ label="Upload Image",
177
+ type="numpy",
178
+ elem_classes="image-container"
179
+ )
180
+
181
+ output_image = gr.Image(
182
+ label="Result",
183
+ type="numpy",
184
+ show_label=True,
185
+ elem_classes="image-container"
186
+ )
187
+
188
+ download_button = gr.File(
189
+ label="Download Result",
190
+ visible=True,
191
+ elem_classes="download-btn"
192
+ )
193
+
194
+ input_image.change(
195
+ fn=process,
196
+ inputs=input_image,
197
+ outputs=[output_image, download_button]
198
+ )
199
+
200
+ if __name__ == "__main__":
201
+ demo.launch()
202
  def process(image, progress=gr.Progress()):
203
  if image is None:
204
  return None, None