Update README.md
The simplest way to use this model to segment an image of the Coralscapes dataset is as follows:

```python
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
from datasets import load_dataset

# Load an image from the coralscapes dataset or load your own image
dataset = load_dataset("EPFL-ECEO/coralscapes")
image = dataset["test"][42]["image"]

preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")

inputs = preprocessor(image, return_tensors="pt")
outputs = model(**inputs)
outputs = preprocessor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])
label_pred = outputs[0].numpy()
```
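
The resulting `label_pred` is a 2D array of class indices with the same height and width as the input image. For a quick visual check, you can plot it next to the input. The snippet below is a minimal sketch of that step, not part of the model's API: the `matplotlib` colormap choice is arbitrary, and the printed class names assume the checkpoint config provides an `id2label` mapping, as Hugging Face segmentation checkpoints usually do.

```python
import numpy as np
import matplotlib.pyplot as plt

# Show the input image and the predicted class-index map side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].imshow(image)
axes[0].set_title("Input image")
axes[1].imshow(label_pred, cmap="tab20", interpolation="nearest")
axes[1].set_title("Predicted class indices")
for ax in axes:
    ax.axis("off")
plt.show()

# List the class names that appear in the prediction
# (assumes `id2label` is present in the model config).
print({int(i): model.config.id2label[int(i)] for i in np.unique(label_pred)})
```
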

While the approach above should still work for images of different sizes and scales, for images that are not close to the model's training size (1024x1024) we recommend the following sliding-window approach to achieve better results:

```python
import torch
import torch.nn.functional as F
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import numpy as np
from datasets import load_dataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def resize_image(image, target_size=1024):
    """
    Resize the image such that its smaller side equals `target_size`.
    """
    h_img, w_img = image.size
    if h_img < w_img:
        new_h, new_w = target_size, int(w_img * (target_size / h_img))
    else:
        new_h, new_w = int(h_img * (target_size / w_img)), target_size
    resized_img = image.resize((new_h, new_w))
    return resized_img

def segment_image(image, preprocessor, model, crop_size=(1024, 1024), num_classes=40, transform=None):
    """
    Finds an optimal stride based on the image size and aspect ratio to create
    overlapping sliding windows of size 1024x1024 which are then fed into the model.
    """
    h_crop, w_crop = crop_size

    img = torch.Tensor(np.array(resize_image(image, target_size=1024)).transpose(2, 0, 1)).unsqueeze(0)
    batch_size, _, h_img, w_img = img.size()

    if transform:
        img = torch.Tensor(transform(image=img.numpy())["image"]).to(device)

    # Number of windows along each axis, then the stride that spreads them evenly over the image
    h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
    w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1

    h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
    w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop

    # Accumulators for the summed logits and the number of windows covering each pixel
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))

    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            with torch.no_grad():
                if preprocessor:
                    inputs = preprocessor(crop_img, return_tensors="pt")
                    inputs["pixel_values"] = inputs["pixel_values"].to(device)
                else:
                    inputs = {"pixel_values": crop_img.to(device)}
                outputs = model(**inputs)

            resized_logits = F.interpolate(
                outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
            )
            # Paste the window logits back at their position in the full-size logit map
            preds += F.pad(resized_logits,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2))).cpu()
            count_mat[:, :, y1:y2, x1:x2] += 1

    assert (count_mat == 0).sum() == 0
    preds = preds / count_mat
    preds = preds.argmax(dim=1)
    # Resize the label map back to the original image resolution (PIL size is (width, height))
    preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode='nearest')
    label_pred = preds.squeeze().cpu().numpy()
    return label_pred

# Load an image from the coralscapes dataset or load your own image
dataset = load_dataset("EPFL-ECEO/coralscapes")
image = dataset["test"][42]["image"]

preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
# Move the model to the same device that the window inputs are sent to
model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024").to(device)

label_pred = segment_image(image, preprocessor, model)
```
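
Since `segment_image` returns a dense map of class indices at the original image resolution, summary statistics such as how much of the image each class covers are a simple `numpy` aggregation. The snippet below is a small sketch of that step; it reuses `label_pred`, `model`, and `np` from the example above and, as before, assumes the checkpoint config provides an `id2label` mapping with the class names.

```python
# Fraction of pixels assigned to each predicted class, largest first
values, counts = np.unique(label_pred, return_counts=True)
coverage = {
    model.config.id2label.get(int(v), f"class_{int(v)}"): c / label_pred.size
    for v, c in zip(values, counts)
}
for name, frac in sorted(coverage.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {100 * frac:.1f}%")
```
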
## Training & Evaluation Details