---
library_name: transformers
tags:
- vision
- image-segmentation
- ecology
datasets:
- coralscapes
metrics:
- mean_iou
license: apache-2.0
---

# Model Card for segformer-b2-finetuned-coralscapes-1024-1024

SegFormer model with a MiT-B2 backbone fine-tuned on Coralscapes at resolution 1024x1024, as introduced in [The Coralscapes Dataset: Semantic Scene Understanding in Coral Reefs](https://arxiv.org/abs/2503.20000).


## Model Details

### Model Description

- **Model type:** SegFormer
- **Finetuned from model:** [SegFormer (b2-sized) encoder pre-trained-only (`nvidia/mit-b2`)](https://huggingface.co/nvidia/mit-b2)

### Model Sources 

- **Repository:** [coralscapesScripts](https://github.com/eceo-epfl/coralscapesScripts/)
- **Demo:** [Hugging Face Spaces](https://huggingface.co/spaces/EPFL-ECEO/coralscapes_demo)

## How to Get Started with the Model

The simplest way to use this model to segment an image of the Coralscapes dataset is as follows:

```python
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
from datasets import load_dataset

# Load an image from the Coralscapes dataset, or load your own image with Image.open(...)
dataset = load_dataset("EPFL-ECEO/coralscapes")
image = dataset["test"][42]["image"]

preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")

inputs = preprocessor(image, return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs)
# target_sizes expects (height, width); PIL's image.size is (width, height)
segmentation = preprocessor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])
label_pred = segmentation[0].numpy()
```
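The resulting `label_pred` is a 2D array of class indices with the same height and width as the input image. As a rough sketch, per-class pixel coverage can be summarised from it. The helper name `class_coverage` is hypothetical (it is not part of the coralscapesScripts repository), and the toy labels below are made up; with the model loaded above, the index-to-name mapping should be available as `model.config.id2label`.

```python
import numpy as np

def class_coverage(label_pred, id2label):
    """Return {class name: fraction of pixels} for the classes present in label_pred.

    label_pred: 2D integer array of predicted class indices.
    id2label:   mapping from class index to class name (e.g. model.config.id2label).
    Hypothetical helper for illustration only.
    """
    ids, counts = np.unique(label_pred, return_counts=True)
    total = label_pred.size
    coverage = {
        id2label.get(int(i), f"class_{int(i)}"): float(c) / total
        for i, c in zip(ids, counts)
    }
    # Sort from most to least frequent class
    return dict(sorted(coverage.items(), key=lambda kv: kv[1], reverse=True))

# Toy prediction with made-up labels; with the real model, pass model.config.id2label
toy_pred = np.array([[0, 0, 1], [1, 1, 2]])
print(class_coverage(toy_pred, {0: "background", 1: "coral", 2: "sand"}))
```

Indices missing from the mapping fall back to a generic `class_<index>` name, so the sketch does not fail if the mapping is incomplete.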

The single-pass approach above also works for images of other sizes and scales. For images whose size differs substantially from the model's training resolution (1024x1024), however,
we recommend the following sliding-window approach, which should give better results:

```python
import torch
import torch.nn.functional as F
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import numpy as np
from datasets import load_dataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def resize_image(image, target_size=1024):
    """
    Resize a PIL image so that its smaller side equals target_size, preserving the aspect ratio.
    """
    w_img, h_img = image.size  # PIL reports (width, height)
    if w_img < h_img:
        new_w, new_h = target_size, int(h_img * (target_size / w_img))
    else:
        new_w, new_h = int(w_img * (target_size / h_img)), target_size
    resized_img = image.resize((new_w, new_h))  # PIL expects (width, height)
    return resized_img

def segment_image(image, preprocessor, model, crop_size=(1024, 1024), num_classes=40, transform=None):
    """
    Chooses a stride based on the image size and aspect ratio to create overlapping
    sliding windows of size crop_size, which are fed to the model one at a time.
    Logits in overlapping regions are averaged before taking the argmax.
    """
    h_crop, w_crop = crop_size

    # Resize so the smaller side is 1024 px and convert to a (1, 3, H, W) float tensor on the CPU
    img = torch.Tensor(np.array(resize_image(image.convert("RGB"), target_size=1024)).transpose(2, 0, 1)).unsqueeze(0)
    if transform:
        # Keep the transformed image on the CPU; only individual crops are moved to `device`
        img = torch.Tensor(transform(image=img.numpy())["image"])
    batch_size, _, h_img, w_img = img.size()

    # Roughly 1.5 windows per crop length along each axis that is larger than the crop
    h_grids = int(np.round(3/2*h_img/h_crop)) if h_img > h_crop else 1
    w_grids = int(np.round(3/2*w_img/w_crop)) if w_img > w_crop else 1

    h_stride = int((h_img - h_crop + h_grids - 1)/(h_grids - 1)) if h_grids > 1 else h_crop
    w_stride = int((w_img - w_crop + w_grids - 1)/(w_grids - 1)) if w_grids > 1 else w_crop

    # Accumulated logits and the number of windows covering each pixel
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))

    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            with torch.no_grad():
                if preprocessor is not None:
                    inputs = preprocessor(crop_img, return_tensors="pt")
                    inputs["pixel_values"] = inputs["pixel_values"].to(device)
                else:
                    # No preprocessor: pass the crop to the model as-is
                    inputs = {"pixel_values": crop_img.to(device)}
                outputs = model(**inputs)

            # SegFormer predicts logits at a reduced resolution; upsample them to the crop size
            resized_logits = F.interpolate(
                outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
            ).cpu()
            # Place the crop's logits at its position in the full image
            preds += F.pad(resized_logits,
                           (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1

    assert (count_mat == 0).sum() == 0  # every pixel is covered by at least one window
    preds = preds / count_mat
    preds = preds.argmax(dim=1)
    # Nearest-neighbour resize of the label map back to the original (height, width)
    preds = F.interpolate(preds.unsqueeze(0).float(), size=image.size[::-1], mode='nearest')
    label_pred = preds.squeeze().to(torch.uint8).numpy()
    return label_pred

# Load an image from the Coralscapes dataset, or load your own image
dataset = load_dataset("EPFL-ECEO/coralscapes")
image = dataset["test"][42]["image"]

preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
# The model must be on the same device as the inputs
model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024").to(device)
model.eval()

label_pred = segment_image(image, preprocessor, model)
```
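To see which windows `segment_image` processes, the grid and stride arithmetic can be pulled out into a standalone helper. This is a sketch: `sliding_window_coords` is a hypothetical name, and it only reproduces the coordinate logic of the code above. Heights and widths refer to the image after `resize_image`, so the smaller side is already 1024.

```python
import numpy as np

def sliding_window_coords(h_img, w_img, h_crop=1024, w_crop=1024):
    """Return (y1, y2, x1, x2) for each window, mirroring the grid/stride logic of segment_image."""
    # Roughly 1.5 windows per crop length along each axis that exceeds the crop size
    h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
    w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1
    # Ceil-divide the leftover length so the last window reaches the border
    h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
    w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop
    coords = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            # Windows that would overrun the border are shifted back inside the image
            y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
            coords.append((y1, y2, x1, x2))
    return coords

# A 1920x1080 photo resized so its smaller side is 1024 becomes 1820 wide x 1024 high
for window in sliding_window_coords(1024, 1820):
    print(window)
```

For that example, three windows span the width and neighbouring windows overlap by roughly 625 pixels; `segment_image` averages the logits in those overlapping regions before taking the argmax.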

## Training & Evaluation Details

### Data

The model is trained and evaluated on the [Coralscapes dataset](https://huggingface.co/datasets/EPFL-ECEO/coralscapes), a general-purpose dense semantic segmentation dataset for coral reefs.


### Procedure

Training follows the original SegFormer [implementation](https://proceedings.neurips.cc/paper_files/paper/2021/file/64f1f27bf1b4ec22924fd0acb550c235-Paper.pdf), with a batch size of 8 for 265 epochs,
the AdamW optimizer with an initial learning rate of 6e-5 and a weight decay of 1e-2, and a polynomial learning rate scheduler with a power of 1.
During training, images are randomly rescaled by a factor between 1 and 2, flipped horizontally with probability 0.5, and randomly cropped to 1024x1024 pixels.
Input images are normalized using the ImageNet mean and standard deviation. For evaluation, a non-overlapping sliding window
with a window size of 1024x1024 is used.
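The optimizer, learning rate schedule and augmentations described above could be set up roughly as follows in plain PyTorch and PIL. This is only a sketch under assumptions: it is not the training code from the coralscapesScripts repository, the helper names `build_optimizer` and `augment` are hypothetical, the polynomial schedule is assumed to be stepped once per iteration (using `torch.optim.lr_scheduler.PolynomialLR`, available from PyTorch 1.13), and the ImageNet statistics are the commonly used torchvision values.

```python
import random
import numpy as np
import torch
from PIL import Image

# Commonly used ImageNet channel statistics (RGB, for pixel values scaled to [0, 1])
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def build_optimizer(model, total_iters):
    """Hypothetical helper: AdamW (lr 6e-5, weight decay 1e-2) with polynomial decay of power 1,
    i.e. a linear decay of the learning rate to 0 over total_iters scheduler steps."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=6e-5, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_iters, power=1.0)
    return optimizer, scheduler

def augment(image, mask, crop_size=1024):
    """Hypothetical helper: random rescale in [1, 2], horizontal flip with p=0.5,
    random crop, then ImageNet normalisation. Returns (3, H, W) float and (H, W) int64 tensors."""
    image = image.convert("RGB")
    scale = random.uniform(1.0, 2.0)
    w, h = image.size
    new_size = (int(w * scale), int(h * scale))
    image = image.resize(new_size, Image.BILINEAR)
    mask = mask.resize(new_size, Image.NEAREST)  # nearest keeps class indices intact
    if random.random() < 0.5:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    # Assumes the rescaled image is at least crop_size on each side (no padding in this sketch)
    w, h = image.size
    x1 = random.randint(0, w - crop_size)
    y1 = random.randint(0, h - crop_size)
    box = (x1, y1, x1 + crop_size, y1 + crop_size)
    image, mask = image.crop(box), mask.crop(box)
    pixels = (np.asarray(image, dtype=np.float32) / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    pixel_tensor = torch.from_numpy(np.ascontiguousarray(pixels.transpose(2, 0, 1)))
    mask_tensor = torch.from_numpy(np.asarray(mask, dtype=np.int64))
    return pixel_tensor, mask_tensor
```

In a training loop, `scheduler.step()` would be called after every `optimizer.step()`, so `total_iters` would be the number of batches per epoch times the number of epochs.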


### Results

- Test Accuracy:  80.904
- Test Mean IoU: 54.682
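Both numbers are standard semantic segmentation metrics. As a sketch of how such values are usually computed (this is not the repository's evaluation code; the helper names, the treatment of "accuracy" as overall pixel accuracy, and the choice to skip classes absent from both prediction and ground truth are assumptions), pixel accuracy and mean IoU can be derived from a confusion matrix:

```python
import numpy as np

def confusion_matrix(gt, pred, num_classes, ignore_index=None):
    """Hypothetical helper: num_classes x num_classes counts; rows are ground truth, columns are predictions."""
    valid = np.ones_like(gt, dtype=bool) if ignore_index is None else gt != ignore_index
    idx = num_classes * gt[valid].astype(np.int64) + pred[valid].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def pixel_accuracy_and_miou(conf):
    """Overall pixel accuracy and mean IoU (in %), averaging IoU over classes seen in gt or pred."""
    tp = np.diag(conf).astype(np.float64)
    # Union = predicted pixels + ground-truth pixels - intersection, per class
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp
    accuracy = tp.sum() / conf.sum()
    present = union > 0
    miou = (tp[present] / union[present]).mean()
    return 100 * accuracy, 100 * miou

# Toy 2x3 label maps with three classes
gt = np.array([[0, 0, 1], [1, 2, 2]])
pred = np.array([[0, 1, 1], [1, 2, 0]])
conf = confusion_matrix(gt, pred, num_classes=3)
print(pixel_accuracy_and_miou(conf))
```

In this sketch, per-image confusion matrices would be summed over the whole test set before calling `pixel_accuracy_and_miou`, so each pixel counts equally regardless of which image it comes from.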

## Citation
If you find this project useful, please consider citing:
```bibtex
@misc{sauder2025coralscapesdatasetsemanticscene,
        title={The Coralscapes Dataset: Semantic Scene Understanding in Coral Reefs}, 
        author={Jonathan Sauder and Viktor Domazetoski and Guilhem Banc-Prandi and Gabriela Perna and Anders Meibom and Devis Tuia},
        year={2025},
        eprint={2503.20000},
        archivePrefix={arXiv},
        primaryClass={cs.CV},
        url={https://arxiv.org/abs/2503.20000}, 
  }
```