Spaces:
Running
Running
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| from huggingface_hub import snapshot_download | |
| import rasterio | |
| from rasterio.enums import Resampling | |
| from rasterio.plot import reshape_as_image | |
| import sys | |
# Fetch a snapshot of the model repository. local_dir is ".", so the files
# land in the current working directory.
repo_id = "truthdotphd/cloud-detection"
repo_subdir = "."
repo_dir = snapshot_download(repo_id=repo_id, local_dir=repo_subdir)

# Make modules shipped with the downloaded repository importable.
sys.path.append(repo_dir)

# Import the prediction entry point. If the first attempt fails, retry once
# with the repository's "omnicloudmask" directory on the search path.
try:
    from omnicloudmask import predict_from_array
except ImportError:
    omnicloudmask_dir = os.path.join(repo_dir, "omnicloudmask")
    if not os.path.exists(omnicloudmask_dir):
        raise ImportError("Could not find the omnicloudmask module in the downloaded repository")
    sys.path.append(omnicloudmask_dir)
    from omnicloudmask import predict_from_array
def visualize_rgb(red_file, green_file, blue_file, nir_file):
    """
    Build an 8-bit RGB preview from separately uploaded band files.

    Args:
        red_file: Path to the red band raster; its height and width define
            the size of the preview.
        green_file: Path to the green band raster.
        blue_file: Path to the blue band raster.
        nir_file: Path to the NIR band raster. Only checked for presence;
            it is not read here.

    Returns:
        A (height, width, 3) uint8 array in red, green, blue order, or None
        when any file is missing or the preview could not be generated.
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None
    try:
        # The red band defines the target grid for the other channels.
        with rasterio.open(red_file) as src:
            target_height = src.height
            target_width = src.width

        red_data = load_band(red_file)
        # Request green and blue on the red band's grid so that bands stored
        # at a different size can still be combined into one image.
        green_data = load_band(
            green_file,
            resample=True,
            target_height=target_height,
            target_width=target_width,
        )
        blue_data = load_band(
            blue_file,
            resample=True,
            target_height=target_height,
            target_width=target_width,
        )

        # Scale every channel by its own maximum; epsilon avoids dividing by
        # zero for an all-zero band.
        epsilon = 1e-10
        rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
        for channel, band in enumerate((red_data, green_data, blue_data)):
            rgb_image[:, :, channel] = band / (np.max(band) + epsilon)
        rgb_image = np.clip(rgb_image, 0, 1)

        # Percentile contrast stretch. Skip it when the 2nd and 98th
        # percentiles coincide (flat imagery): dividing by a zero range
        # would fill the preview with NaN/inf values.
        p2, p98 = np.percentile(rgb_image, [2, 98])
        span = p98 - p2
        if span > 0:
            rgb_image = np.clip((rgb_image - p2) / span, 0, 1)

        # Convert to uint8 for display.
        return (rgb_image * 255).astype(np.uint8)
    except Exception as e:
        # Best-effort preview: report the problem and show nothing.
        print(f"Error generating RGB preview: {e}")
        return None
def visualize_jp2(file_path):
    """
    Render the first band of a raster file as a viridis-coloured image.

    Args:
        file_path: Path to a raster file readable by rasterio.

    Returns:
        A (height, width, 3) uint8 array holding the colour-mapped band.
    """
    with rasterio.open(file_path) as src:
        # Work in floating point so that max - min cannot overflow for
        # signed integer rasters.
        data = src.read(1).astype(np.float64)

    # Min-max normalise to 0-1. A constant band has zero range, so map it to
    # all zeros instead of dividing by zero.
    low = np.min(data)
    span = np.max(data) - low
    if span > 0:
        data = (data - low) / span
    else:
        data = np.zeros_like(data)

    # The colormap returns RGBA floats in 0-1; keep RGB and convert to 8-bit.
    colored_image = plt.get_cmap('viridis')(data)
    return (colored_image[:, :, :3] * 255).astype(np.uint8)
def load_band(file_path, resample=False, target_height=None, target_width=None):
    """
    Load the first band of a raster file as a float32 array.

    Args:
        file_path: Path to a raster file readable by rasterio.
        resample: When True and both target dimensions are given, read the
            band at the target size using bilinear resampling.
        target_height: Output height in pixels used when resampling.
        target_width: Output width in pixels used when resampling.

    Returns:
        A 2-D float32 array. If resampling is requested without both target
        dimensions, the band is returned at its native size.
    """
    with rasterio.open(file_path) as src:
        # Request band 1 directly rather than reading every band of the
        # dataset and discarding all but the first.
        if resample and target_height is not None and target_width is not None:
            band_data = src.read(
                1,
                out_shape=(target_height, target_width),
                resampling=Resampling.bilinear,
            )
        else:
            band_data = src.read(1)
    return band_data.astype(np.float32)
def prepare_input_array(red_file, green_file, blue_file, nir_file):
    """
    Prepare the model input and a display image from four band files.

    Args:
        red_file: Path to the red band raster; its height and width define
            the grid every other band is read on.
        green_file: Path to the green band raster.
        blue_file: Path to the blue band raster (used for display only).
        nir_file: Path to the NIR band raster.

    Returns:
        A tuple (prediction_array, rgb_image_enhanced):
        prediction_array stacks red, green and NIR along axis 0;
        rgb_image_enhanced is a (height, width, 3) float array in 0-1.
    """
    # The red band defines the target grid.
    with rasterio.open(red_file) as src:
        target_height = src.height
        target_width = src.width
    grid = {
        "resample": True,
        "target_height": target_height,
        "target_width": target_width,
    }

    red_data = load_band(red_file)
    # Read every other band on the red band's grid so that files stored at a
    # different size can still be combined and stacked.
    blue_data = load_band(blue_file, **grid)
    green_data = load_band(green_file, **grid)
    nir_data = load_band(nir_file, **grid)

    # Print band shapes for debugging
    print(f"Band shapes - Blue: {blue_data.shape}, Green: {green_data.shape}, Red: {red_data.shape}, NIR: {nir_data.shape}")

    # Per-channel maxima drive the display normalisation.
    red_max = np.max(red_data)
    green_max = np.max(green_data)
    blue_max = np.max(blue_data)
    print(f"Max values - Red: {red_max}, Green: {green_max}, Blue: {blue_max}")

    # Scale each channel by its own maximum; epsilon avoids dividing by zero
    # for an all-zero band.
    epsilon = 1e-10
    rgb_image = np.zeros((red_data.shape[0], red_data.shape[1], 3), dtype=np.float32)
    rgb_image[:, :, 0] = red_data / (red_max + epsilon)
    rgb_image[:, :, 1] = green_data / (green_max + epsilon)
    rgb_image[:, :, 2] = blue_data / (blue_max + epsilon)
    rgb_image = np.clip(rgb_image, 0, 1)

    # Percentile contrast stretch. When the 2nd and 98th percentiles coincide
    # (flat imagery) the range is zero, so keep the normalised image instead
    # of producing NaN/inf values.
    p2, p98 = np.percentile(rgb_image, [2, 98])
    span = p98 - p2
    if span > 0:
        rgb_image_enhanced = np.clip((rgb_image - p2) / span, 0, 1)
    else:
        rgb_image_enhanced = rgb_image

    # Stack bands channel-first for cloud mask prediction (red, green, nir).
    prediction_array = np.stack([red_data, green_data, nir_data], axis=0)
    return prediction_array, rgb_image_enhanced
def visualize_cloud_mask(rgb_image, pred_mask):
    """
    Overlay a colour-coded cloud mask on an RGB image and append a legend.

    Args:
        rgb_image: (height, width, 3) image; it is scaled by 255 before
            blending, so values are expected in the 0-1 range.
        pred_mask: Array of class indices (0 clear, 1 thick cloud,
            2 thin cloud, 3 cloud shadow). Extra singleton dimensions are
            squeezed, and the mask is resized with nearest-neighbour
            interpolation if its shape differs from the image.

    Returns:
        A uint8 array containing the blended image with a 100-pixel-high
        legend strip stacked underneath.
    """
    # Ensure pred_mask has the right dimensions
    if pred_mask.ndim > 2:
        pred_mask = np.squeeze(pred_mask)
    print(f"RGB image shape: {rgb_image.shape}, Pred mask shape: {pred_mask.shape}")

    # Ensure mask has the same spatial dimensions as the image
    if pred_mask.shape != rgb_image.shape[:2]:
        pred_mask = cv2.resize(
            pred_mask.astype(np.float32),
            (rgb_image.shape[1], rgb_image.shape[0]),
            interpolation=cv2.INTER_NEAREST
        ).astype(np.uint8)
        print(f"Resized mask shape: {pred_mask.shape}")

    # Colour for each class index
    colors = {
        0: [0, 255, 0],        # Clear - Green
        1: [255, 255, 255],    # Thick Cloud - White
        2: [200, 200, 200],    # Thin Cloud - Light Gray
        3: [100, 100, 100]     # Cloud Shadow - Dark Gray
    }

    # Colour-coded mask; pixels whose class is not listed stay black.
    mask_vis = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
    for class_idx, color in colors.items():
        mask_vis[pred_mask == class_idx] = color

    # Sanitise before the uint8 cast: non-finite or out-of-range values would
    # otherwise turn into arbitrary pixel values.
    base = (np.clip(np.nan_to_num(rgb_image), 0, 1) * 255).astype(np.uint8)
    alpha = 0.5
    blended = cv2.addWeighted(base, 1 - alpha, mask_vis, alpha, 0)

    # White legend strip with the same width as the image
    image_width = blended.shape[1]
    legend = np.full((100, image_width, 3), 255, dtype=np.uint8)
    legend_text = ["Clear", "Thick Cloud", "Thin Cloud", "Cloud Shadow"]
    legend_colors = [colors[i] for i in range(4)]
    for i, (text, color) in enumerate(zip(legend_text, legend_colors)):
        top_left = (10, 10 + i*20)
        bottom_right = (30, 30 + i*20)
        cv2.rectangle(legend, top_left, bottom_right, color, -1)
        # Outline each swatch: the white "Thick Cloud" swatch is otherwise
        # invisible against the white legend background.
        cv2.rectangle(legend, top_left, bottom_right, (0, 0, 0), 1)
        cv2.putText(legend, text, (40, 25 + i*20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    # Combine image and legend
    return np.vstack([blended, legend])
def process_satellite_images(red_file, green_file, blue_file, nir_file, batch_size, patch_size, patch_overlap):
    """
    Run cloud detection on four band files and summarise the result.

    Args:
        red_file: Path to the red band raster.
        green_file: Path to the green band raster.
        blue_file: Path to the blue band raster.
        nir_file: Path to the NIR band raster.
        batch_size: Batch size forwarded to predict_from_array.
        patch_size: Patch size forwarded to predict_from_array.
        patch_overlap: Patch overlap forwarded to predict_from_array.

    Returns:
        A tuple (rgb_display, visualization, stats): the uint8 RGB image,
        the mask overlay with legend, and a text summary of per-class pixel
        counts. On invalid input the first two items are None and the third
        is an explanatory message.
    """
    if not all([red_file, green_file, blue_file, nir_file]):
        return None, None, "Please upload all four channel files (Red, Green, Blue, NIR)"

    # UI sliders may deliver floats; the patching parameters are sizes and
    # counts, so coerce them to integers before use.
    batch_size = int(batch_size)
    patch_size = int(patch_size)
    patch_overlap = int(patch_overlap)
    # An overlap as large as the patch leaves no positive stride between
    # patches; report it instead of passing it to the model.
    if patch_overlap >= patch_size:
        return None, None, "Patch overlap must be smaller than patch size"

    # Prepare input array and RGB image for visualization
    input_array, rgb_image = prepare_input_array(red_file, green_file, blue_file, nir_file)
    # Convert RGB image to format suitable for display
    rgb_display = (rgb_image * 255).astype(np.uint8)

    # Predict cloud mask using omnicloudmask
    pred_mask = predict_from_array(
        input_array,
        batch_size=batch_size,
        patch_size=patch_size,
        patch_overlap=patch_overlap
    )

    # Drop singleton dimensions so the mask is a plain 2-D class map.
    flat_mask = np.squeeze(pred_mask) if pred_mask.ndim > 2 else pred_mask

    # Per-class pixel counts
    clear_pixels = np.sum(flat_mask == 0)
    thick_cloud_pixels = np.sum(flat_mask == 1)
    thin_cloud_pixels = np.sum(flat_mask == 2)
    cloud_shadow_pixels = np.sum(flat_mask == 3)
    total_pixels = flat_mask.size

    # Same text as before, assembled line by line.
    stats = (
        "\nCloud Mask Statistics:\n"
        f"- Clear: {clear_pixels} pixels ({clear_pixels/total_pixels*100:.2f}%)\n"
        f"- Thick Cloud: {thick_cloud_pixels} pixels ({thick_cloud_pixels/total_pixels*100:.2f}%)\n"
        f"- Thin Cloud: {thin_cloud_pixels} pixels ({thin_cloud_pixels/total_pixels*100:.2f}%)\n"
        f"- Cloud Shadow: {cloud_shadow_pixels} pixels ({cloud_shadow_pixels/total_pixels*100:.2f}%)\n"
        f"- Total Cloud Cover: {(thick_cloud_pixels + thin_cloud_pixels)/total_pixels*100:.2f}%\n"
    )

    # Visualize the cloud mask on the original image
    visualization = visualize_cloud_mask(rgb_image, flat_mask)
    return rgb_display, visualization, stats
# Build the Gradio interface: four band uploads plus three tuning sliders in,
# two images and a statistics text box out.
band_inputs = [
    gr.Image(type="filepath", label=f"{band} Channel (JP2)")
    for band in ("Red", "Green", "Blue", "NIR")
]
tuning_inputs = [
    gr.Slider(
        minimum=1, maximum=32, value=1, step=1,
        label="Batch Size",
        info="Higher values use more memory but process faster",
    ),
    gr.Slider(
        minimum=500, maximum=2000, value=1000, step=100,
        label="Patch Size",
        info="Size of image patches for processing",
    ),
    gr.Slider(
        minimum=100, maximum=500, value=300, step=50,
        label="Patch Overlap",
        info="Overlap between patches to avoid edge artifacts",
    ),
]
result_outputs = [
    gr.Image(label="Original RGB Image"),
    gr.Image(label="Cloud Detection Visualization"),
    gr.Textbox(label="Statistics"),
]
app_description = """
Upload separate JP2 files for Red, Green, Blue, and NIR channels to detect clouds in satellite imagery.
This application uses the OmniCloudMask model to classify each pixel as:
- Clear (0)
- Thick Cloud (1)
- Thin Cloud (2)
- Cloud Shadow (3)
The model works best with imagery at 10-50m resolution. For higher resolution imagery, downsampling is recommended.
"""
# One example row: red, green, blue, NIR files, then batch size, patch size
# and patch overlap.
example_row = ["jp2s/B04.jp2", "jp2s/B03.jp2", "jp2s/B02.jp2", "jp2s/B8A.jp2", 1, 1000, 300]

demo = gr.Interface(
    fn=process_satellite_images,
    inputs=band_inputs + tuning_inputs,
    outputs=result_outputs,
    title="Satellite Cloud Detection",
    description=app_description,
    examples=[example_row],
)

# Start the app
demo.launch(debug=True)