import os
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from io import BytesIO
import urllib.request
import tempfile
import rasterio
import warnings
import pandas as pd
import joblib

warnings.filterwarnings("ignore")

# Try to import segmentation_models_pytorch
try:
    import segmentation_models_pytorch as smp
    smp_available = True
    print("Successfully imported segmentation_models_pytorch")
except ImportError:
    smp_available = False
    print("Warning: segmentation_models_pytorch not available, will try to install it")
    import subprocess
    try:
        subprocess.check_call(["pip", "install", "segmentation-models-pytorch"])
        import segmentation_models_pytorch as smp
        smp_available = True
        print("Successfully installed and imported segmentation_models_pytorch")
    except Exception:
        print("Failed to install segmentation_models_pytorch")

# Try to import albumentations for preprocessing
try:
    import albumentations as A
    albumentations_available = True
    print("Successfully imported albumentations")
except ImportError:
    albumentations_available = False
    print("Warning: albumentations not available, will try to install it")
    import subprocess
    try:
        subprocess.check_call(["pip", "install", "albumentations"])
        import albumentations as A
        albumentations_available = True
        print("Successfully installed and imported albumentations")
    except Exception:
        print("Failed to install albumentations, will use OpenCV for transforms")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize the segmentation model
if smp_available:
    # Define the DeepLabV3+ model using smp
    model = smp.DeepLabV3Plus(
        encoder_name="resnet34",  # ResNet-34 backbone, matching training
        encoder_weights=None,     # Custom weights are loaded below
        in_channels=3,            # RGB input
        classes=1,                # Binary segmentation
    )
else:
    # Fallback to a placeholder model so the UI can still load;
    # it will not produce valid predictions.
    print("Warning: Using a placeholder model that won't produce valid predictions.")
    from torch import nn

    class PlaceholderModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 1, 3, padding=1)

        def forward(self, x):
            return self.conv(x)

    model = PlaceholderModel()

# Download segmentation model weights from HuggingFace
SEGMENTATION_MODEL_REPO = "dcrey7/wetlands_segmentation_deeplabsv3plus"
SEGMENTATION_MODEL_FILENAME = "DeepLabV3plus_best_model.pth"

def download_model_weights():
    """Download model weights from the HuggingFace repository."""
    try:
        os.makedirs('weights', exist_ok=True)
        local_path = os.path.join('weights', SEGMENTATION_MODEL_FILENAME)
        # Reuse weights if they are already downloaded
        if os.path.exists(local_path):
            print(f"Model weights already downloaded at {local_path}")
            return local_path
        # Download the weights
        print(f"Downloading model weights from {SEGMENTATION_MODEL_REPO}...")
        url = f"https://huggingface.co/{SEGMENTATION_MODEL_REPO}/resolve/main/{SEGMENTATION_MODEL_FILENAME}"
        urllib.request.urlretrieve(url, local_path)
        print(f"Model weights downloaded to {local_path}")
        return local_path
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        return None
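# An alternative to the manual URL download above (a sketch, assuming the
# huggingface_hub package is installed; hf_hub_download handles caching and
# retries for you):
#
#     from huggingface_hub import hf_hub_download
#     weights_path = hf_hub_download(repo_id=SEGMENTATION_MODEL_REPO,
#                                    filename=SEGMENTATION_MODEL_FILENAME)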
# Load the segmentation model weights
weights_path = download_model_weights()
if weights_path:
    try:
        # Load with strict=False to tolerate minor parameter mismatches
        state_dict = torch.load(weights_path, map_location=device)
        # Check whether the state dict keys need adapting
        if all(key.startswith('encoder.') or key.startswith('decoder.')
               for key in list(state_dict.keys())[:5]):
            print("Model weights use encoder/decoder format, loading directly")
            model.load_state_dict(state_dict, strict=False)
        else:
            print("Attempting to adapt state dict to match model architecture")
            # Placeholder for state dict adaptation if ever needed
            model.load_state_dict(state_dict, strict=False)
        print("Model weights loaded successfully")
    except Exception as e:
        print(f"Error loading model weights: {e}")
else:
    print("No weights available. Model will not produce valid predictions.")

model.to(device)
model.eval()

# Load the cloud detection model
def load_cloud_detection_model():
    """Load the cloud detection model from the local file."""
    try:
        model_path = "cloud_detection_lightgbm.joblib"
        if os.path.exists(model_path):
            cloud_model = joblib.load(model_path)
            print(f"Cloud detection model loaded successfully from {model_path}")
            return cloud_model
        else:
            print(f"Cloud detection model file not found at {model_path}")
            return None
    except Exception as e:
        print(f"Error loading cloud detection model: {e}")
        return None

cloud_model = load_cloud_detection_model()
if cloud_model:
    print("Cloud detection model is ready for predictions")
else:
    print("Warning: Cloud detection model could not be loaded")

def normalize(band):
    """Normalize band values using the 2-98 percentile range."""
    # Ignore NaN and inf values
    band_cleaned = band[np.isfinite(band)]
    if len(band_cleaned) == 0:
        return band
    # Use percentiles to avoid outliers
    band_min, band_max = np.percentile(band_cleaned, (2, 98))
    # Avoid division by zero
    if band_max == band_min:
        return np.zeros_like(band)
    band_normalized = (band - band_min) / (band_max - band_min)
    band_normalized = np.clip(band_normalized, 0, 1)
    return band_normalized

def calculate_cv(band):
    """Calculate the coefficient of variation (CV) for a band."""
    # Normalize the band first
    band_normalized = normalize(band)
    # Ignore NaN and inf values
    band_cleaned = band_normalized[np.isfinite(band_normalized)]
    if len(band_cleaned) == 0:
        return 0
    mean = np.mean(band_cleaned)
    # Guard against division by zero or very small means
    if abs(mean) < 1e-10:
        return 0
    std = np.std(band_cleaned)
    cv = std / mean  # CV as a ratio (not a percentage)
    return cv
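def _example_cv_sanity_check():
    """Illustrative sketch (not called by the app): the CV feature is just
    std/mean of the percentile-normalized band, so a constant band yields 0
    while a noisy band yields a strictly positive value."""
    flat_band = np.full((32, 32), 0.5, dtype=np.float32)
    noisy_band = np.random.rand(32, 32).astype(np.float32)
    assert calculate_cv(flat_band) == 0  # zero spread -> zero CV
    assert calculate_cv(noisy_band) > 0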
def read_tiff_image_for_segmentation(tiff_path):
    """
    Read a TIFF image using rasterio, keeping the RGB bands
    (first 3 bands) for wetland segmentation.
    """
    try:
        with rasterio.open(tiff_path) as src:
            # Check if we have enough bands for RGB
            if src.count >= 3:
                red = src.read(1)
                green = src.read(2)
                blue = src.read(3)
                # Stack to create an RGB image
                image = np.dstack((red, green, blue)).astype(np.float32)
                # Normalize to [0, 1]
                if image.max() > 0:
                    image = image / image.max()
                return image
            else:
                # Fewer than 3 bands: build an RGB image from what is available
                bands = [src.read(i + 1) for i in range(src.count)]
                if len(bands) == 1:
                    # Single band: duplicate it to create RGB
                    image = np.dstack((bands[0], bands[0], bands[0]))
                else:
                    # Pad with zero bands until we have 3
                    while len(bands) < 3:
                        bands.append(np.zeros_like(bands[0]))
                    image = np.dstack(bands[:3])
                # Normalize to [0, 1]
                if image.max() > 0:
                    image = image / image.max()
                return image
    except Exception as e:
        print(f"Error reading TIFF file for segmentation: {e}")
        return None

def extract_cloud_features_from_tiff(tiff_path):
    """
    Extract CV features from all bands in a TIFF file for cloud detection.
    Uses up to 10 bands.
    """
    try:
        with rasterio.open(tiff_path) as src:
            num_bands = min(src.count, 10)  # Use up to 10 bands
            features = {}
            for i in range(1, num_bands + 1):
                band = src.read(i)
                # Coefficient of variation, named to match the training data
                features[f'band{i}_cv'] = calculate_cv(band)
            # Zero-fill any missing bands so there are always 10 features
            for i in range(num_bands + 1, 11):
                features[f'band{i}_cv'] = 0.0
            return features
    except Exception as e:
        print(f"Error extracting cloud features from TIFF: {e}")
        import traceback
        traceback.print_exc()
        return None

def extract_cloud_features_from_rgb(image):
    """
    Extract CV features from an RGB image for cloud detection.
    Uses the 3 bands and zero-fills the remaining 7 to match
    the 10 features the model expects.
    """
    try:
        # Make sure the image is float in the range [0, 1]
        if image.dtype != np.float32 and image.dtype != np.float64:
            image = image.astype(np.float32)
            if image.max() > 1.0:
                image = image / 255.0
        features = {}
        # Process each channel/band (image is H, W, C)
        num_bands = min(3, image.shape[2])
        for i in range(num_bands):
            band = image[:, :, i]
            features[f'band{i + 1}_cv'] = calculate_cv(band)
        # Zero-fill the remaining bands to match the expected 10 features
        for i in range(num_bands + 1, 11):
            features[f'band{i}_cv'] = 0.0
        return features
    except Exception as e:
        print(f"Error extracting cloud features from RGB: {e}")
        import traceback
        traceback.print_exc()
        return None
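def _example_cloud_feature_vector():
    """Illustrative sketch (not called by the app): an RGB input yields CV
    values for band1..band3 and zero padding for band4..band10, so the
    classifier always sees the 10 features it was trained on."""
    rgb = np.random.rand(64, 64, 3).astype(np.float32)
    features = extract_cloud_features_from_rgb(rgb)
    assert sorted(features) == sorted(f"band{i}_cv" for i in range(1, 11))
    assert all(features[f"band{i}_cv"] == 0.0 for i in range(4, 11))
    return features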
def predict_cloud(features_dict, model):
    """Predict whether an image is cloudy from the extracted features."""
    if model is None:
        return {'prediction': 'Model unavailable', 'probability': 0.0}
    try:
        # Ensure all 10 features (band1_cv .. band10_cv) are present
        feature_dict = {}
        for i in range(1, 11):
            feature_name = f'band{i}_cv'
            feature_dict[feature_name] = features_dict.get(feature_name, 0.0)
        # Build a single-row DataFrame with the required features
        feature_df = pd.DataFrame([feature_dict])
        # Relax LightGBM's feature-shape check for prediction
        if hasattr(model, 'set_params'):
            model.set_params(predict_disable_shape_check=True)
        # Make the prediction
        if hasattr(model, 'predict_proba'):
            proba = model.predict_proba(feature_df)
            if proba.shape[1] > 1:
                # Binary classifier with probabilities for both classes
                probability = proba[0][1]  # Probability of the positive (cloudy) class
            else:
                probability = proba[0][0]  # Single probability output
        else:
            # Without predict_proba, fall back to predict and assume binary output
            pred = model.predict(feature_df)
            probability = float(pred[0])
        # Classify using a 0.5 probability threshold
        prediction = 'Cloudy' if probability >= 0.5 else 'Non-Cloudy'
        return {'prediction': prediction, 'probability': probability}
    except Exception as e:
        print(f"Error predicting cloud: {e}")
        import traceback
        traceback.print_exc()
        return {'prediction': 'Error', 'probability': 0.0}

def read_tiff_mask(mask_path):
    """Read a TIFF mask using rasterio, matching the training data loading."""
    try:
        with rasterio.open(mask_path) as src:
            mask = src.read(1).astype(np.uint8)
        return mask
    except Exception as e:
        print(f"Error reading mask file: {e}")
        return None

def preprocess_image(image, target_size=(128, 128)):
    """Preprocess an image for inference."""
    # If the image is already a numpy array, use it directly
    if isinstance(image, np.ndarray):
        # Ensure RGB format
        if len(image.shape) == 2:  # Grayscale
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:  # RGBA
            image = image[:, :, :3]
        # Make a copy for display
        display_image = image.copy()
        # Normalize to [0, 1] if needed
        if display_image.max() > 1.0:
            image = image.astype(np.float32) / 255.0
            display_image = display_image.astype(np.uint8)  # Keep the display copy as uint8
        else:
            # Already normalized; scale up for display
            display_image = (display_image * 255).astype(np.uint8)
    # Convert a PIL image to numpy
    elif isinstance(image, Image.Image):
        image = np.array(image)
        # Ensure RGB format
        if len(image.shape) == 2:  # Grayscale
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:  # RGBA
            image = image[:, :, :3]
        # Make a copy for display (uint8 at this point)
        display_image = image.copy()
        # Normalize to [0, 1] for the model
        image = image.astype(np.float32) / 255.0
    else:
        print(f"Unsupported image type: {type(image)}")
        return None, None

    # Resize the image to the target size
    if albumentations_available:
        # Use albumentations to match the training preprocessing
        aug = A.Compose([
            A.PadIfNeeded(min_height=target_size[0], min_width=target_size[1],
                          border_mode=cv2.BORDER_CONSTANT, value=0),
            A.CenterCrop(height=target_size[0], width=target_size[1])
        ])
        augmented = aug(image=image)
        image_resized = augmented['image']
    else:
        # Fallback to OpenCV
        image_resized = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)

    # Convert to a [1, C, H, W] tensor
    image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
    return image_tensor, display_image
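def _example_preprocess_smoke_test():
    """Illustrative sketch (not called by the app): whatever the input size,
    preprocess_image should hand the model a [1, 3, 128, 128] float tensor
    and keep a uint8 copy for display."""
    fake = (np.random.rand(300, 200, 3) * 255).astype(np.uint8)
    tensor, display = preprocess_image(fake)
    assert tensor.shape == (1, 3, 128, 128)
    assert display.dtype == np.uint8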
def extract_file_content(file_obj):
    """Extract content from a file object, handling the different input types."""
    try:
        # Gradio File object (has a 'name' attribute pointing to a temp path)
        if hasattr(file_obj, 'name') and isinstance(file_obj.name, str):
            file_path = file_obj.name
            if os.path.exists(file_path):
                with open(file_path, 'rb') as f:
                    return f.read(), file_path  # Return the path for TIFF reading
            else:
                print(f"Temp file path does not exist: {file_path}")
                return None, None
        # String path (e.g. from gr.Examples)
        elif isinstance(file_obj, str):
            if os.path.exists(file_obj):
                with open(file_obj, 'rb') as f:
                    return f.read(), file_obj  # Return the path for TIFF reading
            else:
                print(f"Provided file path does not exist: {file_obj}")
                return None, None
        # BytesIO or other file-like objects (less likely with gr.File/gr.Examples)
        elif hasattr(file_obj, 'read'):
            content = file_obj.read()
            # Save to a temp file so rasterio gets a path
            with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_f:
                temp_path = temp_f.name
                temp_f.write(content)
            return content, temp_path
        elif isinstance(file_obj, bytes):
            # Save to a temp file so rasterio gets a path
            with tempfile.NamedTemporaryFile(delete=False, suffix='.tmp') as temp_f:
                temp_path = temp_f.name
                temp_f.write(file_obj)
            return file_obj, temp_path
        else:
            print(f"Unsupported file object type: {type(file_obj)}")
            return None, None
    except Exception as e:
        print(f"Error extracting file content: {e}")
        import traceback
        traceback.print_exc()
        return None, None

def process_uploaded_tiff(file_obj):
    """Process an uploaded TIFF file for both segmentation and cloud detection."""
    temp_file_to_delete = None
    try:
        # We primarily need the file path for rasterio
        _, file_path = extract_file_content(file_obj)
        if file_path is None:
            print("Failed to get file path for TIFF processing")
            return None, None, None
        # Track any temp file created by extract_file_content
        if file_path.endswith('.tmp'):
            temp_file_to_delete = file_path

        # Read as TIFF for segmentation (first 3 bands only)
        image_for_segmentation = read_tiff_image_for_segmentation(file_path)
        # Extract cloud features from all available bands
        cloud_features = extract_cloud_features_from_tiff(file_path)

        if image_for_segmentation is None:
            print("Failed to read TIFF for segmentation.")
            return None, None, None

        # Make a uint8 copy for display
        if image_for_segmentation.max() <= 1.0 and image_for_segmentation.min() >= 0.0:
            display_image = (image_for_segmentation * 255).astype(np.uint8)
        elif image_for_segmentation.dtype == np.uint8:
            display_image = image_for_segmentation.copy()
        else:
            # Values outside [0, 1]: normalize for display
            display_image = cv2.normalize(image_for_segmentation, None, 0, 255,
                                          cv2.NORM_MINMAX, dtype=cv2.CV_8U)
            print("Warning: TIFF image values outside [0,1], normalized for display.")

        # Resize/preprocess for the segmentation model (needs a float image)
        image_float = image_for_segmentation.astype(np.float32)
        if image_float.max() > 1.0:  # Normalize if not already done
            image_float = image_float / (image_float.max() + 1e-6)
        if albumentations_available:
            aug = A.Compose([
                A.PadIfNeeded(min_height=128, min_width=128,
                              border_mode=cv2.BORDER_CONSTANT, value=0),
                A.CenterCrop(height=128, width=128)
            ])
            augmented = aug(image=image_float)
            image_resized = augmented['image']
        else:
            image_resized = cv2.resize(image_float, (128, 128), interpolation=cv2.INTER_LINEAR)

        # Convert to a tensor
        image_tensor = torch.from_numpy(image_resized.transpose(2, 0, 1)).float().unsqueeze(0)
        return image_tensor, display_image, cloud_features
    except Exception as e:
        print(f"Error processing uploaded TIFF: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None
    finally:
        # Clean up any temporary file created by extract_file_content
        if temp_file_to_delete and os.path.exists(temp_file_to_delete):
            try:
                os.unlink(temp_file_to_delete)
            except Exception as e_del:
                print(f"Error deleting temp file {temp_file_to_delete}: {e_del}")
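def _example_synthetic_tiff(path="synthetic_3band.tif"):
    """Illustrative sketch (not called by the app): write a small 3-band
    GeoTIFF with rasterio and read it back through the segmentation loader.
    The file name is an arbitrary placeholder."""
    data = np.random.rand(3, 64, 64).astype(np.float32)
    with rasterio.open(path, "w", driver="GTiff",
                       height=64, width=64, count=3, dtype="float32") as dst:
        dst.write(data)
    image = read_tiff_image_for_segmentation(path)
    assert image.shape == (64, 64, 3) and image.max() <= 1.0
    return image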
def process_uploaded_mask(file_obj):
    """Process an uploaded ground truth mask file."""
    temp_file_to_delete = None
    try:
        _, file_path = extract_file_content(file_obj)
        if file_path is None:
            print("Failed to get file path for mask processing")
            return None
        # Track any temp file created by extract_file_content
        if file_path.endswith('.tmp'):
            temp_file_to_delete = file_path

        # TIFF masks go through rasterio; everything else through PIL
        if file_path.lower().endswith(('.tif', '.tiff')):
            mask = read_tiff_mask(file_path)
        else:
            try:
                mask_img = Image.open(file_path).convert('L')  # Convert to grayscale
                mask = np.array(mask_img)
            except Exception as e:
                print(f"Error opening mask as regular image: {e}")
                return None

        if mask is None:
            print("Failed to read mask data.")
            return None

        # Resize the mask to 128x128
        if albumentations_available:
            aug = A.Compose([
                A.PadIfNeeded(min_height=128, min_width=128,
                              border_mode=cv2.BORDER_CONSTANT, value=0),
                A.CenterCrop(height=128, width=128)
            ])
            augmented = aug(image=mask)  # Use the 'image' key even for a mask
            mask_resized = augmented['image']
        else:
            mask_resized = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)

        # Binarize the mask (0: background, 1: wetland) using the median of the
        # nonzero values as the threshold, so near-binary masks stay binary
        threshold = np.median(mask_resized[mask_resized > 0]) if np.any(mask_resized > 0) else 1
        mask_binary = (mask_resized >= threshold).astype(np.uint8)
        print(f"Mask binarized. Original min/max: {mask.min()}/{mask.max()}, "
              f"Resized min/max: {mask_resized.min()}/{mask_resized.max()}, "
              f"Binary sum: {mask_binary.sum()}")
        return mask_binary
    except Exception as e:
        print(f"Error processing uploaded mask: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Clean up any temporary file created by extract_file_content
        if temp_file_to_delete and os.path.exists(temp_file_to_delete):
            try:
                os.unlink(temp_file_to_delete)
            except Exception as e_del:
                print(f"Error deleting temp mask file {temp_file_to_delete}: {e_del}")

def predict_segmentation(image_tensor):
    """Run segmentation inference on the model."""
    try:
        image_tensor = image_tensor.to(device)
        with torch.no_grad():
            output = model(image_tensor)
            # Handle different model output formats
            if isinstance(output, dict):
                output = output['out']
            if output.shape[1] > 1:
                # Multi-class output
                pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
            else:
                # Binary output (from smp models)
                pred = (torch.sigmoid(output) > 0.5).squeeze().cpu().numpy().astype(np.uint8)
        return pred
    except Exception as e:
        print(f"Error during prediction: {e}")
        return None

def calculate_metrics(pred_mask, gt_mask):
    """Calculate evaluation metrics between a prediction and the ground truth."""
    if pred_mask is None or gt_mask is None:
        print("Cannot calculate metrics: Invalid masks provided.")
        return {}
    if pred_mask.shape != gt_mask.shape:
        print(f"Cannot calculate metrics directly: Shape mismatch - Pred {pred_mask.shape}, GT {gt_mask.shape}")
        # Resize the ground truth to the prediction's shape
        gt_mask = cv2.resize(gt_mask, (pred_mask.shape[1], pred_mask.shape[0]),
                             interpolation=cv2.INTER_NEAREST)
        print(f"Resized GT mask to {gt_mask.shape}")

    # Ensure binary masks
    pred_binary = (pred_mask > 0).astype(np.uint8)
    gt_binary = (gt_mask > 0).astype(np.uint8)

    # Intersection over union
    intersection = np.logical_and(pred_binary, gt_binary).sum()
    union = np.logical_or(pred_binary, gt_binary).sum()
    iou = intersection / union if union > 0 else 0

    # Precision and recall
    true_positive = intersection
    pred_positive_count = pred_binary.sum()
    gt_positive_count = gt_binary.sum()
    precision = true_positive / pred_positive_count if pred_positive_count > 0 else 0
    recall = true_positive / gt_positive_count if gt_positive_count > 0 else 0

    # F1 score
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {
        "Wetlands IoU": float(iou),
        "Precision": float(precision),
        "Recall": float(recall),
        "F1 Score": float(f1),
    }
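def _example_metrics_by_hand():
    """Illustrative sketch (not called by the app): verify calculate_metrics
    against numbers worked out by hand on two tiny 4x4 masks. The prediction
    covers 6 pixels, the ground truth covers 6, and they overlap on 2, so
    IoU = 2/10 = 0.2 and precision = recall = F1 = 2/6."""
    pred = np.zeros((4, 4), dtype=np.uint8)
    pred[0:2, 0:3] = 1  # 6 predicted-positive pixels
    gt = np.zeros((4, 4), dtype=np.uint8)
    gt[1:3, 1:4] = 1    # 6 ground-truth-positive pixels
    metrics = calculate_metrics(pred, gt)
    assert abs(metrics["Wetlands IoU"] - 0.2) < 1e-6
    assert abs(metrics["F1 Score"] - 2 / 6) < 1e-6
    return metrics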
def process_images(input_image=None, input_tiff=None, gt_mask_file=None):
    """
    Process the inputs and generate predictions for both wetland
    segmentation and cloud detection.
    """
    image_tensor = None
    display_image = None
    cloud_features = None
    pred_mask = None
    gt_mask_processed = None
    result_image = None
    result_text = "Processing..."  # Initial message

    try:
        # Determine the input type and process it
        if input_tiff is not None:
            print("Processing TIFF input...")
            image_tensor, display_image, cloud_features = process_uploaded_tiff(input_tiff)
            if image_tensor is None:
                return None, "Failed to process the input TIFF file."
            print("TIFF processing complete.")
        elif input_image is not None:
            print("Processing Image input...")
            image_tensor, display_image = preprocess_image(input_image)
            if image_tensor is None:
                return None, "Failed to process the input image."
            # For RGB images, extract cloud features separately
            print("Extracting cloud features from RGB...")
            cloud_features = extract_cloud_features_from_rgb(display_image)  # Use the uint8 display copy
            print("Image processing complete.")
        else:
            return None, "Please upload an image or TIFF file."

        # --- Predictions ---
        # Wetland segmentation
        print("Performing wetland segmentation...")
        pred_mask = predict_segmentation(image_tensor)
        if pred_mask is None:
            return None, "Failed to generate wetland segmentation prediction."
        print(f"Segmentation prediction generated, shape: {pred_mask.shape}, type: {pred_mask.dtype}")

        # Cloud detection
        cloud_result = {'prediction': 'Unknown', 'probability': 0.0}
        if cloud_features and cloud_model:
            print("Performing cloud detection...")
            cloud_result = predict_cloud(cloud_features, cloud_model)
            print(f"Cloud detection result: {cloud_result}")
        elif cloud_model is None:
            print("Cloud detection model not available.")
        else:
            print("Cloud features not extracted, skipping cloud detection.")

        # --- Ground truth and metrics (if provided) ---
        metrics_text = ""
        if gt_mask_file is not None:
            print("Processing ground truth mask...")
            gt_mask_processed = process_uploaded_mask(gt_mask_file)
            if gt_mask_processed is not None:
                print(f"Ground truth mask processed, shape: {gt_mask_processed.shape}, type: {gt_mask_processed.dtype}")
                print("Calculating metrics...")
                metrics = calculate_metrics(pred_mask, gt_mask_processed)
                metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
                print(f"Metrics calculated: {metrics_text}")
            else:
                print("Failed to process ground truth mask.")
                metrics_text = "Ground truth mask provided but could not be processed."

        # --- Visualization ---
        print("Creating result visualization...")
        fig, axes = plt.subplots(1, 3 if gt_mask_processed is not None else 2, figsize=(12, 5))
        fig.suptitle("Analysis Results", fontsize=16)

        # Ensure display_image is in the correct format for imshow
        if display_image.dtype != np.uint8:
            display_image_for_plot = (display_image * 255).astype(np.uint8) if display_image.max() <= 1 else display_image.astype(np.uint8)
        else:
            display_image_for_plot = display_image
        # The masks are 128x128; upscale the input copy to 512x512 so it is easier to inspect
        display_image_resized = cv2.resize(display_image_for_plot, (512, 512), interpolation=cv2.INTER_LINEAR)

        ax_idx = 0
        axes[ax_idx].imshow(display_image_resized)
        axes[ax_idx].set_title("Input Image (Resized for Display)")
        axes[ax_idx].axis('off')
        ax_idx += 1

        if gt_mask_processed is not None:
            axes[ax_idx].imshow(gt_mask_processed, cmap='viridis')
            axes[ax_idx].set_title("Ground Truth (128x128)")
            axes[ax_idx].axis('off')
            ax_idx += 1

        # pred_mask is already a 0/1 array, suitable for imshow
        axes[ax_idx].imshow(pred_mask, cmap='viridis')
        axes[ax_idx].set_title("Predicted Wetlands (128x128)")
        axes[ax_idx].axis('off')

        # Wetland percentage from the prediction
        wetland_percentage = np.mean(pred_mask) * 100

        # --- Output text ---
        result_text = "--- Analysis Summary ---\n"
        result_text += f"Wetland Coverage (Predicted): {wetland_percentage:.2f}%\n\n"
        result_text += f"Cloud Detection: {cloud_result['prediction']} "
        result_text += f"({cloud_result['probability'] * 100:.2f}% cloud probability)\n\n"
        # Segmentation metrics, if available
        if metrics_text:
            if "could not be processed" in metrics_text:
                result_text += metrics_text  # Pass the error message through
            else:
                result_text += f"--- Evaluation Metrics (vs Ground Truth) ---\n{metrics_text}"
        elif gt_mask_file is not None:
            result_text += "--- Evaluation Metrics ---\nGround truth provided but metrics could not be calculated."

        # Convert the figure to an image for display
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave room for the suptitle
        buf = BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        result_image = Image.open(buf)
        plt.close(fig)
        print("Visualization complete.")

        return result_image, result_text
    except Exception as e:
        print(f"Error in process_images: {e}")
        import traceback
        traceback.print_exc()
        # Return partial results if they exist, otherwise just the error
        error_message = f"An error occurred during processing: {str(e)}"
        return result_image if result_image else None, result_text + f"\n\nERROR: {error_message}"

# --- Example files ---
# These paths are relative to where app.py runs and should sit at the root
# of the Gradio Space repository.
example_list = [
    ["test_p1_cloudy_input.tif", "test_p1_cloudy_output.tif"],
    ["test_p1_noncloudy_input.tif", "test_p1_noncloudy_output.tif"],
    ["test_p1_cloudy_input.tif", None],      # Example without ground truth
    ["test_p1_noncloudy_input.tif", None],   # Example without ground truth
]
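def _example_full_pipeline():
    """Illustrative sketch (not called by the app): push a random RGB array
    through process_images exactly as the Gradio image input would."""
    fake_rgb = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
    result_image, result_text = process_images(input_image=fake_rgb)
    print(result_text)
    return result_image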
# --- Gradio interface ---
with gr.Blocks(title="Wetlands Segmentation & Cloud Detection") as demo:
    gr.Markdown("# Wetlands Segmentation & Cloud Detection from Satellite Imagery")
    gr.Markdown("Upload a satellite image or TIFF file to identify wetland areas and detect cloud cover. "
                "Optionally, you can also upload a ground truth mask for evaluation.")

    with gr.Row():
        with gr.Column(scale=1):  # Input column
            gr.Markdown("### Input")
            with gr.Tabs():
                with gr.TabItem("Upload Image"):
                    input_image = gr.Image(label="Upload Satellite Image (JPG, PNG etc.)", type="numpy")
                with gr.TabItem("Upload TIFF"):
                    input_tiff = gr.File(label="Upload Multi-Band TIFF File",
                                         file_types=[".tif", ".tiff"],
                                         type="filepath")  # Filepath is easiest to handle

            # Ground truth mask as a file upload
            gt_mask_file = gr.File(label="Ground Truth Mask (Optional)",
                                   file_types=[".tif", ".tiff", ".png", ".jpg", ".jpeg"],
                                   type="filepath")

            # --- Examples section ---
            gr.Markdown("### Load Examples")
            gr.Examples(
                examples=example_list,
                inputs=[input_tiff, gt_mask_file],  # TIFF input and GT mask input
                label="Click an example below to load files:",
                # No outputs needed here; examples only populate the inputs
                # examples_per_page=4  # Optional: control pagination
            )

            process_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Column(scale=2):  # Output column (wider)
            gr.Markdown("### Results")
            output_image = gr.Image(label="Segmentation Results", type="pil", height=450)  # Adjust height as needed
            output_text = gr.Textbox(label="Statistics & Metrics", lines=10, scale=1)

    # Information about the models
    gr.Markdown("### About these models")
    gr.Markdown("""
    This application uses two models:

    **1. Wetland Segmentation Model:**
    - Architecture: DeepLabv3+ with ResNet-34
    - Input: RGB satellite imagery (the first 3 bands of a TIFF, if provided)
    - Output: Binary segmentation mask (Wetland vs Background)
    - Resolution: Processed at 128×128 pixels

    **2. Cloud Detection Model:**
    - Architecture: LightGBM classifier
    - Input: Coefficient of Variation (CV) features extracted from up to 10 image bands (from TIFF)
    - Output: Binary classification (Cloudy vs Non-Cloudy) with probability

    **Tips for best results:**
    - Use the 'Upload TIFF' tab for multi-band satellite data to enable accurate cloud detection.
    - The cloud detection model expects up to 10 bands; performance may vary with fewer bands.
    - The example files demonstrate cloudy/non-cloudy scenarios with corresponding ground truth.
    - The models work best with images similar in characteristics to those used in training.
    - For ground truth masks, both TIFF and standard image formats (PNG, JPG) are supported. Ensure the mask clearly delineates the target class.

    **Repository:** [dcrey7/wetland_segmentation_deeplabsv3plus](https://huggingface.co/spaces/dcrey7/wetland_segmentation_deeplabsv3plus)
    """)

    # Event handlers
    process_btn.click(
        fn=process_images,
        inputs=[input_image, input_tiff, gt_mask_file],
        outputs=[output_image, output_text],
        api_name="analyze"  # Optional: name for the API endpoint
    )

    # Clearing the other input when switching tabs would be good UX, but wiring
    # it up depends on the Tabs component exposing a select/change event; for
    # now the user manages the inputs and this helper stays unused.
    def clear_other_input(selected_tab):
        if selected_tab == 0:    # "Upload Image" tab selected
            return gr.update(value=None), gr.update()  # Clear the TIFF input
        elif selected_tab == 1:  # "Upload TIFF" tab selected
            return gr.update(), gr.update(value=None)  # Clear the Image input
        return gr.update(), gr.update()  # Default case
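# One possible wiring for the tab-clearing idea above (an untested sketch; it
# assumes a Gradio version where gr.Tabs exposes a .select event that passes a
# gr.SelectData carrying the selected tab index):
#
#     with gr.Tabs() as tabs:
#         ...
#     def on_tab_select(evt: gr.SelectData):
#         if evt.index == 0:  # "Upload Image" selected -> clear the TIFF input
#             return gr.update(), gr.update(value=None)
#         return gr.update(value=None), gr.update()  # clear the Image input
#     tabs.select(on_tab_select, None, [input_image, input_tiff])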
# Launch the app
if __name__ == "__main__":
    demo.launch(debug=True)  # debug=True gives more detailed logs during development