Update app.py
app.py CHANGED
@@ -42,7 +42,17 @@ try:
     print("Successfully imported albumentations")
 except ImportError:
     albumentations_available = False
-    print("Warning: albumentations not available, will
+    print("Warning: albumentations not available, will try to install it")
+    import subprocess
+    try:
+        subprocess.check_call([
+            "pip", "install", "albumentations"
+        ])
+        import albumentations as A
+        albumentations_available = True
+        print("Successfully installed and imported albumentations")
+    except:
+        print("Failed to install albumentations, will use OpenCV for transforms")

 # Set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
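
Side note on the install fallback above: invoking the `pip` executable directly can install into a different environment than the one running the app. A common variant (a sketch, not part of this commit) goes through the running interpreter with `sys.executable -m pip`, keeping the same `albumentations_available` flag:

```python
# Sketch of an interpreter-safe fallback; assumes the same albumentations_available flag.
import subprocess
import sys

try:
    import albumentations as A
    albumentations_available = True
except ImportError:
    albumentations_available = False
    try:
        # Install with the interpreter that is actually running the app.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "albumentations"])
        import albumentations as A
        albumentations_available = True
    except (subprocess.CalledProcessError, ImportError):
        print("Failed to install albumentations, will use OpenCV for transforms")
```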
@@ -123,8 +133,10 @@ def read_tiff_image(tiff_path):
     This matches your training data loading approach
     """
     try:
+        print(f"Reading TIFF image from: {tiff_path}")
         # Read the image using rasterio (get RGB channels)
         with rasterio.open(tiff_path) as src:
+            print(f"TIFF opened successfully. Number of bands: {src.count}")
             # Check if we have enough bands
             if src.count >= 3:
                 red = src.read(1)
@@ -133,6 +145,7 @@ def read_tiff_image(tiff_path):

                 # Stack to create RGB image
                 image = np.dstack((red, green, blue)).astype(np.float32)
+                print(f"RGB image created, shape: {image.shape}, min: {image.min()}, max: {image.max()}")

                 # Normalize to [0, 1]
                 if image.max() > 0:
@@ -159,6 +172,8 @@ def read_tiff_image(tiff_path):
         return image
     except Exception as e:
         print(f"Error reading TIFF file: {e}")
+        import traceback
+        traceback.print_exc()
         return None

 def read_tiff_mask(mask_path):
@@ -167,35 +182,59 @@ def read_tiff_mask(mask_path):
     This matches your training data loading approach
     """
     try:
+        print(f"Reading TIFF mask from: {mask_path}")
         # Read mask
         with rasterio.open(mask_path) as src:
+            print(f"Mask TIFF opened successfully. Number of bands: {src.count}")
             mask = src.read(1).astype(np.uint8)
+            print(f"Mask shape: {mask.shape}, min: {mask.min()}, max: {mask.max()}, unique values: {np.unique(mask)}")
             return mask
     except Exception as e:
         print(f"Error reading mask file: {e}")
+        import traceback
+        traceback.print_exc()
         return None

 def preprocess_image(image, target_size=(128, 128)):
     """
     Preprocess an image for inference
     """
-    # If image is a
-    if isinstance(image,
-
-    if image
-
-
+    # If image is already a numpy array, use it directly
+    if isinstance(image, np.ndarray):
+        # Ensure RGB format
+        if len(image.shape) == 2:  # Grayscale
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        elif image.shape[2] == 4:  # RGBA
+            image = image[:, :, :3]
+
+        # Make a copy for display
+        display_image = image.copy()
+
+        # Normalize to [0, 1] if needed
+        if display_image.max() > 1.0:
+            image = image.astype(np.float32) / 255.0
+
+    # Convert PIL image to numpy
     elif isinstance(image, Image.Image):
+        print("Converting PIL image to numpy array")
         image = np.array(image)
+
+        # Ensure RGB format
+        if len(image.shape) == 2:  # Grayscale
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        elif image.shape[2] == 4:  # RGBA
+            image = image[:, :, :3]
+
+        # Make a copy for display
+        display_image = image.copy()
+
+        # Normalize to [0, 1]
+        image = image.astype(np.float32) / 255.0
+    else:
+        print(f"Unsupported image type: {type(image)}")
+        return None, None

-
-    if len(image.shape) == 2: # Grayscale
-        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
-    elif image.shape[2] == 4: # RGBA
-        image = image[:, :, :3]
-
-    # Make a copy for display
-    display_image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.copy()
+    print(f"Image shape: {image.shape}, min: {image.min()}, max: {image.max()}")

     # Resize image to the target size
     if albumentations_available:
@@ -216,141 +255,146 @@ def preprocess_image(image, target_size=(128, 128)):

     return image_tensor, display_image

-def
+def save_temp_file(file_obj, suffix='.tif'):
     """
-
+    Save uploaded file to a temporary file
     """
     try:
-
-        with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as temp_file:
+        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
             temp_path = temp_file.name
-            # Write the content to the temp file
-            if hasattr(file_obj, 'read'):
-                # If it's a file-like object
-                temp_file.write(file_obj.read())
-            elif isinstance(file_obj, (str, bytes)):
-                # If it's a string path or bytes
-                if isinstance(file_obj, str):
-                    with open(file_obj, 'rb') as f:
-                        temp_file.write(f.read())
-                else:
-                    temp_file.write(file_obj)
-
-        # Check if it's a TIFF file
-        if temp_path.lower().endswith(('.tif', '.tiff')):
-            # Read as TIFF
-            image = read_tiff_image(temp_path)
-            if image is None:
-                os.unlink(temp_path)
-                return None, None
-
-            # Make a copy for display
-            display_image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.copy()

-
-
-
-
-
-
-
-
-
+            if hasattr(file_obj, 'name'):
+                # If it has a name attribute, it's likely a FileStorage object
+                file_obj.save(temp_path)
+                print(f"Saved file from FileStorage object to {temp_path}")
+            elif hasattr(file_obj, 'read'):
+                # If it's a file-like object
+                content = file_obj.read()
+                if isinstance(content, str):  # It's text, not binary
+                    content = content.encode('utf-8')
+                temp_file.write(content)
+                print(f"Wrote {len(content)} bytes to {temp_path}")
+            elif isinstance(file_obj, bytes):
+                # If it's already bytes
+                temp_file.write(file_obj)
+                print(f"Wrote {len(file_obj)} bytes to {temp_path}")
+            elif isinstance(file_obj, str):
+                # If it's a path
+                with open(file_obj, 'rb') as f:
+                    temp_file.write(f.read())
+                print(f"Copied file from {file_obj} to {temp_path}")
             else:
-
-
-            # Convert to tensor
-            image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
-        else:
-            # Try to open as a regular image
-            try:
-                pil_image = Image.open(temp_path)
-                image_tensor, display_image = preprocess_image(pil_image, target_size)
-            except Exception as e:
-                print(f"Error opening as regular image: {e}")
+                print(f"Unsupported file object type: {type(file_obj)}")
                 os.unlink(temp_path)
-                return None
+                return None
+
+        return temp_path
+    except Exception as e:
+        print(f"Error saving temporary file: {e}")
+        import traceback
+        traceback.print_exc()
+        return None
+
+def process_uploaded_tiff(file_obj):
+    """
+    Process an uploaded TIFF file
+    """
+    try:
+        print(f"Processing uploaded TIFF file: {type(file_obj)}")
+
+        # Save to a temporary file
+        temp_path = save_temp_file(file_obj)
+        if not temp_path:
+            return None, None
+
+        # Read the TIFF file
+        image = read_tiff_image(temp_path)

         # Clean up
         os.unlink(temp_path)
+
+        if image is None:
+            return None, None
+
+        # Make a copy for display
+        display_image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.copy()
+
+        # Resize/preprocess
+        if albumentations_available:
+            aug = A.Compose([
+                A.PadIfNeeded(min_height=128, min_width=128,
+                              border_mode=cv2.BORDER_CONSTANT, value=0),
+                A.CenterCrop(height=128, width=128)
+            ])
+            augmented = aug(image=image)
+            image_resized = augmented['image']
+        else:
+            image_resized = cv2.resize(image, (128, 128), interpolation=cv2.INTER_LINEAR)
+
+        # Convert to tensor
+        image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
+
         return image_tensor, display_image

     except Exception as e:
-        print(f"Error processing uploaded
+        print(f"Error processing uploaded TIFF: {e}")
         import traceback
         traceback.print_exc()
         return None, None

-def
+def process_uploaded_mask(file_obj):
     """
-
+    Process an uploaded mask file
     """
     try:
-
-        if isinstance(mask_input, str) or hasattr(mask_input, 'read'):
-            # Save to temp file if it's a file-like object
-            if hasattr(mask_input, 'read'):
-                with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as temp_file:
-                    temp_path = temp_file.name
-                    temp_file.write(mask_input.read())
-                mask_path = temp_path
-            else:
-                mask_path = mask_input
-
-            # Check if it's a TIFF
-            if mask_path.lower().endswith(('.tif', '.tiff')):
-                mask = read_tiff_mask(mask_path)
-                if mask is None:
-                    if hasattr(mask_input, 'read'):
-                        os.unlink(temp_path)
-                    return None
-            else:
-                # Try as regular image
-                try:
-                    mask = np.array(Image.open(mask_path))
-                    if len(mask.shape) == 3:
-                        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
-                except Exception as e:
-                    print(f"Error opening mask as image: {e}")
-                    if hasattr(mask_input, 'read'):
-                        os.unlink(temp_path)
-                    return None
-
-            # Clean up if temp file
-            if hasattr(mask_input, 'read'):
-                os.unlink(temp_path)
+        print(f"Processing uploaded mask file: {type(file_obj)}")

-        #
-
-
-
-            mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
+        # Save to a temporary file
+        temp_path = save_temp_file(file_obj)
+        if not temp_path:
+            return None

-        #
+        # Check if it's a TIFF file
+        if temp_path.lower().endswith(('.tif', '.tiff')):
+            mask = read_tiff_mask(temp_path)
         else:
-
-
-
+            # Try to open as a regular image
+            try:
+                mask_img = Image.open(temp_path)
+                mask = np.array(mask_img)
+                if len(mask.shape) == 3:
+                    mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
+                print(f"Opened mask as regular image, shape: {mask.shape}")
+            except Exception as e:
+                print(f"Error opening mask as regular image: {e}")
+                os.unlink(temp_path)
+                return None
+
+        # Clean up
+        os.unlink(temp_path)
+
+        if mask is None:
+            return None

-        # Resize mask to
+        # Resize mask to 128x128
         if albumentations_available:
             aug = A.Compose([
-                A.PadIfNeeded(min_height=
+                A.PadIfNeeded(min_height=128, min_width=128,
                               border_mode=cv2.BORDER_CONSTANT, value=0),
-                A.CenterCrop(height=
+                A.CenterCrop(height=128, width=128)
             ])
             augmented = aug(image=mask)
             mask_resized = augmented['image']
         else:
-            mask_resized = cv2.resize(mask,
+            mask_resized = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)

         # Binarize the mask (0: background, 1: wetland)
-        mask_binary = (mask_resized >
+        mask_binary = (mask_resized > 0).astype(np.uint8)

         return mask_binary

     except Exception as e:
-        print(f"Error
+        print(f"Error processing uploaded mask: {e}")
         import traceback
         traceback.print_exc()
         return None
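
The `PadIfNeeded` + `CenterCrop` pipeline above is duplicated for images and masks. For reference, albumentations can apply one composed transform to an image and its mask in the same call so the two stay spatially aligned; a minimal sketch with made-up arrays, assuming the 128×128 target used throughout this file:

```python
import albumentations as A
import cv2
import numpy as np

# One composed transform applied jointly: pad short sides to 128, then center-crop to 128x128.
resize_128 = A.Compose([
    A.PadIfNeeded(min_height=128, min_width=128, border_mode=cv2.BORDER_CONSTANT, value=0),
    A.CenterCrop(height=128, width=128),
])

image = np.random.rand(200, 150, 3).astype(np.float32)      # stand-in RGB image
mask = np.random.randint(0, 2, (200, 150), dtype=np.uint8)  # stand-in binary mask

out = resize_128(image=image, mask=mask)
image_128, mask_128 = out["image"], out["mask"]              # both 128x128, still aligned
```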
@@ -415,27 +459,36 @@ def calculate_metrics(pred_mask, gt_mask):

     return metrics

-def process_images(input_image=None, input_tiff=None, gt_mask=None):
+def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
     """
     Process input images and generate predictions
     """
     try:
+        print("\n---- Starting new processing request ----")
+        print(f"Input image type: {type(input_image) if input_image is not None else None}")
+        print(f"Input TIFF type: {type(input_tiff) if input_tiff is not None else None}")
+        print(f"Ground truth mask type: {type(gt_mask_file) if gt_mask_file is not None else None}")
+
         # Check if we have input
         if input_image is None and input_tiff is None:
             return None, "Please upload an image or TIFF file."

-        # Process the input
-        if input_tiff is not None:
+        # Process the input
+        if input_tiff is not None and input_tiff:
             # Process uploaded TIFF file
-            image_tensor, display_image =
-
+            image_tensor, display_image = process_uploaded_tiff(input_tiff)
+            if image_tensor is None:
+                return None, "Failed to process the input TIFF file."
+        elif input_image is not None:
             # Process regular image
             image_tensor, display_image = preprocess_image(input_image)
-
-
-
+            if image_tensor is None:
+                return None, "Failed to process the input image."
+        else:
+            return None, "No valid input provided."

         # Get prediction
+        print("Running model prediction...")
         pred_mask = predict_segmentation(image_tensor)
         if pred_mask is None:
             return None, "Failed to generate prediction."
@@ -444,14 +497,20 @@ def process_images(input_image=None, input_tiff=None, gt_mask=None):
         gt_mask_processed = None
         metrics_text = ""

-        if
-
+        if gt_mask_file is not None and gt_mask_file:
+            print("Processing ground truth mask...")
+            gt_mask_processed = process_uploaded_mask(gt_mask_file)

         if gt_mask_processed is not None:
+            print("Calculating metrics...")
             metrics = calculate_metrics(pred_mask, gt_mask_processed)
             metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
+            print(f"Metrics calculated: {metrics}")
+        else:
+            print("Failed to process ground truth mask.")

         # Create visualization
+        print("Creating visualization...")
         fig = plt.figure(figsize=(12, 6))

         if gt_mask_processed is not None:
@@ -498,6 +557,7 @@ def process_images(input_image=None, input_tiff=None, gt_mask=None):
         result_image = Image.open(buf)
         plt.close(fig)

+        print("Processing completed successfully.")
         return result_image, result_text

     except Exception as e:
@@ -521,7 +581,8 @@ with gr.Blocks(title="Wetlands Segmentation from Satellite Imagery") as demo:
             with gr.Tab("Upload TIFF"):
                 input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"])

-
+            # IMPORTANT CHANGE: Changed ground truth from Image to File for TIFF support
+            gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])

             process_btn = gr.Button("Analyze Image", variant="primary")

@@ -546,6 +607,7 @@ with gr.Blocks(title="Wetlands Segmentation from Satellite Imagery") as demo:
     - The model works best with RGB satellite imagery
     - For optimal results, use images with similar characteristics to those used in training
     - The model focuses on identifying wetland regions in natural landscapes
+    - For ground truth masks, both TIFF and standard image formats are supported

     **Repository:** [dcrey7/wetlands_segmentation_deeplabsv3plus](https://huggingface.co/dcrey7/wetlands_segmentation_deeplabsv3plus)
     """)
@@ -553,7 +615,7 @@ with gr.Blocks(title="Wetlands Segmentation from Satellite Imagery") as demo:
     # Set up event handlers
     process_btn.click(
         fn=process_images,
-        inputs=[input_image, input_tiff,
+        inputs=[input_image, input_tiff, gt_mask_file],
         outputs=[output_image, output_text]
     )

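
Taken together, the new helpers can be exercised without the Gradio UI. A hypothetical smoke test, assuming the functions shown above are in scope and using placeholder file names:

```python
# Hypothetical smoke test for the helpers added in this commit.
# "sample_scene.tif" and "sample_mask.tif" are placeholder paths.
image_tensor, display_image = process_uploaded_tiff("sample_scene.tif")
gt_mask = process_uploaded_mask("sample_mask.tif")

if image_tensor is not None:
    pred_mask = predict_segmentation(image_tensor)
    if pred_mask is not None and gt_mask is not None:
        print(calculate_metrics(pred_mask, gt_mask))
```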