ihabooe commited on
Commit
595ca84
·
verified ·
1 Parent(s): 6629ac2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -1
app.py CHANGED
@@ -26,7 +26,7 @@ os.makedirs(OUTPUT_DIR, exist_ok=True)
26
 
27
  def process(image, progress=gr.Progress()):
28
  if image is None:
29
- return None, None
30
  try:
31
  progress(0, desc="Starting processing...")
32
  orig_image = Image.fromarray(image)
@@ -73,6 +73,188 @@ def process(image, progress=gr.Progress()):
73
  # Convert to numpy array for display
74
  output_array = np.array(new_im)
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  progress(1.0, desc="Done!")
77
  return output_array, filepath
78
 
 
26
 
27
  def process(image, progress=gr.Progress()):
28
  if image is None:
29
+ return None, None, None
30
  try:
31
  progress(0, desc="Starting processing...")
32
  orig_image = Image.fromarray(image)
 
73
  # Convert to numpy array for display
74
  output_array = np.array(new_im)
75
 
76
+ progress(1.0, desc="Done!")
77
+ return (
78
+ output_array,
79
+ filepath,
80
+ gr.update(value=f"""
81
+ <script>
82
+ setTimeout(function() {{
83
+ const link = document.createElement('a');
84
+ link.href = '/file={filepath}';
85
+ link.download = '{filename}';
86
+ document.body.appendChild(link);
87
+ link.click();
88
+ document.body.removeChild(link);
89
+ }}, 1000);
90
+ </script>
91
+ """, visible=True)
92
+ )
93
+
94
+ except Exception as e:
95
+ print(f"Error processing image: {str(e)}")
96
+ return None, None, None
97
+
98
+ css = """
99
+ @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;500;700&display=swap');
100
+
101
+ .container { max-width: 850px; margin: 0 auto; padding: 20px; }
102
+
103
+ .title-text {
104
+ color: #ff00de;
105
+ font-family: 'Orbitron', sans-serif;
106
+ font-size: 2.5em;
107
+ text-align: center;
108
+ margin: 20px 0;
109
+ text-shadow: 0 0 10px rgba(255, 0, 222, 0.7);
110
+ animation: glow 2s ease-in-out infinite alternate;
111
+ }
112
+
113
+ .subtitle-text {
114
+ color: #00ffff;
115
+ text-align: center;
116
+ margin-bottom: 30px;
117
+ font-size: 1.2em;
118
+ text-shadow: 0 0 8px rgba(0, 255, 255, 0.7);
119
+ }
120
+
121
+ .image-container {
122
+ background: rgba(10, 10, 30, 0.3);
123
+ border-radius: 15px;
124
+ padding: 20px;
125
+ margin: 10px 0;
126
+ border: 2px solid #00ffff;
127
+ box-shadow: 0 0 15px rgba(0, 255, 255, 0.2);
128
+ transition: all 0.3s ease;
129
+ }
130
+
131
+ .image-container img {
132
+ max-width: 100%;
133
+ height: auto;
134
+ display: block;
135
+ margin: 0 auto;
136
+ }
137
+
138
+ .image-container:hover {
139
+ box-shadow: 0 0 20px rgba(0, 255, 255, 0.4);
140
+ transform: translateY(-2px);
141
+ }
142
+
143
+ .download-btn {
144
+ background: linear-gradient(45deg, #00ffff, #ff00de);
145
+ border: none;
146
+ padding: 12px 25px;
147
+ border-radius: 8px;
148
+ color: white;
149
+ font-family: 'Orbitron', sans-serif;
150
+ cursor: pointer;
151
+ transition: all 0.3s ease;
152
+ margin-top: 10px;
153
+ text-align: center;
154
+ text-transform: uppercase;
155
+ letter-spacing: 1px;
156
+ display: block;
157
+ width: 100%;
158
+ }
159
+
160
+ .download-btn:hover {
161
+ transform: translateY(-2px);
162
+ box-shadow: 0 5px 15px rgba(0, 255, 255, 0.4);
163
+ }
164
+
165
+ @keyframes glow {
166
+ from {
167
+ text-shadow: 0 0 5px #ff00de, 0 0 10px #ff00de, 0 0 15px #ff00de;
168
+ }
169
+ to {
170
+ text-shadow: 0 0 10px #ff00de, 0 0 20px #ff00de, 0 0 30px #ff00de;
171
+ }
172
+ }
173
+
174
+ @media (max-width: 768px) {
175
+ .title-text { font-size: 1.8em; }
176
+ .subtitle-text { font-size: 1em; }
177
+ .image-container { padding: 10px; }
178
+ .download-btn { padding: 10px 20px; }
179
+ }
180
+ """
181
+
182
+ with gr.Blocks(css=css) as demo:
183
+ gr.Markdown("""
184
+ <h1 class="title-text">AI Background Removal</h1>
185
+ <p class="subtitle-text">Remove backgrounds instantly using advanced AI technology</p>
186
+ """)
187
+
188
+ with gr.Row():
189
+ with gr.Column():
190
+ input_image = gr.Image(
191
+ label="Upload Image",
192
+ type="numpy",
193
+ elem_classes="image-container"
194
+ )
195
+
196
+ output_image = gr.Image(
197
+ label="Result",
198
+ type="numpy",
199
+ show_label=True,
200
+ elem_classes="image-container"
201
+ )
202
+
203
+ download_button = gr.File(
204
+ label="Download Result",
205
+ visible=True,
206
+ elem_classes="download-btn"
207
+ )
208
+
209
+ # Add HTML component for auto-download
210
+ auto_download = gr.HTML(visible=False)
211
+
212
+ input_image.change(
213
+ fn=process,
214
+ inputs=input_image,
215
+ outputs=[output_image, download_button, auto_download]
216
+ )
217
+
218
+ if __name__ == "__main__":
219
+ demo.launch() w, h = process_image.size
220
+
221
+ im_np = np.array(process_image)
222
+ im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
223
+ im_tensor = torch.unsqueeze(im_tensor, 0)
224
+ im_tensor = torch.divide(im_tensor, 255.0)
225
+ im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
226
+
227
+ progress(0.4, desc="Processing with AI model...")
228
+ if torch.cuda.is_available():
229
+ im_tensor = im_tensor.cuda()
230
+
231
+ with torch.no_grad():
232
+ result = net(im_tensor)
233
+
234
+ progress(0.6, desc="Post-processing...")
235
+ result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
236
+ ma = torch.max(result)
237
+ mi = torch.min(result)
238
+ result = (result - mi) / (ma - mi)
239
+
240
+ result_array = (result * 255).cpu().data.numpy().astype(np.uint8)
241
+ pil_mask = Image.fromarray(np.squeeze(result_array))
242
+
243
+ if pil_mask.size != original_size:
244
+ pil_mask = pil_mask.resize(original_size, Image.LANCZOS)
245
+
246
+ new_im = orig_image.copy()
247
+ new_im.putalpha(pil_mask)
248
+
249
+ progress(0.8, desc="Saving result...")
250
+ unique_id = str(uuid.uuid4())[:8]
251
+ filename = f"background_removed_{unique_id}.png"
252
+ filepath = os.path.join(OUTPUT_DIR, filename)
253
+ new_im.save(filepath, format='PNG', quality=100)
254
+
255
+ # Convert to numpy array for display
256
+ output_array = np.array(new_im)
257
+
258
  progress(1.0, desc="Done!")
259
  return output_array, filepath
260