Update app.py
app.py (CHANGED)
@@ -12,7 +12,6 @@ from io import BytesIO
 import urllib.request
 import tempfile
 import rasterio
-from rasterio.plot import reshape_as_image
 import warnings
 warnings.filterwarnings("ignore")
 
@@ -69,7 +68,7 @@ if smp_available:
     )
 else:
     # Fallback to a simple model that won't actually work but allows the UI to load
-    print("Warning: Using a placeholder model that won't produce
+    print("Warning: Using a placeholder model that won't produce valid predictions.")
     from torch import nn
     class PlaceholderModel(nn.Module):
         def __init__(self):
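For context, this fallback branch builds a stand-in module so the Gradio UI can still launch when segmentation_models_pytorch is unavailable. A minimal sketch of that pattern follows; the single convolution layer is an assumption for illustration, since the actual body of PlaceholderModel is not shown in this hunk.

    import torch
    from torch import nn

    class PlaceholderModel(nn.Module):
        """Stand-in model: lets the UI load, but its output is not a real segmentation."""
        def __init__(self):
            super().__init__()
            # Hypothetical layer; app.py's real placeholder layers are not visible in this diff
            self.conv = nn.Conv2d(3, 1, kernel_size=1)

        def forward(self, x):
            return self.conv(x)

    model = PlaceholderModel()
    dummy = torch.zeros(1, 3, 128, 128)
    print(model(dummy).shape)  # torch.Size([1, 1, 128, 128])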
@@ -133,10 +132,8 @@ def read_tiff_image(tiff_path):
     This matches your training data loading approach
     """
    try:
-        print(f"Reading TIFF image from: {tiff_path}")
        # Read the image using rasterio (get RGB channels)
        with rasterio.open(tiff_path) as src:
-            print(f"TIFF opened successfully. Number of bands: {src.count}")
            # Check if we have enough bands
            if src.count >= 3:
                red = src.read(1)
@@ -145,7 +142,6 @@ def read_tiff_image(tiff_path):
 
                # Stack to create RGB image
                image = np.dstack((red, green, blue)).astype(np.float32)
-                print(f"RGB image created, shape: {image.shape}, min: {image.min()}, max: {image.max()}")
 
                # Normalize to [0, 1]
                if image.max() > 0:
@@ -154,7 +150,6 @@ def read_tiff_image(tiff_path):
                return image
            else:
                # If less than 3 bands, handle accordingly
-                print(f"Warning: TIFF file has only {src.count} bands, RGB expected")
                bands = [src.read(i+1) for i in range(src.count)]
                # If only one band, duplicate to create RGB
                if len(bands) == 1:
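As a reference for what read_tiff_image does after this cleanup, here is a minimal standalone sketch of the rasterio read-and-normalize pattern it uses; the function name, file path, and band order below are assumptions for illustration, not the app's exact code.

    import numpy as np
    import rasterio

    def read_rgb_tiff(tiff_path):
        """Read the first three bands of a GeoTIFF as a float32 RGB array scaled to [0, 1]."""
        with rasterio.open(tiff_path) as src:
            red, green, blue = src.read(1), src.read(2), src.read(3)
        image = np.dstack((red, green, blue)).astype(np.float32)
        if image.max() > 0:
            image = image / image.max()  # normalize to [0, 1]
        return image

    # Hypothetical usage:
    # rgb = read_rgb_tiff("sample_scene.tif")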
@@ -172,8 +167,6 @@ def read_tiff_image(tiff_path):
        return image
    except Exception as e:
        print(f"Error reading TIFF file: {e}")
-        import traceback
-        traceback.print_exc()
        return None
 
 def read_tiff_mask(mask_path):
@@ -182,17 +175,12 @@ def read_tiff_mask(mask_path):
     This matches your training data loading approach
     """
    try:
-        print(f"Reading TIFF mask from: {mask_path}")
        # Read mask
        with rasterio.open(mask_path) as src:
-            print(f"Mask TIFF opened successfully. Number of bands: {src.count}")
            mask = src.read(1).astype(np.uint8)
-            print(f"Mask shape: {mask.shape}, min: {mask.min()}, max: {mask.max()}, unique values: {np.unique(mask)}")
            return mask
    except Exception as e:
        print(f"Error reading mask file: {e}")
-        import traceback
-        traceback.print_exc()
        return None
 
 def preprocess_image(image, target_size=(128, 128)):
@@ -216,7 +204,6 @@ def preprocess_image(image, target_size=(128, 128)):
 
    # Convert PIL image to numpy
    elif isinstance(image, Image.Image):
-        print("Converting PIL image to numpy array")
        image = np.array(image)
 
    # Ensure RGB format
@@ -234,8 +221,6 @@ def preprocess_image(image, target_size=(128, 128)):
        print(f"Unsupported image type: {type(image)}")
        return None, None
 
-    print(f"Image shape: {image.shape}, min: {image.min()}, max: {image.max()}")
-
    # Resize image to the target size
    if albumentations_available:
        # Use albumentations to match training preprocessing
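The resize step uses albumentations when it is installed. A minimal sketch of that kind of preprocessing pipeline is below; the exact transforms in app.py are not shown in this hunk, so the bare resize-and-tensorize version is an assumption.

    import numpy as np
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    # Resize to the model's input size and convert HWC float32 -> CHW torch tensor
    transform = A.Compose([
        A.Resize(128, 128),
        ToTensorV2(),
    ])

    dummy = np.random.rand(256, 256, 3).astype(np.float32)
    image_tensor = transform(image=dummy)["image"].unsqueeze(0)  # shape: (1, 3, 128, 128)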
@@ -255,59 +240,54 @@ def preprocess_image(image, target_size=(128, 128)):
 
    return image_tensor, display_image
 
-def
-    """
-    Save uploaded file to a temporary file
-    """
+def extract_file_content(file_obj):
+    """Extract content from the file object, handling different types"""
    try:
-
-
-
-        if
-        #
-
-
-
-        #
-        content
-
-
-
-
-
-
-
-
-        # If it's a path
+        if hasattr(file_obj, 'name') and isinstance(file_obj, str):
+            # Handle Gradio's NamedString
+            content = file_obj
+            if os.path.exists(content):
+                # It's a path
+                with open(content, 'rb') as f:
+                    return f.read()
+            else:
+                # It's content
+                return content.encode('latin1')
+        elif hasattr(file_obj, 'read'):
+            # File-like object
+            return file_obj.read()
+        elif isinstance(file_obj, bytes):
+            # Already bytes
+            return file_obj
+        elif isinstance(file_obj, str):
+            # String path
+            if os.path.exists(file_obj):
                with open(file_obj, 'rb') as f:
-
-            print(f"Copied file from {file_obj} to {temp_path}")
+                    return f.read()
            else:
-
-
-
-
-        return temp_path
+                return file_obj.encode('utf-8')
+        else:
+            print(f"Unsupported file object type: {type(file_obj)}")
+            return None
    except Exception as e:
-        print(f"Error
-        import traceback
-        traceback.print_exc()
+        print(f"Error extracting file content: {e}")
        return None
 
 def process_uploaded_tiff(file_obj):
-    """
-    Process an uploaded TIFF file
-    """
+    """Process an uploaded TIFF file"""
    try:
-
+        # Get file content
+        file_content = extract_file_content(file_obj)
+        if file_content is None:
+            print("Failed to extract file content")
+            return None, None
 
        # Save to a temporary file
-
-
-
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as temp_file:
+            temp_path = temp_file.name
+            temp_file.write(file_content)
 
-        # Read
+        # Read as TIFF
        image = read_tiff_image(temp_path)
 
        # Clean up
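The new extract_file_content helper normalizes the different objects Gradio may hand to the callback (a filepath string, a NamedString, raw bytes, or a file-like object) into raw bytes, which are then written to a temporary file before rasterio reads them. A minimal sketch of that round trip, under the same assumptions, with the usage lines purely hypothetical:

    import os
    import tempfile

    def bytes_to_temp_tiff(file_content: bytes) -> str:
        """Write raw upload bytes to a temporary .tif file and return its path."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as temp_file:
            temp_file.write(file_content)
            return temp_file.name

    # Hypothetical usage with a plain path upload:
    # content = extract_file_content("/tmp/upload.tif")   # bytes or None
    # temp_path = bytes_to_temp_tiff(content)
    # image = read_tiff_image(temp_path)
    # os.unlink(temp_path)                                # clean up afterwards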
@@ -343,16 +323,24 @@ def process_uploaded_tiff(file_obj):
        return None, None
 
 def process_uploaded_mask(file_obj):
-    """
-    Process an uploaded mask file
-    """
+    """Process an uploaded mask file"""
    try:
-
+        # Get file content
+        file_content = extract_file_content(file_obj)
+        if file_content is None:
+            return None
 
        # Save to a temporary file
-
-
-
+        # Determine suffix based on file name if available
+        suffix = '.tif'
+        if hasattr(file_obj, 'name'):
+            file_name = getattr(file_obj, 'name')
+            if isinstance(file_name, str) and '.' in file_name:
+                suffix = '.' + file_name.split('.')[-1].lower()
+
+        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
+            temp_path = temp_file.name
+            temp_file.write(file_content)
 
        # Check if it's a TIFF file
        if temp_path.lower().endswith(('.tif', '.tiff')):
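The mask path keeps the uploaded file's extension when creating the temporary file, so TIFF masks go through rasterio while PNG/JPEG masks fall back to PIL. A small sketch of the same suffix inference using os.path.splitext rather than the string split used in app.py; the example filenames are assumptions.

    import os

    def infer_suffix(file_name: str, default: str = ".tif") -> str:
        """Return the lowercase extension of file_name, or default if it has none."""
        _, ext = os.path.splitext(file_name)
        return ext.lower() if ext else default

    print(infer_suffix("ground_truth_mask.PNG"))   # .png
    print(infer_suffix("mask_without_extension"))  # .tif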
@@ -364,7 +352,6 @@ def process_uploaded_mask(file_obj):
                mask = np.array(mask_img)
                if len(mask.shape) == 3:
                    mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
-                print(f"Opened mask as regular image, shape: {mask.shape}")
            except Exception as e:
                print(f"Error opening mask as regular image: {e}")
                os.unlink(temp_path)
@@ -420,8 +407,6 @@ def predict_segmentation(image_tensor):
            return pred
    except Exception as e:
        print(f"Error during prediction: {e}")
-        import traceback
-        traceback.print_exc()
        return None
 
 def calculate_metrics(pred_mask, gt_mask):
@@ -464,11 +449,6 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
     Process input images and generate predictions
     """
    try:
-        print("\n---- Starting new processing request ----")
-        print(f"Input image type: {type(input_image) if input_image is not None else None}")
-        print(f"Input TIFF type: {type(input_tiff) if input_tiff is not None else None}")
-        print(f"Ground truth mask type: {type(gt_mask_file) if gt_mask_file is not None else None}")
-
        # Check if we have input
        if input_image is None and input_tiff is None:
            return None, "Please upload an image or TIFF file."
@@ -488,7 +468,6 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
            return None, "No valid input provided."
 
        # Get prediction
-        print("Running model prediction...")
        pred_mask = predict_segmentation(image_tensor)
        if pred_mask is None:
            return None, "Failed to generate prediction."
@@ -498,19 +477,13 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
        metrics_text = ""
 
        if gt_mask_file is not None and gt_mask_file:
-            print("Processing ground truth mask...")
            gt_mask_processed = process_uploaded_mask(gt_mask_file)
 
            if gt_mask_processed is not None:
-                print("Calculating metrics...")
                metrics = calculate_metrics(pred_mask, gt_mask_processed)
                metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
-                print(f"Metrics calculated: {metrics}")
-            else:
-                print("Failed to process ground truth mask.")
 
        # Create visualization
-        print("Creating visualization...")
        fig = plt.figure(figsize=(12, 6))
 
        if gt_mask_processed is not None:
@@ -549,15 +522,14 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
        if metrics_text:
            result_text += f"\n\nEvaluation Metrics:\n{metrics_text}"
 
-        # Convert figure to image
-        buf = BytesIO()
+        # Convert figure to image for display
        plt.tight_layout()
-
+        buf = BytesIO()
+        plt.savefig(buf, format='png')
        buf.seek(0)
        result_image = Image.open(buf)
        plt.close(fig)
 
-        print("Processing completed successfully.")
        return result_image, result_text
 
    except Exception as e:
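The added plt.savefig(buf, format='png') is what actually serializes the matplotlib figure into the BytesIO buffer before PIL reopens it; without it the buffer stays empty and Image.open fails. A minimal standalone sketch of that conversion, with an assumed throwaway plot:

    from io import BytesIO

    import matplotlib
    matplotlib.use("Agg")  # headless backend, as in a server app
    import matplotlib.pyplot as plt
    from PIL import Image

    fig = plt.figure(figsize=(4, 3))
    plt.plot([0, 1], [0, 1])
    plt.tight_layout()

    buf = BytesIO()
    fig.savefig(buf, format="png")  # render the figure into the in-memory buffer
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)
    print(result_image.size)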
@@ -581,7 +553,7 @@ with gr.Blocks(title="Wetlands Segmentation from Satellite Imagery") as demo:
        with gr.Tab("Upload TIFF"):
            input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"])
 
-    #
+    # Ground truth mask as file upload
    gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
 
    process_btn = gr.Button("Analyze Image", variant="primary")