ViktorDo committed cc7700a (verified) · Parent: 595d52d

Update README.md

Files changed (1): README.md (+95 -14)
### Model Description

Training is conducted following the original SegFormer [paper](https://proceedings.neurips.cc/paper_files/paper/2021/file/64f1f27bf1b4ec22924fd0acb550c235-Paper.pdf), using a batch size of 8 for 265 epochs and the AdamW optimizer with an initial learning rate of 6e-5, a weight decay of 1e-2 and a polynomial learning-rate scheduler with a power of 1.
During training, images are randomly scaled by a factor between 1 and 2, flipped horizontally with a probability of 0.5 and randomly cropped to 1024×1024 pixels.
Input images are normalized using the ImageNet mean and standard deviation. For evaluation, a non-overlapping sliding window strategy is employed, using a window size of 1024×1024.
<!-- TODO - We used a stride of 1024 but in the demo it is variable. Should we move this entire section to training below? -->
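For orientation, the optimization setup described above can be sketched in plain PyTorch as follows. This is only an illustration, not the released training code: the `nvidia/mit-b2` starting checkpoint and the step count are assumptions, and the actual pipeline lives in the coralscapesScripts repository linked below.

```python
# Illustrative sketch of the optimizer and schedule described above (assumptions noted inline).
import torch
from transformers import SegformerForSemanticSegmentation

# Assumption: fine-tuning starts from the MiT-b2 encoder checkpoint, with the 40 Coralscapes classes.
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b2", num_labels=40)

total_steps = 265 * 1000  # 265 epochs x steps per epoch (placeholder value)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-5, weight_decay=1e-2)
# A polynomial schedule with power 1 decays the learning rate linearly to zero over training.
scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=1.0)
```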
 
- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
 
<!-- Provide the basic links for the model. -->

- **Repository:** [coralscapesScripts](https://github.com/eceo-epfl/coralscapesScripts/)
- **Paper [optional]:** [More Information Needed]
- **Demo:** [Hugging Face Spaces](https://huggingface.co/spaces/EPFL-ECEO/coralscapes_demo)
 
## Uses

## How to Get Started with the Model

The simplest way to use this model to segment an image of the Coralscapes dataset is as follows:

```python
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
from datasets import load_dataset

# Load an image from the coralscapes dataset or load your own image
dataset = load_dataset("EPFL-ECEO/coralscapes")
image = dataset["test"][42]["image"]

preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")

inputs = preprocessor(image, return_tensors="pt")
outputs = model(**inputs)
outputs = preprocessor.post_process_semantic_segmentation(outputs, target_sizes=[(image.size[1], image.size[0])])
label_pred = outputs[0].cpu().numpy()
```
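To make the predicted label map human readable, the checkpoint's configuration should carry the class-id-to-name mapping for the 40 Coralscapes classes (e.g. "seagrass", "sand", "branching alive"). A small follow-up sketch, assuming `model` and `label_pred` from the block above and that `model.config.id2label` is populated:

```python
# Follow-up sketch: summarize which classes appear in the prediction.
# Assumes `model` and `label_pred` from the previous snippet and that the
# checkpoint config exposes an id2label mapping (standard for transformers models).
import numpy as np

id2label = model.config.id2label
ids, counts = np.unique(label_pred, return_counts=True)
for class_id, count in zip(ids, counts):
    name = id2label.get(int(class_id), f"class {class_id}")
    print(f"{name}: {100 * count / label_pred.size:.1f}% of pixels")
```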
 
While the approach above should still work for images of different sizes and scales, for images that are not close to the model's training size (1024×1024)
we recommend the following sliding-window approach to achieve better results:

```python
import torch
import torch.nn.functional as F
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
from datasets import load_dataset
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"  # run on GPU when available

def resize_image(image, target_size=1024):
    """
    Resize the image such that its smaller side equals `target_size` (1024 by default).
    """
    h_img, w_img = image.size
    if h_img < w_img:
        new_h, new_w = target_size, int(w_img * (target_size / h_img))
    else:
        new_h, new_w = int(h_img * (target_size / w_img)), target_size
    resized_img = image.resize((new_h, new_w))
    return resized_img

def segment_image(image, preprocessor, model, crop_size=(1024, 1024), num_classes=40, transform=None):
    """
    Finds an optimal stride based on the image size and aspect ratio to create
    overlapping sliding windows of size 1024x1024 which are then fed into the model.
    """
    h_crop, w_crop = crop_size

    img = torch.Tensor(np.array(resize_image(image, target_size=1024)).transpose(2, 0, 1)).unsqueeze(0)
    batch_size, _, h_img, w_img = img.size()

    if transform:
        img = torch.Tensor(transform(image=img.numpy())["image"]).to(device)

    # Roughly 1.5 windows per crop size along each axis, so neighbouring windows overlap.
    h_grids = int(np.round(3 / 2 * h_img / h_crop)) if h_img > h_crop else 1
    w_grids = int(np.round(3 / 2 * w_img / w_crop)) if w_img > w_crop else 1

    h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
    w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop

    # Accumulate logits and a count of how many windows cover each pixel.
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img)).to(device)
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img)).to(device)

    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            with torch.no_grad():
                if preprocessor:
                    inputs = preprocessor(crop_img, return_tensors="pt")
                    inputs["pixel_values"] = inputs["pixel_values"].to(device)
                else:
                    inputs = {"pixel_values": crop_img.to(device)}
                outputs = model(**inputs)

            # Upsample the window logits to the crop size and paste them at the window's location.
            resized_logits = F.interpolate(
                outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
            )
            preds += F.pad(resized_logits,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1

    assert (count_mat == 0).sum() == 0
    preds = preds / count_mat
    preds = preds.argmax(dim=1)
    # Resize the label map back to the original image resolution.
    preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode="nearest")
    label_pred = preds.squeeze().cpu().numpy()
    return label_pred

# Load an image from the coralscapes dataset or load your own image
dataset = load_dataset("EPFL-ECEO/coralscapes")
image = dataset["test"][42]["image"]

preprocessor = SegformerImageProcessor.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
model = SegformerForSemanticSegmentation.from_pretrained("EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024")
model = model.to(device)

label_pred = segment_image(image, preprocessor, model)
```
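For a quick visual check of the sliding-window output, the prediction can be overlaid next to the input image. This is an optional, purely illustrative snippet (matplotlib assumed); it uses a generic colormap rather than the per-class Coralscapes colours:

```python
# Illustrative only: display the input image next to the predicted label map.
# Assumes `image` and `label_pred` from the sliding-window example above.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].imshow(image)
axes[0].set_title("Input image")
axes[1].imshow(label_pred, cmap="tab20")  # generic colormap, not the Coralscapes class colours
axes[1].set_title("Predicted segmentation")
for ax in axes:
    ax.axis("off")
plt.show()
```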
## Training Details

### Training Data