metadata
language: en
license: apache-2.0
library_name: pytorch
tags:
  - image-segmentation
  - wetlands
  - satellite-imagery
  - environmental
  - deeplabv3plus
datasets:
  - custom
metrics:
  - iou
  - precision
  - recall
  - f1

Wetlands Segmentation Model (DeepLabv3+)

Model Description

This repository contains a DeepLabv3+ model trained for wetlands segmentation from satellite imagery. The model is designed to identify wetland areas in multi-band satellite images, which is crucial for environmental monitoring, conservation planning, and climate change studies.

Model Architecture

  • Base Architecture: DeepLabv3+ with ResNet-50 backbone
  • Input: 3-channel RGB image (bands 1, 2, 3 of the multi-band satellite patches)
  • Output: Binary segmentation mask (Wetland vs Background)
  • Resolution: 128×128 pixels

Use Cases

  • Environmental monitoring of wetland regions
  • Land use and land cover change analysis
  • Conservation planning and management
  • Climate change impact assessment
  • Hydrological modeling

Training Data

The model was trained on a dataset of satellite imagery patches containing wetland regions. Each patch is 128×128 pixels and includes multiple spectral bands.

Dataset Structure

patches_data_allbands/
├── train/
│   ├── input/  # Satellite image patches (.tif)
│   └── output/ # Segmentation masks (.tif)
├── val/
│   ├── input/
│   └── output/
└── test/
    β”œβ”€β”€ input/
    └── output/
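A minimal sketch of pairing each input patch with its mask, assuming a mask in `output/` has the same file name as its patch in `input/` (the layout above does not state the naming rule). The function name `list_patch_pairs` is hypothetical.

```python
from pathlib import Path


def list_patch_pairs(root, split):
    """Return sorted (input_path, mask_path) pairs for one split.

    Assumption: each mask in <split>/output/ shares its file name with the
    corresponding patch in <split>/input/.
    """
    input_dir = Path(root) / split / "input"
    output_dir = Path(root) / split / "output"
    pairs = []
    for image_path in sorted(input_dir.glob("*.tif")):
        mask_path = output_dir / image_path.name
        if not mask_path.exists():
            raise FileNotFoundError(f"No mask found for {image_path.name}")
        pairs.append((image_path, mask_path))
    return pairs
```

For example, `list_patch_pairs("patches_data_allbands", "train")` would return the training pairs, failing loudly if any patch has no matching mask.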

Data Preprocessing

  • Each TIF image contains multiple spectral bands
  • For this model, RGB bands (bands 1, 2, 3) were extracted
  • Images were normalized to the range [0, 1]
  • Masks were converted to binary format (0 = background, 1 = wetland)
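The preprocessing steps above can be sketched in NumPy as follows. Two details are assumptions, since the card does not specify them: scaling to [0, 1] divides by the patch maximum (the same rule the inference code below uses, and it assumes non-negative pixel values), and any non-zero mask label counts as wetland. The function name `preprocess_patch` is hypothetical.

```python
import numpy as np

RGB_BANDS = (1, 2, 3)  # 1-based band numbers, as used by rasterio


def preprocess_patch(bands, mask):
    """bands: (num_bands, H, W) array with all spectral bands; mask: (H, W) labels."""
    # Keep bands 1, 2, 3 (array indices 0, 1, 2) and convert to float32.
    rgb = bands[[b - 1 for b in RGB_BANDS]].astype(np.float32)
    # Assumed normalization: divide by the patch maximum so values lie in [0, 1].
    peak = rgb.max()
    if peak > 0:
        rgb = rgb / peak
    # Assumed mask encoding: 0 = background, any non-zero value = wetland.
    binary_mask = (mask > 0).astype(np.int64)
    return rgb, binary_mask
```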

Data Augmentation

The following augmentations were applied during training:

  • Random horizontal flips (p=0.5)
  • Random vertical flips (p=0.5)
  • Random 90-degree rotations (p=0.5)
  • Padding to ensure 128×128 dimensions
  • Random cropping to maintain consistent size
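A minimal NumPy sketch of these augmentations, applying the same spatial transform to the image and its mask. The flip and rotation probabilities come from the list; the choice of rotation angle (90, 180 or 270 degrees), zero padding, and the function name `augment` are assumptions.

```python
import numpy as np

PATCH = 128  # target patch size in pixels


def augment(image, mask, rng):
    """image: (C, H, W) array; mask: (H, W) array. Returns 128x128 copies."""
    if rng.random() < 0.5:  # random horizontal flip
        image, mask = image[:, :, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:  # random vertical flip
        image, mask = image[:, ::-1, :], mask[::-1, :]
    if rng.random() < 0.5:  # random rotation by k * 90 degrees (k assumed in 1..3)
        k = int(rng.integers(1, 4))
        image = np.rot90(image, k, axes=(1, 2))
        mask = np.rot90(mask, k)
    # Zero-pad the bottom/right edges if either side is shorter than PATCH.
    pad_h = max(PATCH - mask.shape[0], 0)
    pad_w = max(PATCH - mask.shape[1], 0)
    image = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w)))
    mask = np.pad(mask, ((0, pad_h), (0, pad_w)))
    # Random crop back to exactly PATCH x PATCH.
    top = int(rng.integers(0, mask.shape[0] - PATCH + 1))
    left = int(rng.integers(0, mask.shape[1] - PATCH + 1))
    image = image[:, top:top + PATCH, left:left + PATCH]
    mask = mask[top:top + PATCH, left:left + PATCH]
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)
```

Calling it with `rng = np.random.default_rng(seed)` makes the augmentation sequence reproducible.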

Performance Metrics

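The model was evaluated on each split using the metrics below. Mean IoU is the average of the background and wetland IoUs; precision, recall, and F1 score refer to the wetland class.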

Training Set

  • Average IoU: 0.2472
  • Background IoU: 0.9379
  • Wetland IoU: 0.2450
  • Mean IoU: 0.5915
  • Precision: 0.2620
  • Recall: 0.7908
  • F1 Score: 0.3936

Validation Set

  • Average IoU: 0.0489
  • Background IoU: 0.9515
  • Wetland IoU: 0.0481
  • Mean IoU: 0.4998
  • Precision: 0.0533
  • Recall: 0.3313
  • F1 Score: 0.0918

Test Set

  • Average IoU: 0.1550
  • Background IoU: 0.8977
  • Wetland IoU: 0.1558
  • Mean IoU: 0.5267
  • Precision: 0.1720
  • Recall: 0.6229
  • F1 Score: 0.2695
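The sketch below computes the per-class IoU, mean IoU, and wetland precision, recall, and F1 from a binary prediction and its ground truth. It assumes the counts are pooled over all pixels passed in; "Average IoU" is not reproduced because its exact definition (for example, per batch or per image) is not stated. The function name `segmentation_metrics` is hypothetical.

```python
import numpy as np


def segmentation_metrics(pred, target):
    """pred, target: binary arrays (1 = wetland, 0 = background) of equal shape."""
    pred = np.asarray(pred).astype(bool)
    target = np.asarray(target).astype(bool)
    tp = np.sum(pred & target)    # wetland predicted as wetland
    fp = np.sum(pred & ~target)   # background predicted as wetland
    fn = np.sum(~pred & target)   # wetland predicted as background
    tn = np.sum(~pred & ~target)  # background predicted as background

    def ratio(num, den):
        # Convention chosen here: an empty denominator yields 0.0.
        return float(num) / float(den) if den > 0 else 0.0

    wetland_iou = ratio(tp, tp + fp + fn)
    background_iou = ratio(tn, tn + fp + fn)
    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1 = ratio(2 * precision * recall, precision + recall)
    return {
        "background_iou": background_iou,
        "wetland_iou": wetland_iou,
        "mean_iou": (background_iou + wetland_iou) / 2,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
```

As a check on the reported numbers, the training-set precision and recall give 2 · 0.2620 · 0.7908 / (0.2620 + 0.7908) ≈ 0.3936, matching the listed F1 score.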

Known Limitations

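  • The model overfits: wetland IoU drops from 0.2450 on the training set to 0.0481 on the validation set and 0.1558 on the test set
  • Precision is well below recall on every split (for example, 0.1720 vs 0.6229 on the test set), meaning the model predicts more wetland pixels than are actually present
  • Only the RGB bands (1, 2, 3) are used; future work could incorporate the remaining spectral bands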
  • Performance varies based on the quality and resolution of input imagery
  • Binary segmentation only (wetland vs. non-wetland)

Usage

Here's how to use the model for inference:

import torch
import rasterio
import numpy as np
from model import DeepLabv3Plus  # Your own model definition; "model" is a placeholder module name

# Load the trained weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepLabv3Plus(num_classes=2)  # Must match the configuration used for training
model.load_state_dict(torch.load("model_weights.pth", map_location=device))
model.to(device)
model.eval()

def predict_wetland(image_path):
    """Return an (H, W) array with 0 = background and 1 = wetland."""
    # Read bands 1, 2, 3 (RGB) as a (3, H, W) array
    with rasterio.open(image_path) as src:
        image = src.read([1, 2, 3]).astype(np.float32)

    # Normalize to [0, 1] by dividing by the image maximum
    if image.max() > 0:
        image = image / image.max()

    # Convert to a (1, 3, H, W) tensor on the target device
    tensor = torch.from_numpy(image).unsqueeze(0).to(device)

    # Pick the highest-scoring class for each pixel
    with torch.no_grad():
        logits = model(tensor)
        prediction = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

    return prediction

# Example call ("patch.tif" is an illustrative path)
mask = predict_wetland("patch.tif")

Citation

If you use this model in your research, please cite:

@software{wetlands_segmentation_deeplabsv3plus,
  author = {dcrey7},
  title = {Wetlands Segmentation using DeepLabv3+},
  url = {https://huggingface.co/dcrey7/wetlands_segmentation_deeplabsv3plus},
  year = {2025},
}

License

This model is available under the Apache 2.0 license.

Contact

For questions or feedback, please open an issue on this repository or contact the repository owner via Hugging Face.