dcrey7 commited on
Commit
476cf4b
·
verified ·
1 Parent(s): 66f27f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -85
app.py CHANGED
@@ -12,7 +12,6 @@ from io import BytesIO
12
  import urllib.request
13
  import tempfile
14
  import rasterio
15
- from rasterio.plot import reshape_as_image
16
  import warnings
17
  warnings.filterwarnings("ignore")
18
 
@@ -69,7 +68,7 @@ if smp_available:
69
  )
70
  else:
71
  # Fallback to a simple model that won't actually work but allows the UI to load
72
- print("Warning: Using a placeholder model that won't produce correct predictions.")
73
  from torch import nn
74
  class PlaceholderModel(nn.Module):
75
  def __init__(self):
@@ -133,10 +132,8 @@ def read_tiff_image(tiff_path):
133
  This matches your training data loading approach
134
  """
135
  try:
136
- print(f"Reading TIFF image from: {tiff_path}")
137
  # Read the image using rasterio (get RGB channels)
138
  with rasterio.open(tiff_path) as src:
139
- print(f"TIFF opened successfully. Number of bands: {src.count}")
140
  # Check if we have enough bands
141
  if src.count >= 3:
142
  red = src.read(1)
@@ -145,7 +142,6 @@ def read_tiff_image(tiff_path):
145
 
146
  # Stack to create RGB image
147
  image = np.dstack((red, green, blue)).astype(np.float32)
148
- print(f"RGB image created, shape: {image.shape}, min: {image.min()}, max: {image.max()}")
149
 
150
  # Normalize to [0, 1]
151
  if image.max() > 0:
@@ -154,7 +150,6 @@ def read_tiff_image(tiff_path):
154
  return image
155
  else:
156
  # If less than 3 bands, handle accordingly
157
- print(f"Warning: TIFF file has only {src.count} bands, RGB expected")
158
  bands = [src.read(i+1) for i in range(src.count)]
159
  # If only one band, duplicate to create RGB
160
  if len(bands) == 1:
@@ -172,8 +167,6 @@ def read_tiff_image(tiff_path):
172
  return image
173
  except Exception as e:
174
  print(f"Error reading TIFF file: {e}")
175
- import traceback
176
- traceback.print_exc()
177
  return None
178
 
179
  def read_tiff_mask(mask_path):
@@ -182,17 +175,12 @@ def read_tiff_mask(mask_path):
182
  This matches your training data loading approach
183
  """
184
  try:
185
- print(f"Reading TIFF mask from: {mask_path}")
186
  # Read mask
187
  with rasterio.open(mask_path) as src:
188
- print(f"Mask TIFF opened successfully. Number of bands: {src.count}")
189
  mask = src.read(1).astype(np.uint8)
190
- print(f"Mask shape: {mask.shape}, min: {mask.min()}, max: {mask.max()}, unique values: {np.unique(mask)}")
191
  return mask
192
  except Exception as e:
193
  print(f"Error reading mask file: {e}")
194
- import traceback
195
- traceback.print_exc()
196
  return None
197
 
198
  def preprocess_image(image, target_size=(128, 128)):
@@ -216,7 +204,6 @@ def preprocess_image(image, target_size=(128, 128)):
216
 
217
  # Convert PIL image to numpy
218
  elif isinstance(image, Image.Image):
219
- print("Converting PIL image to numpy array")
220
  image = np.array(image)
221
 
222
  # Ensure RGB format
@@ -234,8 +221,6 @@ def preprocess_image(image, target_size=(128, 128)):
234
  print(f"Unsupported image type: {type(image)}")
235
  return None, None
236
 
237
- print(f"Image shape: {image.shape}, min: {image.min()}, max: {image.max()}")
238
-
239
  # Resize image to the target size
240
  if albumentations_available:
241
  # Use albumentations to match training preprocessing
@@ -255,59 +240,54 @@ def preprocess_image(image, target_size=(128, 128)):
255
 
256
  return image_tensor, display_image
257
 
258
- def save_temp_file(file_obj, suffix='.tif'):
259
- """
260
- Save uploaded file to a temporary file
261
- """
262
  try:
263
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
264
- temp_path = temp_file.name
265
-
266
- if hasattr(file_obj, 'name'):
267
- # If it has a name attribute, it's likely a FileStorage object
268
- file_obj.save(temp_path)
269
- print(f"Saved file from FileStorage object to {temp_path}")
270
- elif hasattr(file_obj, 'read'):
271
- # If it's a file-like object
272
- content = file_obj.read()
273
- if isinstance(content, str): # It's text, not binary
274
- content = content.encode('utf-8')
275
- temp_file.write(content)
276
- print(f"Wrote {len(content)} bytes to {temp_path}")
277
- elif isinstance(file_obj, bytes):
278
- # If it's already bytes
279
- temp_file.write(file_obj)
280
- print(f"Wrote {len(file_obj)} bytes to {temp_path}")
281
- elif isinstance(file_obj, str):
282
- # If it's a path
283
  with open(file_obj, 'rb') as f:
284
- temp_file.write(f.read())
285
- print(f"Copied file from {file_obj} to {temp_path}")
286
  else:
287
- print(f"Unsupported file object type: {type(file_obj)}")
288
- os.unlink(temp_path)
289
- return None
290
-
291
- return temp_path
292
  except Exception as e:
293
- print(f"Error saving temporary file: {e}")
294
- import traceback
295
- traceback.print_exc()
296
  return None
297
 
298
  def process_uploaded_tiff(file_obj):
299
- """
300
- Process an uploaded TIFF file
301
- """
302
  try:
303
- print(f"Processing uploaded TIFF file: {type(file_obj)}")
 
 
 
 
304
 
305
  # Save to a temporary file
306
- temp_path = save_temp_file(file_obj)
307
- if not temp_path:
308
- return None, None
309
 
310
- # Read the TIFF file
311
  image = read_tiff_image(temp_path)
312
 
313
  # Clean up
@@ -343,16 +323,24 @@ def process_uploaded_tiff(file_obj):
343
  return None, None
344
 
345
  def process_uploaded_mask(file_obj):
346
- """
347
- Process an uploaded mask file
348
- """
349
  try:
350
- print(f"Processing uploaded mask file: {type(file_obj)}")
 
 
 
351
 
352
  # Save to a temporary file
353
- temp_path = save_temp_file(file_obj)
354
- if not temp_path:
355
- return None
 
 
 
 
 
 
 
356
 
357
  # Check if it's a TIFF file
358
  if temp_path.lower().endswith(('.tif', '.tiff')):
@@ -364,7 +352,6 @@ def process_uploaded_mask(file_obj):
364
  mask = np.array(mask_img)
365
  if len(mask.shape) == 3:
366
  mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
367
- print(f"Opened mask as regular image, shape: {mask.shape}")
368
  except Exception as e:
369
  print(f"Error opening mask as regular image: {e}")
370
  os.unlink(temp_path)
@@ -420,8 +407,6 @@ def predict_segmentation(image_tensor):
420
  return pred
421
  except Exception as e:
422
  print(f"Error during prediction: {e}")
423
- import traceback
424
- traceback.print_exc()
425
  return None
426
 
427
  def calculate_metrics(pred_mask, gt_mask):
@@ -464,11 +449,6 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
464
  Process input images and generate predictions
465
  """
466
  try:
467
- print("\n---- Starting new processing request ----")
468
- print(f"Input image type: {type(input_image) if input_image is not None else None}")
469
- print(f"Input TIFF type: {type(input_tiff) if input_tiff is not None else None}")
470
- print(f"Ground truth mask type: {type(gt_mask_file) if gt_mask_file is not None else None}")
471
-
472
  # Check if we have input
473
  if input_image is None and input_tiff is None:
474
  return None, "Please upload an image or TIFF file."
@@ -488,7 +468,6 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
488
  return None, "No valid input provided."
489
 
490
  # Get prediction
491
- print("Running model prediction...")
492
  pred_mask = predict_segmentation(image_tensor)
493
  if pred_mask is None:
494
  return None, "Failed to generate prediction."
@@ -498,19 +477,13 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
498
  metrics_text = ""
499
 
500
  if gt_mask_file is not None and gt_mask_file:
501
- print("Processing ground truth mask...")
502
  gt_mask_processed = process_uploaded_mask(gt_mask_file)
503
 
504
  if gt_mask_processed is not None:
505
- print("Calculating metrics...")
506
  metrics = calculate_metrics(pred_mask, gt_mask_processed)
507
  metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
508
- print(f"Metrics calculated: {metrics}")
509
- else:
510
- print("Failed to process ground truth mask.")
511
 
512
  # Create visualization
513
- print("Creating visualization...")
514
  fig = plt.figure(figsize=(12, 6))
515
 
516
  if gt_mask_processed is not None:
@@ -549,15 +522,14 @@ def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
549
  if metrics_text:
550
  result_text += f"\n\nEvaluation Metrics:\n{metrics_text}"
551
 
552
- # Convert figure to image
553
- buf = BytesIO()
554
  plt.tight_layout()
555
- plt.savefig(buf, format='png', dpi=150)
 
556
  buf.seek(0)
557
  result_image = Image.open(buf)
558
  plt.close(fig)
559
 
560
- print("Processing completed successfully.")
561
  return result_image, result_text
562
 
563
  except Exception as e:
@@ -581,7 +553,7 @@ with gr.Blocks(title="Wetlands Segmentation from Satellite Imagery") as demo:
581
  with gr.Tab("Upload TIFF"):
582
  input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"])
583
 
584
- # IMPORTANT CHANGE: Changed ground truth from Image to File for TIFF support
585
  gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
586
 
587
  process_btn = gr.Button("Analyze Image", variant="primary")
 
12
  import urllib.request
13
  import tempfile
14
  import rasterio
 
15
  import warnings
16
  warnings.filterwarnings("ignore")
17
 
 
68
  )
69
  else:
70
  # Fallback to a simple model that won't actually work but allows the UI to load
71
+ print("Warning: Using a placeholder model that won't produce valid predictions.")
72
  from torch import nn
73
  class PlaceholderModel(nn.Module):
74
  def __init__(self):
 
132
  This matches your training data loading approach
133
  """
134
  try:
 
135
  # Read the image using rasterio (get RGB channels)
136
  with rasterio.open(tiff_path) as src:
 
137
  # Check if we have enough bands
138
  if src.count >= 3:
139
  red = src.read(1)
 
142
 
143
  # Stack to create RGB image
144
  image = np.dstack((red, green, blue)).astype(np.float32)
 
145
 
146
  # Normalize to [0, 1]
147
  if image.max() > 0:
 
150
  return image
151
  else:
152
  # If less than 3 bands, handle accordingly
 
153
  bands = [src.read(i+1) for i in range(src.count)]
154
  # If only one band, duplicate to create RGB
155
  if len(bands) == 1:
 
167
  return image
168
  except Exception as e:
169
  print(f"Error reading TIFF file: {e}")
 
 
170
  return None
171
 
172
  def read_tiff_mask(mask_path):
 
175
  This matches your training data loading approach
176
  """
177
  try:
 
178
  # Read mask
179
  with rasterio.open(mask_path) as src:
 
180
  mask = src.read(1).astype(np.uint8)
 
181
  return mask
182
  except Exception as e:
183
  print(f"Error reading mask file: {e}")
 
 
184
  return None
185
 
186
  def preprocess_image(image, target_size=(128, 128)):
 
204
 
205
  # Convert PIL image to numpy
206
  elif isinstance(image, Image.Image):
 
207
  image = np.array(image)
208
 
209
  # Ensure RGB format
 
221
  print(f"Unsupported image type: {type(image)}")
222
  return None, None
223
 
 
 
224
  # Resize image to the target size
225
  if albumentations_available:
226
  # Use albumentations to match training preprocessing
 
240
 
241
  return image_tensor, display_image
242
 
243
+ def extract_file_content(file_obj):
244
+ """Extract content from the file object, handling different types"""
 
 
245
  try:
246
+ if hasattr(file_obj, 'name') and isinstance(file_obj, str):
247
+ # Handle Gradio's NamedString
248
+ content = file_obj
249
+ if os.path.exists(content):
250
+ # It's a path
251
+ with open(content, 'rb') as f:
252
+ return f.read()
253
+ else:
254
+ # It's content
255
+ return content.encode('latin1')
256
+ elif hasattr(file_obj, 'read'):
257
+ # File-like object
258
+ return file_obj.read()
259
+ elif isinstance(file_obj, bytes):
260
+ # Already bytes
261
+ return file_obj
262
+ elif isinstance(file_obj, str):
263
+ # String path
264
+ if os.path.exists(file_obj):
 
265
  with open(file_obj, 'rb') as f:
266
+ return f.read()
 
267
  else:
268
+ return file_obj.encode('utf-8')
269
+ else:
270
+ print(f"Unsupported file object type: {type(file_obj)}")
271
+ return None
 
272
  except Exception as e:
273
+ print(f"Error extracting file content: {e}")
 
 
274
  return None
275
 
276
  def process_uploaded_tiff(file_obj):
277
+ """Process an uploaded TIFF file"""
 
 
278
  try:
279
+ # Get file content
280
+ file_content = extract_file_content(file_obj)
281
+ if file_content is None:
282
+ print("Failed to extract file content")
283
+ return None, None
284
 
285
  # Save to a temporary file
286
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as temp_file:
287
+ temp_path = temp_file.name
288
+ temp_file.write(file_content)
289
 
290
+ # Read as TIFF
291
  image = read_tiff_image(temp_path)
292
 
293
  # Clean up
 
323
  return None, None
324
 
325
  def process_uploaded_mask(file_obj):
326
+ """Process an uploaded mask file"""
 
 
327
  try:
328
+ # Get file content
329
+ file_content = extract_file_content(file_obj)
330
+ if file_content is None:
331
+ return None
332
 
333
  # Save to a temporary file
334
+ # Determine suffix based on file name if available
335
+ suffix = '.tif'
336
+ if hasattr(file_obj, 'name'):
337
+ file_name = getattr(file_obj, 'name')
338
+ if isinstance(file_name, str) and '.' in file_name:
339
+ suffix = '.' + file_name.split('.')[-1].lower()
340
+
341
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
342
+ temp_path = temp_file.name
343
+ temp_file.write(file_content)
344
 
345
  # Check if it's a TIFF file
346
  if temp_path.lower().endswith(('.tif', '.tiff')):
 
352
  mask = np.array(mask_img)
353
  if len(mask.shape) == 3:
354
  mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
 
355
  except Exception as e:
356
  print(f"Error opening mask as regular image: {e}")
357
  os.unlink(temp_path)
 
407
  return pred
408
  except Exception as e:
409
  print(f"Error during prediction: {e}")
 
 
410
  return None
411
 
412
  def calculate_metrics(pred_mask, gt_mask):
 
449
  Process input images and generate predictions
450
  """
451
  try:
 
 
 
 
 
452
  # Check if we have input
453
  if input_image is None and input_tiff is None:
454
  return None, "Please upload an image or TIFF file."
 
468
  return None, "No valid input provided."
469
 
470
  # Get prediction
 
471
  pred_mask = predict_segmentation(image_tensor)
472
  if pred_mask is None:
473
  return None, "Failed to generate prediction."
 
477
  metrics_text = ""
478
 
479
  if gt_mask_file is not None and gt_mask_file:
 
480
  gt_mask_processed = process_uploaded_mask(gt_mask_file)
481
 
482
  if gt_mask_processed is not None:
 
483
  metrics = calculate_metrics(pred_mask, gt_mask_processed)
484
  metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
 
 
 
485
 
486
  # Create visualization
 
487
  fig = plt.figure(figsize=(12, 6))
488
 
489
  if gt_mask_processed is not None:
 
522
  if metrics_text:
523
  result_text += f"\n\nEvaluation Metrics:\n{metrics_text}"
524
 
525
+ # Convert figure to image for display
 
526
  plt.tight_layout()
527
+ buf = BytesIO()
528
+ plt.savefig(buf, format='png')
529
  buf.seek(0)
530
  result_image = Image.open(buf)
531
  plt.close(fig)
532
 
 
533
  return result_image, result_text
534
 
535
  except Exception as e:
 
553
  with gr.Tab("Upload TIFF"):
554
  input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"])
555
 
556
+ # Ground truth mask as file upload
557
  gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
558
 
559
  process_btn = gr.Button("Analyze Image", variant="primary")