Upload 42 files
PyTorch segmentation pipeline; use it to train and predict on binary, multiclass, and raster imagery.
- README.md +52 -0
- examples/.ipynb_checkpoints/predict-checkpoint.ipynb +255 -0
- examples/.ipynb_checkpoints/train-checkpoint.ipynb +113 -0
- examples/predict.ipynb +255 -0
- examples/train.ipynb +113 -0
- requirements.txt +15 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/custom_losses-checkpoint.py +97 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/data_loader-checkpoint.py +129 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/encoder_management-checkpoint.py +136 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/evaluation_utils-checkpoint.py +108 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/image_preprocessing-checkpoint.py +81 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/metrics-checkpoint.py +94 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/model_core-checkpoint.py +129 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/prediction-checkpoint.py +336 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/training-checkpoint.py +313 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/utilities-checkpoint.py +119 -0
- semantic-segmentation/SemanticModel/.ipynb_checkpoints/visualization-checkpoint.py +115 -0
- semantic-segmentation/SemanticModel/__init__.py +0 -0
- semantic-segmentation/SemanticModel/__pycache__/__init__.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/custom_losses.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/data_loader.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/encoder_management.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/evaluation_utils.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/image_preprocessing.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/metrics.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/model_core.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/prediction.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/training.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/utilities.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/__pycache__/visualization.cpython-38.pyc +0 -0
- semantic-segmentation/SemanticModel/custom_losses.py +97 -0
- semantic-segmentation/SemanticModel/data_loader.py +129 -0
- semantic-segmentation/SemanticModel/encoder_management.py +136 -0
- semantic-segmentation/SemanticModel/evaluation_utils.py +108 -0
- semantic-segmentation/SemanticModel/image_preprocessing.py +81 -0
- semantic-segmentation/SemanticModel/metrics.py +94 -0
- semantic-segmentation/SemanticModel/model_core.py +129 -0
- semantic-segmentation/SemanticModel/prediction.py +336 -0
- semantic-segmentation/SemanticModel/training.py +313 -0
- semantic-segmentation/SemanticModel/utilities.py +119 -0
- semantic-segmentation/SemanticModel/visualization.py +115 -0
- setup.py +34 -0
README.md
ADDED
@@ -0,0 +1,52 @@
# SemanticModel

Deep learning framework for semantic segmentation using PyTorch.

## Install
```bash
pip install -r requirements.txt
python setup.py install
```

## Usage
```python
from SemanticModel.model_core import SegmentationModel
from SemanticModel.prediction import PredictionPipeline
from SemanticModel.training import ModelTrainer

# Train
model = SegmentationModel(
    classes=['background', 'object'],
    architecture='unet',
    encoder='timm-regnety_120'
)

trainer = ModelTrainer(
    model_config=model,
    root_dir='path/to/dataset',
    epochs=40
)
model, metrics = trainer.train()

# Predict
predictor = PredictionPipeline(model)
predictor.predict_single_image('image.jpg')
predictor.predict_directory('image_dir/')
predictor.predict_raster('raster.tif')

# Load pretrained
model = SegmentationModel(
    classes=['background', 'object'],
    weights='path/to/best_model.pth'
)
```

## Data Structure
```
dataset/
├── train/
│   ├── Images/
│   └── Masks/
└── val/
    ├── Images/
    └── Masks/
```
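The `Masks/` folders are read by `SegmentationDataset._load_mask` as single-channel PNGs whose pixel values are class indices in the same order as the `classes` list. A minimal sketch for sanity-checking one mask; the file path is a placeholder:

```python
import cv2
import numpy as np

# 'dataset/train/Masks/example.png' is a placeholder path.
mask = cv2.imread('dataset/train/Masks/example.png', 0)  # read as single channel
print(mask.shape, np.unique(mask))  # expect values drawn from 0..len(classes)-1
```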
examples/.ipynb_checkpoints/predict-checkpoint.ipynb
ADDED
@@ -0,0 +1,255 @@
(Jupyter autosave checkpoint; its content is identical to examples/predict.ipynb, shown below.)
examples/.ipynb_checkpoints/train-checkpoint.ipynb
ADDED
@@ -0,0 +1,113 @@
(Jupyter autosave checkpoint; its content is identical to examples/train.ipynb, shown below.)
examples/predict.ipynb
ADDED
@@ -0,0 +1,255 @@
The committed notebook (kernel "AgLab - Python 3", Python 3.8.13) contains the following code cells and key outputs.

Cell 1 – change to the repository root (output: `/home/jovyan/shared/Chima/ml_project/repos/MYSMP`):

```python
%cd "../../MYSMP"
```

Cell 2 – optional dependency install:

```python
# %pip install -r requirements.txt
```

Cell 3 – enter the package directory (output: `/home/jovyan/shared/Chima/ml_project/repos/MYSMP/semantic-segmentation`):

```python
%cd "../MYSMP/semantic-segmentation"
```

Cell 4 – confirm the working directory:

```python
!pwd
```

Cell 5 – initialize the model and prediction pipeline, then predict a single image:

```python
# Initialize the model
from SemanticModel.model_core import SegmentationModel

model = SegmentationModel(
    classes=['bg', 'cacao', 'matarraton', 'abarco'],
    architecture='unet',
    encoder='timm-regnety_120',
    weights='../data/model_outputs-unet[timm-regnety_120]-01-23-2025_075803/best_model.pth'
)

# Initialize prediction pipeline
from SemanticModel.prediction import PredictionPipeline

predictor = PredictionPipeline(model)
output_dir = '../predictions'
image_path = '../data/Images/2019-Mission2-odm_1_42.jpg'

# Make prediction
prediction = predictor.predict_single_image(image_path=image_path, output_dir=output_dir)
```

Output: "Loading pretrained model..." followed by a segmentation_models_pytorch UserWarning about an implicit softmax dimension.

Cell 6 – predict every image in a directory:

```python
# Directory of images
image_path = 'path/to/folderofImages'
predictor.predict_directory(image_path)
```

Output: "Predictions saved to: path/to/folderofImages/predictions".

Cell 7 – tile and predict a large georeferenced raster:

```python
# Large raster
output_path = '../predictions/prediction.tif'
raster_path = '../data/2021-Mission7_clipped_2.tif'
predictor.predict_raster(raster_path, tile_size=1024, output_path=output_path, format='color')
```

Output: "Loading raster... Processed 6/6 tiles. Prediction saved to: ../predictions/prediction.tif". The call returns the predicted color mask (uint8, 1797 x 2365, 3 bands) together with the source GeoTIFF profile (EPSG:32618).
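The raster cell returns the color array together with the raster's GeoTIFF profile. A minimal sketch of reusing that returned pair with rasterio (`pred_array` and `profile` below are placeholder names for the tuple; `predict_raster` already writes `output_path` on its own):

```python
import numpy as np
import rasterio

# Continue from the notebook: capture the (array, profile) pair returned above.
pred_array, profile = predictor.predict_raster(
    raster_path, tile_size=1024, output_path=output_path, format='color')

with rasterio.open('../predictions/prediction_copy.tif', 'w', **profile) as dst:
    dst.write(np.moveaxis(pred_array, -1, 0))  # rasterio expects (bands, rows, cols)
```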
examples/train.ipynb
ADDED
@@ -0,0 +1,113 @@
The committed notebook (kernel "AgLab - Python 3", Python 3.8.13) contains the following code cells; none carry saved outputs.

Cell 1 – change to the repository root:

```python
%cd "../../MYSMP"
```

Cell 2 – optional dependency install:

```python
# %pip install -r requirements.txt
```

Cell 3 – enter the package directory:

```python
%cd "../MYSMP/semantic-segmentation"
```

Cell 4 – imports:

```python
from SemanticModel.model_core import SegmentationModel
from SemanticModel.training import ModelTrainer
```

Cell 5 – configure the model, loss function, and training parameters:

```python
# Initialize the model and loss function
model = SegmentationModel(
    classes=['bg', 'cacao', 'matarraton', 'abarco'],
    architecture='unet',
    encoder='timm-regnety_120',
    weights='imagenet',
    loss='dice'  # Try 'dice' or 'tversky' instead of default
)

# Training parameters
trainer = ModelTrainer(
    model_config=model,
    root_dir='../data',
    epochs=100,
    train_size=1024,
    batch_size=4,
    learning_rate=1e-3,  # Increased learning rate
    step_count=3,        # More learning rate adjustments
    decay_factor=0.5     # Stronger decay
)
```

Cell 6 – run training:

```python
trained_model, metrics = trainer.train()
```
requirements.txt
ADDED
@@ -0,0 +1,15 @@
torch
torchvision
tensorboard
pyproj
fiona==1.8.20
rtree
geopandas
rasterio
slidingwindow
opencv-python
wandb
tifffile
imagecodecs
albumentations
segmentation-models-pytorch>=0.3.3
semantic-segmentation/SemanticModel/.ipynb_checkpoints/custom_losses-checkpoint.py
ADDED
@@ -0,0 +1,97 @@
```python
import torch
import torch.nn.functional as F
from segmentation_models_pytorch.utils import base
from segmentation_models_pytorch.base.modules import Activation

class FocalLossFunction(base.Loss):
    def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        if inputs.shape[1] == 1:  # Binary case
            inputs = torch.cat((inputs, 1 - inputs), dim=1)
            targets = torch.cat((targets, 1 - targets), dim=1)

        targets = torch.argmax(targets, dim=1)
        cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
        probability = torch.exp(-cross_entropy)
        alpha_factor = self.alpha if inputs.shape[1] > 1 else torch.where(
            targets == 1, 1 - self.alpha, self.alpha)

        focal_weight = alpha_factor * (1 - probability) ** self.gamma * cross_entropy

        if self.reduction == 'mean':
            return focal_weight.mean()
        elif self.reduction == 'sum':
            return focal_weight.sum()
        return focal_weight

class TverskyLossFunction(base.Loss):
    def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
                 reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.beta = beta
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        if self.ignore_channels is not None:
            mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            mask[self.ignore_channels] = False
            inputs = inputs[:, mask, ...]

        num_classes = inputs.shape[1]
        inputs_softmax = (torch.sigmoid(inputs) if num_classes == 1
                          else F.softmax(inputs, dim=1))

        if num_classes == 1:
            inputs_softmax = inputs_softmax.squeeze(1)
            targets = targets.squeeze(1)

        tversky_loss = 0
        for class_idx in range(num_classes):
            if num_classes == 1:
                flat_inputs = inputs_softmax.reshape(-1)
                flat_targets = targets.reshape(-1)
            else:
                flat_inputs = inputs_softmax[:, class_idx].reshape(-1)
                flat_targets = targets[:, class_idx].reshape(-1)

            intersection = (flat_inputs * flat_targets).sum()
            fps = ((1 - flat_targets) * flat_inputs).sum()
            fns = (flat_targets * (1 - flat_inputs)).sum()

            tversky_index = intersection + self.alpha * fps + self.beta * fns + 1e-10
            tversky_loss += 1 - intersection / tversky_index

        if self.reduction == 'mean':
            return tversky_loss / (1 if num_classes == 1 else num_classes)
        elif self.reduction == 'sum':
            return tversky_loss
        return tversky_loss / inputs.shape[0]

class EnhancedCrossEntropy(base.Loss):
    def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        inputs = self.activation(inputs)

        if self.ignore_channels is not None:
            mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            mask[self.ignore_channels] = False
            inputs = inputs[:, mask, ...]

        if targets.dim() == 4:  # Convert one-hot to class indices
            targets = torch.argmax(targets, dim=1)

        return F.cross_entropy(inputs, targets, reduction=self.reduction)
```
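A minimal usage sketch for these loss classes, assuming raw logits and one-hot masks shaped `[batch, n_classes, H, W]` as expected by the `forward` methods above; the tensor sizes are arbitrary dummy data:

```python
import torch
import torch.nn.functional as F
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction

logits = torch.randn(2, 4, 64, 64)  # raw model outputs for 4 classes
labels = torch.randint(0, 4, (2, 64, 64))
one_hot = F.one_hot(labels, num_classes=4).permute(0, 3, 1, 2).float()

focal = FocalLossFunction(alpha=0.25, gamma=1.5)
tversky = TverskyLossFunction(alpha=0.3, beta=0.7)  # larger beta penalizes false negatives more

print(focal(logits, one_hot).item(), tversky(logits, one_hot).item())
```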
semantic-segmentation/SemanticModel/.ipynb_checkpoints/data_loader-checkpoint.py
ADDED
@@ -0,0 +1,129 @@
```python
import os
import cv2
import numpy as np
from torch.utils.data import Dataset as BaseDataset

class SegmentationDataset(BaseDataset):
    """Dataset class for semantic segmentation task."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):

        self.image_dir = os.path.join(data_dir, 'Images')
        self.mask_dir = os.path.join(data_dir, 'Masks')

        for dir_path in [self.image_dir, self.mask_dir]:
            if not os.path.exists(dir_path):
                raise FileNotFoundError(f"Directory not found: {dir_path}")

        self.filenames = self._get_filenames()
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in self.filenames]
        self.mask_paths = self._get_mask_paths()

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        image = self._load_image(self.image_paths[index])
        mask = self._load_mask(self.mask_paths[index])

        if self.augmentation:
            processed = self.augmentation(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        if self.preprocessing:
            processed = self.preprocessing(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        return image, mask

    def __len__(self):
        return len(self.filenames)

    def _get_filenames(self):
        """Returns sorted list of filenames, excluding directories."""
        files = sorted(os.listdir(self.image_dir))
        return [f for f in files if not os.path.isdir(os.path.join(self.image_dir, f))]

    def _get_mask_paths(self):
        """Generates corresponding mask paths for each image."""
        mask_paths = []
        for image_file in self.filenames:
            name, _ = os.path.splitext(image_file)
            mask_paths.append(os.path.join(self.mask_dir, f"{name}.png"))
        return mask_paths

    def _load_image(self, path):
        """Loads and converts image to RGB."""
        image = cv2.imread(path)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_mask(self, path):
        """Loads and processes segmentation mask."""
        mask = cv2.imread(path, 0)
        masks = [(mask == value) for value in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        return mask

class InferenceDataset(BaseDataset):
    """Dataset class for inference without ground truth masks."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.filenames = sorted([
            f for f in os.listdir(data_dir)
            if not os.path.isdir(os.path.join(data_dir, f))
        ])
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        image = cv2.imread(self.image_paths[index])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_height, original_width = image.shape[:2]

        if self.augmentation:
            image = self.augmentation(image=image)['image']

        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        return image, original_height, original_width

    def __len__(self):
        return len(self.filenames)

class StreamingDataset(BaseDataset):
    """Dataset class optimized for video frame processing."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.filenames = self._get_frame_filenames(data_dir)
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _get_frame_filenames(self, directory):
        """Returns sorted list of frame filenames."""
        files = sorted(os.listdir(directory))
        return [f for f in files if (('frame' in f or 'Image' in f) and
                                     f.lower().endswith('jpg') and
                                     not os.path.isdir(os.path.join(directory, f)))]

    def __getitem__(self, index):
        return InferenceDataset.__getitem__(self, index)

    def __len__(self):
        return len(self.filenames)
```
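A short sketch of wiring `SegmentationDataset` into a PyTorch `DataLoader`; the dataset path is a placeholder and must contain the `Images/` and `Masks/` subfolders checked in `__init__`:

```python
from torch.utils.data import DataLoader

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.image_preprocessing import get_training_augmentations

dataset = SegmentationDataset(
    'path/to/dataset/train',  # placeholder root holding Images/ and Masks/
    classes=['background', 'object'],
    augmentation=get_training_augmentations(width=1024, height=1024),
)
image, mask = dataset[0]  # HxWx3 RGB array and HxWx1 stack of non-background class masks
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
```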
semantic-segmentation/SemanticModel/.ipynb_checkpoints/encoder_management-checkpoint.py
ADDED
@@ -0,0 +1,136 @@
```python
import os
import ssl
import shutil
import tempfile
import hashlib
from tqdm import tqdm
from torch.hub import get_dir
from urllib.request import urlopen, Request

from segmentation_models_pytorch.encoders import (
    resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
    densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
    efficient_net_encoders, mobilenet_encoders, xception_encoders,
    timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
    timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
    timm_gernet_encoders
)

from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder

def initialize_encoders():
    """Initialize dictionary of available encoders."""
    available_encoders = {}
    encoder_modules = [
        resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
        densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
        efficient_net_encoders, mobilenet_encoders, xception_encoders,
        timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
        timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
        timm_gernet_encoders
    ]

    for module in encoder_modules:
        available_encoders.update(module)

    try:
        import segmentation_models_pytorch
        from packaging import version
        if version.parse(segmentation_models_pytorch.__version__) >= version.parse("0.3.3"):
            from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
            from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
            available_encoders.update(mix_transformer_encoders)
            available_encoders.update(mobileone_encoders)
    except ImportError:
        pass

    return available_encoders

def download_weights(url, destination, hash_prefix=None, show_progress=True):
    """Downloads model weights with progress tracking and verification."""
    ssl._create_default_https_context = ssl._create_unverified_context

    req = Request(url, headers={"User-Agent": "torch.hub"})
    response = urlopen(req)
    content_length = response.headers.get("Content-Length")
    file_size = int(content_length[0]) if content_length else None

    destination = os.path.expanduser(destination)
    temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))

    try:
        hasher = hashlib.sha256() if hash_prefix else None

        with tqdm(total=file_size, disable=not show_progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = response.read(8192)
                if not buffer:
                    break

                temp_file.write(buffer)
                if hasher:
                    hasher.update(buffer)
                pbar.update(len(buffer))

        temp_file.close()

        if hasher and hash_prefix:
            digest = hasher.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')

        shutil.move(temp_file.name, destination)

    finally:
        temp_file.close()
        if os.path.exists(temp_file.name):
            os.remove(temp_file.name)

def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Initializes and returns configured encoder."""
    encoders = initialize_encoders()

    if name.startswith("tu-"):
        name = name[3:]
        return TimmUniversalEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )

    try:
        encoder_config = encoders[name]
    except KeyError:
        raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}")

    encoder_class = encoder_config["encoder"]
    encoder_params = encoder_config["params"]
    encoder_params.update(depth=depth)

    if weights:
        try:
            weights_config = encoder_config["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                f"Invalid weights '{weights}' for encoder '{name}'. "
                f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
            )

        cache_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(cache_dir, exist_ok=True)

        weights_file = os.path.basename(weights_config["url"])
        weights_path = os.path.join(cache_dir, weights_file)

        if not os.path.exists(weights_path):
            print(f'Downloading {weights_file}...')
            download_weights(
                weights_config["url"].replace("https", "http"),
                weights_path
            )

    return encoder_class(**encoder_params)
```
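A hedged sketch of calling `initialize_encoder` directly (normally `SegmentationModel` does this internally). Passing `weights=None` skips the pretrained-weight download, and the comment about a list of feature maps reflects the usual behavior of segmentation_models_pytorch encoders rather than anything stated in this file:

```python
import torch
from SemanticModel.encoder_management import initialize_encoder

encoder = initialize_encoder('timm-regnety_120', in_channels=3, depth=5, weights=None)
features = encoder(torch.randn(1, 3, 256, 256))
print([f.shape for f in features])  # one feature map per encoder stage
```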
semantic-segmentation/SemanticModel/.ipynb_checkpoints/evaluation_utils-checkpoint.py
ADDED
@@ -0,0 +1,108 @@
```python
import os
import cv2
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_validation_augmentations

def evaluate_model(model_config, data_path, image_size=None):
    """Evaluates model performance on a dataset."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    classes = ['background'] + model_config.classes if model_config.background_flag else model_config.classes

    data_path = os.path.realpath(data_path)
    image_subdir = os.path.join(data_path, 'Images')
    mask_subdir = os.path.join(data_path, 'Masks')

    if not all(os.path.exists(d) for d in [image_subdir, mask_subdir]):
        raise Exception("Missing required subdirectories: 'Images' and 'Masks'")

    if not image_size:
        sample_image = cv2.imread(os.path.join(image_subdir, os.listdir(image_subdir)[0]))
        height, width = sample_image.shape[:2]
        image_size = max(height, width)

    evaluation_dataset = SegmentationDataset(
        data_path,
        classes=classes,
        augmentation=get_validation_augmentations(
            width=image_size,
            height=image_size,
            fixed_size=False
        ),
        preprocessing=model_config.preprocessing
    )

    evaluation_loader = DataLoader(
        evaluation_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2
    )

    model = model_config.model.to(device)
    model.eval()

    requires_sigmoid = False
    if model_config.n_classes == 1:
        current_activation = _check_activation_function(model)
        if current_activation != 'Sigmoid':
            requires_sigmoid = True

    predictions = []
    ground_truth = []

    print("Evaluating model performance...")
    with torch.no_grad():
        for images, masks in tqdm(evaluation_loader):
            images = images.to(device)
            masks = masks.to(device)

            outputs = model.forward(images)

            if model_config.n_classes > 1:
                predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
                ground_truth.extend([gt.cpu().argmax(dim=0) for gt in masks])
            else:
                if requires_sigmoid:
                    predictions.extend([
                        (torch.sigmoid(p) > 0.5).float().squeeze().cpu()
                        for p in outputs
                    ])
                else:
                    predictions.extend([
                        (p > 0.5).float().squeeze().cpu()
                        for p in outputs
                    ])
                ground_truth.extend([gt.cpu().squeeze() for gt in masks])

    metrics = compute_mean_iou(
        predictions,
        ground_truth,
        num_classes=len(classes),
        ignore_index=255
    )

    print("\nEvaluation Results:")
    print(f"Mean IoU: {metrics['mean_iou']:.3f}")
    print("\nPer-class IoU:")
    for idx, iou in enumerate(metrics['per_category_iou']):
        print(f"{classes[idx]}: {iou:.3f}")

    return metrics

def _check_activation_function(model):
    """Checks the activation function used in model's segmentation head."""
    from segmentation_models_pytorch.base.modules import Activation

    activation_functions = []
    for _, module in model.segmentation_head.named_children():
        if isinstance(module, Activation):
            activation_functions.append(type(module.activation).__name__)

    return activation_functions[-1] if activation_functions else None
```
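A minimal sketch of scoring a trained model on a held-out split with `evaluate_model`; the checkpoint and dataset paths are placeholders, and the constructor call follows the README's "Load pretrained" example:

```python
from SemanticModel.model_core import SegmentationModel
from SemanticModel.evaluation_utils import evaluate_model

model = SegmentationModel(
    classes=['background', 'object'],
    weights='path/to/best_model.pth',  # placeholder checkpoint from ModelTrainer
)
metrics = evaluate_model(model, 'path/to/dataset/val')  # split must contain Images/ and Masks/
print(metrics['mean_iou'])
```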
semantic-segmentation/SemanticModel/.ipynb_checkpoints/image_preprocessing-checkpoint.py
ADDED
@@ -0,0 +1,81 @@
```python
import cv2
import numpy as np
import albumentations as albu
from albumentations.augmentations.geometric.resize import LongestMaxSize

def round_pixel_dim(dimension: float) -> int:
    """Rounds pixel dimensions consistently."""
    if abs(round(dimension) - dimension) == 0.5:
        return int(2.0 * round(dimension / 2.0))
    return int(round(dimension))

def resize_with_padding(image, target_size, stride=32, interpolation=cv2.INTER_LINEAR):
    """Resizes image maintaining aspect ratio and ensures dimensions are stride-compatible."""
    height, width = image.shape[:2]
    max_dimension = max(height, width)

    if ((height % stride == 0) and (width % stride == 0) and
            (max_dimension <= target_size)):
        return image

    scale = target_size / float(max(width, height))
    new_dims = tuple(round_pixel_dim(dim * scale) for dim in (height, width))
    new_height, new_width = new_dims

    new_height = ((new_height // stride + 1) * stride
                  if new_height % stride != 0 else new_height)
    new_width = ((new_width // stride + 1) * stride
                 if new_width % stride != 0 else new_width)

    return cv2.resize(image, (new_width, new_height), interpolation=interpolation)

class PaddedResize(LongestMaxSize):
    def apply(self, img: np.ndarray, target_size: int = 1024,
              interpolation: int = cv2.INTER_LINEAR, **params) -> np.ndarray:
        return resize_with_padding(img, target_size=target_size, interpolation=interpolation)

def get_training_augmentations(width=768, height=576):
    """Configures training-time augmentations."""
    target_size = max([width, height])
    transforms = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(
            scale_limit=0.5, rotate_limit=90, shift_limit=0.1, p=0.5, border_mode=0),
        albu.PadIfNeeded(min_height=target_size, min_width=target_size, always_apply=True),
        albu.RandomCrop(height=target_size, width=target_size, always_apply=True),
        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.2),
        albu.OneOf([albu.CLAHE(p=1), albu.RandomGamma(p=1)], p=0.33),
        albu.OneOf([
            albu.Sharpen(p=1),
            albu.Blur(blur_limit=3, p=1),
            albu.MotionBlur(blur_limit=3, p=1)], p=0.33),
        albu.OneOf([
            albu.RandomBrightnessContrast(p=1),
            albu.HueSaturationValue(p=1)], p=0.33),
    ]
    return albu.Compose(transforms)

def get_validation_augmentations(width=1920, height=1440, fixed_size=True):
    """Configures validation/inference-time augmentations."""
    if fixed_size:
        transforms = [albu.Resize(height=height, width=width, always_apply=True)]
        return albu.Compose(transforms)

    target_size = max(width, height)
    transforms = [PaddedResize(max_size=target_size, always_apply=True)]
    return albu.Compose(transforms)

def convert_to_tensor(x, **kwargs):
    """Converts image array to PyTorch tensor format."""
    if x.ndim == 2:
        x = np.expand_dims(x, axis=-1)
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing_pipeline(preprocessing_fn):
    """Builds preprocessing pipeline including normalization and tensor conversion."""
    transforms = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=convert_to_tensor, mask=convert_to_tensor),
    ]
    return albu.Compose(transforms)
```
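A short sketch applying the two augmentation pipelines to one image/mask pair; the file names are placeholders:

```python
import cv2
from SemanticModel.image_preprocessing import (
    get_training_augmentations, get_validation_augmentations)

image = cv2.cvtColor(cv2.imread('image.jpg'), cv2.COLOR_BGR2RGB)  # placeholder image
mask = cv2.imread('mask.png', 0)                                  # placeholder mask

train_tf = get_training_augmentations(width=1024, height=1024)
augmented = train_tf(image=image, mask=mask)        # random flips, crops, noise, color jitter
aug_image, aug_mask = augmented['image'], augmented['mask']

val_tf = get_validation_augmentations(fixed_size=False)  # PaddedResize to stride-32 dimensions
resized = val_tf(image=image)['image']
```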
semantic-segmentation/SemanticModel/.ipynb_checkpoints/metrics-checkpoint.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Optional
import numpy as np

def compute_intersection_union(prediction, ground_truth, num_classes, ignore_index: int,
                               label_mapping: Optional[Dict[int, int]] = None,
                               reduce_labels: bool = False):
    """Computes intersection and union for IoU calculation."""

    if label_mapping:
        for old_id, new_id in label_mapping.items():
            ground_truth[ground_truth == old_id] = new_id

    prediction = np.array(prediction)
    ground_truth = np.array(ground_truth)

    if reduce_labels:
        ground_truth[ground_truth == 0] = 255
        ground_truth = ground_truth - 1
        ground_truth[ground_truth == 254] = 255

    valid_mask = np.not_equal(ground_truth, ignore_index)
    prediction = prediction[valid_mask]
    ground_truth = ground_truth[valid_mask]

    intersection_mask = prediction == ground_truth
    intersection = prediction[intersection_mask]

    area_intersection = np.histogram(intersection, bins=num_classes,
                                     range=(0, num_classes - 1))[0]
    area_prediction = np.histogram(prediction, bins=num_classes,
                                   range=(0, num_classes - 1))[0]
    area_ground_truth = np.histogram(ground_truth, bins=num_classes,
                                     range=(0, num_classes - 1))[0]
    area_union = area_prediction + area_ground_truth - area_intersection

    return area_intersection, area_union, area_prediction, area_ground_truth

def compute_total_intersection_union(predictions, ground_truths, num_classes, ignore_index: int,
                                     label_mapping: Optional[Dict[int, int]] = None,
                                     reduce_labels: bool = False):
    """Computes total intersection and union across all samples."""

    totals = {
        'intersection': np.zeros((num_classes,), dtype=np.float64),
        'union': np.zeros((num_classes,), dtype=np.float64),
        'prediction': np.zeros((num_classes,), dtype=np.float64),
        'ground_truth': np.zeros((num_classes,), dtype=np.float64)
    }

    for pred, gt in zip(predictions, ground_truths):
        intersection, union, pred_area, gt_area = compute_intersection_union(
            pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
        )
        totals['intersection'] += intersection
        totals['union'] += union
        totals['prediction'] += pred_area
        totals['ground_truth'] += gt_area

    return tuple(totals.values())

def compute_mean_iou(predictions, ground_truths, num_classes, ignore_index: int,
                     nan_to_num: Optional[int] = None,
                     label_mapping: Optional[Dict[int, int]] = None,
                     reduce_labels: bool = False):
    """Computes mean IoU and related metrics."""

    intersection, union, prediction_area, ground_truth_area = compute_total_intersection_union(
        predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
    )

    metrics = {}

    # Compute overall accuracy
    total_accuracy = intersection.sum() / ground_truth_area.sum()

    # Compute IoU per class
    iou_per_class = intersection / union
    accuracy_per_class = intersection / ground_truth_area

    metrics.update({
        "mean_iou": np.nanmean(iou_per_class),
        "mean_accuracy": np.nanmean(accuracy_per_class),
        "overall_accuracy": total_accuracy,
        "per_category_iou": iou_per_class,
        "per_category_accuracy": accuracy_per_class
    })

    if nan_to_num is not None:
        metrics = {
            metric: np.nan_to_num(value, nan=nan_to_num)
            for metric, value in metrics.items()
        }

    return metrics
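Usage note: a minimal sketch of calling `compute_mean_iou` on toy arrays, just to show the expected inputs and outputs; the arrays are illustrative, not real predictions.

```python
# Sketch only: two 2x2 label maps, two classes, no ignored pixels.
import numpy as np

pred = np.array([[0, 1], [1, 1]])
gt   = np.array([[0, 1], [0, 1]])

scores = compute_mean_iou([pred], [gt], num_classes=2, ignore_index=255)
print(scores['mean_iou'])           # mean over the two classes
print(scores['per_category_iou'])   # one IoU value per class
```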
semantic-segmentation/SemanticModel/.ipynb_checkpoints/model_core-checkpoint.py
ADDED
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

from SemanticModel.encoder_management import initialize_encoder
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
from SemanticModel.image_preprocessing import get_preprocessing_pipeline

class SegmentationModel:
    def __init__(self, classes=['background', 'foreground'], architecture='unet',
                 encoder='timm-regnety_120', weights='imagenet', loss=None):
        self._initialize_classes(classes)
        self.architecture = architecture
        self.encoder = encoder
        self.weights = weights
        self._setup_loss_function(loss)
        self._initialize_model()

    def _initialize_classes(self, classes):
        """Sets up class configuration."""
        if len(classes) <= 2:
            self.classes = [c for c in classes if c.lower() != 'background']
            self.class_values = [i for i, c in enumerate(classes) if c.lower() != 'background']
            self.background_flag = 'background' in classes
        else:
            self.classes = classes
            self.class_values = list(range(len(classes)))
            self.background_flag = False
        self.n_classes = len(self.classes)

    def _setup_loss_function(self, loss):
        """Configures model's loss function."""
        if not loss:
            loss = 'bce_with_logits' if self.n_classes > 1 else 'dice'

        if loss.lower() not in ['dice', 'bce_with_logits', 'focal', 'tversky']:
            print(f'Invalid loss: {loss}, defaulting to dice')
            loss = 'dice'

        loss_configs = {
            'bce_with_logits': {
                'activation': None,
                'loss': EnhancedCrossEntropy() if self.n_classes > 1 else utils.losses.BCEWithLogitsLoss()
            },
            'dice': {
                'activation': 'softmax' if self.n_classes > 1 else 'sigmoid',
                'loss': utils.losses.DiceLoss()
            },
            'focal': {
                'activation': None,
                'loss': FocalLossFunction()
            },
            'tversky': {
                'activation': None,
                'loss': TverskyLossFunction()
            }
        }

        config = loss_configs[loss.lower()]
        self.activation = config['activation']
        self.loss = config['loss']
        self.loss_name = loss

    def _initialize_model(self):
        """Initializes the segmentation model architecture."""
        if self.weights.endswith('pth'):
            self._load_pretrained_model()
        else:
            self._create_new_model()

    def _load_pretrained_model(self):
        """Loads model from pretrained weights."""
        print('Loading pretrained model...')
        self.model = torch.load(self.weights)
        if isinstance(self.model, torch.nn.DataParallel):
            self.model = self.model.module

        try:
            preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
            self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        except:
            print('Failed to configure preprocessing. Setting to None.')
            self.preprocessing = None

    def _create_new_model(self):
        """Creates new model with specified architecture."""
        preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
        self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        initialize_encoder(name=self.encoder, weights=self.weights)

        architectures = {
            'unet': smp.Unet,
            'unet++': smp.UnetPlusPlus,
            'deeplabv3': smp.DeepLabV3,
            'deeplabv3+': smp.DeepLabV3Plus,
            'fpn': smp.FPN,
            'linknet': smp.Linknet,
            'manet': smp.MAnet,
            'pan': smp.PAN,
            'pspnet': smp.PSPNet
        }

        if self.architecture not in architectures:
            raise ValueError(f'Unsupported architecture: {self.architecture}')

        self.model = architectures[self.architecture](
            encoder_name=self.encoder,
            encoder_weights=self.weights,
            classes=self.n_classes,
            activation=self.activation
        )

    @property
    def config_data(self):
        """Returns model configuration data."""
        return {
            'architecture': self.architecture,
            'encoder': self.encoder,
            'weights': self.weights,
            'activation': self.activation,
            'loss': self.loss_name,
            'classes': ['background'] + self.classes if self.background_flag else self.classes
        }

def list_architectures():
    """Returns available architecture options."""
    return ['unet', 'unet++', 'deeplabv3', 'deeplabv3+', 'fpn',
            'linknet', 'manet', 'pan', 'pspnet']
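Usage note: a minimal sketch of building a model configuration from this class. The class list, architecture, and encoder are example values (downloading 'imagenet' weights requires network access); only `SegmentationModel` and its `config_data` property from the module above are used.

```python
# Sketch only: a binary (background/vegetation) U-Net configuration with assumed encoder 'resnet34'.
model_config = SegmentationModel(
    classes=['background', 'vegetation'],
    architecture='unet',
    encoder='resnet34',
    weights='imagenet',
    loss='dice',
)
print(model_config.config_data)  # architecture, encoder, loss, activation, class list
```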
semantic-segmentation/SemanticModel/.ipynb_checkpoints/prediction-checkpoint.py
ADDED
@@ -0,0 +1,336 @@
import os
import cv2
import time
import torch
import imageio
import tifffile
import numpy as np
import slidingwindow
import rasterio as rio
import geopandas as gpd
from shapely.geometry import Polygon
from rasterio import mask as riomask
from torch.utils.data import DataLoader
from SemanticModel.visualization import generate_color_mapping
from SemanticModel.image_preprocessing import get_validation_augmentations
from SemanticModel.data_loader import InferenceDataset, StreamingDataset
from SemanticModel.utilities import calc_image_size, convert_coordinates

class PredictionPipeline:
    def __init__(self, model_config, device=None):
        self.config = model_config
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.classes = ['background'] + model_config.classes if model_config.background_flag else model_config.classes
        self.colors = generate_color_mapping(len(self.classes))
        self.model = model_config.model.to(self.device)
        self.model.eval()

    def _preprocess_image(self, image_path, target_size=None):
        """Preprocesses single image for prediction."""
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        height, width = image.shape[:2]

        target_size = target_size or max(height, width)
        test_height, test_width = calc_image_size(image, target_size)

        augmentation = get_validation_augmentations(test_width, test_height)
        image = augmentation(image=image)['image']
        image = self.config.preprocessing(image=image)['image']

        return image, (height, width)

    def predict_single_image(self, image_path, target_size=None, output_dir=None,
                             format='integer', save_output=True):
        """Generates prediction for a single image."""
        image, original_dims = self._preprocess_image(image_path, target_size)
        x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)

        with torch.no_grad():
            prediction = self.model.predict(x_tensor)

        if self.config.n_classes > 1:
            prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
        else:
            prediction = prediction.squeeze().cpu().numpy().round()

        # Resize to original dimensions if needed
        if prediction.shape[:2] != original_dims:
            prediction = cv2.resize(prediction, original_dims[::-1],
                                    interpolation=cv2.INTER_NEAREST)

        prediction = self._format_prediction(prediction, format)

        if save_output:
            self._save_prediction(prediction, image_path, output_dir, format)

        return prediction

    def predict_directory(self, input_dir, target_size=None, output_dir=None,
                          fixed_size=True, format='integer'):
        """Generates predictions for all images in directory."""
        output_dir = output_dir or os.path.join(input_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)

        dataset = InferenceDataset(
            input_dir,
            classes=self.classes,
            augmentation=get_validation_augmentations(
                target_size, target_size, fixed_size=fixed_size
            ) if target_size else None,
            preprocessing=self.config.preprocessing
        )

        total_images = len(dataset)
        start_time = time.time()

        for idx in range(total_images):
            if (idx + 1) % 10 == 0 or idx == total_images - 1:
                elapsed = time.time() - start_time
                print(f'\rProcessed {idx+1}/{total_images} images in {elapsed:.1f}s',
                      end='')

            image, height, width = dataset[idx]
            filename = dataset.filenames[idx]

            x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
            with torch.no_grad():
                prediction = self.model.predict(x_tensor)

            if self.config.n_classes > 1:
                prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
            else:
                prediction = prediction.squeeze().cpu().numpy().round()

            if prediction.shape != (height, width):
                prediction = cv2.resize(prediction, (width, height),
                                        interpolation=cv2.INTER_NEAREST)

            prediction = self._format_prediction(prediction, format)
            self._save_prediction(prediction, filename, output_dir, format)

        print(f'\nPredictions saved to: {output_dir}')
        return output_dir

    def predict_raster(self, raster_path, tile_size=1024, overlap=0.175,
                       boundary_path=None, output_path=None, format='integer'):
        """Processes large raster images using tiling approach."""
        print('Loading raster...')
        with rio.open(raster_path) as src:
            raster = src.read()
            raster = np.moveaxis(raster, 0, 2)[:,:,:3]
            profile = src.profile
            transform = src.transform

        if boundary_path:
            boundary = gpd.read_file(boundary_path)
            boundary = boundary.to_crs(profile['crs'])
            boundary_geom = boundary.iloc[0].geometry

        tiles = slidingwindow.generate(
            raster,
            slidingwindow.DimOrder.HeightWidthChannel,
            tile_size,
            overlap
        )

        pred_raster = np.zeros_like(raster[:,:,0], dtype='uint8')
        confidence = np.zeros_like(pred_raster, dtype=np.float32)

        aug = get_validation_augmentations(tile_size, tile_size, fixed_size=False)

        for idx, tile in enumerate(tiles):
            if (idx + 1) % 10 == 0 or idx == len(tiles) - 1:
                print(f'\rProcessed {idx+1}/{len(tiles)} tiles', end='')

            bounds = tile.indices()

            tile_image = raster[bounds[0], bounds[1]]

            if boundary_path:
                corners = [
                    convert_coordinates(transform, bounds[1].start, bounds[0].start),
                    convert_coordinates(transform, bounds[1].stop, bounds[0].start),
                    convert_coordinates(transform, bounds[1].stop, bounds[0].stop),
                    convert_coordinates(transform, bounds[1].start, bounds[0].stop)
                ]
                if not Polygon(corners).intersects(boundary_geom):
                    continue

            processed = aug(image=tile_image)['image']
            processed = self.config.preprocessing(image=processed)['image']

            x_tensor = torch.from_numpy(processed).to(self.device).unsqueeze(0)
            with torch.no_grad():
                prediction = self.model.predict(x_tensor)
                prediction = prediction.squeeze().cpu().numpy()

            if self.config.n_classes > 1:
                tile_pred = np.argmax(prediction, axis=0)
                tile_conf = np.max(prediction, axis=0)
            else:
                tile_conf = np.abs(prediction - 0.5)
                tile_pred = prediction.round()

            if tile_pred.shape != tile_image.shape[:2]:
                tile_pred = cv2.resize(tile_pred, tile_image.shape[:2][::-1],
                                       interpolation=cv2.INTER_NEAREST)
                tile_conf = cv2.resize(tile_conf, tile_image.shape[:2][::-1],
                                       interpolation=cv2.INTER_LINEAR)

            # Update prediction and confidence maps
            existing_conf = confidence[bounds[0], bounds[1]]
            existing_pred = pred_raster[bounds[0], bounds[1]]

            mask = existing_conf < tile_conf
            existing_pred[mask] = tile_pred[mask]
            existing_conf[mask] = tile_conf[mask]

            pred_raster[bounds[0], bounds[1]] = existing_pred
            confidence[bounds[0], bounds[1]] = existing_conf

        pred_raster = self._format_prediction(pred_raster, format)

        if output_path or boundary_path:
            self._save_raster_prediction(
                pred_raster, raster_path, output_path,
                profile, boundary_geom if boundary_path else None
            )

        return pred_raster, profile

    def _format_prediction(self, prediction, format):
        """Formats prediction according to specified output type."""
        if format == 'integer':
            return prediction.astype('uint8')
        elif format == 'color':
            return self._apply_color_mapping(prediction)
        else:
            raise ValueError(f"Unsupported format: {format}")

    def _save_prediction(self, prediction, source_path, output_dir, format):
        """Saves prediction to disk."""
        filename = os.path.splitext(os.path.basename(source_path))[0]
        output_path = os.path.join(output_dir, f"{filename}_pred.png")
        cv2.imwrite(output_path, prediction)

    def _save_raster_prediction(self, prediction, source_path, output_path,
                                profile, boundary=None):
        """Saves raster prediction with geospatial information."""
        output_path = output_path or source_path.replace(
            os.path.splitext(source_path)[1], '_predicted.tif'
        )

        profile.update(
            dtype='uint8',
            count=3 if prediction.ndim == 3 else 1
        )

        with rio.open(output_path, 'w', **profile) as dst:
            if prediction.ndim == 3:
                for i in range(3):
                    dst.write(prediction[:,:,i], i+1)
            else:
                dst.write(prediction, 1)

        if boundary:
            with rio.open(output_path) as src:
                cropped, transform = riomask.mask(src, [boundary], crop=True)
                profile.update(
                    height=cropped.shape[1],
                    width=cropped.shape[2],
                    transform=transform
                )

            os.remove(output_path)
            with rio.open(output_path, 'w', **profile) as dst:
                dst.write(cropped)

        print(f'\nPrediction saved to: {output_path}')

    def predict_video_frames(self, input_dir, target_size=None, output_dir=None):
        """Processes video frames with specialized visualization."""
        output_dir = output_dir or os.path.join(input_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)

        dataset = StreamingDataset(
            input_dir,
            classes=self.classes,
            augmentation=get_validation_augmentations(
                target_size, target_size
            ) if target_size else None,
            preprocessing=self.config.preprocessing
        )

        image = cv2.imread(dataset.image_paths[0])
        height, width = image.shape[:2]

        white = 255 * np.ones((height, width))
        black = np.zeros_like(white)
        red = np.dstack((white, black, black))
        blue = np.dstack((black, black, white))

        # Pre-compute rotated versions
        rotated_red = np.rot90(red)
        rotated_blue = np.rot90(blue)

        total_frames = len(dataset)
        start_time = time.time()

        for idx in range(total_frames):
            if (idx + 1) % 10 == 0 or idx == total_frames - 1:
                elapsed = time.time() - start_time
                print(f'\rProcessed {idx+1}/{total_frames} frames in {elapsed:.1f}s', end='')

            frame, height, width = dataset[idx]
            filename = dataset.filenames[idx]

            x_tensor = torch.from_numpy(frame).to(self.device).unsqueeze(0)
            with torch.no_grad():
                prediction = self.model.predict(x_tensor)

            if self.config.n_classes > 1:
                prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
                masks = [prediction == i for i in range(1, self.config.n_classes)]
            else:
                prediction = prediction.squeeze().cpu().numpy().round()
                masks = [prediction == 1]

            if prediction.shape != (height, width):
                prediction = cv2.resize(prediction, (width, height),
                                        interpolation=cv2.INTER_NEAREST)

            original = cv2.imread(os.path.join(input_dir, filename))
            original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)

            try:
                for i, mask in enumerate(masks):
                    color = red if i == 0 else blue
                    rotated_color = rotated_red if i == 0 else rotated_blue
                    try:
                        original[mask,:] = 0.45*original[mask,:] + 0.55*color[mask,:]
                    except:
                        original[mask,:] = 0.45*original[mask,:] + 0.55*rotated_color[mask,:]
            except:
                print(f"\nWarning: Error processing frame {filename}")
                continue

            output_path = os.path.join(output_dir, filename)
            imageio.imwrite(output_path, original, quality=100)

        print(f'\nProcessed frames saved to: {output_dir}')
        return output_dir

    def _apply_color_mapping(self, prediction):
        """Applies color mapping to prediction."""
        height, width = prediction.shape
        colored = np.zeros((height, width, 3), dtype='uint8')

        for i, class_name in enumerate(self.classes):
            if class_name.lower() == 'background':
                continue
            color = self.colors[i]
            colored[prediction == i] = color

        return colored
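Usage note: a minimal sketch of running single-image inference with this pipeline. `model_config` is assumed to be a `SegmentationModel` whose `weights` point to a trained `.pth` checkpoint, and the paths are placeholders.

```python
# Sketch only: predict one image and write the integer mask next to it.
import os

os.makedirs('data/predictions', exist_ok=True)  # placeholder output directory

pipeline = PredictionPipeline(model_config)  # model_config: trained SegmentationModel (assumption)
mask = pipeline.predict_single_image(
    'data/example.jpg',          # placeholder input path
    output_dir='data/predictions',
    format='integer',
)
```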
semantic-segmentation/SemanticModel/.ipynb_checkpoints/training-checkpoint.py
ADDED
@@ -0,0 +1,313 @@
import os
import json
import torch
import wandb
import datetime
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_training_augmentations, get_validation_augmentations
from SemanticModel.utilities import list_images, validate_dimensions

class ModelTrainer:
    def __init__(self, model_config, root_dir, epochs=40, train_size=1024,
                 val_size=None, workers=2, batch_size=2, learning_rate=1e-4,
                 step_count=2, decay_factor=0.8, wandb_config=None,
                 optimizer='rmsprop', target_class=None, resume_path=None):

        self.config = model_config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.root_dir = root_dir
        self._initialize_training_params(epochs, train_size, val_size, workers,
                                         batch_size, learning_rate, step_count,
                                         decay_factor, optimizer, target_class)
        self._setup_directories()
        self._initialize_datasets()
        self._setup_optimizer()
        self._initialize_tracking()

        if resume_path:
            self._resume_training(resume_path)

    def _initialize_training_params(self, epochs, train_size, val_size, workers,
                                    batch_size, learning_rate, step_count,
                                    decay_factor, optimizer, target_class):
        self.epochs = epochs
        self.train_size = train_size
        self.val_size = val_size
        self.workers = workers
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.step_schedule = self._calculate_step_schedule(epochs, step_count)
        self.decay_factor = decay_factor
        self.optimizer_type = optimizer
        self.target_class = target_class
        self.current_epoch = 1
        self.best_iou = 0.0
        self.best_epoch = 0
        self.classes = ['background'] + self.config.classes if self.config.background_flag else self.config.classes

    def _setup_directories(self):
        """Verifies and creates necessary directories."""
        self.train_dir = os.path.join(self.root_dir, 'train')
        self.val_dir = os.path.join(self.root_dir, 'val')

        required_subdirs = ['Images', 'Masks']
        for path in [self.train_dir] + ([self.val_dir] if os.path.exists(self.val_dir) else []):
            for subdir in required_subdirs:
                full_path = os.path.join(path, subdir)
                if not os.path.exists(full_path):
                    raise FileNotFoundError(f"Missing directory: {full_path}")

    def _initialize_datasets(self):
        """Sets up training and validation datasets."""
        self.train_dataset = SegmentationDataset(
            self.train_dir,
            classes=self.classes,
            augmentation=get_training_augmentations(self.train_size, self.train_size),
            preprocessing=self.config.preprocessing
        )

        if os.path.exists(self.val_dir):
            self.val_dataset = SegmentationDataset(
                self.val_dir,
                classes=self.classes,
                augmentation=get_validation_augmentations(
                    self.val_size or self.train_size,
                    self.val_size or self.train_size,
                    fixed_size=False
                ),
                preprocessing=self.config.preprocessing
            )
            self.val_loader = DataLoader(
                self.val_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=self.workers
            )
        else:
            self.val_dataset = self.train_dataset
            self.val_loader = DataLoader(
                self.val_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=self.workers
            )

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers
        )

    def _setup_optimizer(self):
        """Configures model optimizer."""
        optimizer_map = {
            'adam': torch.optim.Adam,
            'sgd': lambda params: torch.optim.SGD(params, momentum=0.9),
            'rmsprop': torch.optim.RMSprop
        }
        optimizer_class = optimizer_map.get(self.optimizer_type.lower())
        if not optimizer_class:
            raise ValueError(f"Unsupported optimizer: {self.optimizer_type}")

        self.optimizer = optimizer_class([{'params': self.config.model.parameters(),
                                           'lr': self.learning_rate}])

    def _initialize_tracking(self):
        """Sets up training progress tracking."""
        timestamp = datetime.datetime.now().strftime("%m-%d-%Y_%H%M%S")
        self.output_dir = os.path.join(
            self.root_dir,
            f'model_outputs-{self.config.architecture}[{self.config.encoder}]-{timestamp}'
        )
        os.makedirs(self.output_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.output_dir)
        self.metrics = {
            'best_epoch': self.best_epoch,
            'best_epoch_iou': self.best_iou,
            'last_epoch': 0,
            'last_epoch_iou': 0.0,
            'last_epoch_lr': self.learning_rate,
            'step_schedule': self.step_schedule,
            'decay_factor': self.decay_factor,
            'target_class': self.target_class or 'overall'
        }

    def _calculate_step_schedule(self, epochs, steps):
        """Calculates learning rate step schedule."""
        return list(map(int, np.linspace(0, epochs, steps + 2)[1:-1]))

    def train(self):
        """Executes training loop."""
        model = self.config.model.to(self.device)
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
            print(f'Using {torch.cuda.device_count()} GPUs')

        self._save_config()

        for epoch in range(self.current_epoch, self.epochs + 1):
            print(f'\nEpoch {epoch}/{self.epochs}')
            print(f'Learning rate: {self.optimizer.param_groups[0]["lr"]:.3e}')

            train_loss = self._train_epoch(model)
            val_loss, val_metrics = self._validate_epoch(model)

            self._update_tracking(epoch, train_loss, val_loss, val_metrics)
            self._adjust_learning_rate(epoch)
            self._save_checkpoints(model, epoch, val_metrics)

        print(f'\nTraining completed. Best {self.metrics["target_class"]} IoU: {self.best_iou:.3f}')
        return model, self.metrics

    def _train_epoch(self, model):
        """Executes single training epoch."""
        model.train()
        total_loss = 0
        sample_count = 0

        for batch in tqdm(self.train_loader, desc='Training'):
            images, masks = [x.to(self.device) for x in batch]
            self.optimizer.zero_grad()

            outputs = model(images)
            loss = self.config.loss(outputs, masks)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * len(images)
            sample_count += len(images)

        return total_loss / sample_count

    def _validate_epoch(self, model):
        """Executes validation pass."""
        model.eval()
        total_loss = 0
        predictions = []
        ground_truth = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                images, masks = [x.to(self.device) for x in batch]
                outputs = model(images)
                loss = self.config.loss(outputs, masks)

                total_loss += loss.item()

                if self.config.n_classes > 1:
                    predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
                    ground_truth.extend([m.cpu().argmax(dim=0) for m in masks])
                else:
                    predictions.extend([(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
                                        for p in outputs])
                    ground_truth.extend([m.cpu().squeeze() for m in masks])

        metrics = compute_mean_iou(
            predictions,
            ground_truth,
            num_classes=len(self.classes),
            ignore_index=255
        )

        return total_loss / len(self.val_loader), metrics

    def _update_tracking(self, epoch, train_loss, val_loss, val_metrics):
        """Updates training metrics and logging."""
        mean_iou = val_metrics['mean_iou']
        print(f"\nLosses - Train: {train_loss:.3f}, Val: {val_loss:.3f}")
        print(f"Mean IoU: {mean_iou:.3f}")

        self.writer.add_scalar('Loss/train', train_loss, epoch)
        self.writer.add_scalar('Loss/val', val_loss, epoch)
        self.writer.add_scalar('IoU/mean', mean_iou, epoch)

        for idx, iou in enumerate(val_metrics['per_category_iou']):
            print(f"{self.classes[idx]} IoU: {iou:.3f}")
            self.writer.add_scalar(f'IoU/{self.classes[idx]}', iou, epoch)

    def _adjust_learning_rate(self, epoch):
        """Adjusts learning rate according to schedule."""
        if epoch in self.step_schedule:
            current_lr = self.optimizer.param_groups[0]['lr']
            new_lr = current_lr * self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f'\nDecreased learning rate: {current_lr:.3e} -> {new_lr:.3e}')

    def _save_checkpoints(self, model, epoch, metrics):
        """Saves model checkpoints and metrics."""
        epoch_iou = (metrics['mean_iou'] if self.target_class is None
                     else metrics['per_category_iou'][self.classes.index(self.target_class)])

        self.metrics.update({
            'last_epoch': epoch,
            'last_epoch_iou': round(float(epoch_iou), 3),
            'last_epoch_lr': self.optimizer.param_groups[0]['lr']
        })

        if epoch_iou > self.best_iou:
            self.best_iou = epoch_iou
            self.best_epoch = epoch
            self.metrics.update({
                'best_epoch': epoch,
                'best_epoch_iou': round(float(epoch_iou), 3),
                'overall_iou': round(float(metrics['mean_iou']), 3)
            })
            torch.save(model, os.path.join(self.output_dir, 'best_model.pth'))
            print(f'New best model saved (IoU: {epoch_iou:.3f})')

        torch.save(model, os.path.join(self.output_dir, 'last_model.pth'))
        with open(os.path.join(self.output_dir, 'metrics.json'), 'w') as f:
            json.dump(self.metrics, f, indent=4)

    def _save_config(self):
        """Saves training configuration."""
        config = {
            **self.config.config_data,
            'train_size': self.train_size,
            'val_size': self.val_size,
            'epochs': self.epochs,
            'batch_size': self.batch_size,
            'optimizer': self.optimizer_type,
            'workers': self.workers,
            'target_class': self.target_class or 'overall'
        }

        with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)

    def _resume_training(self, resume_path):
        """Resumes training from checkpoint."""
        if not os.path.exists(resume_path):
            raise FileNotFoundError(f"Resume path not found: {resume_path}")

        required_files = {
            'model': 'last_model.pth',
            'metrics': 'metrics.json',
            'config': 'config.json'
        }

        paths = {k: os.path.join(resume_path, v) for k, v in required_files.items()}
        if not all(os.path.exists(p) for p in paths.values()):
            raise FileNotFoundError("Missing required checkpoint files")

        with open(paths['config']) as f:
            config = json.load(f)
        with open(paths['metrics']) as f:
            metrics = json.load(f)

        self.current_epoch = metrics['last_epoch'] + 1
        self.best_iou = metrics['best_epoch_iou']
        self.best_epoch = metrics['best_epoch']
        self.learning_rate = metrics['last_epoch_lr']

        print(f'Resuming training from epoch {self.current_epoch}')
semantic-segmentation/SemanticModel/.ipynb_checkpoints/utilities-checkpoint.py
ADDED
@@ -0,0 +1,119 @@
import os
import cv2
import shutil
import imageio
import numpy as np
from glob import glob
from pathlib import Path
from typing import List, Tuple, Optional

def validate_dimensions(width: int, height: int, stride: int = 32) -> Tuple[int, int]:
    if height % stride != 0 or width % stride != 0:
        new_height = ((height // stride + 1) * stride
                      if height % stride != 0 else height)
        new_width = ((width // stride + 1) * stride
                     if width % stride != 0 else width)
        print(f'Adjusted dimensions to: {new_height}H x {new_width}W')
        # Return the stride-aligned dimensions so downstream resizing uses them.
        return new_width, new_height
    return width, height

def calc_image_size(image: np.ndarray, target_size: int) -> Tuple[int, int]:
    height, width = image.shape[:2]
    aspect_ratio = width / height

    if aspect_ratio >= 1:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_height = target_size
        new_width = int(target_size * aspect_ratio)

    return validate_dimensions(new_width, new_height)

def convert_coordinates(transform: np.ndarray, x: float, y: float) -> Tuple[float, float]:
    transformed = transform @ np.array([x, y, 1])
    return transformed[0], transformed[1]

def list_images(directory: str, mask_format: bool = False) -> List[str]:
    extensions = ['*.png', '*.PNG'] if mask_format else [
        '*.jpg', '*.jpeg', '*.png', '*.tif', '*.tiff',
        '*.JPG', '*.JPEG', '*.PNG', '*.TIF', '*.TIFF'
    ]

    image_paths = []
    for ext in extensions:
        image_paths.extend(glob(os.path.join(directory, ext)))

    return sorted(list(set(image_paths)))

def prepare_dataset_split(root_dir: str, train_ratio: float = 0.7,
                          generate_empty_masks: bool = False) -> None:
    image_dir = os.path.join(root_dir, 'Images')
    mask_dir = os.path.join(root_dir, 'Masks')

    if not all(os.path.exists(d) for d in [image_dir, mask_dir]):
        raise Exception("Required 'Images' and 'Masks' directories not found")

    image_paths = np.array(list_images(image_dir))
    mask_paths = np.array(list_images(mask_dir, mask_format=True))

    if generate_empty_masks:
        temp_dir = os.path.join(mask_dir, 'temp')
        create_empty_masks(image_dir, outdir=temp_dir)

        for mask_path in list_images(temp_dir, mask_format=True):
            target_path = os.path.join(mask_dir, os.path.basename(mask_path))
            if not os.path.exists(target_path):
                shutil.move(mask_path, target_path)

        shutil.rmtree(temp_dir)
        mask_paths = np.array(list_images(mask_dir, mask_format=True))

    if len(image_paths) != len(mask_paths):
        raise Exception(f"Unmatched images ({len(image_paths)}) and masks ({len(mask_paths)})")

    train_ratio = float(train_ratio)
    if not (0 < train_ratio <= 1):
        raise ValueError(f"Invalid train ratio: {train_ratio}")

    train_size = int(np.floor(train_ratio * len(image_paths)))
    indices = np.random.permutation(len(image_paths))

    splits = {
        'train': {'indices': indices[:train_size]},
        'val': {'indices': indices[train_size:]} if train_ratio < 1 else None
    }

    for split_name, split_data in splits.items():
        if split_data is None:
            continue

        split_dir = os.path.join(root_dir, split_name)
        for subdir in ['Images', 'Masks']:
            subdir_path = os.path.join(split_dir, subdir)
            os.makedirs(subdir_path, exist_ok=True)

            sources = image_paths if subdir == 'Images' else mask_paths
            for idx in split_data['indices']:
                source = sources[idx]
                destination = os.path.join(subdir_path, os.path.basename(source))
                shutil.copyfile(source, destination)

        print(f"Created {split_name} split with {len(split_data['indices'])} samples")

def create_empty_masks(image_dir: str, pixel_value: int = 0,
                       outdir: Optional[str] = None) -> str:
    outdir = outdir or os.path.join(image_dir, 'Masks')
    os.makedirs(outdir, exist_ok=True)

    image_paths = list_images(image_dir)
    print(f"Generating {len(image_paths)} empty masks...")

    for image_path in image_paths:
        image = imageio.imread(image_path)
        mask = np.full((image.shape[0], image.shape[1]), pixel_value, dtype='uint8')

        output_path = os.path.join(outdir,
                                   f"{Path(image_path).stem}.png")
        imageio.imwrite(output_path, mask)

    return outdir
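Usage note: a minimal sketch of preparing a dataset with these helpers. `'dataset_root'` is a placeholder directory that already contains flat `Images` and `Masks` folders; the split ratio is an example value.

```python
# Sketch only: split Images/Masks into train/ and val/ subfolders (80/20).
prepare_dataset_split('dataset_root', train_ratio=0.8)

# Optionally generate all-background masks first for unlabeled images.
# create_empty_masks('dataset_root/Images', pixel_value=0)
```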
semantic-segmentation/SemanticModel/.ipynb_checkpoints/visualization-checkpoint.py
ADDED
@@ -0,0 +1,115 @@
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

def plot_predictions(model, images, masks, device, num_samples=4):
    """Visualize model predictions against ground truth."""
    with torch.no_grad():
        model.eval()
        predictions = model.predict(images.to(device))

    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))

    for idx in range(num_samples):
        # Original image
        img = images[idx].permute(1, 2, 0).cpu().numpy()
        axes[idx, 0].imshow(img)
        axes[idx, 0].set_title('Original Image')

        # Ground truth
        truth = masks[idx].argmax(dim=0).cpu().numpy()
        axes[idx, 1].imshow(truth, cmap='tab20')
        axes[idx, 1].set_title('Ground Truth')

        # Prediction
        pred = predictions[idx].argmax(dim=0).cpu().numpy()
        axes[idx, 2].imshow(pred, cmap='tab20')
        axes[idx, 2].set_title('Prediction')

        for ax in axes[idx]:
            ax.axis('off')

    plt.tight_layout()
    return fig

def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
    """Create transparent overlay of segmentation mask on image."""
    if color_map is None:
        color_map = {
            0: [0, 0, 0],       # background
            1: [255, 0, 0],     # class 1 (red)
            2: [0, 255, 0],     # class 2 (green)
            3: [0, 0, 255],     # class 3 (blue)
        }

    overlay = image.copy()
    mask_colored = np.zeros_like(image)

    for label, color in color_map.items():
        mask_colored[mask == label] = color

    cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
    return overlay

def plot_training_history(history):
    """Plot training and validation metrics."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Loss plot
    ax1.plot(history['train_loss'], label='Training Loss')
    ax1.plot(history['val_loss'], label='Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()

    # IoU plot
    ax2.plot(history['mean_iou'], label='Mean IoU')
    for class_name, ious in history['class_ious'].items():
        ax2.plot(ious, label=f'{class_name} IoU')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('IoU')
    ax2.set_title('IoU Metrics')
    ax2.legend()

    plt.tight_layout()
    return fig

def visualize_predictions_on_batch(model, batch_images, batch_size=8):
    """Create grid visualization for a batch of predictions."""
    with torch.no_grad():
        predictions = model.predict(batch_images)

    fig = plt.figure(figsize=(15, 5))
    for idx in range(min(batch_size, len(batch_images))):
        plt.subplot(2, 4, idx + 1)
        img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
        mask = predictions[idx].argmax(dim=0).cpu().numpy()
        overlay = create_overlay_mask(img, mask)
        plt.imshow(overlay)
        plt.axis('off')

    plt.tight_layout()
    return fig

def save_visualization(fig, save_path):
    """Save visualization figure."""
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

def generate_color_mapping(num_classes):
    """Generate distinct colors for segmentation classes."""
    colors = [
        [0, 0, 0],       # Background (black)
        [255, 0, 0],     # Red
        [0, 255, 0],     # Green
        [0, 0, 255],     # Blue
        [255, 255, 0],   # Yellow
        [255, 0, 255],   # Magenta
        [0, 255, 255],   # Cyan
        [128, 0, 0],     # Dark Red
        [0, 128, 0],     # Dark Green
        [0, 0, 128]      # Dark Blue
    ]
    return colors[:num_classes]
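Usage note: a minimal sketch of overlaying a mask on an image with `create_overlay_mask`. The image path and the toy mask are illustrative assumptions; only functions from the module above plus OpenCV/NumPy are used.

```python
# Sketch only: blend a toy class-1 mask onto an image and save the result.
import cv2
import numpy as np

image = cv2.cvtColor(cv2.imread('data/example.jpg'), cv2.COLOR_BGR2RGB)  # placeholder path
mask = np.zeros(image.shape[:2], dtype=np.uint8)
mask[50:100, 50:100] = 1  # toy square labeled as class 1

overlay = create_overlay_mask(image, mask, alpha=0.4)
cv2.imwrite('overlay.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
```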
semantic-segmentation/SemanticModel/__init__.py
ADDED
File without changes
semantic-segmentation/SemanticModel/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (187 Bytes).
semantic-segmentation/SemanticModel/__pycache__/custom_losses.cpython-38.pyc
ADDED
Binary file (3.45 kB).
semantic-segmentation/SemanticModel/__pycache__/data_loader.cpython-38.pyc
ADDED
Binary file (6.72 kB).
semantic-segmentation/SemanticModel/__pycache__/encoder_management.cpython-38.pyc
ADDED
Binary file (4.17 kB).
semantic-segmentation/SemanticModel/__pycache__/evaluation_utils.cpython-38.pyc
ADDED
Binary file (3.62 kB).
semantic-segmentation/SemanticModel/__pycache__/image_preprocessing.cpython-38.pyc
ADDED
Binary file (3.62 kB).
semantic-segmentation/SemanticModel/__pycache__/metrics.cpython-38.pyc
ADDED
Binary file (2.56 kB).
semantic-segmentation/SemanticModel/__pycache__/model_core.cpython-38.pyc
ADDED
Binary file (4.52 kB).
semantic-segmentation/SemanticModel/__pycache__/prediction.cpython-38.pyc
ADDED
Binary file (9.62 kB).
semantic-segmentation/SemanticModel/__pycache__/training.cpython-38.pyc
ADDED
Binary file (10.8 kB).
semantic-segmentation/SemanticModel/__pycache__/utilities.cpython-38.pyc
ADDED
Binary file (4.03 kB).
semantic-segmentation/SemanticModel/__pycache__/visualization.cpython-38.pyc
ADDED
Binary file (3.41 kB).
semantic-segmentation/SemanticModel/custom_losses.py
ADDED
@@ -0,0 +1,97 @@
import torch
import torch.nn.functional as F
from segmentation_models_pytorch.utils import base
from segmentation_models_pytorch.base.modules import Activation

class FocalLossFunction(base.Loss):
    def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        if inputs.shape[1] == 1:  # Binary case
            inputs = torch.cat((inputs, 1 - inputs), dim=1)
            targets = torch.cat((targets, 1 - targets), dim=1)

        targets = torch.argmax(targets, dim=1)
        cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
        probability = torch.exp(-cross_entropy)
        alpha_factor = self.alpha if inputs.shape[1] > 1 else torch.where(
            targets == 1, 1-self.alpha, self.alpha)

        focal_weight = alpha_factor * (1 - probability) ** self.gamma * cross_entropy

        if self.reduction == 'mean':
            return focal_weight.mean()
        elif self.reduction == 'sum':
            return focal_weight.sum()
        return focal_weight

class TverskyLossFunction(base.Loss):
    def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
                 reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.beta = beta
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        if self.ignore_channels is not None:
            mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            mask[self.ignore_channels] = False
            inputs = inputs[:, mask, ...]

        num_classes = inputs.shape[1]
        inputs_softmax = (torch.sigmoid(inputs) if num_classes == 1
                          else F.softmax(inputs, dim=1))

        if num_classes == 1:
            inputs_softmax = inputs_softmax.squeeze(1)
            targets = targets.squeeze(1)

        tversky_loss = 0
        for class_idx in range(num_classes):
            if num_classes == 1:
                flat_inputs = inputs_softmax.reshape(-1)
                flat_targets = targets.reshape(-1)
            else:
                flat_inputs = inputs_softmax[:, class_idx].reshape(-1)
                flat_targets = targets[:, class_idx].reshape(-1)

            intersection = (flat_inputs * flat_targets).sum()
            fps = ((1 - flat_targets) * flat_inputs).sum()
            fns = (flat_targets * (1 - flat_inputs)).sum()

            tversky_index = intersection + self.alpha * fps + self.beta * fns + 1e-10
            tversky_loss += 1 - intersection / tversky_index

        if self.reduction == 'mean':
            return tversky_loss / (1 if num_classes == 1 else num_classes)
        elif self.reduction == 'sum':
            return tversky_loss
        return tversky_loss / inputs.shape[0]

class EnhancedCrossEntropy(base.Loss):
    def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        inputs = self.activation(inputs)

        if self.ignore_channels is not None:
            mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            mask[self.ignore_channels] = False
            inputs = inputs[:, mask, ...]

        if targets.dim() == 4:  # Convert one-hot to class indices
            targets = torch.argmax(targets, dim=1)

        return F.cross_entropy(inputs, targets, reduction=self.reduction)
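Usage note: a minimal sketch of evaluating `FocalLossFunction` on toy multi-class logits and one-hot targets; the tensor shapes and random values are illustrative assumptions.

```python
# Sketch only: focal loss on a random batch of 2 images, 3 classes, 64x64 resolution.
import torch

criterion = FocalLossFunction(alpha=0.25, gamma=1.5)
logits = torch.randn(2, 3, 64, 64)
targets = torch.nn.functional.one_hot(
    torch.randint(0, 3, (2, 64, 64)), num_classes=3
).permute(0, 3, 1, 2).float()  # one-hot, channels-first, matching the forward() expectation

print(criterion(logits, targets).item())
```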
semantic-segmentation/SemanticModel/data_loader.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from torch.utils.data import Dataset as BaseDataset

class SegmentationDataset(BaseDataset):
    """Dataset class for semantic segmentation task."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):

        self.image_dir = os.path.join(data_dir, 'Images')
        self.mask_dir = os.path.join(data_dir, 'Masks')

        for dir_path in [self.image_dir, self.mask_dir]:
            if not os.path.exists(dir_path):
                raise FileNotFoundError(f"Directory not found: {dir_path}")

        self.filenames = self._get_filenames()
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in self.filenames]
        self.mask_paths = self._get_mask_paths()

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        image = self._load_image(self.image_paths[index])
        mask = self._load_mask(self.mask_paths[index])

        if self.augmentation:
            processed = self.augmentation(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        if self.preprocessing:
            processed = self.preprocessing(image=image, mask=mask)
            image, mask = processed['image'], processed['mask']

        return image, mask

    def __len__(self):
        return len(self.filenames)

    def _get_filenames(self):
        """Returns sorted list of filenames, excluding directories."""
        files = sorted(os.listdir(self.image_dir))
        return [f for f in files if not os.path.isdir(os.path.join(self.image_dir, f))]

    def _get_mask_paths(self):
        """Generates corresponding mask paths for each image."""
        mask_paths = []
        for image_file in self.filenames:
            name, _ = os.path.splitext(image_file)
            mask_paths.append(os.path.join(self.mask_dir, f"{name}.png"))
        return mask_paths

    def _load_image(self, path):
        """Loads and converts image to RGB."""
        image = cv2.imread(path)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_mask(self, path):
        """Loads and processes segmentation mask."""
        mask = cv2.imread(path, 0)
        masks = [(mask == value) for value in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        return mask


class InferenceDataset(BaseDataset):
    """Dataset class for inference without ground truth masks."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.filenames = sorted([
            f for f in os.listdir(data_dir)
            if not os.path.isdir(os.path.join(data_dir, f))
        ])
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, index):
        image = cv2.imread(self.image_paths[index])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_height, original_width = image.shape[:2]

        if self.augmentation:
            image = self.augmentation(image=image)['image']

        if self.preprocessing:
            image = self.preprocessing(image=image)['image']

        return image, original_height, original_width

    def __len__(self):
        return len(self.filenames)


class StreamingDataset(BaseDataset):
    """Dataset class optimized for video frame processing."""

    def __init__(self, data_dir, classes=['background', 'object'],
                 augmentation=None, preprocessing=None):
        self.filenames = self._get_frame_filenames(data_dir)
        self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]

        self.target_classes = [cls for cls in classes if cls.lower() != 'background']
        self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _get_frame_filenames(self, directory):
        """Returns sorted list of frame filenames."""
        files = sorted(os.listdir(directory))
        return [f for f in files if (('frame' in f or 'Image' in f) and
                                     f.lower().endswith('jpg') and
                                     not os.path.isdir(os.path.join(directory, f)))]

    def __getitem__(self, index):
        return InferenceDataset.__getitem__(self, index)

    def __len__(self):
        return len(self.filenames)
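
For orientation, a minimal sketch of how these dataset classes are usually wired into a `DataLoader`. The directory path and class names below are illustrative assumptions, not part of the module; the expected layout is `<root>/Images` and `<root>/Masks` with matching filenames (masks saved as `.png`).

```python
# Minimal usage sketch; 'data/train' and the class list are hypothetical.
from torch.utils.data import DataLoader

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.image_preprocessing import get_training_augmentations

dataset = SegmentationDataset(
    'data/train',                       # hypothetical directory with Images/ and Masks/
    classes=['background', 'object'],
    augmentation=get_training_augmentations(width=768, height=576),
    # preprocessing is normally taken from SegmentationModel.preprocessing
)

loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
image_batch, mask_batch = next(iter(loader))
```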
semantic-segmentation/SemanticModel/encoder_management.py
ADDED
@@ -0,0 +1,136 @@
import os
import ssl
import shutil
import tempfile
import hashlib
from tqdm import tqdm
from torch.hub import get_dir
from urllib.request import urlopen, Request

from segmentation_models_pytorch.encoders import (
    resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
    densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
    efficient_net_encoders, mobilenet_encoders, xception_encoders,
    timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
    timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
    timm_gernet_encoders
)

from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder


def initialize_encoders():
    """Initialize dictionary of available encoders."""
    available_encoders = {}
    encoder_modules = [
        resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
        densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
        efficient_net_encoders, mobilenet_encoders, xception_encoders,
        timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
        timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
        timm_gernet_encoders
    ]

    for module in encoder_modules:
        available_encoders.update(module)

    try:
        import segmentation_models_pytorch
        from packaging import version
        if version.parse(segmentation_models_pytorch.__version__) >= version.parse("0.3.3"):
            from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
            from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
            available_encoders.update(mix_transformer_encoders)
            available_encoders.update(mobileone_encoders)
    except ImportError:
        pass

    return available_encoders


def download_weights(url, destination, hash_prefix=None, show_progress=True):
    """Downloads model weights with progress tracking and verification."""
    ssl._create_default_https_context = ssl._create_unverified_context

    req = Request(url, headers={"User-Agent": "torch.hub"})
    response = urlopen(req)
    content_length = response.headers.get("Content-Length")
    file_size = int(content_length) if content_length else None  # header value is a string

    destination = os.path.expanduser(destination)
    temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))

    try:
        hasher = hashlib.sha256() if hash_prefix else None

        with tqdm(total=file_size, disable=not show_progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = response.read(8192)
                if not buffer:
                    break

                temp_file.write(buffer)
                if hasher:
                    hasher.update(buffer)
                pbar.update(len(buffer))

        temp_file.close()

        if hasher and hash_prefix:
            digest = hasher.hexdigest()
            if digest[:len(hash_prefix)] != hash_prefix:
                raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')

        shutil.move(temp_file.name, destination)

    finally:
        temp_file.close()
        if os.path.exists(temp_file.name):
            os.remove(temp_file.name)


def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Initializes and returns configured encoder."""
    encoders = initialize_encoders()

    if name.startswith("tu-"):
        name = name[3:]
        return TimmUniversalEncoder(
            name=name,
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs
        )

    try:
        encoder_config = encoders[name]
    except KeyError:
        raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}")

    encoder_class = encoder_config["encoder"]
    encoder_params = encoder_config["params"]
    encoder_params.update(depth=depth)

    if weights:
        try:
            weights_config = encoder_config["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                f"Invalid weights '{weights}' for encoder '{name}'. "
                f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
            )

        cache_dir = os.path.join(get_dir(), 'checkpoints')
        os.makedirs(cache_dir, exist_ok=True)

        weights_file = os.path.basename(weights_config["url"])
        weights_path = os.path.join(cache_dir, weights_file)

        if not os.path.exists(weights_path):
            print(f'Downloading {weights_file}...')
            download_weights(
                weights_config["url"].replace("https", "http"),
                weights_path
            )

    return encoder_class(**encoder_params)
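
As a quick illustration (not part of the upload), `initialize_encoder` is mainly used by `model_core` to pre-download encoder weights into the torch hub cache before `segmentation_models_pytorch` builds the full model. A sketch, assuming the `resnet34`/`imagenet` combination registered in segmentation_models_pytorch:

```python
# Illustrative sketch: list registered encoders and pre-cache ImageNet weights for one of them.
from SemanticModel.encoder_management import initialize_encoders, initialize_encoder

available = initialize_encoders()
print(sorted(available.keys())[:5])        # a few of the registered encoder names

# Downloads the weight file into ~/.cache/torch/hub/checkpoints if not already present.
encoder = initialize_encoder('resnet34', weights='imagenet')
```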
semantic-segmentation/SemanticModel/evaluation_utils.py
ADDED
@@ -0,0 +1,108 @@
import os
import cv2
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_validation_augmentations


def evaluate_model(model_config, data_path, image_size=None):
    """Evaluates model performance on a dataset."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    classes = ['background'] + model_config.classes if model_config.background_flag else model_config.classes

    data_path = os.path.realpath(data_path)
    image_subdir = os.path.join(data_path, 'Images')
    mask_subdir = os.path.join(data_path, 'Masks')

    if not all(os.path.exists(d) for d in [image_subdir, mask_subdir]):
        raise Exception("Missing required subdirectories: 'Images' and 'Masks'")

    if not image_size:
        sample_image = cv2.imread(os.path.join(image_subdir, os.listdir(image_subdir)[0]))
        height, width = sample_image.shape[:2]
        image_size = max(height, width)

    evaluation_dataset = SegmentationDataset(
        data_path,
        classes=classes,
        augmentation=get_validation_augmentations(
            width=image_size,
            height=image_size,
            fixed_size=False
        ),
        preprocessing=model_config.preprocessing
    )

    evaluation_loader = DataLoader(
        evaluation_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2
    )

    model = model_config.model.to(device)
    model.eval()

    requires_sigmoid = False
    if model_config.n_classes == 1:
        current_activation = _check_activation_function(model)
        if current_activation != 'Sigmoid':
            requires_sigmoid = True

    predictions = []
    ground_truth = []

    print("Evaluating model performance...")
    with torch.no_grad():
        for images, masks in tqdm(evaluation_loader):
            images = images.to(device)
            masks = masks.to(device)

            outputs = model.forward(images)

            if model_config.n_classes > 1:
                predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
                ground_truth.extend([gt.cpu().argmax(dim=0) for gt in masks])
            else:
                if requires_sigmoid:
                    predictions.extend([
                        (torch.sigmoid(p) > 0.5).float().squeeze().cpu()
                        for p in outputs
                    ])
                else:
                    predictions.extend([
                        (p > 0.5).float().squeeze().cpu()
                        for p in outputs
                    ])
                ground_truth.extend([gt.cpu().squeeze() for gt in masks])

    metrics = compute_mean_iou(
        predictions,
        ground_truth,
        num_classes=len(classes),
        ignore_index=255
    )

    print("\nEvaluation Results:")
    print(f"Mean IoU: {metrics['mean_iou']:.3f}")
    print("\nPer-class IoU:")
    for idx, iou in enumerate(metrics['per_category_iou']):
        print(f"{classes[idx]}: {iou:.3f}")

    return metrics


def _check_activation_function(model):
    """Checks the activation function used in model's segmentation head."""
    from segmentation_models_pytorch.base.modules import Activation

    activation_functions = []
    for _, module in model.segmentation_head.named_children():
        if isinstance(module, Activation):
            activation_functions.append(type(module.activation).__name__)

    return activation_functions[-1] if activation_functions else None
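
A minimal sketch of how `evaluate_model` is typically invoked, assuming a trained checkpoint and a `val/Images` + `val/Masks` layout; the paths below are hypothetical.

```python
# Hypothetical paths; evaluate_model expects <data_path>/Images and <data_path>/Masks.
from SemanticModel.model_core import SegmentationModel
from SemanticModel.evaluation_utils import evaluate_model

model_config = SegmentationModel(
    classes=['background', 'object'],
    weights='runs/best_model.pth',     # hypothetical checkpoint path (.pth triggers reload)
)
metrics = evaluate_model(model_config, 'data/val', image_size=1024)
print(metrics['mean_iou'])
```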
semantic-segmentation/SemanticModel/image_preprocessing.py
ADDED
@@ -0,0 +1,81 @@
import cv2
import numpy as np
import albumentations as albu
from albumentations.augmentations.geometric.resize import LongestMaxSize


def round_pixel_dim(dimension: float) -> int:
    """Rounds pixel dimensions consistently."""
    if abs(round(dimension) - dimension) == 0.5:
        return int(2.0 * round(dimension / 2.0))
    return int(round(dimension))


def resize_with_padding(image, target_size, stride=32, interpolation=cv2.INTER_LINEAR):
    """Resizes image maintaining aspect ratio and ensures dimensions are stride-compatible."""
    height, width = image.shape[:2]
    max_dimension = max(height, width)

    if ((height % stride == 0) and (width % stride == 0) and
            (max_dimension <= target_size)):
        return image

    scale = target_size / float(max(width, height))
    new_dims = tuple(round_pixel_dim(dim * scale) for dim in (height, width))
    new_height, new_width = new_dims

    new_height = ((new_height // stride + 1) * stride
                  if new_height % stride != 0 else new_height)
    new_width = ((new_width // stride + 1) * stride
                 if new_width % stride != 0 else new_width)

    return cv2.resize(image, (new_width, new_height), interpolation=interpolation)


class PaddedResize(LongestMaxSize):
    def apply(self, img: np.ndarray, target_size: int = 1024,
              interpolation: int = cv2.INTER_LINEAR, **params) -> np.ndarray:
        return resize_with_padding(img, target_size=target_size, interpolation=interpolation)


def get_training_augmentations(width=768, height=576):
    """Configures training-time augmentations."""
    target_size = max([width, height])
    transforms = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(
            scale_limit=0.5, rotate_limit=90, shift_limit=0.1, p=0.5, border_mode=0),
        albu.PadIfNeeded(min_height=target_size, min_width=target_size, always_apply=True),
        albu.RandomCrop(height=target_size, width=target_size, always_apply=True),
        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.2),
        albu.OneOf([albu.CLAHE(p=1), albu.RandomGamma(p=1)], p=0.33),
        albu.OneOf([
            albu.Sharpen(p=1),
            albu.Blur(blur_limit=3, p=1),
            albu.MotionBlur(blur_limit=3, p=1)], p=0.33),
        albu.OneOf([
            albu.RandomBrightnessContrast(p=1),
            albu.HueSaturationValue(p=1)], p=0.33),
    ]
    return albu.Compose(transforms)


def get_validation_augmentations(width=1920, height=1440, fixed_size=True):
    """Configures validation/inference-time augmentations."""
    if fixed_size:
        transforms = [albu.Resize(height=height, width=width, always_apply=True)]
        return albu.Compose(transforms)

    target_size = max(width, height)
    transforms = [PaddedResize(max_size=target_size, always_apply=True)]
    return albu.Compose(transforms)


def convert_to_tensor(x, **kwargs):
    """Converts image array to PyTorch tensor format."""
    if x.ndim == 2:
        x = np.expand_dims(x, axis=-1)
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing_pipeline(preprocessing_fn):
    """Builds preprocessing pipeline including normalization and tensor conversion."""
    transforms = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=convert_to_tensor, mask=convert_to_tensor),
    ]
    return albu.Compose(transforms)
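
To show how these pieces compose, here is a sketch (not part of the upload) that resizes a stand-in image with the aspect-preserving path and then applies an encoder normalization function from segmentation_models_pytorch; the encoder name and input size are assumptions, and exact behavior can vary with the installed albumentations version.

```python
# Sketch: aspect-preserving resize to stride-compatible dimensions, then normalize + CHW conversion.
import numpy as np
import segmentation_models_pytorch as smp

from SemanticModel.image_preprocessing import (
    get_validation_augmentations, get_preprocessing_pipeline)

image = np.random.randint(0, 255, (600, 800, 3), dtype=np.uint8)   # stand-in for a real photo

resize = get_validation_augmentations(width=1024, height=1024, fixed_size=False)
normalize = get_preprocessing_pipeline(smp.encoders.get_preprocessing_fn('resnet34', 'imagenet'))

resized = resize(image=image)['image']            # both sides become multiples of 32 (here 768 x 1024)
tensor_ready = normalize(image=resized)['image']
print(tensor_ready.shape, tensor_ready.dtype)     # (3, 768, 1024) float32
```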
semantic-segmentation/SemanticModel/metrics.py
ADDED
@@ -0,0 +1,94 @@
from typing import Dict, Optional
import numpy as np


def compute_intersection_union(prediction, ground_truth, num_classes, ignore_index: int,
                               label_mapping: Optional[Dict[int, int]] = None,
                               reduce_labels: bool = False):
    """Computes intersection and union for IoU calculation."""

    if label_mapping:
        for old_id, new_id in label_mapping.items():
            ground_truth[ground_truth == old_id] = new_id

    prediction = np.array(prediction)
    ground_truth = np.array(ground_truth)

    if reduce_labels:
        ground_truth[ground_truth == 0] = 255
        ground_truth = ground_truth - 1
        ground_truth[ground_truth == 254] = 255

    valid_mask = np.not_equal(ground_truth, ignore_index)
    prediction = prediction[valid_mask]
    ground_truth = ground_truth[valid_mask]

    intersection_mask = prediction == ground_truth
    intersection = prediction[intersection_mask]

    area_intersection = np.histogram(intersection, bins=num_classes,
                                     range=(0, num_classes - 1))[0]
    area_prediction = np.histogram(prediction, bins=num_classes,
                                   range=(0, num_classes - 1))[0]
    area_ground_truth = np.histogram(ground_truth, bins=num_classes,
                                     range=(0, num_classes - 1))[0]
    area_union = area_prediction + area_ground_truth - area_intersection

    return area_intersection, area_union, area_prediction, area_ground_truth


def compute_total_intersection_union(predictions, ground_truths, num_classes, ignore_index: int,
                                     label_mapping: Optional[Dict[int, int]] = None,
                                     reduce_labels: bool = False):
    """Computes total intersection and union across all samples."""

    totals = {
        'intersection': np.zeros((num_classes,), dtype=np.float64),
        'union': np.zeros((num_classes,), dtype=np.float64),
        'prediction': np.zeros((num_classes,), dtype=np.float64),
        'ground_truth': np.zeros((num_classes,), dtype=np.float64)
    }

    for pred, gt in zip(predictions, ground_truths):
        intersection, union, pred_area, gt_area = compute_intersection_union(
            pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
        )
        totals['intersection'] += intersection
        totals['union'] += union
        totals['prediction'] += pred_area
        totals['ground_truth'] += gt_area

    return tuple(totals.values())


def compute_mean_iou(predictions, ground_truths, num_classes, ignore_index: int,
                     nan_to_num: Optional[int] = None,
                     label_mapping: Optional[Dict[int, int]] = None,
                     reduce_labels: bool = False):
    """Computes mean IoU and related metrics."""

    intersection, union, prediction_area, ground_truth_area = compute_total_intersection_union(
        predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
    )

    metrics = {}

    # Compute overall accuracy
    total_accuracy = intersection.sum() / ground_truth_area.sum()

    # Compute IoU per class
    iou_per_class = intersection / union
    accuracy_per_class = intersection / ground_truth_area

    metrics.update({
        "mean_iou": np.nanmean(iou_per_class),
        "mean_accuracy": np.nanmean(accuracy_per_class),
        "overall_accuracy": total_accuracy,
        "per_category_iou": iou_per_class,
        "per_category_accuracy": accuracy_per_class
    })

    if nan_to_num is not None:
        metrics = {
            metric: np.nan_to_num(value, nan=nan_to_num)
            for metric, value in metrics.items()
        }

    return metrics
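
A toy check of the metric on two tiny masks (values invented for illustration) makes the per-class accounting concrete:

```python
# Toy example: one pair of 2x2 masks, two classes (0 and 1), no ignored pixels.
import numpy as np
from SemanticModel.metrics import compute_mean_iou

preds  = [np.array([[0, 1], [1, 1]])]
truths = [np.array([[0, 1], [0, 1]])]

m = compute_mean_iou(preds, truths, num_classes=2, ignore_index=255)
print(m['per_category_iou'])   # class 0: 1/2, class 1: 2/3
print(m['mean_iou'])           # (0.5 + 0.666...) / 2 = 0.583...
```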
semantic-segmentation/SemanticModel/model_core.py
ADDED
@@ -0,0 +1,129 @@
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

from SemanticModel.encoder_management import initialize_encoder
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
from SemanticModel.image_preprocessing import get_preprocessing_pipeline


class SegmentationModel:
    def __init__(self, classes=['background', 'foreground'], architecture='unet',
                 encoder='timm-regnety_120', weights='imagenet', loss=None):
        self._initialize_classes(classes)
        self.architecture = architecture
        self.encoder = encoder
        self.weights = weights
        self._setup_loss_function(loss)
        self._initialize_model()

    def _initialize_classes(self, classes):
        """Sets up class configuration."""
        if len(classes) <= 2:
            self.classes = [c for c in classes if c.lower() != 'background']
            self.class_values = [i for i, c in enumerate(classes) if c.lower() != 'background']
            self.background_flag = 'background' in classes
        else:
            self.classes = classes
            self.class_values = list(range(len(classes)))
            self.background_flag = False
        self.n_classes = len(self.classes)

    def _setup_loss_function(self, loss):
        """Configures model's loss function."""
        if not loss:
            loss = 'bce_with_logits' if self.n_classes > 1 else 'dice'

        if loss.lower() not in ['dice', 'bce_with_logits', 'focal', 'tversky']:
            print(f'Invalid loss: {loss}, defaulting to dice')
            loss = 'dice'

        loss_configs = {
            'bce_with_logits': {
                'activation': None,
                'loss': EnhancedCrossEntropy() if self.n_classes > 1 else utils.losses.BCEWithLogitsLoss()
            },
            'dice': {
                'activation': 'softmax' if self.n_classes > 1 else 'sigmoid',
                'loss': utils.losses.DiceLoss()
            },
            'focal': {
                'activation': None,
                'loss': FocalLossFunction()
            },
            'tversky': {
                'activation': None,
                'loss': TverskyLossFunction()
            }
        }

        config = loss_configs[loss.lower()]
        self.activation = config['activation']
        self.loss = config['loss']
        self.loss_name = loss

    def _initialize_model(self):
        """Initializes the segmentation model architecture."""
        if self.weights.endswith('pth'):
            self._load_pretrained_model()
        else:
            self._create_new_model()

    def _load_pretrained_model(self):
        """Loads model from pretrained weights."""
        print('Loading pretrained model...')
        self.model = torch.load(self.weights)
        if isinstance(self.model, torch.nn.DataParallel):
            self.model = self.model.module

        try:
            preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
            self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        except:
            print('Failed to configure preprocessing. Setting to None.')
            self.preprocessing = None

    def _create_new_model(self):
        """Creates new model with specified architecture."""
        preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
        self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        initialize_encoder(name=self.encoder, weights=self.weights)

        architectures = {
            'unet': smp.Unet,
            'unet++': smp.UnetPlusPlus,
            'deeplabv3': smp.DeepLabV3,
            'deeplabv3+': smp.DeepLabV3Plus,
            'fpn': smp.FPN,
            'linknet': smp.Linknet,
            'manet': smp.MAnet,
            'pan': smp.PAN,
            'pspnet': smp.PSPNet
        }

        if self.architecture not in architectures:
            raise ValueError(f'Unsupported architecture: {self.architecture}')

        self.model = architectures[self.architecture](
            encoder_name=self.encoder,
            encoder_weights=self.weights,
            classes=self.n_classes,
            activation=self.activation
        )

    @property
    def config_data(self):
        """Returns model configuration data."""
        return {
            'architecture': self.architecture,
            'encoder': self.encoder,
            'weights': self.weights,
            'activation': self.activation,
            'loss': self.loss_name,
            'classes': ['background'] + self.classes if self.background_flag else self.classes
        }


def list_architectures():
    """Returns available architecture options."""
    return ['unet', 'unet++', 'deeplabv3', 'deeplabv3+', 'fpn',
            'linknet', 'manet', 'pan', 'pspnet']
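
For orientation, a short sketch of constructing a fresh model configuration; the encoder, architecture, and class names below are just example choices.

```python
# Sketch: build a binary (background/object) model with a new ImageNet-initialized encoder.
from SemanticModel.model_core import SegmentationModel, list_architectures

print(list_architectures())

model_config = SegmentationModel(
    classes=['background', 'object'],
    architecture='unet',
    encoder='resnet34',
    weights='imagenet',    # pass a path ending in .pth to reload a trained model instead
    loss='dice',
)
print(model_config.n_classes, model_config.config_data['loss'])   # 1 dice
```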
semantic-segmentation/SemanticModel/prediction.py
ADDED
@@ -0,0 +1,336 @@
import os
import cv2
import time
import torch
import imageio
import tifffile
import numpy as np
import slidingwindow
import rasterio as rio
import geopandas as gpd
from shapely.geometry import Polygon
from rasterio import mask as riomask
from torch.utils.data import DataLoader
from SemanticModel.visualization import generate_color_mapping
from SemanticModel.image_preprocessing import get_validation_augmentations
from SemanticModel.data_loader import InferenceDataset, StreamingDataset
from SemanticModel.utilities import calc_image_size, convert_coordinates


class PredictionPipeline:
    def __init__(self, model_config, device=None):
        self.config = model_config
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.classes = ['background'] + model_config.classes if model_config.background_flag else model_config.classes
        self.colors = generate_color_mapping(len(self.classes))
        self.model = model_config.model.to(self.device)
        self.model.eval()

    def _preprocess_image(self, image_path, target_size=None):
        """Preprocesses single image for prediction."""
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        height, width = image.shape[:2]

        target_size = target_size or max(height, width)
        test_height, test_width = calc_image_size(image, target_size)

        augmentation = get_validation_augmentations(test_width, test_height)
        image = augmentation(image=image)['image']
        image = self.config.preprocessing(image=image)['image']

        return image, (height, width)

    def predict_single_image(self, image_path, target_size=None, output_dir=None,
                             format='integer', save_output=True):
        """Generates prediction for a single image."""
        image, original_dims = self._preprocess_image(image_path, target_size)
        x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)

        with torch.no_grad():
            prediction = self.model.predict(x_tensor)

        if self.config.n_classes > 1:
            prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
        else:
            prediction = prediction.squeeze().cpu().numpy().round()

        # Resize to original dimensions if needed
        if prediction.shape[:2] != original_dims:
            prediction = cv2.resize(prediction, original_dims[::-1],
                                    interpolation=cv2.INTER_NEAREST)

        prediction = self._format_prediction(prediction, format)

        if save_output:
            self._save_prediction(prediction, image_path, output_dir, format)

        return prediction

    def predict_directory(self, input_dir, target_size=None, output_dir=None,
                          fixed_size=True, format='integer'):
        """Generates predictions for all images in directory."""
        output_dir = output_dir or os.path.join(input_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)

        dataset = InferenceDataset(
            input_dir,
            classes=self.classes,
            augmentation=get_validation_augmentations(
                target_size, target_size, fixed_size=fixed_size
            ) if target_size else None,
            preprocessing=self.config.preprocessing
        )

        total_images = len(dataset)
        start_time = time.time()

        for idx in range(total_images):
            if (idx + 1) % 10 == 0 or idx == total_images - 1:
                elapsed = time.time() - start_time
                print(f'\rProcessed {idx+1}/{total_images} images in {elapsed:.1f}s',
                      end='')

            image, height, width = dataset[idx]
            filename = dataset.filenames[idx]

            x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
            with torch.no_grad():
                prediction = self.model.predict(x_tensor)

            if self.config.n_classes > 1:
                prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
            else:
                prediction = prediction.squeeze().cpu().numpy().round()

            if prediction.shape != (height, width):
                prediction = cv2.resize(prediction, (width, height),
                                        interpolation=cv2.INTER_NEAREST)

            prediction = self._format_prediction(prediction, format)
            self._save_prediction(prediction, filename, output_dir, format)

        print(f'\nPredictions saved to: {output_dir}')
        return output_dir

    def predict_raster(self, raster_path, tile_size=1024, overlap=0.175,
                       boundary_path=None, output_path=None, format='integer'):
        """Processes large raster images using tiling approach."""
        print('Loading raster...')
        with rio.open(raster_path) as src:
            raster = src.read()
            raster = np.moveaxis(raster, 0, 2)[:, :, :3]
            profile = src.profile
            transform = src.transform

        if boundary_path:
            boundary = gpd.read_file(boundary_path)
            boundary = boundary.to_crs(profile['crs'])
            boundary_geom = boundary.iloc[0].geometry

        tiles = slidingwindow.generate(
            raster,
            slidingwindow.DimOrder.HeightWidthChannel,
            tile_size,
            overlap
        )

        pred_raster = np.zeros_like(raster[:, :, 0], dtype='uint8')
        confidence = np.zeros_like(pred_raster, dtype=np.float32)

        aug = get_validation_augmentations(tile_size, tile_size, fixed_size=False)

        for idx, tile in enumerate(tiles):
            if (idx + 1) % 10 == 0 or idx == len(tiles) - 1:
                print(f'\rProcessed {idx+1}/{len(tiles)} tiles', end='')

            bounds = tile.indices()

            tile_image = raster[bounds[0], bounds[1]]

            if boundary_path:
                corners = [
                    convert_coordinates(transform, bounds[1].start, bounds[0].start),
                    convert_coordinates(transform, bounds[1].stop, bounds[0].start),
                    convert_coordinates(transform, bounds[1].stop, bounds[0].stop),
                    convert_coordinates(transform, bounds[1].start, bounds[0].stop)
                ]
                if not Polygon(corners).intersects(boundary_geom):
                    continue

            processed = aug(image=tile_image)['image']
            processed = self.config.preprocessing(image=processed)['image']

            x_tensor = torch.from_numpy(processed).to(self.device).unsqueeze(0)
            with torch.no_grad():
                prediction = self.model.predict(x_tensor)
                prediction = prediction.squeeze().cpu().numpy()

            if self.config.n_classes > 1:
                tile_pred = np.argmax(prediction, axis=0)
                tile_conf = np.max(prediction, axis=0)
            else:
                tile_conf = np.abs(prediction - 0.5)
                tile_pred = prediction.round()

            if tile_pred.shape != tile_image.shape[:2]:
                tile_pred = cv2.resize(tile_pred, tile_image.shape[:2][::-1],
                                       interpolation=cv2.INTER_NEAREST)
                tile_conf = cv2.resize(tile_conf, tile_image.shape[:2][::-1],
                                       interpolation=cv2.INTER_LINEAR)

            # Update prediction and confidence maps
            existing_conf = confidence[bounds[0], bounds[1]]
            existing_pred = pred_raster[bounds[0], bounds[1]]

            mask = existing_conf < tile_conf
            existing_pred[mask] = tile_pred[mask]
            existing_conf[mask] = tile_conf[mask]

            pred_raster[bounds[0], bounds[1]] = existing_pred
            confidence[bounds[0], bounds[1]] = existing_conf

        pred_raster = self._format_prediction(pred_raster, format)

        if output_path or boundary_path:
            self._save_raster_prediction(
                pred_raster, raster_path, output_path,
                profile, boundary_geom if boundary_path else None
            )

        return pred_raster, profile

    def _format_prediction(self, prediction, format):
        """Formats prediction according to specified output type."""
        if format == 'integer':
            return prediction.astype('uint8')
        elif format == 'color':
            return self._apply_color_mapping(prediction)
        else:
            raise ValueError(f"Unsupported format: {format}")

    def _save_prediction(self, prediction, source_path, output_dir, format):
        """Saves prediction to disk."""
        filename = os.path.splitext(os.path.basename(source_path))[0]
        output_path = os.path.join(output_dir, f"{filename}_pred.png")
        cv2.imwrite(output_path, prediction)

    def _save_raster_prediction(self, prediction, source_path, output_path,
                                profile, boundary=None):
        """Saves raster prediction with geospatial information."""
        output_path = output_path or source_path.replace(
            os.path.splitext(source_path)[1], '_predicted.tif'
        )

        profile.update(
            dtype='uint8',
            count=3 if prediction.ndim == 3 else 1
        )

        with rio.open(output_path, 'w', **profile) as dst:
            if prediction.ndim == 3:
                for i in range(3):
                    dst.write(prediction[:, :, i], i + 1)
            else:
                dst.write(prediction, 1)

        if boundary:
            with rio.open(output_path) as src:
                cropped, transform = riomask.mask(src, [boundary], crop=True)
                profile.update(
                    height=cropped.shape[1],
                    width=cropped.shape[2],
                    transform=transform
                )

            os.remove(output_path)
            with rio.open(output_path, 'w', **profile) as dst:
                dst.write(cropped)

        print(f'\nPrediction saved to: {output_path}')

    def predict_video_frames(self, input_dir, target_size=None, output_dir=None):
        """Processes video frames with specialized visualization."""
        output_dir = output_dir or os.path.join(input_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)

        dataset = StreamingDataset(
            input_dir,
            classes=self.classes,
            augmentation=get_validation_augmentations(
                target_size, target_size
            ) if target_size else None,
            preprocessing=self.config.preprocessing
        )

        image = cv2.imread(dataset.image_paths[0])
        height, width = image.shape[:2]

        white = 255 * np.ones((height, width))
        black = np.zeros_like(white)
        red = np.dstack((white, black, black))
        blue = np.dstack((black, black, white))

        # Pre-compute rotated versions
        rotated_red = np.rot90(red)
        rotated_blue = np.rot90(blue)

        total_frames = len(dataset)
        start_time = time.time()

        for idx in range(total_frames):
            if (idx + 1) % 10 == 0 or idx == total_frames - 1:
                elapsed = time.time() - start_time
                print(f'\rProcessed {idx+1}/{total_frames} frames in {elapsed:.1f}s', end='')

            frame, height, width = dataset[idx]
            filename = dataset.filenames[idx]

            x_tensor = torch.from_numpy(frame).to(self.device).unsqueeze(0)
            with torch.no_grad():
                prediction = self.model.predict(x_tensor)

            if self.config.n_classes > 1:
                prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
                masks = [prediction == i for i in range(1, self.config.n_classes)]
            else:
                prediction = prediction.squeeze().cpu().numpy().round()
                masks = [prediction == 1]

            if prediction.shape != (height, width):
                prediction = cv2.resize(prediction, (width, height),
                                        interpolation=cv2.INTER_NEAREST)

            original = cv2.imread(os.path.join(input_dir, filename))
            original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)

            try:
                for i, mask in enumerate(masks):
                    color = red if i == 0 else blue
                    rotated_color = rotated_red if i == 0 else rotated_blue
                    try:
                        original[mask, :] = 0.45 * original[mask, :] + 0.55 * color[mask, :]
                    except:
                        original[mask, :] = 0.45 * original[mask, :] + 0.55 * rotated_color[mask, :]
            except:
                print(f"\nWarning: Error processing frame {filename}")
                continue

            output_path = os.path.join(output_dir, filename)
            imageio.imwrite(output_path, original, quality=100)

        print(f'\nProcessed frames saved to: {output_dir}')
        return output_dir

    def _apply_color_mapping(self, prediction):
        """Applies color mapping to prediction."""
        height, width = prediction.shape
        colored = np.zeros((height, width, 3), dtype='uint8')

        for i, class_name in enumerate(self.classes):
            if class_name.lower() == 'background':
                continue
            color = self.colors[i]
            colored[prediction == i] = color

        return colored
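
A minimal sketch of running the pipeline over a folder of images; every path below is a placeholder, and the encoder must match the one the checkpoint was trained with.

```python
# Sketch: load a trained checkpoint and predict a single image plus a whole folder.
from SemanticModel.model_core import SegmentationModel
from SemanticModel.prediction import PredictionPipeline

model_config = SegmentationModel(
    classes=['background', 'object'],
    encoder='timm-regnety_120',        # must match the training-time encoder
    weights='runs/best_model.pth',     # hypothetical checkpoint path
)

pipeline = PredictionPipeline(model_config)
pipeline.predict_single_image('samples/field.jpg', target_size=1024)          # hypothetical image
pipeline.predict_directory('samples/', target_size=1024,
                           output_dir='samples/predictions')
```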
semantic-segmentation/SemanticModel/training.py
ADDED
@@ -0,0 +1,313 @@
import os
import json
import torch
import wandb
import datetime
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from segmentation_models_pytorch.base.modules import Activation

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.metrics import compute_mean_iou
from SemanticModel.image_preprocessing import get_training_augmentations, get_validation_augmentations
from SemanticModel.utilities import list_images, validate_dimensions


class ModelTrainer:
    def __init__(self, model_config, root_dir, epochs=40, train_size=1024,
                 val_size=None, workers=2, batch_size=2, learning_rate=1e-4,
                 step_count=2, decay_factor=0.8, wandb_config=None,
                 optimizer='rmsprop', target_class=None, resume_path=None):

        self.config = model_config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.root_dir = root_dir
        self._initialize_training_params(epochs, train_size, val_size, workers,
                                         batch_size, learning_rate, step_count,
                                         decay_factor, optimizer, target_class)
        self._setup_directories()
        self._initialize_datasets()
        self._setup_optimizer()
        self._initialize_tracking()

        if resume_path:
            self._resume_training(resume_path)

    def _initialize_training_params(self, epochs, train_size, val_size, workers,
                                    batch_size, learning_rate, step_count,
                                    decay_factor, optimizer, target_class):
        self.epochs = epochs
        self.train_size = train_size
        self.val_size = val_size
        self.workers = workers
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.step_schedule = self._calculate_step_schedule(epochs, step_count)
        self.decay_factor = decay_factor
        self.optimizer_type = optimizer
        self.target_class = target_class
        self.current_epoch = 1
        self.best_iou = 0.0
        self.best_epoch = 0
        self.classes = ['background'] + self.config.classes if self.config.background_flag else self.config.classes

    def _setup_directories(self):
        """Verifies and creates necessary directories."""
        self.train_dir = os.path.join(self.root_dir, 'train')
        self.val_dir = os.path.join(self.root_dir, 'val')

        required_subdirs = ['Images', 'Masks']
        for path in [self.train_dir] + ([self.val_dir] if os.path.exists(self.val_dir) else []):
            for subdir in required_subdirs:
                full_path = os.path.join(path, subdir)
                if not os.path.exists(full_path):
                    raise FileNotFoundError(f"Missing directory: {full_path}")

    def _initialize_datasets(self):
        """Sets up training and validation datasets."""
        self.train_dataset = SegmentationDataset(
            self.train_dir,
            classes=self.classes,
            augmentation=get_training_augmentations(self.train_size, self.train_size),
            preprocessing=self.config.preprocessing
        )

        if os.path.exists(self.val_dir):
            self.val_dataset = SegmentationDataset(
                self.val_dir,
                classes=self.classes,
                augmentation=get_validation_augmentations(
                    self.val_size or self.train_size,
                    self.val_size or self.train_size,
                    fixed_size=False
                ),
                preprocessing=self.config.preprocessing
            )
            self.val_loader = DataLoader(
                self.val_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=self.workers
            )
        else:
            self.val_dataset = self.train_dataset
            self.val_loader = DataLoader(
                self.val_dataset,
                batch_size=1,
                shuffle=False,
                num_workers=self.workers
            )

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.workers
        )

    def _setup_optimizer(self):
        """Configures model optimizer."""
        optimizer_map = {
            'adam': torch.optim.Adam,
            'sgd': lambda params: torch.optim.SGD(params, momentum=0.9),
            'rmsprop': torch.optim.RMSprop
        }
        optimizer_class = optimizer_map.get(self.optimizer_type.lower())
        if not optimizer_class:
            raise ValueError(f"Unsupported optimizer: {self.optimizer_type}")

        self.optimizer = optimizer_class([{'params': self.config.model.parameters(),
                                           'lr': self.learning_rate}])

    def _initialize_tracking(self):
        """Sets up training progress tracking."""
        timestamp = datetime.datetime.now().strftime("%m-%d-%Y_%H%M%S")
        self.output_dir = os.path.join(
            self.root_dir,
            f'model_outputs-{self.config.architecture}[{self.config.encoder}]-{timestamp}'
        )
        os.makedirs(self.output_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.output_dir)
        self.metrics = {
            'best_epoch': self.best_epoch,
            'best_epoch_iou': self.best_iou,
            'last_epoch': 0,
            'last_epoch_iou': 0.0,
            'last_epoch_lr': self.learning_rate,
            'step_schedule': self.step_schedule,
            'decay_factor': self.decay_factor,
            'target_class': self.target_class or 'overall'
        }

    def _calculate_step_schedule(self, epochs, steps):
        """Calculates learning rate step schedule."""
        return list(map(int, np.linspace(0, epochs, steps + 2)[1:-1]))

    def train(self):
        """Executes training loop."""
        model = self.config.model.to(self.device)
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
            print(f'Using {torch.cuda.device_count()} GPUs')

        self._save_config()

        for epoch in range(self.current_epoch, self.epochs + 1):
            print(f'\nEpoch {epoch}/{self.epochs}')
            print(f'Learning rate: {self.optimizer.param_groups[0]["lr"]:.3e}')

            train_loss = self._train_epoch(model)
            val_loss, val_metrics = self._validate_epoch(model)

            self._update_tracking(epoch, train_loss, val_loss, val_metrics)
            self._adjust_learning_rate(epoch)
            self._save_checkpoints(model, epoch, val_metrics)

        print(f'\nTraining completed. Best {self.metrics["target_class"]} IoU: {self.best_iou:.3f}')
        return model, self.metrics

    def _train_epoch(self, model):
        """Executes single training epoch."""
        model.train()
        total_loss = 0
        sample_count = 0

        for batch in tqdm(self.train_loader, desc='Training'):
            images, masks = [x.to(self.device) for x in batch]
            self.optimizer.zero_grad()

            outputs = model(images)
            loss = self.config.loss(outputs, masks)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item() * len(images)
            sample_count += len(images)

        return total_loss / sample_count

    def _validate_epoch(self, model):
        """Executes validation pass."""
        model.eval()
        total_loss = 0
        predictions = []
        ground_truth = []

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                images, masks = [x.to(self.device) for x in batch]
                outputs = model(images)
                loss = self.config.loss(outputs, masks)

                total_loss += loss.item()

                if self.config.n_classes > 1:
                    predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
                    ground_truth.extend([m.cpu().argmax(dim=0) for m in masks])
                else:
                    predictions.extend([(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
                                        for p in outputs])
                    ground_truth.extend([m.cpu().squeeze() for m in masks])

        metrics = compute_mean_iou(
            predictions,
            ground_truth,
            num_classes=len(self.classes),
            ignore_index=255
        )

        return total_loss / len(self.val_loader), metrics

    def _update_tracking(self, epoch, train_loss, val_loss, val_metrics):
        """Updates training metrics and logging."""
        mean_iou = val_metrics['mean_iou']
        print(f"\nLosses - Train: {train_loss:.3f}, Val: {val_loss:.3f}")
        print(f"Mean IoU: {mean_iou:.3f}")

        self.writer.add_scalar('Loss/train', train_loss, epoch)
        self.writer.add_scalar('Loss/val', val_loss, epoch)
        self.writer.add_scalar('IoU/mean', mean_iou, epoch)

        for idx, iou in enumerate(val_metrics['per_category_iou']):
            print(f"{self.classes[idx]} IoU: {iou:.3f}")
            self.writer.add_scalar(f'IoU/{self.classes[idx]}', iou, epoch)

    def _adjust_learning_rate(self, epoch):
        """Adjusts learning rate according to schedule."""
        if epoch in self.step_schedule:
            current_lr = self.optimizer.param_groups[0]['lr']
            new_lr = current_lr * self.decay_factor
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr
            print(f'\nDecreased learning rate: {current_lr:.3e} -> {new_lr:.3e}')

    def _save_checkpoints(self, model, epoch, metrics):
        """Saves model checkpoints and metrics."""
        epoch_iou = (metrics['mean_iou'] if self.target_class is None
                     else metrics['per_category_iou'][self.classes.index(self.target_class)])

        self.metrics.update({
            'last_epoch': epoch,
            'last_epoch_iou': round(float(epoch_iou), 3),
            'last_epoch_lr': self.optimizer.param_groups[0]['lr']
        })

        if epoch_iou > self.best_iou:
            self.best_iou = epoch_iou
            self.best_epoch = epoch
            self.metrics.update({
                'best_epoch': epoch,
                'best_epoch_iou': round(float(epoch_iou), 3),
                'overall_iou': round(float(metrics['mean_iou']), 3)
            })
            torch.save(model, os.path.join(self.output_dir, 'best_model.pth'))
            print(f'New best model saved (IoU: {epoch_iou:.3f})')

        torch.save(model, os.path.join(self.output_dir, 'last_model.pth'))
        with open(os.path.join(self.output_dir, 'metrics.json'), 'w') as f:
            json.dump(self.metrics, f, indent=4)

    def _save_config(self):
        """Saves training configuration."""
        config = {
            **self.config.config_data,
            'train_size': self.train_size,
            'val_size': self.val_size,
            'epochs': self.epochs,
            'batch_size': self.batch_size,
            'optimizer': self.optimizer_type,
            'workers': self.workers,
            'target_class': self.target_class or 'overall'
        }

        with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)

    def _resume_training(self, resume_path):
        """Resumes training from checkpoint."""
        if not os.path.exists(resume_path):
            raise FileNotFoundError(f"Resume path not found: {resume_path}")

        required_files = {
            'model': 'last_model.pth',
            'metrics': 'metrics.json',
            'config': 'config.json'
        }

        paths = {k: os.path.join(resume_path, v) for k, v in required_files.items()}
        if not all(os.path.exists(p) for p in paths.values()):
            raise FileNotFoundError("Missing required checkpoint files")

        with open(paths['config']) as f:
            config = json.load(f)
        with open(paths['metrics']) as f:
            metrics = json.load(f)

        self.current_epoch = metrics['last_epoch'] + 1
        self.best_iou = metrics['best_epoch_iou']
        self.best_epoch = metrics['best_epoch']
        self.learning_rate = metrics['last_epoch_lr']

        print(f'Resuming training from epoch {self.current_epoch}')
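
A compact sketch of kicking off training, assuming a dataset root laid out as `train/Images`, `train/Masks` and optionally `val/Images`, `val/Masks`; the root path and hyperparameters below are illustrative.

```python
# Sketch: train a binary model; 'data' is a hypothetical dataset root with train/ (and optionally val/).
from SemanticModel.model_core import SegmentationModel
from SemanticModel.training import ModelTrainer

model_config = SegmentationModel(classes=['background', 'object'],
                                 architecture='unet', encoder='resnet34', weights='imagenet')

trainer = ModelTrainer(
    model_config,
    root_dir='data',
    epochs=10,
    train_size=512,
    batch_size=4,
    learning_rate=1e-4,
    optimizer='adam',
)
model, metrics = trainer.train()
print(metrics['best_epoch_iou'])
```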
semantic-segmentation/SemanticModel/utilities.py
ADDED
@@ -0,0 +1,119 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import shutil
|
4 |
+
import imageio
|
5 |
+
import numpy as np
|
6 |
+
from glob import glob
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import List, Tuple, Optional
|
9 |
+
|
10 |
+
def validate_dimensions(width: int, height: int, stride: int = 32) -> Tuple[int, int]:
|
11 |
+
if height % stride != 0 or width % stride != 0:
|
12 |
+
new_height = ((height // stride + 1) * stride
|
13 |
+
if height % stride != 0 else height)
|
14 |
+
new_width = ((width // stride + 1) * stride
|
15 |
+
if width % stride != 0 else width)
|
16 |
+
print(f'Adjusted dimensions to: {new_height}H x {new_width}W')
|
17 |
+
return width, height
|
18 |
+
|
def calc_image_size(image: np.ndarray, target_size: int) -> Tuple[int, int]:
    height, width = image.shape[:2]
    aspect_ratio = width / height

    if aspect_ratio >= 1:
        new_width = target_size
        new_height = int(target_size / aspect_ratio)
    else:
        new_height = target_size
        new_width = int(target_size * aspect_ratio)

    return validate_dimensions(new_width, new_height)

def convert_coordinates(transform: np.ndarray, x: float, y: float) -> Tuple[float, float]:
    transformed = transform @ np.array([x, y, 1])
    return transformed[0], transformed[1]

def list_images(directory: str, mask_format: bool = False) -> List[str]:
    extensions = ['*.png', '*.PNG'] if mask_format else [
        '*.jpg', '*.jpeg', '*.png', '*.tif', '*.tiff',
        '*.JPG', '*.JPEG', '*.PNG', '*.TIF', '*.TIFF'
    ]

    image_paths = []
    for ext in extensions:
        image_paths.extend(glob(os.path.join(directory, ext)))

    return sorted(list(set(image_paths)))

def prepare_dataset_split(root_dir: str, train_ratio: float = 0.7,
                          generate_empty_masks: bool = False) -> None:
    image_dir = os.path.join(root_dir, 'Images')
    mask_dir = os.path.join(root_dir, 'Masks')

    if not all(os.path.exists(d) for d in [image_dir, mask_dir]):
        raise Exception("Required 'Images' and 'Masks' directories not found")

    image_paths = np.array(list_images(image_dir))
    mask_paths = np.array(list_images(mask_dir, mask_format=True))

    if generate_empty_masks:
        temp_dir = os.path.join(mask_dir, 'temp')
        create_empty_masks(image_dir, outdir=temp_dir)

        for mask_path in list_images(temp_dir, mask_format=True):
            target_path = os.path.join(mask_dir, os.path.basename(mask_path))
            if not os.path.exists(target_path):
                shutil.move(mask_path, target_path)

        shutil.rmtree(temp_dir)
        mask_paths = np.array(list_images(mask_dir, mask_format=True))

    if len(image_paths) != len(mask_paths):
        raise Exception(f"Unmatched images ({len(image_paths)}) and masks ({len(mask_paths)})")

    train_ratio = float(train_ratio)
    if not (0 < train_ratio <= 1):
        raise ValueError(f"Invalid train ratio: {train_ratio}")

    train_size = int(np.floor(train_ratio * len(image_paths)))
    indices = np.random.permutation(len(image_paths))

    splits = {
        'train': {'indices': indices[:train_size]},
        'val': {'indices': indices[train_size:]} if train_ratio < 1 else None
    }

    for split_name, split_data in splits.items():
        if split_data is None:
            continue

        split_dir = os.path.join(root_dir, split_name)
        for subdir in ['Images', 'Masks']:
            subdir_path = os.path.join(split_dir, subdir)
            os.makedirs(subdir_path, exist_ok=True)

            sources = image_paths if subdir == 'Images' else mask_paths
            for idx in split_data['indices']:
                source = sources[idx]
                destination = os.path.join(subdir_path, os.path.basename(source))
                shutil.copyfile(source, destination)

        print(f"Created {split_name} split with {len(split_data['indices'])} samples")

def create_empty_masks(image_dir: str, pixel_value: int = 0,
                       outdir: Optional[str] = None) -> str:
    outdir = outdir or os.path.join(image_dir, 'Masks')
    os.makedirs(outdir, exist_ok=True)

    image_paths = list_images(image_dir)
    print(f"Generating {len(image_paths)} empty masks...")

    for image_path in image_paths:
        image = imageio.imread(image_path)
        mask = np.full((image.shape[0], image.shape[1]), pixel_value, dtype='uint8')

        output_path = os.path.join(outdir,
                                   f"{Path(image_path).stem}.png")
        imageio.imwrite(output_path, mask)

    return outdir
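As a usage sketch of the helpers above: `prepare_dataset_split` expects a root folder with `Images/` and `Masks/` subdirectories and copies a train/val split into `train/` and `val/`. The import path assumes the package is installed and importable as `SemanticModel`; the `dataset` path is hypothetical.

```python
from SemanticModel.utilities import prepare_dataset_split

# Hypothetical dataset root containing 'Images/' and 'Masks/' subdirectories.
root = 'dataset'

# Generate all-background masks for any images that lack one, then copy a
# 70/30 split into dataset/train/{Images,Masks} and dataset/val/{Images,Masks}.
prepare_dataset_split(root, train_ratio=0.7, generate_empty_masks=True)
```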
semantic-segmentation/SemanticModel/visualization.py
ADDED
@@ -0,0 +1,115 @@
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

def plot_predictions(model, images, masks, device, num_samples=4):
    """Visualize model predictions against ground truth."""
    with torch.no_grad():
        model.eval()
        predictions = model.predict(images.to(device))

    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))

    for idx in range(num_samples):
        # Original image
        img = images[idx].permute(1, 2, 0).cpu().numpy()
        axes[idx, 0].imshow(img)
        axes[idx, 0].set_title('Original Image')

        # Ground truth
        truth = masks[idx].argmax(dim=0).cpu().numpy()
        axes[idx, 1].imshow(truth, cmap='tab20')
        axes[idx, 1].set_title('Ground Truth')

        # Prediction
        pred = predictions[idx].argmax(dim=0).cpu().numpy()
        axes[idx, 2].imshow(pred, cmap='tab20')
        axes[idx, 2].set_title('Prediction')

        for ax in axes[idx]:
            ax.axis('off')

    plt.tight_layout()
    return fig

def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
    """Create transparent overlay of segmentation mask on image."""
    if color_map is None:
        color_map = {
            0: [0, 0, 0],      # background
            1: [255, 0, 0],    # class 1 (red)
            2: [0, 255, 0],    # class 2 (green)
            3: [0, 0, 255],    # class 3 (blue)
        }

    overlay = image.copy()
    mask_colored = np.zeros_like(image)

    for label, color in color_map.items():
        mask_colored[mask == label] = color

    cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
    return overlay

def plot_training_history(history):
    """Plot training and validation metrics."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Loss plot
    ax1.plot(history['train_loss'], label='Training Loss')
    ax1.plot(history['val_loss'], label='Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()

    # IoU plot
    ax2.plot(history['mean_iou'], label='Mean IoU')
    for class_name, ious in history['class_ious'].items():
        ax2.plot(ious, label=f'{class_name} IoU')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('IoU')
    ax2.set_title('IoU Metrics')
    ax2.legend()

    plt.tight_layout()
    return fig

def visualize_predictions_on_batch(model, batch_images, batch_size=8):
    """Create grid visualization for a batch of predictions."""
    with torch.no_grad():
        predictions = model.predict(batch_images)

    fig = plt.figure(figsize=(15, 5))
    for idx in range(min(batch_size, len(batch_images))):
        plt.subplot(2, 4, idx + 1)
        img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
        mask = predictions[idx].argmax(dim=0).cpu().numpy()
        overlay = create_overlay_mask(img, mask)
        plt.imshow(overlay)
        plt.axis('off')

    plt.tight_layout()
    return fig

def save_visualization(fig, save_path):
    """Save visualization figure."""
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

def generate_color_mapping(num_classes):
    """Generate distinct colors for segmentation classes."""
    colors = [
        [0, 0, 0],       # Background (black)
        [255, 0, 0],     # Red
        [0, 255, 0],     # Green
        [0, 0, 255],     # Blue
        [255, 255, 0],   # Yellow
        [255, 0, 255],   # Magenta
        [0, 255, 255],   # Cyan
        [128, 0, 0],     # Dark Red
        [0, 128, 0],     # Dark Green
        [0, 0, 128]      # Dark Blue
    ]
    return colors[:num_classes]
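A brief usage sketch of the overlay and history helpers above. The image, mask, and history values are placeholders; `plot_training_history` reads `train_loss`, `val_loss`, `mean_iou`, and per-class series under `class_ious`, and `create_overlay_mask` expects an HxWx3 image with an HxW integer label mask.

```python
import numpy as np
from SemanticModel.visualization import (create_overlay_mask,
                                         generate_color_mapping,
                                         plot_training_history,
                                         save_visualization)

# Placeholder uint8 image and integer label mask of matching size.
image = np.zeros((256, 256, 3), dtype=np.uint8)
mask = np.random.randint(0, 3, (256, 256))

colors = generate_color_mapping(num_classes=3)
overlay = create_overlay_mask(image, mask,
                              color_map={i: c for i, c in enumerate(colors)})
# overlay is an HxWx3 uint8 array ready for plt.imshow or cv2.imwrite.

# Dummy history dict following the keys plot_training_history reads.
history = {
    'train_loss': [0.9, 0.6, 0.4],
    'val_loss': [1.0, 0.7, 0.5],
    'mean_iou': [0.3, 0.5, 0.6],
    'class_ious': {'background': [0.8, 0.9, 0.9], 'target': [0.2, 0.4, 0.5]},
}
fig = plot_training_history(history)
save_visualization(fig, 'training_history.png')
```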
setup.py
ADDED
@@ -0,0 +1,34 @@
# setup.py
from setuptools import setup, find_packages

setup(
    name="SemanticModel",
    version="0.1.0",
    description="Deep learning framework for semantic segmentation",
    author="Your Name",
    packages=find_packages(),
    python_requires=">=3.8",
    install_requires=[
        'torch',
        'torchvision',
        'tensorboard',
        'pyproj',
        'fiona==1.8.20',
        'rtree',
        'geopandas',
        'rasterio',
        'slidingwindow',
        'opencv-python',
        'wandb',
        'tifffile',
        'imagecodecs',
        'albumentations',
        'segmentation-models-pytorch>=0.3.3'
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.8",
    ],
)