import streamlit as st
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import pathlib
import natsort
import datetime
import shutil
import rasterio
import cv2
import tensorflow as tf
import tempfile
from rasterio import features
from shapely.geometry import shape
import geopandas as gpd
import zipfile

# Configuration
HEIGHT = WIDTH = 256
SAR_SHAPE = (HEIGHT, WIDTH, 1)
OPTIC_SHAPE = (HEIGHT, WIDTH, 3)
# NOTE(review): this declares 4 one-hot channels, but CLASS_COLORS and
# convertColorToLabel define only 3 classes -- confirm against the trained
# model's output layer.
MASK_SHAPE = (HEIGHT, WIDTH, 4)

# Class colors (RGB in [0, 1]): non-mining (green), illegal mining (red), beach (black)
CLASS_COLORS = [
    [115 / 255, 178 / 255, 115 / 255],
    [1, 0, 0],
    [0, 0, 0],
]

# Channel index of the "illegal mining" class in one-hot masks / model output.
ILLEGAL_MINING_CLASS = 1
# Probability above which a pixel counts as predicted illegal mining.
PRED_THRESHOLD = 0.5
# uint8 RGB color painted over illegal-mining pixels in the overlays.
ILLEGAL_MINING_RGB = [255, 0, 0]
# Trained model file, resolved relative to the current working directory.
MODEL_PATH = "Residual_UNET_Bilinear.keras"


@tf.keras.saving.register_keras_serializable()
def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
    """Return the smoothed Dice coefficient between ``y_true`` and ``y_pred``.

    Predictions are binarised with ``y_pred >= threshold`` before scoring.
    When the last axis has more than one channel, a Dice score is computed
    for each channel (class) separately and the mean is returned. Otherwise
    a single binary Dice score is returned.

    Args:
        y_true: Ground-truth tensor (one-hot along the last axis for multiclass).
        y_pred: Predicted probabilities with the same shape as ``y_true``.
        threshold: Binarisation threshold applied to ``y_pred``.
        smooth: Additive smoothing that avoids 0/0 on empty masks.

    Returns:
        A scalar float32 tensor.
    """

    def _binary_dice(t, p):
        # Flatten and binarise so every pixel contributes equally.
        t_flat = tf.cast(tf.reshape(t, [-1]), dtype=tf.float32)
        p_flat = tf.cast(tf.reshape(p >= threshold, [-1]), dtype=tf.float32)
        intersection = tf.reduce_sum(t_flat * p_flat)
        return (2.0 * intersection + smooth) / (
            tf.reduce_sum(t_flat) + tf.reduce_sum(p_flat) + smooth
        )

    num_classes = y_true.shape[-1]
    # A statically unknown channel count is scored as one binary mask.
    if num_classes is None or num_classes <= 1:
        return _binary_dice(y_true, y_pred)

    # Slice each class's own channel so every class gets its own score.
    per_class = [
        _binary_dice(y_true[..., i], y_pred[..., i]) for i in range(num_classes)
    ]
    return tf.reduce_mean(tf.stack(per_class))


@tf.keras.saving.register_keras_serializable()
def dice_loss(y_true, y_pred):
    """Return ``1 - dice_score(y_true, y_pred)`` as a float32 tensor."""
    dice = dice_score(y_true, y_pred)
    loss = 1.0 - dice
    return tf.cast(loss, dtype=tf.float32)


@tf.keras.saving.register_keras_serializable()
def cce_dice_loss(y_true, y_pred):
    """Return categorical cross-entropy plus Dice loss (both float32)."""
    cce = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return tf.cast(cce, dtype=tf.float32) + dice


def convertColorToLabel(img):
    """Convert an RGB color-coded mask into a one-hot label array.

    Args:
        img: ``(H, W, 3)`` RGB mask using the class colors listed below.

    Returns:
        ``(H, W, num_classes)`` uint8 one-hot array.

    Pixels whose color matches none of the known classes are assigned
    label 0 (non-mining land).
    """
    color_to_label = {
        (115, 178, 115): 0,  # non_mining_land (green)
        (255, 0, 0): 1,  # illegal_mining_land (red)
        (0, 0, 0): 2,  # beach (black)
    }

    label_img = np.zeros(img.shape[:2], dtype=np.uint8)
    for color, label in color_to_label.items():
        label_img[np.all(img == color, axis=2)] = label

    # One-hot encode: compare each pixel's label against every class index.
    num_classes = len(color_to_label)
    return (label_img[..., np.newaxis] == np.arange(num_classes)).astype(np.uint8)


def _read_sar(path, width, height):
    """Read every band of a SAR raster, stretch it to uint8 and resize it.

    Returns an array of shape ``(height, width, bands)``.
    """
    with rasterio.open(str(path)) as src:
        # src.read() is (bands, rows, cols); move bands to the last axis.
        sar = np.moveaxis(src.read(), 0, -1)

    # A 2-98 percentile stretch clips extreme backscatter before scaling to 0-255.
    p2, p98 = np.percentile(sar, (2, 98))
    if p98 > p2:
        sar = (np.clip(sar, p2, p98) - p2) / (p98 - p2) * 255
    else:
        # Constant image: avoid 0/0 (NaN) and return black.
        sar = np.zeros(sar.shape, dtype=np.float64)
    sar = cv2.resize(sar.astype(np.uint8), (width, height), interpolation=cv2.INTER_AREA)

    # cv2.resize drops a trailing singleton channel; restore it so that
    # single-band SAR comes out as (height, width, 1).
    if sar.ndim == 2:
        sar = sar[..., np.newaxis]
    return sar


def _read_rgb(path):
    """Read an image with OpenCV and return it as RGB; raise if it cannot be decoded."""
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError(f"Could not read image file: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def readImages(data, typeData, width, height):
    """Read and resize a batch of images from disk.

    Args:
        data: Iterable of file paths (str or path-like).
        typeData: ``'s'`` SAR raster (rasterio, percentile-stretched to uint8),
            ``'m'`` color-coded mask (converted to one-hot), or
            ``'o'`` optical RGB image.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        ``np.ndarray`` stacking the processed images along a new first axis.

    Raises:
        ValueError: If ``typeData`` is unknown or an image cannot be decoded.
    """
    if typeData not in ("s", "m", "o"):
        raise ValueError(f"Unknown typeData {typeData!r}; expected 's', 'm' or 'o'")

    images = []
    for path in data:
        if typeData == "s":
            images.append(_read_sar(path, width, height))
        elif typeData == "m":
            # Nearest-neighbour keeps mask colors exact (no blended class colors).
            rgb = cv2.resize(_read_rgb(path), (width, height), interpolation=cv2.INTER_NEAREST)
            images.append(convertColorToLabel(rgb))
        else:
            images.append(
                cv2.resize(_read_rgb(path), (width, height), interpolation=cv2.INTER_AREA)
            )
    print(f"(INFO..) Read {len(images)} '{typeData}' image(s)")
    return np.array(images)


def normalizeImages(images, typeData):
    """Cast images to uint8 and scale SAR/optical data to [0, 1].

    For ``typeData`` ``'s'`` or ``'o'`` each image is divided by 255. Any
    other code (e.g. ``'m'`` for masks) is only cast to uint8.
    """
    normalized_images = []
    for img in images:
        img = img.astype(np.uint8)
        if typeData in ["s", "o"]:
            img = img / 255.0
        normalized_images.append(img)
    print("(INFO..) Normalization Image Done")
    return np.array(normalized_images)


def save_uploaded_file(uploaded_file, suffix=".tif"):
    """Write an uploaded file's bytes to a new temporary file and return its path.

    The file is created with ``delete=False``; the caller must remove it.
    ``getvalue()`` is preferred because it returns the full contents
    regardless of the current stream position.
    """
    if hasattr(uploaded_file, "getvalue"):
        payload = uploaded_file.getvalue()
    else:
        payload = uploaded_file.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(payload)
        return tmp.name


def _remove_paths(paths):
    """Best-effort removal of temporary files and directories created for one run."""
    for path in paths:
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except OSError:
                pass  # A leftover temp file is harmless; do not mask the real result.


def _load_wiup(zip_path, extract_dir):
    """Extract a shapefile ZIP and read its (first) .shp into a GeoDataFrame.

    The .shp is searched recursively, so archives that wrap the shapefile
    in a folder also work.
    """
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)

    shp_files = sorted(
        p for p in pathlib.Path(extract_dir).rglob("*") if p.suffix.lower() == ".shp"
    )
    if not shp_files:
        raise ValueError("The WIUP ZIP archive does not contain a .shp file.")
    if len(shp_files) > 1:
        st.warning(f"Multiple shapefiles found in the ZIP; using {shp_files[0].name}.")
    return gpd.read_file(shp_files[0])


def _pred_illegal_mask(pred_mask, is_multiclass):
    """Return a boolean (H, W) mask of pixels predicted as illegal mining."""
    prob = pred_mask[..., ILLEGAL_MINING_CLASS] if is_multiclass else pred_mask.squeeze()
    return prob > PRED_THRESHOLD


def _colorize(class_maps):
    """Blend per-class channels (one-hot or probabilities) into RGB via CLASS_COLORS."""
    n = min(class_maps.shape[-1], len(CLASS_COLORS))
    # (H, W, n) @ (n, 3) -> (H, W, 3): each pixel is a weighted sum of class colors.
    return class_maps[..., :n].astype(np.float64) @ np.asarray(CLASS_COLORS[:n], dtype=np.float64)


def _show(fig):
    """Render a Matplotlib figure in the app and release its memory."""
    st.pyplot(fig)
    plt.close(fig)


def _plot_predictions(sar_images, optic_images, masks, pred_masks, is_multiclass, num_samples):
    """Build a figure with SAR, optical, ground-truth and predicted masks per sample."""
    fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples), squeeze=False)
    for i, row in enumerate(axes):
        if is_multiclass:
            pred_img, pred_cmap = _colorize(pred_masks[i]), None
        else:
            pred_img, pred_cmap = pred_masks[i].squeeze(), "gray"
        panels = [
            (sar_images[i].squeeze(), "gray", f"SAR Image {i+1}"),
            (optic_images[i], None, f"Optic Image {i+1}"),
            # Ground-truth masks are always one-hot (see convertColorToLabel).
            (_colorize(masks[i]), None, f"Ground Truth Mask {i+1}"),
            (pred_img, pred_cmap, f"Predicted Mask {i+1}"),
        ]
        for ax, (image, cmap, title) in zip(row, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")
    fig.tight_layout()
    return fig


def _plot_overlays(sar_images, optic_images, masks, pred_masks, is_multiclass, num_samples):
    """Build a figure painting ground-truth / predicted illegal mining over the optical image."""
    if optic_images.dtype == np.uint8:
        optic_u8 = optic_images
    else:
        optic_u8 = np.rint(np.clip(optic_images, 0, 1) * 255).astype(np.uint8)

    fig, axes = plt.subplots(num_samples, 4, figsize=(21, 6 * num_samples), squeeze=False)
    for i, row in enumerate(axes):
        gt_overlay = optic_u8[i].copy()
        gt_overlay[masks[i][:, :, ILLEGAL_MINING_CLASS] == 1] = ILLEGAL_MINING_RGB
        pred_overlay = optic_u8[i].copy()
        pred_overlay[_pred_illegal_mask(pred_masks[i], is_multiclass)] = ILLEGAL_MINING_RGB

        row[0].imshow(sar_images[i].squeeze(), cmap="gray")
        row[0].set_title(f"SAR Image {i+1}")
        row[1].imshow(optic_u8[i])
        row[1].set_title(f"Optic Image {i+1}")
        for ax, overlay, title in (
            (row[2], gt_overlay, f"Ground Truth Overlay {i+1}"),
            (row[3], pred_overlay, f"Predicted Overlay {i+1}"),
        ):
            ax.imshow(optic_u8[i])
            ax.imshow(overlay, alpha=0.4)
            ax.set_title(title)
        for ax in row:
            ax.axis("off")
    fig.tight_layout()
    return fig


def _rasterize_wiup(wiup_gdf, raster_path, height, width):
    """Burn WIUP polygons into a (height, width) uint8 mask aligned with the resized raster.

    The raster's geotransform is rescaled from its native grid to
    ``height x width`` so the polygons land on the pixels the model predicted.
    NOTE(review): this uses the SAR raster's georeferencing and assumes the
    optical image and mask are co-registered with it -- confirm.
    """
    with rasterio.open(str(raster_path)) as src:
        transform = src.transform * src.transform.scale(src.width / width, src.height / height)
        raster_crs = src.crs

    if raster_crs is None:
        st.warning("The SAR image has no CRS; WIUP overlap may be inaccurate.")
    elif wiup_gdf.crs is not None:
        wiup_gdf = wiup_gdf.to_crs(raster_crs.to_wkt())

    shapes = [
        (geom, 1) for geom in wiup_gdf.geometry if geom is not None and not geom.is_empty
    ]
    if not shapes:
        return np.zeros((height, width), dtype=np.uint8)
    return features.rasterize(
        shapes, out_shape=(height, width), transform=transform, fill=0, dtype=np.uint8
    )


def _report_outside_wiup(pred_mask, is_multiclass, wiup_gdf, raster_path):
    """Show the share of predicted illegal-mining pixels lying outside the WIUP boundary."""
    illegal_mask = _pred_illegal_mask(pred_mask, is_multiclass)
    wiup_mask = _rasterize_wiup(wiup_gdf, raster_path, HEIGHT, WIDTH)
    outside_mask = illegal_mask & (wiup_mask == 0)

    illegal_pixels = np.count_nonzero(illegal_mask)
    if illegal_pixels == 0:
        # The percentage is undefined (0/0) when nothing is predicted as illegal mining.
        st.markdown("### ✅ No illegal mining predicted in this image.")
    else:
        outside_percentage = 100 * np.count_nonzero(outside_mask) / illegal_pixels
        st.markdown(f"### 🚨 Illegal Mining Outside WIUP: `{outside_percentage:.2f}%`")

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.imshow(outside_mask.astype(np.uint8), cmap="Reds")
    ax.set_title("Illegal Mining Outside WIUP Boundary")
    ax.axis("off")
    _show(fig)


def _run_inference(sar_file, optic_file, mask_file, wiup_file, num_samples):
    """Run the full pipeline for one set of uploads and render the results.

    Temporary files and the extracted shapefile directory are removed when
    the run finishes, even if a step raises.
    """
    temp_paths = []
    try:
        for uploaded, suffix in (
            (sar_file, ".tif"),
            (optic_file, ".tif"),
            (mask_file, ".tif"),
            (wiup_file, ".zip"),
        ):
            temp_paths.append(save_uploaded_file(uploaded, suffix=suffix))
        sar_path, optic_path, mask_path, wiup_zip_path = temp_paths

        extract_dir = tempfile.mkdtemp()
        temp_paths.append(extract_dir)
        wiup_gdf = _load_wiup(wiup_zip_path, extract_dir)

        sar_images = normalizeImages(
            readImages([sar_path], typeData="s", width=WIDTH, height=HEIGHT), "s"
        )
        # Type 'o' scales optical pixels to [0, 1], like the SAR input
        # (presumably matching the training preprocessing -- confirm).
        optic_images = normalizeImages(
            readImages([optic_path], typeData="o", width=WIDTH, height=HEIGHT), "o"
        )
        masks = readImages([mask_path], typeData="m", width=WIDTH, height=HEIGHT)

        model = tf.keras.models.load_model(
            MODEL_PATH,
            custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score},
        )
        # The model is fed [optical, SAR] in this order.
        pred_masks = model.predict([optic_images, sar_images], verbose=0)
        is_multiclass = pred_masks.shape[-1] > 1
        num_samples = min(num_samples, len(sar_images))

        _show(_plot_predictions(sar_images, optic_images, masks, pred_masks, is_multiclass, num_samples))
        _show(_plot_overlays(sar_images, optic_images, masks, pred_masks, is_multiclass, num_samples))
        _report_outside_wiup(pred_masks[0], is_multiclass, wiup_gdf, sar_path)
    finally:
        _remove_paths(temp_paths)


# Streamlit App Title
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")

# Accept both common TIFF extensions.
sar_file = st.file_uploader("Upload SAR Image", type=["tif", "tiff"])
optic_file = st.file_uploader("Upload Optical Image", type=["tif", "tiff"])
mask_file = st.file_uploader("Upload Mask Image", type=["tif", "tiff"])
wiup_file = st.file_uploader("Upload WIUP Boundary (Shapefile ZIP)", type=["zip"])
num_samples = 1

if st.button("Run Inference"):
    uploads = (sar_file, optic_file, mask_file, wiup_file)
    if any(upload is None for upload in uploads):
        st.warning(
            "Please upload the SAR, optical and mask .tiff files and the "
            "WIUP shapefile ZIP to proceed."
        )
    else:
        with st.spinner("Loading data and model..."):
            st.success("All files uploaded successfully!")
            _run_inference(*uploads, num_samples)