dcrey7 committed on
Commit
e45fd18
·
verified ·
1 Parent(s): 984c6f3

examples feature

Browse files
Files changed (1) hide show
  1. app.py +369 -248
app.py CHANGED
@@ -89,12 +89,12 @@ def download_model_weights():
89
  try:
90
  os.makedirs('weights', exist_ok=True)
91
  local_path = os.path.join('weights', SEGMENTATION_MODEL_FILENAME)
92
-
93
  # Check if weights are already downloaded
94
  if os.path.exists(local_path):
95
  print(f"Model weights already downloaded at {local_path}")
96
  return local_path
97
-
98
  # Download weights
99
  print(f"Downloading model weights from {SEGMENTATION_MODEL_REPO}...")
100
  url = f"https://huggingface.co/{SEGMENTATION_MODEL_REPO}/resolve/main/{SEGMENTATION_MODEL_FILENAME}"
@@ -159,14 +159,14 @@ def normalize(band):
159
  band_cleaned = band[np.isfinite(band)]
160
  if len(band_cleaned) == 0:
161
  return band
162
-
163
  # Use percentiles to avoid outliers
164
  band_min, band_max = np.percentile(band_cleaned, (2, 98))
165
-
166
  # Avoid division by zero
167
  if band_max == band_min:
168
  return np.zeros_like(band)
169
-
170
  band_normalized = (band - band_min) / (band_max - band_min)
171
  band_normalized = np.clip(band_normalized, 0, 1)
172
  return band_normalized
@@ -175,19 +175,19 @@ def calculate_cv(band):
175
  """Calculate coefficient of variation (CV) for a band"""
176
  # First normalize the band
177
  band_normalized = normalize(band)
178
-
179
  # Handle potential NaN or inf values
180
  band_cleaned = band_normalized[np.isfinite(band_normalized)]
181
  if len(band_cleaned) == 0:
182
  return 0
183
-
184
  # Get mean and std dev
185
  mean = np.mean(band_cleaned)
186
-
187
  # Guard against division by zero or very small means
188
  if abs(mean) < 1e-10:
189
  return 0
190
-
191
  std = np.std(band_cleaned)
192
  cv = (std / mean) # CV as ratio (not percentage)
193
  return cv
@@ -205,14 +205,14 @@ def read_tiff_image_for_segmentation(tiff_path):
205
  red = src.read(1)
206
  green = src.read(2)
207
  blue = src.read(3)
208
-
209
  # Stack to create RGB image
210
  image = np.dstack((red, green, blue)).astype(np.float32)
211
-
212
  # Normalize to [0, 1]
213
  if image.max() > 0:
214
  image = image / image.max()
215
-
216
  return image
217
  else:
218
  # If less than 3 bands, handle accordingly
@@ -225,11 +225,11 @@ def read_tiff_image_for_segmentation(tiff_path):
225
  while len(bands) < 3:
226
  bands.append(np.zeros_like(bands[0]))
227
  image = np.dstack(bands[:3]) # Use first 3 bands
228
-
229
  # Normalize
230
  if image.max() > 0:
231
  image = image / image.max()
232
-
233
  return image
234
  except Exception as e:
235
  print(f"Error reading TIFF file for segmentation: {e}")
@@ -243,22 +243,22 @@ def extract_cloud_features_from_tiff(tiff_path):
243
  try:
244
  with rasterio.open(tiff_path) as src:
245
  num_bands = min(src.count, 10) # Use up to 10 bands
246
-
247
  # Process each band
248
  features = {}
249
  for i in range(1, num_bands + 1):
250
  band = src.read(i)
251
-
252
  # Calculate coefficient of variation
253
  cv_value = calculate_cv(band)
254
-
255
  # Store feature with name matching the training data
256
  features[f'band{i}_cv'] = cv_value
257
-
258
  # If we have fewer than 10 bands, fill the missing ones with zeros
259
  for i in range(num_bands + 1, 11):
260
  features[f'band{i}_cv'] = 0.0
261
-
262
  return features
263
  except Exception as e:
264
  print(f"Error extracting cloud features from TIFF: {e}")
@@ -275,25 +275,26 @@ def extract_cloud_features_from_rgb(image):
275
  # Make sure image is in float format in range [0,1]
276
  if image.dtype != np.float32 and image.dtype != np.float64:
277
  image = image.astype(np.float32)
278
-
279
  if image.max() > 1.0:
280
  image = image / 255.0
281
-
282
  # Create a dictionary for band CV features
283
  features = {}
284
-
285
- # Process each channel/band
286
- for i in range(min(1, image.shape[2])):
 
287
  band = image[:, :, i]
288
  cv_value = calculate_cv(band)
289
  features[f'band{i+1}_cv'] = cv_value
290
-
291
  # Fill remaining bands with zeros to match the expected 10 features
292
- # for i in range(4, 11):
293
- # features[f'band{i}_cv'] = 0.0
294
-
295
  return features
296
-
297
  except Exception as e:
298
  print(f"Error extracting cloud features from RGB: {e}")
299
  import traceback
@@ -304,21 +305,21 @@ def predict_cloud(features_dict, model):
304
  """Predict if an image is cloudy based on extracted features"""
305
  if model is None:
306
  return {'prediction': 'Model unavailable', 'probability': 0.0}
307
-
308
  try:
309
  # Ensure all 10 features from band1_cv to band10_cv are present
310
  feature_dict = {}
311
  for i in range(1, 11):
312
  feature_name = f'band{i}_cv'
313
  feature_dict[feature_name] = features_dict.get(feature_name, 0.0)
314
-
315
  # Create a DataFrame with all required features
316
  feature_df = pd.DataFrame([feature_dict])
317
-
318
  # Enable shape check disabling for prediction
319
  if hasattr(model, 'set_params'):
320
  model.set_params(predict_disable_shape_check=True)
321
-
322
  # Make prediction
323
  if hasattr(model, 'predict_proba'):
324
  proba = model.predict_proba(feature_df)
@@ -330,10 +331,10 @@ def predict_cloud(features_dict, model):
330
  # If model doesn't have predict_proba, use predict and assume binary output
331
  pred = model.predict(feature_df)
332
  probability = float(pred[0])
333
-
334
  # Classification based on probability threshold
335
  prediction = 'Cloudy' if probability >= 0.5 else 'Non-Cloudy'
336
-
337
  return {
338
  'prediction': prediction,
339
  'probability': probability
@@ -369,38 +370,42 @@ def preprocess_image(image, target_size=(128, 128)):
369
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
370
  elif image.shape[2] == 4: # RGBA
371
  image = image[:, :, :3]
372
-
373
  # Make a copy for display
374
  display_image = image.copy()
375
-
376
  # Normalize to [0, 1] if needed
377
  if display_image.max() > 1.0:
378
  image = image.astype(np.float32) / 255.0
379
-
 
 
 
 
380
  # Convert PIL image to numpy
381
  elif isinstance(image, Image.Image):
382
  image = np.array(image)
383
-
384
  # Ensure RGB format
385
  if len(image.shape) == 2: # Grayscale
386
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
387
  elif image.shape[2] == 4: # RGBA
388
  image = image[:, :, :3]
389
-
390
  # Make a copy for display
391
- display_image = image.copy()
392
-
393
- # Normalize to [0, 1]
394
  image = image.astype(np.float32) / 255.0
395
  else:
396
  print(f"Unsupported image type: {type(image)}")
397
  return None, None
398
-
399
  # Resize image to the target size
400
  if albumentations_available:
401
  # Use albumentations to match training preprocessing
402
  aug = A.Compose([
403
- A.PadIfNeeded(min_height=target_size[0], min_width=target_size[1],
404
  border_mode=cv2.BORDER_CONSTANT, value=0),
405
  A.CenterCrop(height=target_size[0], width=target_size[1])
406
  ])
@@ -409,160 +414,196 @@ def preprocess_image(image, target_size=(128, 128)):
409
  else:
410
  # Fallback to OpenCV
411
  image_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
412
-
413
  # Convert to tensor [C, H, W]
414
  image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
415
-
416
  return image_tensor, display_image
417
 
418
  def extract_file_content(file_obj):
419
  """Extract content from the file object, handling different types"""
420
  try:
421
- if hasattr(file_obj, 'name') and isinstance(file_obj, str):
422
- # Handle Gradio's NamedString
423
- content = file_obj
424
- if os.path.exists(content):
425
- # It's a path
426
- with open(content, 'rb') as f:
427
- return f.read()
428
  else:
429
- # It's content
430
- return content.encode('latin1')
431
- elif hasattr(file_obj, 'read'):
432
- # File-like object
433
- return file_obj.read()
434
- elif isinstance(file_obj, bytes):
435
- # Already bytes
436
- return file_obj
437
  elif isinstance(file_obj, str):
438
- # String path
439
  if os.path.exists(file_obj):
440
- with open(file_obj, 'rb') as f:
441
- return f.read()
442
  else:
443
- return file_obj.encode('utf-8')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  else:
445
  print(f"Unsupported file object type: {type(file_obj)}")
446
- return None
447
  except Exception as e:
448
  print(f"Error extracting file content: {e}")
449
- return None
 
 
450
 
451
  def process_uploaded_tiff(file_obj):
452
  """Process an uploaded TIFF file for both segmentation and cloud detection"""
 
453
  try:
454
- # Get file content
455
- file_content = extract_file_content(file_obj)
456
- if file_content is None:
457
- print("Failed to extract file content")
 
 
458
  return None, None, None
459
-
460
- # Save to a temporary file
461
- with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as temp_file:
462
- temp_path = temp_file.name
463
- temp_file.write(file_content)
464
-
465
  # Read as TIFF for segmentation (only using first 3 bands)
466
- image_for_segmentation = read_tiff_image_for_segmentation(temp_path)
467
-
468
  # Extract cloud features from all available bands
469
- cloud_features = extract_cloud_features_from_tiff(temp_path)
470
-
471
- # Clean up
472
- os.unlink(temp_path)
473
-
474
  if image_for_segmentation is None:
 
475
  return None, None, None
476
-
477
- # Make a copy for display
478
- display_image = (image_for_segmentation * 255).astype(np.uint8) if image_for_segmentation.max() <= 1.0 else image_for_segmentation.copy()
479
-
 
 
 
 
 
 
 
 
480
  # Resize/preprocess for segmentation model
481
  if albumentations_available:
482
  aug = A.Compose([
483
- A.PadIfNeeded(min_height=128, min_width=128,
484
  border_mode=cv2.BORDER_CONSTANT, value=0),
485
  A.CenterCrop(height=128, width=128)
486
  ])
487
- augmented = aug(image=image_for_segmentation)
 
 
 
 
 
488
  image_resized = augmented['image']
489
  else:
490
- image_resized = cv2.resize(image_for_segmentation, (128, 128), interpolation=cv2.INTER_LINEAR)
491
-
 
 
 
492
  # Convert to tensor
493
  image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
494
-
495
  return image_tensor, display_image, cloud_features
496
-
497
  except Exception as e:
498
  print(f"Error processing uploaded TIFF: {e}")
499
  import traceback
500
  traceback.print_exc()
501
  return None, None, None
 
 
 
 
 
 
 
 
502
 
503
  def process_uploaded_mask(file_obj):
504
  """Process an uploaded mask file"""
 
505
  try:
506
- # Get file content
507
- file_content = extract_file_content(file_obj)
508
- if file_content is None:
509
- return None
510
-
511
- # Save to a temporary file
512
- # Determine suffix based on file name if available
513
- suffix = '.tif'
514
- if hasattr(file_obj, 'name'):
515
- file_name = getattr(file_obj, 'name')
516
- if isinstance(file_name, str) and '.' in file_name:
517
- suffix = '.' + file_name.split('.')[-1].lower()
518
-
519
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
520
- temp_path = temp_file.name
521
- temp_file.write(file_content)
522
-
523
  # Check if it's a TIFF file
524
- if temp_path.lower().endswith(('.tif', '.tiff')):
525
- mask = read_tiff_mask(temp_path)
526
  else:
527
  # Try to open as a regular image
528
  try:
529
- mask_img = Image.open(temp_path)
530
  mask = np.array(mask_img)
531
- if len(mask.shape) == 3:
532
- mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
533
  except Exception as e:
534
  print(f"Error opening mask as regular image: {e}")
535
- os.unlink(temp_path)
536
  return None
537
-
538
- # Clean up
539
- os.unlink(temp_path)
540
-
541
  if mask is None:
542
- return None
543
-
 
544
  # Resize mask to 128x128
545
  if albumentations_available:
546
  aug = A.Compose([
547
- A.PadIfNeeded(min_height=128, min_width=128,
548
  border_mode=cv2.BORDER_CONSTANT, value=0),
549
  A.CenterCrop(height=128, width=128)
550
  ])
551
- augmented = aug(image=mask)
552
  mask_resized = augmented['image']
553
  else:
554
  mask_resized = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)
555
-
556
  # Binarize the mask (0: background, 1: wetland)
557
- mask_binary = (mask_resized > 0).astype(np.uint8)
558
-
 
 
 
 
559
  return mask_binary
560
-
561
  except Exception as e:
562
  print(f"Error processing uploaded mask: {e}")
563
  import traceback
564
  traceback.print_exc()
565
  return None
 
 
 
 
 
 
 
 
566
 
567
  def predict_segmentation(image_tensor):
568
  """
@@ -570,10 +611,10 @@ def predict_segmentation(image_tensor):
570
  """
571
  try:
572
  image_tensor = image_tensor.to(device)
573
-
574
  with torch.no_grad():
575
  output = model(image_tensor)
576
-
577
  # Handle different model output formats
578
  if isinstance(output, dict):
579
  output = output['out']
@@ -581,7 +622,7 @@ def predict_segmentation(image_tensor):
581
  pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
582
  else: # Binary output (from smp models)
583
  pred = (torch.sigmoid(output) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
584
-
585
  return pred
586
  except Exception as e:
587
  print(f"Error during prediction: {e}")
@@ -591,208 +632,288 @@ def calculate_metrics(pred_mask, gt_mask):
591
  """
592
  Calculate evaluation metrics between prediction and ground truth
593
  """
 
 
 
 
 
 
 
 
 
 
594
  # Ensure binary masks
595
  pred_binary = (pred_mask > 0).astype(np.uint8)
596
  gt_binary = (gt_mask > 0).astype(np.uint8)
597
-
598
  # Calculate intersection and union
599
  intersection = np.logical_and(pred_binary, gt_binary).sum()
600
  union = np.logical_or(pred_binary, gt_binary).sum()
601
-
602
  # Calculate IoU
603
  iou = intersection / union if union > 0 else 0
604
-
605
  # Calculate precision and recall
606
  true_positive = intersection
607
- false_positive = pred_binary.sum() - true_positive
608
- false_negative = gt_binary.sum() - true_positive
609
-
610
- precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
611
- recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
612
-
 
 
 
613
  # Calculate F1 score
614
  f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
615
-
616
  metrics = {
617
  "Wetlands IoU": float(iou),
618
  "Precision": float(precision),
619
  "Recall": float(recall),
620
  "F1 Score": float(f1)
621
  }
622
-
623
  return metrics
624
 
625
  def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
626
  """
627
  Process input images, generate predictions for both wetland segmentation and cloud detection
628
  """
 
 
 
 
 
 
 
 
629
  try:
630
- # Check if we have input
631
- if input_image is None and input_tiff is None:
632
- return None, "Please upload an image or TIFF file."
633
-
634
- # Process the input and initialize cloud_features
635
- cloud_features = None
636
-
637
- if input_tiff is not None and input_tiff:
638
- # Process uploaded TIFF file for both segmentation and cloud detection
639
  image_tensor, display_image, cloud_features = process_uploaded_tiff(input_tiff)
640
  if image_tensor is None:
641
  return None, "Failed to process the input TIFF file."
 
642
  elif input_image is not None:
643
- # Process regular image
644
  image_tensor, display_image = preprocess_image(input_image)
645
  if image_tensor is None:
646
  return None, "Failed to process the input image."
647
-
648
- # For RGB images, we need to extract cloud features separately
649
- cloud_features = extract_cloud_features_from_rgb(display_image)
 
650
  else:
651
- return None, "No valid input provided."
652
-
 
653
  # Get wetland segmentation prediction
 
654
  pred_mask = predict_segmentation(image_tensor)
655
  if pred_mask is None:
656
  return None, "Failed to generate wetland segmentation prediction."
657
-
 
 
658
  # Get cloud prediction
659
  cloud_result = {'prediction': 'Unknown', 'probability': 0.0}
660
  if cloud_features and cloud_model:
 
661
  cloud_result = predict_cloud(cloud_features, cloud_model)
662
-
663
- # Process ground truth mask if provided
664
- gt_mask_processed = None
 
 
 
 
 
665
  metrics_text = ""
666
-
667
- if gt_mask_file is not None and gt_mask_file:
668
  gt_mask_processed = process_uploaded_mask(gt_mask_file)
669
-
670
  if gt_mask_processed is not None:
 
 
671
  metrics = calculate_metrics(pred_mask, gt_mask_processed)
672
  metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
673
-
674
- # Create visualization
675
- fig = plt.figure(figsize=(12, 6))
676
-
677
- if gt_mask_processed is not None:
678
- # Show original, ground truth, and prediction
679
- plt.subplot(1, 3, 1)
680
- plt.imshow(display_image)
681
- plt.title("Input Image")
682
- plt.axis('off')
683
-
684
- plt.subplot(1, 3, 2)
685
- plt.imshow(gt_mask_processed, cmap='binary')
686
- plt.title("Ground Truth")
687
- plt.axis('off')
688
-
689
- plt.subplot(1, 3, 3)
690
- plt.imshow(pred_mask, cmap='binary')
691
- plt.title("Prediction")
692
- plt.axis('off')
693
  else:
694
- # Show original and prediction
695
- plt.subplot(1, 2, 1)
696
- plt.imshow(display_image)
697
- plt.title("Input Image")
698
- plt.axis('off')
699
-
700
- plt.subplot(1, 2, 2)
701
- plt.imshow(pred_mask, cmap='binary')
702
- plt.title("Predicted Wetlands")
703
- plt.axis('off')
704
-
705
- # Calculate wetland percentage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
  wetland_percentage = np.mean(pred_mask) * 100
707
-
708
- # Add results information
709
- result_text = f"Wetland Coverage: {wetland_percentage:.2f}%\n\n"
710
-
 
711
  # Add cloud detection results
712
  result_text += f"Cloud Detection: {cloud_result['prediction']} "
713
  result_text += f"({cloud_result['probability']*100:.2f}% Cloud probability)\n\n"
714
-
715
  # Add segmentation metrics if available
716
  if metrics_text:
717
- result_text += f"Evaluation Metrics:\n{metrics_text}"
718
-
 
 
 
 
 
 
719
  # Convert figure to image for display
720
- plt.tight_layout()
721
  buf = BytesIO()
722
- plt.savefig(buf, format='png')
723
  buf.seek(0)
724
  result_image = Image.open(buf)
725
  plt.close(fig)
726
-
 
727
  return result_image, result_text
728
-
729
  except Exception as e:
730
- print(f"Error in processing: {e}")
731
  import traceback
732
  traceback.print_exc()
733
- return None, f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
734
 
735
- # Create Gradio interface
736
  with gr.Blocks(title="Wetlands Segmentation & Cloud Detection") as demo:
737
  gr.Markdown("# Wetlands Segmentation & Cloud Detection from Satellite Imagery")
738
  gr.Markdown("Upload a satellite image or TIFF file to identify wetland areas and detect cloud cover. Optionally, you can also upload a ground truth mask for evaluation.")
739
-
740
  with gr.Row():
741
- with gr.Column():
742
- # Input options
743
  gr.Markdown("### Input")
744
- with gr.Tab("Upload Image"):
745
- input_image = gr.Image(label="Upload Satellite Image", type="numpy")
746
-
747
- with gr.Tab("Upload TIFF"):
748
- input_tiff = gr.File(label="Upload TIFF File", file_types=[".tif", ".tiff"])
749
-
750
  # Ground truth mask as file upload
751
- gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"])
752
-
 
 
 
 
 
 
 
 
 
 
 
753
  process_btn = gr.Button("Analyze Image", variant="primary")
754
-
755
- with gr.Column():
756
- # Output
757
  gr.Markdown("### Results")
758
- output_image = gr.Image(label="Segmentation Results", type="pil")
759
- output_text = gr.Textbox(label="Statistics", lines=8)
760
-
761
  # Information about the models
762
  gr.Markdown("### About these models")
763
  gr.Markdown("""
764
  This application uses two models:
765
-
766
  **1. Wetland Segmentation Model:**
767
  - Architecture: DeepLabv3+ with ResNet-34
768
- - Input: RGB satellite imagery
769
  - Output: Binary segmentation mask (Wetland vs Background)
770
- - Resolution: 128×128 pixels
771
-
772
  **2. Cloud Detection Model:**
773
  - Architecture: LightGBM Classifier
774
- - Input: CV features extracted from up to 10 image bands
775
  - Output: Binary classification (Cloudy vs Non-Cloudy) with probability
776
-
777
  **Tips for best results:**
778
- - For Cloudy image - train_11202327_p1, for Non cloudy image - train_02202325_p1
779
- - For Cloudy image - test_07202330_p1, for Non cloudy image - test_02202325_p1
780
- - The models work best with multi-band satellite imagery (TIFF files)
781
- - For optimal cloud detection results, use TIFF files with 10 bands
782
- - For optimal results, use images with similar characteristics to those used in training
783
- - The wetland model focuses on identifying wetland regions in natural landscapes
784
- - The cloud model detects cloud cover based on image band statistics
785
- - For ground truth masks, both TIFF and standard image formats are supported
786
-
787
- **Repository:** [dcrey7/wetlands_segmentation_deeplabsv3plus](https://huggingface.co/dcrey7/wetlands_segmentation_deeplabsv3plus)
788
  """)
789
-
790
  # Set up event handlers
791
  process_btn.click(
792
  fn=process_images,
793
  inputs=[input_image, input_tiff, gt_mask_file],
794
- outputs=[output_image, output_text]
 
795
  )
796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
797
  # Launch the app
798
- demo.launch()
 
 
89
  try:
90
  os.makedirs('weights', exist_ok=True)
91
  local_path = os.path.join('weights', SEGMENTATION_MODEL_FILENAME)
92
+
93
  # Check if weights are already downloaded
94
  if os.path.exists(local_path):
95
  print(f"Model weights already downloaded at {local_path}")
96
  return local_path
97
+
98
  # Download weights
99
  print(f"Downloading model weights from {SEGMENTATION_MODEL_REPO}...")
100
  url = f"https://huggingface.co/{SEGMENTATION_MODEL_REPO}/resolve/main/{SEGMENTATION_MODEL_FILENAME}"
 
159
  band_cleaned = band[np.isfinite(band)]
160
  if len(band_cleaned) == 0:
161
  return band
162
+
163
  # Use percentiles to avoid outliers
164
  band_min, band_max = np.percentile(band_cleaned, (2, 98))
165
+
166
  # Avoid division by zero
167
  if band_max == band_min:
168
  return np.zeros_like(band)
169
+
170
  band_normalized = (band - band_min) / (band_max - band_min)
171
  band_normalized = np.clip(band_normalized, 0, 1)
172
  return band_normalized
 
175
  """Calculate coefficient of variation (CV) for a band"""
176
  # First normalize the band
177
  band_normalized = normalize(band)
178
+
179
  # Handle potential NaN or inf values
180
  band_cleaned = band_normalized[np.isfinite(band_normalized)]
181
  if len(band_cleaned) == 0:
182
  return 0
183
+
184
  # Get mean and std dev
185
  mean = np.mean(band_cleaned)
186
+
187
  # Guard against division by zero or very small means
188
  if abs(mean) < 1e-10:
189
  return 0
190
+
191
  std = np.std(band_cleaned)
192
  cv = (std / mean) # CV as ratio (not percentage)
193
  return cv
 
205
  red = src.read(1)
206
  green = src.read(2)
207
  blue = src.read(3)
208
+
209
  # Stack to create RGB image
210
  image = np.dstack((red, green, blue)).astype(np.float32)
211
+
212
  # Normalize to [0, 1]
213
  if image.max() > 0:
214
  image = image / image.max()
215
+
216
  return image
217
  else:
218
  # If less than 3 bands, handle accordingly
 
225
  while len(bands) < 3:
226
  bands.append(np.zeros_like(bands[0]))
227
  image = np.dstack(bands[:3]) # Use first 3 bands
228
+
229
  # Normalize
230
  if image.max() > 0:
231
  image = image / image.max()
232
+
233
  return image
234
  except Exception as e:
235
  print(f"Error reading TIFF file for segmentation: {e}")
 
243
  try:
244
  with rasterio.open(tiff_path) as src:
245
  num_bands = min(src.count, 10) # Use up to 10 bands
246
+
247
  # Process each band
248
  features = {}
249
  for i in range(1, num_bands + 1):
250
  band = src.read(i)
251
+
252
  # Calculate coefficient of variation
253
  cv_value = calculate_cv(band)
254
+
255
  # Store feature with name matching the training data
256
  features[f'band{i}_cv'] = cv_value
257
+
258
  # If we have fewer than 10 bands, fill the missing ones with zeros
259
  for i in range(num_bands + 1, 11):
260
  features[f'band{i}_cv'] = 0.0
261
+
262
  return features
263
  except Exception as e:
264
  print(f"Error extracting cloud features from TIFF: {e}")
 
275
  # Make sure image is in float format in range [0,1]
276
  if image.dtype != np.float32 and image.dtype != np.float64:
277
  image = image.astype(np.float32)
278
+
279
  if image.max() > 1.0:
280
  image = image / 255.0
281
+
282
  # Create a dictionary for band CV features
283
  features = {}
284
+
285
+ # Process each channel/band (assuming image is H, W, C)
286
+ num_bands = min(3, image.shape[2])
287
+ for i in range(num_bands):
288
  band = image[:, :, i]
289
  cv_value = calculate_cv(band)
290
  features[f'band{i+1}_cv'] = cv_value
291
+
292
  # Fill remaining bands with zeros to match the expected 10 features
293
+ for i in range(num_bands + 1, 11):
294
+ features[f'band{i}_cv'] = 0.0
295
+
296
  return features
297
+
298
  except Exception as e:
299
  print(f"Error extracting cloud features from RGB: {e}")
300
  import traceback
 
305
  """Predict if an image is cloudy based on extracted features"""
306
  if model is None:
307
  return {'prediction': 'Model unavailable', 'probability': 0.0}
308
+
309
  try:
310
  # Ensure all 10 features from band1_cv to band10_cv are present
311
  feature_dict = {}
312
  for i in range(1, 11):
313
  feature_name = f'band{i}_cv'
314
  feature_dict[feature_name] = features_dict.get(feature_name, 0.0)
315
+
316
  # Create a DataFrame with all required features
317
  feature_df = pd.DataFrame([feature_dict])
318
+
319
  # Enable shape check disabling for prediction
320
  if hasattr(model, 'set_params'):
321
  model.set_params(predict_disable_shape_check=True)
322
+
323
  # Make prediction
324
  if hasattr(model, 'predict_proba'):
325
  proba = model.predict_proba(feature_df)
 
331
  # If model doesn't have predict_proba, use predict and assume binary output
332
  pred = model.predict(feature_df)
333
  probability = float(pred[0])
334
+
335
  # Classification based on probability threshold
336
  prediction = 'Cloudy' if probability >= 0.5 else 'Non-Cloudy'
337
+
338
  return {
339
  'prediction': prediction,
340
  'probability': probability
 
370
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
371
  elif image.shape[2] == 4: # RGBA
372
  image = image[:, :, :3]
373
+
374
  # Make a copy for display
375
  display_image = image.copy()
376
+
377
  # Normalize to [0, 1] if needed
378
  if display_image.max() > 1.0:
379
  image = image.astype(np.float32) / 255.0
380
+ display_image = display_image.astype(np.uint8) # Keep display image as uint8
381
+ else:
382
+ # If already normalized, scale up for display
383
+ display_image = (display_image * 255).astype(np.uint8)
384
+
385
  # Convert PIL image to numpy
386
  elif isinstance(image, Image.Image):
387
  image = np.array(image)
388
+
389
  # Ensure RGB format
390
  if len(image.shape) == 2: # Grayscale
391
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
392
  elif image.shape[2] == 4: # RGBA
393
  image = image[:, :, :3]
394
+
395
  # Make a copy for display
396
+ display_image = image.copy() # display_image is uint8 here
397
+
398
+ # Normalize to [0, 1] for model
399
  image = image.astype(np.float32) / 255.0
400
  else:
401
  print(f"Unsupported image type: {type(image)}")
402
  return None, None
403
+
404
  # Resize image to the target size
405
  if albumentations_available:
406
  # Use albumentations to match training preprocessing
407
  aug = A.Compose([
408
+ A.PadIfNeeded(min_height=target_size[0], min_width=target_size[1],
409
  border_mode=cv2.BORDER_CONSTANT, value=0),
410
  A.CenterCrop(height=target_size[0], width=target_size[1])
411
  ])
 
414
  else:
415
  # Fallback to OpenCV
416
  image_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
417
+
418
  # Convert to tensor [C, H, W]
419
  image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
420
+
421
  return image_tensor, display_image
422
 
423
  def extract_file_content(file_obj):
424
  """Extract content from the file object, handling different types"""
425
  try:
426
+ # Handle Gradio File object (which has 'name' attribute pointing to temp path)
427
+ if hasattr(file_obj, 'name') and isinstance(file_obj.name, str):
428
+ file_path = file_obj.name
429
+ if os.path.exists(file_path):
430
+ with open(file_path, 'rb') as f:
431
+ return f.read(), file_path # Return path for TIFF reading
 
432
  else:
433
+ print(f"Temp file path does not exist: {file_path}")
434
+ return None, None
435
+
436
+ # Handle string path (from gr.Examples)
 
 
 
 
437
  elif isinstance(file_obj, str):
 
438
  if os.path.exists(file_obj):
439
+ with open(file_obj, 'rb') as f:
440
+ return f.read(), file_obj # Return path for TIFF reading
441
  else:
442
+ print(f"Provided file path does not exist: {file_obj}")
443
+ return None, None
444
+ # Handle BytesIO or other file-like objects (less likely with gr.File/gr.Examples)
445
+ elif hasattr(file_obj, 'read'):
446
+ content = file_obj.read()
447
+ # Need to save to temp file to get a path for rasterio
448
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_f:
449
+ temp_path = temp_f.name
450
+ temp_f.write(content)
451
+ return content, temp_path # Return path
452
+ elif isinstance(file_obj, bytes):
453
+ # Need to save to temp file to get a path for rasterio
454
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_f:
455
+ temp_path = temp_f.name
456
+ temp_f.write(file_obj)
457
+ return file_obj, temp_path # Return path
458
  else:
459
  print(f"Unsupported file object type: {type(file_obj)}")
460
+ return None, None
461
  except Exception as e:
462
  print(f"Error extracting file content: {e}")
463
+ import traceback
464
+ traceback.print_exc()
465
+ return None, None
466
 
467
  def process_uploaded_tiff(file_obj):
468
  """Process an uploaded TIFF file for both segmentation and cloud detection"""
469
+ temp_file_to_delete = None
470
  try:
471
+ # Get file content and path
472
+ # We primarily need the path for rasterio
473
+ _, file_path = extract_file_content(file_obj)
474
+
475
+ if file_path is None:
476
+ print("Failed to get file path for TIFF processing")
477
  return None, None, None
478
+
479
+ # Check if extract_file_content created a temp file we need to manage
480
+ if file_path.endswith('.tmp'):
481
+ temp_file_to_delete = file_path
482
+
 
483
  # Read as TIFF for segmentation (only using first 3 bands)
484
+ image_for_segmentation = read_tiff_image_for_segmentation(file_path)
485
+
486
  # Extract cloud features from all available bands
487
+ cloud_features = extract_cloud_features_from_tiff(file_path)
488
+
 
 
 
489
  if image_for_segmentation is None:
490
+ print("Failed to read TIFF for segmentation.")
491
  return None, None, None
492
+
493
+ # Make a copy for display, ensuring uint8 for display
494
+ if image_for_segmentation.max() <= 1.0 and image_for_segmentation.min() >= 0.0:
495
+ display_image = (image_for_segmentation * 255).astype(np.uint8)
496
+ elif image_for_segmentation.dtype == np.uint8:
497
+ display_image = image_for_segmentation.copy()
498
+ else:
499
+ # Attempt normalization for display if out of range
500
+ display_image = cv2.normalize(image_for_segmentation, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
501
+ print("Warning: TIFF image values outside [0,1], normalized for display.")
502
+
503
+
504
  # Resize/preprocess for segmentation model
505
  if albumentations_available:
506
  aug = A.Compose([
507
+ A.PadIfNeeded(min_height=128, min_width=128,
508
  border_mode=cv2.BORDER_CONSTANT, value=0),
509
  A.CenterCrop(height=128, width=128)
510
  ])
511
+ # Need float image for albumentations
512
+ image_float = image_for_segmentation.astype(np.float32)
513
+ if image_float.max() > 1.0: # Normalize if not already done
514
+ image_float = image_float / (image_float.max() + 1e-6)
515
+
516
+ augmented = aug(image=image_float)
517
  image_resized = augmented['image']
518
  else:
519
+ image_float = image_for_segmentation.astype(np.float32)
520
+ if image_float.max() > 1.0: # Normalize if not already done
521
+ image_float = image_float / (image_float.max() + 1e-6)
522
+ image_resized = cv2.resize(image_float, (128, 128), interpolation=cv2.INTER_LINEAR)
523
+
524
  # Convert to tensor
525
  image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
526
+
527
  return image_tensor, display_image, cloud_features
528
+
529
  except Exception as e:
530
  print(f"Error processing uploaded TIFF: {e}")
531
  import traceback
532
  traceback.print_exc()
533
  return None, None, None
534
+ finally:
535
+ # Clean up temporary file if created by extract_file_content
536
+ if temp_file_to_delete and os.path.exists(temp_file_to_delete):
537
+ try:
538
+ os.unlink(temp_file_to_delete)
539
+ except Exception as e_del:
540
+ print(f"Error deleting temp file {temp_file_to_delete}: {e_del}")
541
+
542
 
543
def process_uploaded_mask(file_obj):
    """Load an uploaded ground-truth mask and return it as a 128x128 binary array.

    Args:
        file_obj: Upload handle accepted by ``extract_file_content``.

    Returns:
        A ``uint8`` array of shape (128, 128) holding 0 (background) and
        1 (wetland), or ``None`` when the file cannot be located or read.
    """
    temp_file_to_delete = None
    try:
        # Only the on-disk path is needed here; the raw content is ignored.
        _, file_path = extract_file_content(file_obj)
        if file_path is None:
            print("Failed to get file path for Mask processing")
            return None

        # A '.tmp' suffix is treated as a temp file created by
        # extract_file_content, which this function must remove when done.
        if file_path.endswith('.tmp'):
            temp_file_to_delete = file_path

        if file_path.lower().endswith(('.tif', '.tiff')):
            mask = read_tiff_mask(file_path)
        else:
            # Anything else is opened as a regular image and reduced to grayscale.
            try:
                # The context manager releases the file handle before the
                # temp-file cleanup in ``finally`` tries to delete the file.
                with Image.open(file_path) as mask_img:
                    mask = np.array(mask_img.convert('L'))
            except Exception as e:
                print(f"Error opening mask as regular image: {e}")
                return None

        if mask is None:
            print("Failed to read mask data.")
            return None

        # Bring the mask to the 128x128 grid used for predictions.
        if albumentations_available:
            aug = A.Compose([
                A.PadIfNeeded(min_height=128, min_width=128,
                              border_mode=cv2.BORDER_CONSTANT, value=0),
                A.CenterCrop(height=128, width=128)
            ])
            # Albumentations is driven through the 'image' key even for a mask.
            mask_resized = aug(image=mask)['image']
        else:
            mask_resized = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)

        # Binarize the mask (0: background, 1: wetland).
        # The cut-off is half of the brightest positive value. Clean 0/1 and
        # 0/255 masks are unaffected, while lossy or anti-aliased masks keep
        # near-peak foreground pixels (e.g. 250-254) and drop near-zero noise.
        # A median-of-positives cut-off would discard about half of the
        # foreground whenever its values are not perfectly uniform.
        positive_values = mask_resized[mask_resized > 0]
        if positive_values.size:
            threshold = positive_values.max() / 2.0
            mask_binary = (mask_resized > threshold).astype(np.uint8)
        else:
            mask_binary = np.zeros(mask_resized.shape, dtype=np.uint8)
        print(f"Mask binarized. Original min/max: {mask.min()}/{mask.max()}, Resized min/max: {mask_resized.min()}/{mask_resized.max()}, Binary sum: {mask_binary.sum()}")

        return mask_binary

    except Exception as e:
        print(f"Error processing uploaded mask: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Clean up temporary file if created by extract_file_content
        if temp_file_to_delete and os.path.exists(temp_file_to_delete):
            try:
                os.unlink(temp_file_to_delete)
            except Exception as e_del:
                print(f"Error deleting temp mask file {temp_file_to_delete}: {e_del}")
606
+
607
 
608
  def predict_segmentation(image_tensor):
609
  """
 
611
  """
612
  try:
613
  image_tensor = image_tensor.to(device)
614
+
615
  with torch.no_grad():
616
  output = model(image_tensor)
617
+
618
  # Handle different model output formats
619
  if isinstance(output, dict):
620
  output = output['out']
 
622
  pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
623
  else: # Binary output (from smp models)
624
  pred = (torch.sigmoid(output) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
625
+
626
  return pred
627
  except Exception as e:
628
  print(f"Error during prediction: {e}")
 
632
  """
633
  Calculate evaluation metrics between prediction and ground truth
634
  """
635
+ if pred_mask is None or gt_mask is None:
636
+ print("Cannot calculate metrics: Invalid masks provided.")
637
+ return {}
638
+ if pred_mask.shape != gt_mask.shape:
639
+ print(f"Cannot calculate metrics: Shape mismatch - Pred {pred_mask.shape}, GT {gt_mask.shape}")
640
+ # Optionally resize one to match the other, e.g., resize GT to Pred shape
641
+ gt_mask = cv2.resize(gt_mask, (pred_mask.shape[1], pred_mask.shape[0]), interpolation=cv2.INTER_NEAREST)
642
+ print(f"Resized GT mask to {gt_mask.shape}")
643
+
644
+
645
  # Ensure binary masks
646
  pred_binary = (pred_mask > 0).astype(np.uint8)
647
  gt_binary = (gt_mask > 0).astype(np.uint8)
648
+
649
  # Calculate intersection and union
650
  intersection = np.logical_and(pred_binary, gt_binary).sum()
651
  union = np.logical_or(pred_binary, gt_binary).sum()
652
+
653
  # Calculate IoU
654
  iou = intersection / union if union > 0 else 0
655
+
656
  # Calculate precision and recall
657
  true_positive = intersection
658
+ pred_positive_count = pred_binary.sum()
659
+ gt_positive_count = gt_binary.sum()
660
+
661
+ false_positive = pred_positive_count - true_positive
662
+ false_negative = gt_positive_count - true_positive
663
+
664
+ precision = true_positive / pred_positive_count if pred_positive_count > 0 else 0
665
+ recall = true_positive / gt_positive_count if gt_positive_count > 0 else 0
666
+
667
  # Calculate F1 score
668
  f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
669
+
670
  metrics = {
671
  "Wetlands IoU": float(iou),
672
  "Precision": float(precision),
673
  "Recall": float(recall),
674
  "F1 Score": float(f1)
675
  }
676
+
677
  return metrics
678
 
679
def _render_analysis_figure(display_image, pred_mask, gt_mask):
    """Draw the input / ground-truth / prediction panels and return a PIL image.

    The figure is always closed, even when drawing or saving fails, so
    repeated failing requests do not accumulate open Matplotlib figures.
    """
    panel_count = 3 if gt_mask is not None else 2
    fig, axes = plt.subplots(1, panel_count, figsize=(12, 5))
    try:
        fig.suptitle("Analysis Results", fontsize=16)

        # imshow is fed uint8 data: scale [0, 1] floats, cast anything else.
        if display_image.dtype != np.uint8:
            if display_image.max() <= 1:
                display_image = (display_image * 255).astype(np.uint8)
            else:
                display_image = display_image.astype(np.uint8)

        # The input panel is shown at a fixed 512x512; the masks stay 128x128.
        input_panel = cv2.resize(display_image, (512, 512), interpolation=cv2.INTER_LINEAR)

        panels = [(input_panel, "Input Image (Resized for Display)", None)]
        if gt_mask is not None:
            panels.append((gt_mask, "Ground Truth (128x128)", 'viridis'))
        panels.append((pred_mask, "Predicted Wetlands (128x128)", 'plasma'))

        for axis, (pixels, title, cmap) in zip(axes, panels):
            axis.imshow(pixels, cmap=cmap)
            axis.set_title(title)
            axis.axis('off')

        # Use the figure's own methods rather than pyplot's global "current
        # figure", which another concurrent request could have replaced.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the title
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        return Image.open(buf)
    finally:
        plt.close(fig)


def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
    """Run wetland segmentation and cloud detection on an uploaded input.

    A TIFF upload takes precedence over a plain image when both are given.

    Args:
        input_image: Image passed to ``preprocess_image``, or ``None``.
        input_tiff: TIFF upload passed to ``process_uploaded_tiff``, or ``None``.
        gt_mask_file: Optional ground-truth mask upload used for metrics.

    Returns:
        A ``(result_image, result_text)`` tuple. ``result_image`` is a PIL
        image of the result panels, or ``None`` on failure; ``result_text``
        is the summary, or an error message.
    """
    result_image = None
    result_text = "Processing..."  # Initial message

    try:
        # --- Load and preprocess the input ---
        if input_tiff is not None:
            print("Processing TIFF input...")
            image_tensor, display_image, cloud_features = process_uploaded_tiff(input_tiff)
            if image_tensor is None:
                return None, "Failed to process the input TIFF file."
            print("TIFF processing complete.")
        elif input_image is not None:
            print("Processing Image input...")
            image_tensor, display_image = preprocess_image(input_image)
            if image_tensor is None:
                return None, "Failed to process the input image."
            # For RGB images, cloud features come from the display image.
            print("Extracting cloud features from RGB...")
            cloud_features = extract_cloud_features_from_rgb(display_image)
            print("Image processing complete.")
        else:
            return None, "Please upload an image or TIFF file."

        # --- Perform Predictions ---
        print("Performing wetland segmentation...")
        pred_mask = predict_segmentation(image_tensor)
        if pred_mask is None:
            return None, "Failed to generate wetland segmentation prediction."
        print(f"Segmentation prediction generated, shape: {pred_mask.shape}, type: {pred_mask.dtype}")

        # Falls back to an 'Unknown' result when cloud detection cannot run.
        cloud_result = {'prediction': 'Unknown', 'probability': 0.0}
        if cloud_features and cloud_model:
            print("Performing cloud detection...")
            cloud_result = predict_cloud(cloud_features, cloud_model)
            print(f"Cloud detection result: {cloud_result}")
        elif cloud_model is None:
            print("Cloud detection model not available.")
        else:
            print("Cloud features not extracted, skipping cloud detection.")

        # --- Process Ground Truth and Metrics (if provided) ---
        gt_mask_processed = None
        metrics_text = ""
        if gt_mask_file is not None:
            print("Processing ground truth mask...")
            gt_mask_processed = process_uploaded_mask(gt_mask_file)

            if gt_mask_processed is not None:
                print(f"Ground truth mask processed, shape: {gt_mask_processed.shape}, type: {gt_mask_processed.dtype}")
                print("Calculating metrics...")
                metrics = calculate_metrics(pred_mask, gt_mask_processed)
                metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
                print(f"Metrics calculated: {metrics_text}")
            else:
                print("Failed to process ground truth mask.")
                metrics_text = "Ground truth mask provided but could not be processed."

        # --- Create Visualization ---
        print("Creating result visualization...")
        result_image = _render_analysis_figure(display_image, pred_mask, gt_mask_processed)

        # --- Format Output Text ---
        # Count every non-zero class as wetland, as the metrics code does.
        wetland_percentage = np.mean(pred_mask > 0) * 100

        result_text = f"--- Analysis Summary ---\n"
        result_text += f"Wetland Coverage (Predicted): {wetland_percentage:.2f}%\n\n"

        # Add cloud detection results
        result_text += f"Cloud Detection: {cloud_result['prediction']} "
        result_text += f"({cloud_result['probability']*100:.2f}% Cloud probability)\n\n"

        # Add segmentation metrics if available
        if metrics_text:
            if "could not be processed" in metrics_text:
                result_text += metrics_text  # Add the error message
            else:
                result_text += f"--- Evaluation Metrics (vs Ground Truth) ---\n{metrics_text}"
        elif gt_mask_file is not None:
            result_text += "--- Evaluation Metrics ---\nGround truth provided but metrics could not be calculated."

        print("Visualization complete.")

        return result_image, result_text

    except Exception as e:
        print(f"Error in process_images: {e}")
        import traceback
        traceback.print_exc()
        # Return whatever was produced so far together with the error.
        error_message = f"An error occurred during processing: {str(e)}"
        return result_image, result_text + f"\n\nERROR: {error_message}"
822
+
823
+
824
+ # --- Define Example Files ---
825
+ # Ensure these paths are correct relative to where app.py is run
826
+ # These should be at the root of your Gradio Space repository
827
+ example_list = [
828
+ ["test_p1_cloudy_input.tif", "test_p1_cloudy_output.tif"],
829
+ ["test_p1_noncloudy_input.tif", "test_p1_noncloudy_output.tif"],
830
+ ["test_p1_cloudy_input.tif", None], # Example without ground truth
831
+ ["test_p1_noncloudy_input.tif", None], # Example without ground truth
832
+ ]
833
 
834
+ # --- Create Gradio Interface ---
835
  with gr.Blocks(title="Wetlands Segmentation & Cloud Detection") as demo:
836
  gr.Markdown("# Wetlands Segmentation & Cloud Detection from Satellite Imagery")
837
  gr.Markdown("Upload a satellite image or TIFF file to identify wetland areas and detect cloud cover. Optionally, you can also upload a ground truth mask for evaluation.")
838
+
839
  with gr.Row():
840
+ with gr.Column(scale=1): # Input column
 
841
  gr.Markdown("### Input")
842
+ with gr.Tabs():
843
+ with gr.TabItem("Upload Image"):
844
+ input_image = gr.Image(label="Upload Satellite Image (JPG, PNG etc.)", type="numpy")
845
+ with gr.TabItem("Upload TIFF"):
846
+ input_tiff = gr.File(label="Upload Multi-Band TIFF File", file_types=[".tif", ".tiff"], type="filepath") # Use filepath for easier handling
847
+
848
  # Ground truth mask as file upload
849
+ gt_mask_file = gr.File(label="Ground Truth Mask (Optional)", file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"], type="filepath")
850
+
851
+ # --- Add Examples Section ---
852
+ gr.Markdown("### Load Examples")
853
+ gr.Examples(
854
+ examples=example_list,
855
+ inputs=[input_tiff, gt_mask_file], # Corresponds to TIFF input and GT Mask input
856
+ label="Click an example below to load files:",
857
+ # Outputs are not needed here, examples just populate inputs
858
+ # examples_per_page=4 # Optional: control pagination
859
+ )
860
+ # --- End Examples Section ---
861
+
862
  process_btn = gr.Button("Analyze Image", variant="primary")
863
+
864
+ with gr.Column(scale=2): # Output column (make it wider)
 
865
  gr.Markdown("### Results")
866
+ output_image = gr.Image(label="Segmentation Results", type="pil", height=450) # Adjust height as needed
867
+ output_text = gr.Textbox(label="Statistics & Metrics", lines=10, scale=1) # Adjust lines and scale
868
+
869
  # Information about the models
870
  gr.Markdown("### About these models")
871
  gr.Markdown("""
872
  This application uses two models:
873
+
874
  **1. Wetland Segmentation Model:**
875
  - Architecture: DeepLabv3+ with ResNet-34
876
+ - Input: RGB satellite imagery (extracted from first 3 bands of TIFF if provided)
877
  - Output: Binary segmentation mask (Wetland vs Background)
878
+ - Resolution: Processed at 128×128 pixels
879
+
880
  **2. Cloud Detection Model:**
881
  - Architecture: LightGBM Classifier
882
+ - Input: Coefficient of Variation (CV) features extracted from up to 10 image bands (from TIFF)
883
  - Output: Binary classification (Cloudy vs Non-Cloudy) with probability
884
+
885
  **Tips for best results:**
886
+ - Use the 'Upload TIFF' tab for multi-band satellite data to enable accurate cloud detection.
887
+ - The cloud detection model expects up to 10 bands. Performance may vary with fewer bands.
888
+ - The example files demonstrate cloudy/non-cloudy scenarios with corresponding ground truth.
889
+ - The models work best with images similar in characteristics to those used in training.
890
+ - For ground truth masks, both TIFF and standard image formats (PNG, JPG) are supported. Ensure the mask clearly delineates the target class.
891
+
892
+ **Repository:** [dcrey7/wetland_segmentation_deeplabsv3plus](https://huggingface.co/spaces/dcrey7/wetland_segmentation_deeplabsv3plus)
 
 
 
893
  """)
894
+
895
  # Set up event handlers
896
  process_btn.click(
897
  fn=process_images,
898
  inputs=[input_image, input_tiff, gt_mask_file],
899
+ outputs=[output_image, output_text],
900
+ api_name="analyze" # Optional: name for API endpoint
901
  )
902
 
903
+ # Clear inputs when switching tabs (optional but good UX)
904
def clear_other_input(selected_tab):
    """Build a pair of Gradio updates that resets one of two components.

    Tab index 0 resets the first target to ``None``, tab index 1 resets the
    second, and any other value leaves both untouched.

    NOTE(review): the targets are presumably (TIFF input, image input);
    no event wiring for this helper is visible here, so confirm before use.
    """
    first_update = gr.update()
    second_update = gr.update()

    if selected_tab == 0:  # "Upload Image" tab selected
        first_update = gr.update(value=None)
    elif selected_tab == 1:  # "Upload TIFF" tab selected
        second_update = gr.update(value=None)

    return first_update, second_update
910
+
911
+ # Assuming the Tabs component itself can be accessed or named,
912
+ # otherwise, this part might need adjustment based on Gradio's specific API for Tabs.
913
+ # If Tabs don't directly support change events easily, this part can be omitted.
914
+ # For now, let's assume tab switching doesn't automatically clear, user needs to manage inputs.
915
+
916
+
917
  # Launch the app
918
+ if __name__ == "__main__":
919
+ demo.launch(debug=True) # Enable debug for more detailed logs during development