obichimav committed on
Commit 8e5d8c7 · verified · 1 Parent(s): 062ecd1

Upload 42 files


PyTorch semantic segmentation pipeline: train models and run prediction on binary, multiclass, and raster imagery.

  Files changed (42)
  1. README.md +52 -0
  2. examples/.ipynb_checkpoints/predict-checkpoint.ipynb +255 -0
  3. examples/.ipynb_checkpoints/train-checkpoint.ipynb +113 -0
  4. examples/predict.ipynb +255 -0
  5. examples/train.ipynb +113 -0
  6. requirements.txt +15 -0
  7. semantic-segmentation/SemanticModel/.ipynb_checkpoints/custom_losses-checkpoint.py +97 -0
  8. semantic-segmentation/SemanticModel/.ipynb_checkpoints/data_loader-checkpoint.py +129 -0
  9. semantic-segmentation/SemanticModel/.ipynb_checkpoints/encoder_management-checkpoint.py +136 -0
  10. semantic-segmentation/SemanticModel/.ipynb_checkpoints/evaluation_utils-checkpoint.py +108 -0
  11. semantic-segmentation/SemanticModel/.ipynb_checkpoints/image_preprocessing-checkpoint.py +81 -0
  12. semantic-segmentation/SemanticModel/.ipynb_checkpoints/metrics-checkpoint.py +94 -0
  13. semantic-segmentation/SemanticModel/.ipynb_checkpoints/model_core-checkpoint.py +129 -0
  14. semantic-segmentation/SemanticModel/.ipynb_checkpoints/prediction-checkpoint.py +336 -0
  15. semantic-segmentation/SemanticModel/.ipynb_checkpoints/training-checkpoint.py +313 -0
  16. semantic-segmentation/SemanticModel/.ipynb_checkpoints/utilities-checkpoint.py +119 -0
  17. semantic-segmentation/SemanticModel/.ipynb_checkpoints/visualization-checkpoint.py +115 -0
  18. semantic-segmentation/SemanticModel/__init__.py +0 -0
  19. semantic-segmentation/SemanticModel/__pycache__/__init__.cpython-38.pyc +0 -0
  20. semantic-segmentation/SemanticModel/__pycache__/custom_losses.cpython-38.pyc +0 -0
  21. semantic-segmentation/SemanticModel/__pycache__/data_loader.cpython-38.pyc +0 -0
  22. semantic-segmentation/SemanticModel/__pycache__/encoder_management.cpython-38.pyc +0 -0
  23. semantic-segmentation/SemanticModel/__pycache__/evaluation_utils.cpython-38.pyc +0 -0
  24. semantic-segmentation/SemanticModel/__pycache__/image_preprocessing.cpython-38.pyc +0 -0
  25. semantic-segmentation/SemanticModel/__pycache__/metrics.cpython-38.pyc +0 -0
  26. semantic-segmentation/SemanticModel/__pycache__/model_core.cpython-38.pyc +0 -0
  27. semantic-segmentation/SemanticModel/__pycache__/prediction.cpython-38.pyc +0 -0
  28. semantic-segmentation/SemanticModel/__pycache__/training.cpython-38.pyc +0 -0
  29. semantic-segmentation/SemanticModel/__pycache__/utilities.cpython-38.pyc +0 -0
  30. semantic-segmentation/SemanticModel/__pycache__/visualization.cpython-38.pyc +0 -0
  31. semantic-segmentation/SemanticModel/custom_losses.py +97 -0
  32. semantic-segmentation/SemanticModel/data_loader.py +129 -0
  33. semantic-segmentation/SemanticModel/encoder_management.py +136 -0
  34. semantic-segmentation/SemanticModel/evaluation_utils.py +108 -0
  35. semantic-segmentation/SemanticModel/image_preprocessing.py +81 -0
  36. semantic-segmentation/SemanticModel/metrics.py +94 -0
  37. semantic-segmentation/SemanticModel/model_core.py +129 -0
  38. semantic-segmentation/SemanticModel/prediction.py +336 -0
  39. semantic-segmentation/SemanticModel/training.py +313 -0
  40. semantic-segmentation/SemanticModel/utilities.py +119 -0
  41. semantic-segmentation/SemanticModel/visualization.py +115 -0
  42. setup.py +34 -0
README.md ADDED
@@ -0,0 +1,52 @@
+ # SemanticModel
+
+ Deep learning framework for semantic segmentation using PyTorch.
+
+ ## Install
+ ```bash
+ pip install -r requirements.txt
+ python setup.py install
+ ```
+
+ ## Usage
+ ```python
+ from SemanticModel.model_core import SegmentationModel
+ from SemanticModel.training import ModelTrainer
+ from SemanticModel.prediction import PredictionPipeline
+
+ # Train
+ model = SegmentationModel(
+     classes=['background', 'object'],
+     architecture='unet',
+     encoder='timm-regnety_120'
+ )
+
+ trainer = ModelTrainer(
+     model_config=model,
+     root_dir='path/to/dataset',
+     epochs=40
+ )
+ model, metrics = trainer.train()
+
+ # Predict
+ predictor = PredictionPipeline(model)
+ predictor.predict_single_image('image.jpg')
+ predictor.predict_directory('image_dir/')
+ predictor.predict_raster('raster.tif')
+
+ # Load pretrained
+ model = SegmentationModel(
+     classes=['background', 'object'],
+     weights='path/to/best_model.pth'
+ )
+ ```
+
+ ## Data Structure
+ ```
+ dataset/
+ ├── train/
+ │   ├── Images/
+ │   └── Masks/
+ └── val/
+     ├── Images/
+     └── Masks/
+ ```
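A note on the mask format implied by `SegmentationDataset._load_mask` in `semantic-segmentation/SemanticModel/data_loader.py`: each mask is a single-channel PNG whose pixel values are class indices in the same order as the `classes` list (background = 0), and its filename mirrors the image filename. A minimal sketch of writing a compliant mask (hypothetical sizes and filenames, using the `numpy` and `opencv-python` dependencies from requirements.txt):

```python
import cv2
import numpy as np

# Hypothetical 4-class example: pixel value = index into the classes list
# (0 = background, 1..3 = foreground classes)
mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:200, 100:200] = 1
mask[300:400, 300:400] = 2

# Save under Masks/ with the same base name as the image,
# e.g. Images/plot_01.jpg pairs with Masks/plot_01.png
cv2.imwrite('plot_01.png', mask)
```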
examples/.ipynb_checkpoints/predict-checkpoint.ipynb ADDED
@@ -0,0 +1,255 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "42e4027f",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/home/jovyan/shared/Chima/ml_project/repos/MYSMP\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "%cd \"../../MYSMP\""
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "ef6ea33e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "# %pip install -r requirements.txt"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 3,
34
+ "id": "30812aed",
35
+ "metadata": {},
36
+ "outputs": [
37
+ {
38
+ "name": "stdout",
39
+ "output_type": "stream",
40
+ "text": [
41
+ "/home/jovyan/shared/Chima/ml_project/repos/MYSMP/semantic-segmentation\n"
42
+ ]
43
+ }
44
+ ],
45
+ "source": [
46
+ "%cd \"../MYSMP/semantic-segmentation\""
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 5,
52
+ "id": "aaa4a036",
53
+ "metadata": {},
54
+ "outputs": [
55
+ {
56
+ "name": "stdout",
57
+ "output_type": "stream",
58
+ "text": [
59
+ "/home/jovyan/shared/Chima/ml_project/repos/MYSMP/semantic-segmentation\n"
60
+ ]
61
+ }
62
+ ],
63
+ "source": [
64
+ "!pwd"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 6,
70
+ "id": "099239c4",
71
+ "metadata": {},
72
+ "outputs": [
73
+ {
74
+ "name": "stdout",
75
+ "output_type": "stream",
76
+ "text": [
77
+ "Loading pretrained model...\n"
78
+ ]
79
+ },
80
+ {
81
+ "name": "stderr",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "/srv/conda/envs/notebook/lib/python3.8/site-packages/segmentation_models_pytorch/base/modules.py:116: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
85
+ " return self.activation(x)\n"
86
+ ]
87
+ }
88
+ ],
89
+ "source": [
90
+ "# Initialize the model\n",
91
+ "from SemanticModel.model_core import SegmentationModel\n",
92
+ "\n",
93
+ "model = SegmentationModel(\n",
94
+ " classes=['bg', 'cacao', 'matarraton', 'abarco'],\n",
95
+ " architecture='unet',\n",
96
+ " encoder='timm-regnety_120',\n",
97
+ " weights='../data/model_outputs-unet[timm-regnety_120]-01-23-2025_075803/best_model.pth'\n",
98
+ ")\n",
99
+ "\n",
100
+ "# Initialize prediction pipeline\n",
101
+ "from SemanticModel.prediction import PredictionPipeline\n",
102
+ "\n",
103
+ "predictor = PredictionPipeline(model)\n",
104
+ "output_dir='../predictions'\n",
105
+ "image_path= '../data/Images/2019-Mission2-odm_1_42.jpg'\n",
106
+ "\n",
107
+ "# Make prediction\n",
108
+ "prediction = predictor.predict_single_image(image_path=image_path,output_dir=output_dir)"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 7,
114
+ "id": "e0ae4c21",
115
+ "metadata": {},
116
+ "outputs": [
117
+ {
118
+ "name": "stdout",
119
+ "output_type": "stream",
120
+ "text": [
121
+ "\n",
122
+ "Predictions saved to: path/to/folderofImages/predictions\n"
123
+ ]
124
+ },
125
+ {
126
+ "data": {
127
+ "text/plain": [
128
+ "'path/to/folderofImages/predictions'"
129
+ ]
130
+ },
131
+ "execution_count": 7,
132
+ "metadata": {},
133
+ "output_type": "execute_result"
134
+ }
135
+ ],
136
+ "source": [
137
+ "# Directory of images\n",
138
+ "image_path= 'path/to/folderofImages'\n",
139
+ "predictor.predict_directory(image_path)"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 8,
145
+ "id": "fc4f2a48",
146
+ "metadata": {},
147
+ "outputs": [
148
+ {
149
+ "name": "stdout",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "Loading raster...\n",
153
+ "Processed 6/6 tiles\n",
154
+ "Prediction saved to: ../predictions/prediction.tif\n"
155
+ ]
156
+ },
157
+ {
158
+ "data": {
159
+ "text/plain": [
160
+ "(array([[[0, 0, 0],\n",
161
+ " [0, 0, 0],\n",
162
+ " [0, 0, 0],\n",
163
+ " ...,\n",
164
+ " [0, 0, 0],\n",
165
+ " [0, 0, 0],\n",
166
+ " [0, 0, 0]],\n",
167
+ " \n",
168
+ " [[0, 0, 0],\n",
169
+ " [0, 0, 0],\n",
170
+ " [0, 0, 0],\n",
171
+ " ...,\n",
172
+ " [0, 0, 0],\n",
173
+ " [0, 0, 0],\n",
174
+ " [0, 0, 0]],\n",
175
+ " \n",
176
+ " [[0, 0, 0],\n",
177
+ " [0, 0, 0],\n",
178
+ " [0, 0, 0],\n",
179
+ " ...,\n",
180
+ " [0, 0, 0],\n",
181
+ " [0, 0, 0],\n",
182
+ " [0, 0, 0]],\n",
183
+ " \n",
184
+ " ...,\n",
185
+ " \n",
186
+ " [[0, 0, 0],\n",
187
+ " [0, 0, 0],\n",
188
+ " [0, 0, 0],\n",
189
+ " ...,\n",
190
+ " [0, 0, 0],\n",
191
+ " [0, 0, 0],\n",
192
+ " [0, 0, 0]],\n",
193
+ " \n",
194
+ " [[0, 0, 0],\n",
195
+ " [0, 0, 0],\n",
196
+ " [0, 0, 0],\n",
197
+ " ...,\n",
198
+ " [0, 0, 0],\n",
199
+ " [0, 0, 0],\n",
200
+ " [0, 0, 0]],\n",
201
+ " \n",
202
+ " [[0, 0, 0],\n",
203
+ " [0, 0, 0],\n",
204
+ " [0, 0, 0],\n",
205
+ " ...,\n",
206
+ " [0, 0, 0],\n",
207
+ " [0, 0, 0],\n",
208
+ " [0, 0, 0]]], dtype=uint8),\n",
209
+ " {'driver': 'GTiff', 'dtype': 'uint8', 'nodata': None, 'width': 2365, 'height': 1797, 'count': 3, 'crs': CRS.from_epsg(32618), 'transform': Affine(0.03564594364277113, 0.0, 740295.5186183113,\n",
210
+ " 0.0, -0.03564594364276106, 485117.0212715292), 'tiled': False, 'interleave': 'pixel'})"
211
+ ]
212
+ },
213
+ "execution_count": 8,
214
+ "metadata": {},
215
+ "output_type": "execute_result"
216
+ }
217
+ ],
218
+ "source": [
219
+ "# Large raster\n",
220
+ "output_path='../predictions/prediction.tif'\n",
221
+ "raster_path = '../data/2021-Mission7_clipped_2.tif'\n",
222
+ "predictor.predict_raster(raster_path, tile_size=1024,output_path=output_path,format='color')"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "id": "8e4ab5a5",
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": []
232
+ }
233
+ ],
234
+ "metadata": {
235
+ "kernelspec": {
236
+ "display_name": "AgLab - Python 3",
237
+ "language": "python",
238
+ "name": "python3"
239
+ },
240
+ "language_info": {
241
+ "codemirror_mode": {
242
+ "name": "ipython",
243
+ "version": 3
244
+ },
245
+ "file_extension": ".py",
246
+ "mimetype": "text/x-python",
247
+ "name": "python",
248
+ "nbconvert_exporter": "python",
249
+ "pygments_lexer": "ipython3",
250
+ "version": "3.8.13"
251
+ }
252
+ },
253
+ "nbformat": 4,
254
+ "nbformat_minor": 5
255
+ }
examples/.ipynb_checkpoints/train-checkpoint.ipynb ADDED
@@ -0,0 +1,113 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "333ede5f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%cd \"../../MYSMP\""
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "84d4c945",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "# %pip install -r requirements.txt"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "7b088ef9",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "%cd \"../MYSMP/semantic-segmentation\""
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "99937292",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "from SemanticModel.model_core import SegmentationModel\n",
41
+ "from SemanticModel.training import ModelTrainer"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "id": "ab69a291",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "# initialization loss function\n",
52
+ "model = SegmentationModel(\n",
53
+ " classes=['bg', 'cacao', 'matarraton', 'abarco'],\n",
54
+ " architecture='unet',\n",
55
+ " encoder='timm-regnety_120',\n",
56
+ " weights='imagenet',\n",
57
+ " loss='dice' # Try 'dice' or 'tversky' instead of default\n",
58
+ ")\n",
59
+ "\n",
60
+ "# training parameters\n",
61
+ "trainer = ModelTrainer(\n",
62
+ " model_config=model,\n",
63
+ " root_dir='../data',\n",
64
+ " epochs=100,\n",
65
+ " train_size=1024,\n",
66
+ " batch_size=4,\n",
67
+ " learning_rate=1e-3, # Increased learning rate\n",
68
+ " step_count=3, # More learning rate adjustments\n",
69
+ " decay_factor=0.5 # Stronger decay\n",
70
+ ")"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "38fc7c6f",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "trained_model, metrics = trainer.train()"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "id": "a053c2ae",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": []
90
+ }
91
+ ],
92
+ "metadata": {
93
+ "kernelspec": {
94
+ "display_name": "AgLab - Python 3",
95
+ "language": "python",
96
+ "name": "python3"
97
+ },
98
+ "language_info": {
99
+ "codemirror_mode": {
100
+ "name": "ipython",
101
+ "version": 3
102
+ },
103
+ "file_extension": ".py",
104
+ "mimetype": "text/x-python",
105
+ "name": "python",
106
+ "nbconvert_exporter": "python",
107
+ "pygments_lexer": "ipython3",
108
+ "version": "3.8.13"
109
+ }
110
+ },
111
+ "nbformat": 4,
112
+ "nbformat_minor": 5
113
+ }
examples/predict.ipynb ADDED
@@ -0,0 +1,255 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "42e4027f",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/home/jovyan/shared/Chima/ml_project/repos/MYSMP\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "%cd \"../../MYSMP\""
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "ef6ea33e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "# %pip install -r requirements.txt"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 3,
34
+ "id": "30812aed",
35
+ "metadata": {},
36
+ "outputs": [
37
+ {
38
+ "name": "stdout",
39
+ "output_type": "stream",
40
+ "text": [
41
+ "/home/jovyan/shared/Chima/ml_project/repos/MYSMP/semantic-segmentation\n"
42
+ ]
43
+ }
44
+ ],
45
+ "source": [
46
+ "%cd \"../MYSMP/semantic-segmentation\""
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 5,
52
+ "id": "aaa4a036",
53
+ "metadata": {},
54
+ "outputs": [
55
+ {
56
+ "name": "stdout",
57
+ "output_type": "stream",
58
+ "text": [
59
+ "/home/jovyan/shared/Chima/ml_project/repos/MYSMP/semantic-segmentation\n"
60
+ ]
61
+ }
62
+ ],
63
+ "source": [
64
+ "!pwd"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 6,
70
+ "id": "099239c4",
71
+ "metadata": {},
72
+ "outputs": [
73
+ {
74
+ "name": "stdout",
75
+ "output_type": "stream",
76
+ "text": [
77
+ "Loading pretrained model...\n"
78
+ ]
79
+ },
80
+ {
81
+ "name": "stderr",
82
+ "output_type": "stream",
83
+ "text": [
84
+ "/srv/conda/envs/notebook/lib/python3.8/site-packages/segmentation_models_pytorch/base/modules.py:116: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
85
+ " return self.activation(x)\n"
86
+ ]
87
+ }
88
+ ],
89
+ "source": [
90
+ "# Initialize the model\n",
91
+ "from SemanticModel.model_core import SegmentationModel\n",
92
+ "\n",
93
+ "model = SegmentationModel(\n",
94
+ " classes=['bg', 'cacao', 'matarraton', 'abarco'],\n",
95
+ " architecture='unet',\n",
96
+ " encoder='timm-regnety_120',\n",
97
+ " weights='../data/model_outputs-unet[timm-regnety_120]-01-23-2025_075803/best_model.pth'\n",
98
+ ")\n",
99
+ "\n",
100
+ "# Initialize prediction pipeline\n",
101
+ "from SemanticModel.prediction import PredictionPipeline\n",
102
+ "\n",
103
+ "predictor = PredictionPipeline(model)\n",
104
+ "output_dir='../predictions'\n",
105
+ "image_path= '../data/Images/2019-Mission2-odm_1_42.jpg'\n",
106
+ "\n",
107
+ "# Make prediction\n",
108
+ "prediction = predictor.predict_single_image(image_path=image_path,output_dir=output_dir)"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 7,
114
+ "id": "e0ae4c21",
115
+ "metadata": {},
116
+ "outputs": [
117
+ {
118
+ "name": "stdout",
119
+ "output_type": "stream",
120
+ "text": [
121
+ "\n",
122
+ "Predictions saved to: path/to/folderofImages/predictions\n"
123
+ ]
124
+ },
125
+ {
126
+ "data": {
127
+ "text/plain": [
128
+ "'path/to/folderofImages/predictions'"
129
+ ]
130
+ },
131
+ "execution_count": 7,
132
+ "metadata": {},
133
+ "output_type": "execute_result"
134
+ }
135
+ ],
136
+ "source": [
137
+ "# Directory of images\n",
138
+ "image_path= 'path/to/folderofImages'\n",
139
+ "predictor.predict_directory(image_path)"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 8,
145
+ "id": "fc4f2a48",
146
+ "metadata": {},
147
+ "outputs": [
148
+ {
149
+ "name": "stdout",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "Loading raster...\n",
153
+ "Processed 6/6 tiles\n",
154
+ "Prediction saved to: ../predictions/prediction.tif\n"
155
+ ]
156
+ },
157
+ {
158
+ "data": {
159
+ "text/plain": [
160
+ "(array([[[0, 0, 0],\n",
161
+ " [0, 0, 0],\n",
162
+ " [0, 0, 0],\n",
163
+ " ...,\n",
164
+ " [0, 0, 0],\n",
165
+ " [0, 0, 0],\n",
166
+ " [0, 0, 0]],\n",
167
+ " \n",
168
+ " [[0, 0, 0],\n",
169
+ " [0, 0, 0],\n",
170
+ " [0, 0, 0],\n",
171
+ " ...,\n",
172
+ " [0, 0, 0],\n",
173
+ " [0, 0, 0],\n",
174
+ " [0, 0, 0]],\n",
175
+ " \n",
176
+ " [[0, 0, 0],\n",
177
+ " [0, 0, 0],\n",
178
+ " [0, 0, 0],\n",
179
+ " ...,\n",
180
+ " [0, 0, 0],\n",
181
+ " [0, 0, 0],\n",
182
+ " [0, 0, 0]],\n",
183
+ " \n",
184
+ " ...,\n",
185
+ " \n",
186
+ " [[0, 0, 0],\n",
187
+ " [0, 0, 0],\n",
188
+ " [0, 0, 0],\n",
189
+ " ...,\n",
190
+ " [0, 0, 0],\n",
191
+ " [0, 0, 0],\n",
192
+ " [0, 0, 0]],\n",
193
+ " \n",
194
+ " [[0, 0, 0],\n",
195
+ " [0, 0, 0],\n",
196
+ " [0, 0, 0],\n",
197
+ " ...,\n",
198
+ " [0, 0, 0],\n",
199
+ " [0, 0, 0],\n",
200
+ " [0, 0, 0]],\n",
201
+ " \n",
202
+ " [[0, 0, 0],\n",
203
+ " [0, 0, 0],\n",
204
+ " [0, 0, 0],\n",
205
+ " ...,\n",
206
+ " [0, 0, 0],\n",
207
+ " [0, 0, 0],\n",
208
+ " [0, 0, 0]]], dtype=uint8),\n",
209
+ " {'driver': 'GTiff', 'dtype': 'uint8', 'nodata': None, 'width': 2365, 'height': 1797, 'count': 3, 'crs': CRS.from_epsg(32618), 'transform': Affine(0.03564594364277113, 0.0, 740295.5186183113,\n",
210
+ " 0.0, -0.03564594364276106, 485117.0212715292), 'tiled': False, 'interleave': 'pixel'})"
211
+ ]
212
+ },
213
+ "execution_count": 8,
214
+ "metadata": {},
215
+ "output_type": "execute_result"
216
+ }
217
+ ],
218
+ "source": [
219
+ "# Large raster\n",
220
+ "output_path='../predictions/prediction.tif'\n",
221
+ "raster_path = '../data/2021-Mission7_clipped_2.tif'\n",
222
+ "predictor.predict_raster(raster_path, tile_size=1024,output_path=output_path,format='color')"
223
+ ]
224
+ },
225
+ {
226
+ "cell_type": "code",
227
+ "execution_count": null,
228
+ "id": "8e4ab5a5",
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": []
232
+ }
233
+ ],
234
+ "metadata": {
235
+ "kernelspec": {
236
+ "display_name": "AgLab - Python 3",
237
+ "language": "python",
238
+ "name": "python3"
239
+ },
240
+ "language_info": {
241
+ "codemirror_mode": {
242
+ "name": "ipython",
243
+ "version": 3
244
+ },
245
+ "file_extension": ".py",
246
+ "mimetype": "text/x-python",
247
+ "name": "python",
248
+ "nbconvert_exporter": "python",
249
+ "pygments_lexer": "ipython3",
250
+ "version": "3.8.13"
251
+ }
252
+ },
253
+ "nbformat": 4,
254
+ "nbformat_minor": 5
255
+ }
examples/train.ipynb ADDED
@@ -0,0 +1,113 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "333ede5f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%cd \"../../MYSMP\""
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "84d4c945",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "# %pip install -r requirements.txt"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "7b088ef9",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "%cd \"../MYSMP/semantic-segmentation\""
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "99937292",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "from SemanticModel.model_core import SegmentationModel\n",
41
+ "from SemanticModel.training import ModelTrainer"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "id": "ab69a291",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "# initialization loss function\n",
52
+ "model = SegmentationModel(\n",
53
+ " classes=['bg', 'cacao', 'matarraton', 'abarco'],\n",
54
+ " architecture='unet',\n",
55
+ " encoder='timm-regnety_120',\n",
56
+ " weights='imagenet',\n",
57
+ " loss='dice' # Try 'dice' or 'tversky' instead of default\n",
58
+ ")\n",
59
+ "\n",
60
+ "# training parameters\n",
61
+ "trainer = ModelTrainer(\n",
62
+ " model_config=model,\n",
63
+ " root_dir='../data',\n",
64
+ " epochs=100,\n",
65
+ " train_size=1024,\n",
66
+ " batch_size=4,\n",
67
+ " learning_rate=1e-3, # Increased learning rate\n",
68
+ " step_count=3, # More learning rate adjustments\n",
69
+ " decay_factor=0.5 # Stronger decay\n",
70
+ ")"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "38fc7c6f",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "trained_model, metrics = trainer.train()"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "id": "a053c2ae",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": []
90
+ }
91
+ ],
92
+ "metadata": {
93
+ "kernelspec": {
94
+ "display_name": "AgLab - Python 3",
95
+ "language": "python",
96
+ "name": "python3"
97
+ },
98
+ "language_info": {
99
+ "codemirror_mode": {
100
+ "name": "ipython",
101
+ "version": 3
102
+ },
103
+ "file_extension": ".py",
104
+ "mimetype": "text/x-python",
105
+ "name": "python",
106
+ "nbconvert_exporter": "python",
107
+ "pygments_lexer": "ipython3",
108
+ "version": "3.8.13"
109
+ }
110
+ },
111
+ "nbformat": 4,
112
+ "nbformat_minor": 5
113
+ }
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ torch
2
+ torchvision
3
+ tensorboard
4
+ pyproj
5
+ fiona==1.8.20
6
+ rtree
7
+ geopandas
8
+ rasterio
9
+ slidingwindow
10
+ opencv-python
11
+ wandb
12
+ tifffile
13
+ imagecodecs
14
+ albumentations
15
+ segmentation-models-pytorch>=0.3.3
semantic-segmentation/SemanticModel/.ipynb_checkpoints/custom_losses-checkpoint.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from segmentation_models_pytorch.utils import base
4
+ from segmentation_models_pytorch.base.modules import Activation
5
+
6
+ class FocalLossFunction(base.Loss):
7
+ def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
8
+ super().__init__(**kwargs)
9
+ self.activation = Activation(activation)
10
+ self.alpha = alpha
11
+ self.gamma = gamma
12
+ self.reduction = reduction
13
+
14
+ def forward(self, inputs, targets):
15
+ if inputs.shape[1] == 1: # Binary case
16
+ inputs = torch.cat((inputs, 1 - inputs), dim=1)
17
+ targets = torch.cat((targets, 1 - targets), dim=1)
18
+
19
+ targets = torch.argmax(targets, dim=1)
20
+ cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
21
+ probability = torch.exp(-cross_entropy)
22
+ alpha_factor = self.alpha if inputs.shape[1] > 1 else torch.where(
23
+ targets == 1, 1-self.alpha, self.alpha)
24
+
25
+ focal_weight = alpha_factor * (1 - probability) ** self.gamma * cross_entropy
26
+
27
+ if self.reduction == 'mean':
28
+ return focal_weight.mean()
29
+ elif self.reduction == 'sum':
30
+ return focal_weight.sum()
31
+ return focal_weight
32
+
33
+ class TverskyLossFunction(base.Loss):
34
+ def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
35
+ reduction='mean', **kwargs):
36
+ super().__init__(**kwargs)
37
+ self.activation = Activation(activation)
38
+ self.alpha = alpha
39
+ self.beta = beta
40
+ self.ignore_channels = ignore_channels
41
+ self.reduction = reduction
42
+
43
+ def forward(self, inputs, targets):
44
+ if self.ignore_channels is not None:
45
+ mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
46
+ mask[self.ignore_channels] = False
47
+ inputs = inputs[:, mask, ...]
48
+
49
+ num_classes = inputs.shape[1]
50
+ inputs_softmax = (torch.sigmoid(inputs) if num_classes == 1
51
+ else F.softmax(inputs, dim=1))
52
+
53
+ if num_classes == 1:
54
+ inputs_softmax = inputs_softmax.squeeze(1)
55
+ targets = targets.squeeze(1)
56
+
57
+ tversky_loss = 0
58
+ for class_idx in range(num_classes):
59
+ if num_classes == 1:
60
+ flat_inputs = inputs_softmax.reshape(-1)
61
+ flat_targets = targets.reshape(-1)
62
+ else:
63
+ flat_inputs = inputs_softmax[:, class_idx].reshape(-1)
64
+ flat_targets = targets[:, class_idx].reshape(-1)
65
+
66
+ intersection = (flat_inputs * flat_targets).sum()
67
+ fps = ((1 - flat_targets) * flat_inputs).sum()
68
+ fns = (flat_targets * (1 - flat_inputs)).sum()
69
+
70
+ tversky_index = intersection + self.alpha * fps + self.beta * fns + 1e-10
71
+ tversky_loss += 1 - intersection / tversky_index
72
+
73
+ if self.reduction == 'mean':
74
+ return tversky_loss / (1 if num_classes == 1 else num_classes)
75
+ elif self.reduction == 'sum':
76
+ return tversky_loss
77
+ return tversky_loss / inputs.shape[0]
78
+
79
+ class EnhancedCrossEntropy(base.Loss):
80
+ def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
81
+ super().__init__(**kwargs)
82
+ self.activation = Activation(activation)
83
+ self.ignore_channels = ignore_channels
84
+ self.reduction = reduction
85
+
86
+ def forward(self, inputs, targets):
87
+ inputs = self.activation(inputs)
88
+
89
+ if self.ignore_channels is not None:
90
+ mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
91
+ mask[self.ignore_channels] = False
92
+ inputs = inputs[:, mask, ...]
93
+
94
+ if targets.dim() == 4: # Convert one-hot to class indices
95
+ targets = torch.argmax(targets, dim=1)
96
+
97
+ return F.cross_entropy(inputs, targets, reduction=self.reduction)
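A minimal usage sketch of the losses above, assuming the `SemanticModel` package from this commit is installed. In the multiclass branches both losses take raw logits of shape (B, C, H, W) and one-hot targets of the same shape:

```python
import torch
import torch.nn.functional as F
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction

# Dummy 4-class batch: logits and one-hot targets, both (B, C, H, W)
logits = torch.randn(2, 4, 64, 64)
labels = torch.randint(0, 4, (2, 64, 64))
targets = F.one_hot(labels, num_classes=4).permute(0, 3, 1, 2).float()

focal = FocalLossFunction(gamma=2.0)                # defaults: alpha=0.25, gamma=1.5
tversky = TverskyLossFunction(alpha=0.7, beta=0.3)  # alpha weights FPs, beta weights FNs
print(focal(logits, targets).item(), tversky(logits, targets).item())
```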
semantic-segmentation/SemanticModel/.ipynb_checkpoints/data_loader-checkpoint.py ADDED
@@ -0,0 +1,129 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ from torch.utils.data import Dataset as BaseDataset
5
+
6
+ class SegmentationDataset(BaseDataset):
7
+ """Dataset class for semantic segmentation task."""
8
+
9
+ def __init__(self, data_dir, classes=['background', 'object'],
10
+ augmentation=None, preprocessing=None):
11
+
12
+ self.image_dir = os.path.join(data_dir, 'Images')
13
+ self.mask_dir = os.path.join(data_dir, 'Masks')
14
+
15
+ for dir_path in [self.image_dir, self.mask_dir]:
16
+ if not os.path.exists(dir_path):
17
+ raise FileNotFoundError(f"Directory not found: {dir_path}")
18
+
19
+ self.filenames = self._get_filenames()
20
+ self.image_paths = [os.path.join(self.image_dir, fname) for fname in self.filenames]
21
+ self.mask_paths = self._get_mask_paths()
22
+
23
+ self.target_classes = [cls for cls in classes if cls.lower() != 'background']
24
+ self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
25
+
26
+ self.augmentation = augmentation
27
+ self.preprocessing = preprocessing
28
+
29
+ def __getitem__(self, index):
30
+ image = self._load_image(self.image_paths[index])
31
+ mask = self._load_mask(self.mask_paths[index])
32
+
33
+ if self.augmentation:
34
+ processed = self.augmentation(image=image, mask=mask)
35
+ image, mask = processed['image'], processed['mask']
36
+
37
+ if self.preprocessing:
38
+ processed = self.preprocessing(image=image, mask=mask)
39
+ image, mask = processed['image'], processed['mask']
40
+
41
+ return image, mask
42
+
43
+ def __len__(self):
44
+ return len(self.filenames)
45
+
46
+ def _get_filenames(self):
47
+ """Returns sorted list of filenames, excluding directories."""
48
+ files = sorted(os.listdir(self.image_dir))
49
+ return [f for f in files if not os.path.isdir(os.path.join(self.image_dir, f))]
50
+
51
+ def _get_mask_paths(self):
52
+ """Generates corresponding mask paths for each image."""
53
+ mask_paths = []
54
+ for image_file in self.filenames:
55
+ name, _ = os.path.splitext(image_file)
56
+ mask_paths.append(os.path.join(self.mask_dir, f"{name}.png"))
57
+ return mask_paths
58
+
59
+ def _load_image(self, path):
60
+ """Loads and converts image to RGB."""
61
+ image = cv2.imread(path)
62
+ return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
63
+
64
+ def _load_mask(self, path):
65
+ """Loads and processes segmentation mask."""
66
+ mask = cv2.imread(path, 0)
67
+ masks = [(mask == value) for value in self.class_values]
68
+ mask = np.stack(masks, axis=-1).astype('float')
69
+ return mask
70
+
71
+ class InferenceDataset(BaseDataset):
72
+ """Dataset class for inference without ground truth masks."""
73
+
74
+ def __init__(self, data_dir, classes=['background', 'object'],
75
+ augmentation=None, preprocessing=None):
76
+ self.filenames = sorted([
77
+ f for f in os.listdir(data_dir)
78
+ if not os.path.isdir(os.path.join(data_dir, f))
79
+ ])
80
+ self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]
81
+
82
+ self.target_classes = [cls for cls in classes if cls.lower() != 'background']
83
+ self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
84
+
85
+ self.augmentation = augmentation
86
+ self.preprocessing = preprocessing
87
+
88
+ def __getitem__(self, index):
89
+ image = cv2.imread(self.image_paths[index])
90
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
91
+ original_height, original_width = image.shape[:2]
92
+
93
+ if self.augmentation:
94
+ image = self.augmentation(image=image)['image']
95
+
96
+ if self.preprocessing:
97
+ image = self.preprocessing(image=image)['image']
98
+
99
+ return image, original_height, original_width
100
+
101
+ def __len__(self):
102
+ return len(self.filenames)
103
+
104
+ class StreamingDataset(BaseDataset):
105
+ """Dataset class optimized for video frame processing."""
106
+
107
+ def __init__(self, data_dir, classes=['background', 'object'],
108
+ augmentation=None, preprocessing=None):
109
+ self.filenames = self._get_frame_filenames(data_dir)
110
+ self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]
111
+
112
+ self.target_classes = [cls for cls in classes if cls.lower() != 'background']
113
+ self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
114
+
115
+ self.augmentation = augmentation
116
+ self.preprocessing = preprocessing
117
+
118
+ def _get_frame_filenames(self, directory):
119
+ """Returns sorted list of frame filenames."""
120
+ files = sorted(os.listdir(directory))
121
+ return [f for f in files if (('frame' in f or 'Image' in f) and
122
+ f.lower().endswith('jpg') and
123
+ not os.path.isdir(os.path.join(directory, f)))]
124
+
125
+ def __getitem__(self, index):
126
+ return InferenceDataset.__getitem__(self, index)
127
+
128
+ def __len__(self):
129
+ return len(self.filenames)
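A sketch of wiring `SegmentationDataset` into a PyTorch `DataLoader` together with the augmentation and preprocessing helpers from `image_preprocessing.py` (hypothetical dataset path and class names; the encoder preprocessing function comes from segmentation-models-pytorch):

```python
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader

from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.image_preprocessing import (
    get_preprocessing_pipeline, get_training_augmentations)

preprocessing_fn = smp.encoders.get_preprocessing_fn('timm-regnety_120', 'imagenet')
train_dataset = SegmentationDataset(
    'dataset/train',                      # must contain Images/ and Masks/
    classes=['background', 'object'],
    augmentation=get_training_augmentations(width=1024, height=1024),
    preprocessing=get_preprocessing_pipeline(preprocessing_fn),
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
images, masks = next(iter(train_loader))  # images: (4, 3, H, W); masks: (4, 1, H, W)
```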
semantic-segmentation/SemanticModel/.ipynb_checkpoints/encoder_management-checkpoint.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+ import ssl
3
+ import shutil
4
+ import tempfile
5
+ import hashlib
6
+ from tqdm import tqdm
7
+ from torch.hub import get_dir
8
+ from urllib.request import urlopen, Request
9
+
10
+ from segmentation_models_pytorch.encoders import (
11
+ resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
12
+ densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
13
+ efficient_net_encoders, mobilenet_encoders, xception_encoders,
14
+ timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
15
+ timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
16
+ timm_gernet_encoders
17
+ )
18
+
19
+ from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder
20
+
21
+ def initialize_encoders():
22
+ """Initialize dictionary of available encoders."""
23
+ available_encoders = {}
24
+ encoder_modules = [
25
+ resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
26
+ densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
27
+ efficient_net_encoders, mobilenet_encoders, xception_encoders,
28
+ timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
29
+ timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
30
+ timm_gernet_encoders
31
+ ]
32
+
33
+ for module in encoder_modules:
34
+ available_encoders.update(module)
35
+
36
+ try:
37
+ import segmentation_models_pytorch
38
+ from packaging import version
39
+ if version.parse(segmentation_models_pytorch.__version__) >= version.parse("0.3.3"):
40
+ from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
41
+ from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
42
+ available_encoders.update(mix_transformer_encoders)
43
+ available_encoders.update(mobileone_encoders)
44
+ except ImportError:
45
+ pass
46
+
47
+ return available_encoders
48
+
49
+ def download_weights(url, destination, hash_prefix=None, show_progress=True):
50
+ """Downloads model weights with progress tracking and verification."""
51
+ ssl._create_default_https_context = ssl._create_unverified_context
52
+
53
+ req = Request(url, headers={"User-Agent": "torch.hub"})
54
+ response = urlopen(req)
55
+ content_length = response.headers.get("Content-Length")
56
+ file_size = int(content_length) if content_length else None
57
+
58
+ destination = os.path.expanduser(destination)
59
+ temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))
60
+
61
+ try:
62
+ hasher = hashlib.sha256() if hash_prefix else None
63
+
64
+ with tqdm(total=file_size, disable=not show_progress,
65
+ unit='B', unit_scale=True, unit_divisor=1024) as pbar:
66
+ while True:
67
+ buffer = response.read(8192)
68
+ if not buffer:
69
+ break
70
+
71
+ temp_file.write(buffer)
72
+ if hasher:
73
+ hasher.update(buffer)
74
+ pbar.update(len(buffer))
75
+
76
+ temp_file.close()
77
+
78
+ if hasher and hash_prefix:
79
+ digest = hasher.hexdigest()
80
+ if digest[:len(hash_prefix)] != hash_prefix:
81
+ raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')
82
+
83
+ shutil.move(temp_file.name, destination)
84
+
85
+ finally:
86
+ temp_file.close()
87
+ if os.path.exists(temp_file.name):
88
+ os.remove(temp_file.name)
89
+
90
+ def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
91
+ """Initializes and returns configured encoder."""
92
+ encoders = initialize_encoders()
93
+
94
+ if name.startswith("tu-"):
95
+ name = name[3:]
96
+ return TimmUniversalEncoder(
97
+ name=name,
98
+ in_channels=in_channels,
99
+ depth=depth,
100
+ output_stride=output_stride,
101
+ pretrained=weights is not None,
102
+ **kwargs
103
+ )
104
+
105
+ try:
106
+ encoder_config = encoders[name]
107
+ except KeyError:
108
+ raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}")
109
+
110
+ encoder_class = encoder_config["encoder"]
111
+ encoder_params = encoder_config["params"]
112
+ encoder_params.update(depth=depth)
113
+
114
+ if weights:
115
+ try:
116
+ weights_config = encoder_config["pretrained_settings"][weights]
117
+ except KeyError:
118
+ raise KeyError(
119
+ f"Invalid weights '{weights}' for encoder '{name}'. "
120
+ f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
121
+ )
122
+
123
+ cache_dir = os.path.join(get_dir(), 'checkpoints')
124
+ os.makedirs(cache_dir, exist_ok=True)
125
+
126
+ weights_file = os.path.basename(weights_config["url"])
127
+ weights_path = os.path.join(cache_dir, weights_file)
128
+
129
+ if not os.path.exists(weights_path):
130
+ print(f'Downloading {weights_file}...')
131
+ download_weights(
132
+ weights_config["url"].replace("https", "http"),
133
+ weights_path
134
+ )
135
+
136
+ return encoder_class(**encoder_params)
semantic-segmentation/SemanticModel/.ipynb_checkpoints/evaluation_utils-checkpoint.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ from tqdm import tqdm
5
+ from torch.utils.data import DataLoader
6
+ from segmentation_models_pytorch.base.modules import Activation
7
+
8
+ from SemanticModel.data_loader import SegmentationDataset
9
+ from SemanticModel.metrics import compute_mean_iou
10
+ from SemanticModel.image_preprocessing import get_validation_augmentations
11
+
12
+ def evaluate_model(model_config, data_path, image_size=None):
13
+ """Evaluates model performance on a dataset."""
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ classes = ['background'] + model_config.classes if model_config.background_flag else model_config.classes
17
+
18
+ data_path = os.path.realpath(data_path)
19
+ image_subdir = os.path.join(data_path, 'Images')
20
+ mask_subdir = os.path.join(data_path, 'Masks')
21
+
22
+ if not all(os.path.exists(d) for d in [image_subdir, mask_subdir]):
23
+ raise Exception("Missing required subdirectories: 'Images' and 'Masks'")
24
+
25
+ if not image_size:
26
+ sample_image = cv2.imread(os.path.join(image_subdir, os.listdir(image_subdir)[0]))
27
+ height, width = sample_image.shape[:2]
28
+ image_size = max(height, width)
29
+
30
+ evaluation_dataset = SegmentationDataset(
31
+ data_path,
32
+ classes=classes,
33
+ augmentation=get_validation_augmentations(
34
+ im_width=image_size,
35
+ im_height=image_size,
36
+ fixed_size=False
37
+ ),
38
+ preprocessing=model_config.preprocessing
39
+ )
40
+
41
+ evaluation_loader = DataLoader(
42
+ evaluation_dataset,
43
+ batch_size=1,
44
+ shuffle=False,
45
+ num_workers=2
46
+ )
47
+
48
+ model = model_config.model.to(device)
49
+ model.eval()
50
+
51
+ requires_sigmoid = False
52
+ if model_config.n_classes == 1:
53
+ current_activation = _check_activation_function(model)
54
+ if current_activation != 'Sigmoid':
55
+ requires_sigmoid = True
56
+
57
+ predictions = []
58
+ ground_truth = []
59
+
60
+ print("Evaluating model performance...")
61
+ with torch.no_grad():
62
+ for images, masks in tqdm(evaluation_loader):
63
+ images = images.to(device)
64
+ masks = masks.to(device)
65
+
66
+ outputs = model.forward(images)
67
+
68
+ if model_config.n_classes > 1:
69
+ predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
70
+ ground_truth.extend([gt.cpu().argmax(dim=0) for gt in masks])
71
+ else:
72
+ if requires_sigmoid:
73
+ predictions.extend([
74
+ (torch.sigmoid(p) > 0.5).float().squeeze().cpu()
75
+ for p in outputs
76
+ ])
77
+ else:
78
+ predictions.extend([
79
+ (p > 0.5).float().squeeze().cpu()
80
+ for p in outputs
81
+ ])
82
+ ground_truth.extend([gt.cpu().squeeze() for gt in masks])
83
+
84
+ metrics = compute_mean_iou(
85
+ predictions,
86
+ ground_truth,
87
+ num_classes=len(classes),
88
+ ignore_index=255
89
+ )
90
+
91
+ print("\nEvaluation Results:")
92
+ print(f"Mean IoU: {metrics['mean_iou']:.3f}")
93
+ print("\nPer-class IoU:")
94
+ for idx, iou in enumerate(metrics['per_category_iou']):
95
+ print(f"{classes[idx]}: {iou:.3f}")
96
+
97
+ return metrics
98
+
99
+ def _check_activation_function(model):
100
+ """Checks the activation function used in model's segmentation head."""
101
+ from segmentation_models_pytorch.base.modules import Activation
102
+
103
+ activation_functions = []
104
+ for _, module in model.segmentation_head.named_children():
105
+ if isinstance(module, Activation):
106
+ activation_functions.append(type(module.activation).__name__)
107
+
108
+ return activation_functions[-1] if activation_functions else None
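A short sketch of running `evaluate_model` on a validation split laid out as in the README (the checkpoint path is hypothetical):

```python
from SemanticModel.model_core import SegmentationModel
from SemanticModel.evaluation_utils import evaluate_model

model = SegmentationModel(
    classes=['background', 'object'],
    weights='path/to/best_model.pth',   # hypothetical trained checkpoint
)
# 'dataset/val' must contain Images/ and Masks/ subdirectories
metrics = evaluate_model(model, 'dataset/val')
print(metrics['mean_iou'], metrics['per_category_iou'])
```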
semantic-segmentation/SemanticModel/.ipynb_checkpoints/image_preprocessing-checkpoint.py ADDED
@@ -0,0 +1,81 @@
1
+ import cv2
2
+ import numpy as np
3
+ import albumentations as albu
4
+ from albumentations.augmentations.geometric.resize import LongestMaxSize
5
+
6
+ def round_pixel_dim(dimension: float) -> int:
7
+ """Rounds pixel dimensions consistently."""
8
+ if abs(round(dimension) - dimension) == 0.5:
9
+ return int(2.0 * round(dimension / 2.0))
10
+ return int(round(dimension))
11
+
12
+ def resize_with_padding(image, target_size, stride=32, interpolation=cv2.INTER_LINEAR):
13
+ """Resizes image maintaining aspect ratio and ensures dimensions are stride-compatible."""
14
+ height, width = image.shape[:2]
15
+ max_dimension = max(height, width)
16
+
17
+ if ((height % stride == 0) and (width % stride == 0) and
18
+ (max_dimension <= target_size)):
19
+ return image
20
+
21
+ scale = target_size / float(max(width, height))
22
+ new_dims = tuple(round_pixel_dim(dim * scale) for dim in (height, width))
23
+ new_height, new_width = new_dims
24
+
25
+ new_height = ((new_height // stride + 1) * stride
26
+ if new_height % stride != 0 else new_height)
27
+ new_width = ((new_width // stride + 1) * stride
28
+ if new_width % stride != 0 else new_width)
29
+
30
+ return cv2.resize(image, (new_width, new_height), interpolation=interpolation)
31
+
32
+ class PaddedResize(LongestMaxSize):
33
+ def apply(self, img: np.ndarray, max_size: int = 1024,
34
+ interpolation: int = cv2.INTER_LINEAR, **params) -> np.ndarray:
35
+ return resize_with_padding(img, target_size=max_size, interpolation=interpolation)
36
+
37
+ def get_training_augmentations(width=768, height=576):
38
+ """Configures training-time augmentations."""
39
+ target_size = max([width, height])
40
+ transforms = [
41
+ albu.HorizontalFlip(p=0.5),
42
+ albu.ShiftScaleRotate(
43
+ scale_limit=0.5, rotate_limit=90, shift_limit=0.1, p=0.5, border_mode=0),
44
+ albu.PadIfNeeded(min_height=target_size, min_width=target_size, always_apply=True),
45
+ albu.RandomCrop(height=target_size, width=target_size, always_apply=True),
46
+ albu.GaussNoise(p=0.2),
47
+ albu.Perspective(p=0.2),
48
+ albu.OneOf([albu.CLAHE(p=1), albu.RandomGamma(p=1)], p=0.33),
49
+ albu.OneOf([
50
+ albu.Sharpen(p=1),
51
+ albu.Blur(blur_limit=3, p=1),
52
+ albu.MotionBlur(blur_limit=3, p=1)], p=0.33),
53
+ albu.OneOf([
54
+ albu.RandomBrightnessContrast(p=1),
55
+ albu.HueSaturationValue(p=1)], p=0.33),
56
+ ]
57
+ return albu.Compose(transforms)
58
+
59
+ def get_validation_augmentations(width=1920, height=1440, fixed_size=True):
60
+ """Configures validation/inference-time augmentations."""
61
+ if fixed_size:
62
+ transforms = [albu.Resize(height=height, width=width, always_apply=True)]
63
+ return albu.Compose(transforms)
64
+
65
+ target_size = max(width, height)
66
+ transforms = [PaddedResize(max_size=target_size, always_apply=True)]
67
+ return albu.Compose(transforms)
68
+
69
+ def convert_to_tensor(x, **kwargs):
70
+ """Converts image array to PyTorch tensor format."""
71
+ if x.ndim == 2:
72
+ x = np.expand_dims(x, axis=-1)
73
+ return x.transpose(2, 0, 1).astype('float32')
74
+
75
+ def get_preprocessing_pipeline(preprocessing_fn):
76
+ """Builds preprocessing pipeline including normalization and tensor conversion."""
77
+ transforms = [
78
+ albu.Lambda(image=preprocessing_fn),
79
+ albu.Lambda(image=convert_to_tensor, mask=convert_to_tensor),
80
+ ]
81
+ return albu.Compose(transforms)
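A quick sketch of what `resize_with_padding` does with the default stride of 32: the longest side is scaled to the target size and each side is then rounded up to a multiple of 32 so the encoder's downsampling stages divide evenly:

```python
import numpy as np
from SemanticModel.image_preprocessing import resize_with_padding

# Dummy 750x1000 RGB image, longest side capped at 768
image = np.random.randint(0, 255, (750, 1000, 3), dtype=np.uint8)
resized = resize_with_padding(image, target_size=768)
print(resized.shape)  # (576, 768, 3): both dimensions are multiples of 32
```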
semantic-segmentation/SemanticModel/.ipynb_checkpoints/metrics-checkpoint.py ADDED
@@ -0,0 +1,94 @@
1
+ from typing import Dict, Optional
2
+ import numpy as np
3
+
4
+ def compute_intersection_union(prediction, ground_truth, num_classes, ignore_index: int,
5
+ label_mapping: Optional[Dict[int, int]] = None,
6
+ reduce_labels: bool = False):
7
+ """Computes intersection and union for IoU calculation."""
8
+
9
+ if label_mapping:
10
+ for old_id, new_id in label_mapping.items():
11
+ ground_truth[ground_truth == old_id] = new_id
12
+
13
+ prediction = np.array(prediction)
14
+ ground_truth = np.array(ground_truth)
15
+
16
+ if reduce_labels:
17
+ ground_truth[ground_truth == 0] = 255
18
+ ground_truth = ground_truth - 1
19
+ ground_truth[ground_truth == 254] = 255
20
+
21
+ valid_mask = np.not_equal(ground_truth, ignore_index)
22
+ prediction = prediction[valid_mask]
23
+ ground_truth = ground_truth[valid_mask]
24
+
25
+ intersection_mask = prediction == ground_truth
26
+ intersection = prediction[intersection_mask]
27
+
28
+ area_intersection = np.histogram(intersection, bins=num_classes,
29
+ range=(0, num_classes - 1))[0]
30
+ area_prediction = np.histogram(prediction, bins=num_classes,
31
+ range=(0, num_classes - 1))[0]
32
+ area_ground_truth = np.histogram(ground_truth, bins=num_classes,
33
+ range=(0, num_classes - 1))[0]
34
+ area_union = area_prediction + area_ground_truth - area_intersection
35
+
36
+ return area_intersection, area_union, area_prediction, area_ground_truth
37
+
38
+ def compute_total_intersection_union(predictions, ground_truths, num_classes, ignore_index: int,
39
+ label_mapping: Optional[Dict[int, int]] = None,
40
+ reduce_labels: bool = False):
41
+ """Computes total intersection and union across all samples."""
42
+
43
+ totals = {
44
+ 'intersection': np.zeros((num_classes,), dtype=np.float64),
45
+ 'union': np.zeros((num_classes,), dtype=np.float64),
46
+ 'prediction': np.zeros((num_classes,), dtype=np.float64),
47
+ 'ground_truth': np.zeros((num_classes,), dtype=np.float64)
48
+ }
49
+
50
+ for pred, gt in zip(predictions, ground_truths):
51
+ intersection, union, pred_area, gt_area = compute_intersection_union(
52
+ pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
53
+ )
54
+ totals['intersection'] += intersection
55
+ totals['union'] += union
56
+ totals['prediction'] += pred_area
57
+ totals['ground_truth'] += gt_area
58
+
59
+ return tuple(totals.values())
60
+
61
+ def compute_mean_iou(predictions, ground_truths, num_classes, ignore_index: int,
62
+ nan_to_num: Optional[int] = None,
63
+ label_mapping: Optional[Dict[int, int]] = None,
64
+ reduce_labels: bool = False):
65
+ """Computes mean IoU and related metrics."""
66
+
67
+ intersection, union, prediction_area, ground_truth_area = compute_total_intersection_union(
68
+ predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
69
+ )
70
+
71
+ metrics = {}
72
+
73
+ # Compute overall accuracy
74
+ total_accuracy = intersection.sum() / ground_truth_area.sum()
75
+
76
+ # Compute IoU per class
77
+ iou_per_class = intersection / union
78
+ accuracy_per_class = intersection / ground_truth_area
79
+
80
+ metrics.update({
81
+ "mean_iou": np.nanmean(iou_per_class),
82
+ "mean_accuracy": np.nanmean(accuracy_per_class),
83
+ "overall_accuracy": total_accuracy,
84
+ "per_category_iou": iou_per_class,
85
+ "per_category_accuracy": accuracy_per_class
86
+ })
87
+
88
+ if nan_to_num is not None:
89
+ metrics = {
90
+ metric: np.nan_to_num(value, nan=nan_to_num)
91
+ for metric, value in metrics.items()
92
+ }
93
+
94
+ return metrics
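A tiny worked example of `compute_mean_iou`, using integer class maps with 255 as the ignore value (matching how `evaluation_utils.py` calls it):

```python
import numpy as np
from SemanticModel.metrics import compute_mean_iou

# Two 2x2 predictions and labels with 2 classes; 255 marks one ignored pixel
predictions   = [np.array([[0, 1], [1, 1]]), np.array([[0, 0], [1, 0]])]
ground_truths = [np.array([[0, 1], [0, 1]]), np.array([[0, 0], [1, 255]])]

metrics = compute_mean_iou(predictions, ground_truths,
                           num_classes=2, ignore_index=255)
print(metrics['mean_iou'])          # 0.75 for this toy example
print(metrics['per_category_iou'])  # [0.75 0.75]
```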
semantic-segmentation/SemanticModel/.ipynb_checkpoints/model_core-checkpoint.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from segmentation_models_pytorch import utils
5
+
6
+ from SemanticModel.encoder_management import initialize_encoder
7
+ from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
8
+ from SemanticModel.image_preprocessing import get_preprocessing_pipeline
9
+
10
+ class SegmentationModel:
11
+ def __init__(self, classes=['background', 'foreground'], architecture='unet',
12
+ encoder='timm-regnety_120', weights='imagenet', loss=None):
13
+ self._initialize_classes(classes)
14
+ self.architecture = architecture
15
+ self.encoder = encoder
16
+ self.weights = weights
17
+ self._setup_loss_function(loss)
18
+ self._initialize_model()
19
+
20
+ def _initialize_classes(self, classes):
21
+ """Sets up class configuration."""
22
+ if len(classes) <= 2:
23
+ self.classes = [c for c in classes if c.lower() != 'background']
24
+ self.class_values = [i for i, c in enumerate(classes) if c.lower() != 'background']
25
+ self.background_flag = 'background' in classes
26
+ else:
27
+ self.classes = classes
28
+ self.class_values = list(range(len(classes)))
29
+ self.background_flag = False
30
+ self.n_classes = len(self.classes)
31
+
32
+ def _setup_loss_function(self, loss):
33
+ """Configures model's loss function."""
34
+ if not loss:
35
+ loss = 'bce_with_logits' if self.n_classes > 1 else 'dice'
36
+
37
+ if loss.lower() not in ['dice', 'bce_with_logits', 'focal', 'tversky']:
38
+ print(f'Invalid loss: {loss}, defaulting to dice')
39
+ loss = 'dice'
40
+
41
+ loss_configs = {
42
+ 'bce_with_logits': {
43
+ 'activation': None,
44
+ 'loss': EnhancedCrossEntropy() if self.n_classes > 1 else utils.losses.BCEWithLogitsLoss()
45
+ },
46
+ 'dice': {
47
+ 'activation': 'softmax' if self.n_classes > 1 else 'sigmoid',
48
+ 'loss': utils.losses.DiceLoss()
49
+ },
50
+ 'focal': {
51
+ 'activation': None,
52
+ 'loss': FocalLossFunction()
53
+ },
54
+ 'tversky': {
55
+ 'activation': None,
56
+ 'loss': TverskyLossFunction()
57
+ }
58
+ }
59
+
60
+ config = loss_configs[loss.lower()]
61
+ self.activation = config['activation']
62
+ self.loss = config['loss']
63
+ self.loss_name = loss
64
+
65
+ def _initialize_model(self):
66
+ """Initializes the segmentation model architecture."""
67
+ if self.weights.endswith('pth'):
68
+ self._load_pretrained_model()
69
+ else:
70
+ self._create_new_model()
71
+
72
+ def _load_pretrained_model(self):
73
+ """Loads model from pretrained weights."""
74
+ print('Loading pretrained model...')
75
+ self.model = torch.load(self.weights)
76
+ if isinstance(self.model, torch.nn.DataParallel):
77
+ self.model = self.model.module
78
+
79
+ try:
80
+ preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
81
+ self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
82
+ except:
83
+ print('Failed to configure preprocessing. Setting to None.')
84
+ self.preprocessing = None
85
+
86
+ def _create_new_model(self):
87
+ """Creates new model with specified architecture."""
88
+ preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
89
+ self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
90
+ initialize_encoder(name=self.encoder, weights=self.weights)
91
+
92
+ architectures = {
93
+ 'unet': smp.Unet,
94
+ 'unet++': smp.UnetPlusPlus,
95
+ 'deeplabv3': smp.DeepLabV3,
96
+ 'deeplabv3+': smp.DeepLabV3Plus,
97
+ 'fpn': smp.FPN,
98
+ 'linknet': smp.Linknet,
99
+ 'manet': smp.MAnet,
100
+ 'pan': smp.PAN,
101
+ 'pspnet': smp.PSPNet
102
+ }
103
+
104
+ if self.architecture not in architectures:
105
+ raise ValueError(f'Unsupported architecture: {self.architecture}')
106
+
107
+ self.model = architectures[self.architecture](
108
+ encoder_name=self.encoder,
109
+ encoder_weights=self.weights,
110
+ classes=self.n_classes,
111
+ activation=self.activation
112
+ )
113
+
114
+ @property
115
+ def config_data(self):
116
+ """Returns model configuration data."""
117
+ return {
118
+ 'architecture': self.architecture,
119
+ 'encoder': self.encoder,
120
+ 'weights': self.weights,
121
+ 'activation': self.activation,
122
+ 'loss': self.loss_name,
123
+ 'classes': ['background'] + self.classes if self.background_flag else self.classes
124
+ }
125
+
126
+ def list_architectures():
127
+ """Returns available architecture options."""
128
+ return ['unet', 'unet++', 'deeplabv3', 'deeplabv3+', 'fpn',
129
+ 'linknet', 'manet', 'pan', 'pspnet']
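For reference, `_setup_loss_function` pairs each loss with a head activation: `dice` attaches softmax (multiclass) or sigmoid (binary), while the logit-based losses (`bce_with_logits`, `focal`, `tversky`) leave the head linear. A hedged sketch of inspecting that pairing:

```python
from SemanticModel.model_core import SegmentationModel

# Multiclass model with Dice loss -> softmax activation on the segmentation head
model = SegmentationModel(
    classes=['background', 'cacao', 'matarraton', 'abarco'],
    architecture='unet',
    encoder='timm-regnety_120',
    weights='imagenet',
    loss='dice',
)
print(model.config_data)  # architecture, encoder, weights, activation, loss, classes
```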
semantic-segmentation/SemanticModel/.ipynb_checkpoints/prediction-checkpoint.py ADDED
@@ -0,0 +1,336 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import torch
5
+ import imageio
6
+ import tifffile
7
+ import numpy as np
8
+ import slidingwindow
9
+ import rasterio as rio
10
+ import geopandas as gpd
11
+ from shapely.geometry import Polygon
12
+ from rasterio import mask as riomask
13
+ from torch.utils.data import DataLoader
14
+ from SemanticModel.visualization import generate_color_mapping
15
+ from SemanticModel.image_preprocessing import get_validation_augmentations
16
+ from SemanticModel.data_loader import InferenceDataset, StreamingDataset
17
+ from SemanticModel.utilities import calc_image_size, convert_coordinates
18
+
19
+ class PredictionPipeline:
20
+ def __init__(self, model_config, device=None):
21
+ self.config = model_config
22
+ self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ self.classes = ['background'] + model_config.classes if model_config.background_flag else model_config.classes
24
+ self.colors = generate_color_mapping(len(self.classes))
25
+ self.model = model_config.model.to(self.device)
26
+ self.model.eval()
27
+
28
+ def _preprocess_image(self, image_path, target_size=None):
29
+ """Preprocesses single image for prediction."""
30
+ image = cv2.imread(image_path)
31
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
32
+ height, width = image.shape[:2]
33
+
34
+ target_size = target_size or max(height, width)
35
+ test_height, test_width = calc_image_size(image, target_size)
36
+
37
+ augmentation = get_validation_augmentations(test_width, test_height)
38
+ image = augmentation(image=image)['image']
39
+ image = self.config.preprocessing(image=image)['image']
40
+
41
+ return image, (height, width)
42
+
43
+ def predict_single_image(self, image_path, target_size=None, output_dir=None,
44
+ format='integer', save_output=True):
45
+ """Generates prediction for a single image."""
46
+ image, original_dims = self._preprocess_image(image_path, target_size)
47
+ x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
48
+
49
+ with torch.no_grad():
50
+ prediction = self.model.predict(x_tensor)
51
+
52
+ if self.config.n_classes > 1:
53
+ prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
54
+ else:
55
+ prediction = prediction.squeeze().cpu().numpy().round()
56
+
57
+ # Resize to original dimensions if needed
58
+ if prediction.shape[:2] != original_dims:
59
+ prediction = cv2.resize(prediction, original_dims[::-1],
60
+ interpolation=cv2.INTER_NEAREST)
61
+
62
+ prediction = self._format_prediction(prediction, format)
63
+
64
+ if save_output:
65
+ self._save_prediction(prediction, image_path, output_dir, format)
66
+
67
+ return prediction
68
+
69
+ def predict_directory(self, input_dir, target_size=None, output_dir=None,
70
+ fixed_size=True, format='integer'):
71
+ """Generates predictions for all images in directory."""
72
+ output_dir = output_dir or os.path.join(input_dir, 'predictions')
73
+ os.makedirs(output_dir, exist_ok=True)
74
+
75
+ dataset = InferenceDataset(
76
+ input_dir,
77
+ classes=self.classes,
78
+ augmentation=get_validation_augmentations(
79
+ target_size, target_size, fixed_size=fixed_size
80
+ ) if target_size else None,
81
+ preprocessing=self.config.preprocessing
82
+ )
83
+
84
+ total_images = len(dataset)
85
+ start_time = time.time()
86
+
87
+ for idx in range(total_images):
88
+ if (idx + 1) % 10 == 0 or idx == total_images - 1:
89
+ elapsed = time.time() - start_time
90
+ print(f'\rProcessed {idx+1}/{total_images} images in {elapsed:.1f}s',
91
+ end='')
92
+
93
+ image, height, width = dataset[idx]
94
+ filename = dataset.filenames[idx]
95
+
96
+ x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
97
+ with torch.no_grad():
98
+ prediction = self.model.predict(x_tensor)
99
+
100
+ if self.config.n_classes > 1:
101
+ prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
102
+ else:
103
+ prediction = prediction.squeeze().cpu().numpy().round()
104
+
105
+ if prediction.shape != (height, width):
106
+ prediction = cv2.resize(prediction, (width, height),
107
+ interpolation=cv2.INTER_NEAREST)
108
+
109
+ prediction = self._format_prediction(prediction, format)
110
+ self._save_prediction(prediction, filename, output_dir, format)
111
+
112
+ print(f'\nPredictions saved to: {output_dir}')
113
+ return output_dir
114
+
115
+ def predict_raster(self, raster_path, tile_size=1024, overlap=0.175,
116
+ boundary_path=None, output_path=None, format='integer'):
117
+ """Processes large raster images using tiling approach."""
118
+ print('Loading raster...')
119
+ with rio.open(raster_path) as src:
120
+ raster = src.read()
121
+ raster = np.moveaxis(raster, 0, 2)[:,:,:3]
122
+ profile = src.profile
123
+ transform = src.transform
124
+
125
+ if boundary_path:
126
+ boundary = gpd.read_file(boundary_path)
127
+ boundary = boundary.to_crs(profile['crs'])
128
+ boundary_geom = boundary.iloc[0].geometry
129
+
130
+ tiles = slidingwindow.generate(
131
+ raster,
132
+ slidingwindow.DimOrder.HeightWidthChannel,
133
+ tile_size,
134
+ overlap
135
+ )
136
+
137
+ pred_raster = np.zeros_like(raster[:,:,0], dtype='uint8')
138
+ confidence = np.zeros_like(pred_raster, dtype=np.float32)
139
+
140
+ aug = get_validation_augmentations(tile_size, tile_size, fixed_size=False)
141
+
142
+ for idx, tile in enumerate(tiles):
143
+ if (idx + 1) % 10 == 0 or idx == len(tiles) - 1:
144
+ print(f'\rProcessed {idx+1}/{len(tiles)} tiles', end='')
145
+
146
+ bounds = tile.indices()
147
+
148
+ tile_image = raster[bounds[0], bounds[1]]
149
+
150
+ if boundary_path:
151
+ corners = [
152
+ convert_coordinates(transform, bounds[1].start, bounds[0].start),
153
+ convert_coordinates(transform, bounds[1].stop, bounds[0].start),
154
+ convert_coordinates(transform, bounds[1].stop, bounds[0].stop),
155
+ convert_coordinates(transform, bounds[1].start, bounds[0].stop)
156
+ ]
157
+ if not Polygon(corners).intersects(boundary_geom):
158
+ continue
159
+
160
+ processed = aug(image=tile_image)['image']
161
+ processed = self.config.preprocessing(image=processed)['image']
162
+
163
+ x_tensor = torch.from_numpy(processed).to(self.device).unsqueeze(0)
164
+ with torch.no_grad():
165
+ prediction = self.model.predict(x_tensor)
166
+ prediction = prediction.squeeze().cpu().numpy()
167
+
168
+ if self.config.n_classes > 1:
169
+ tile_pred = np.argmax(prediction, axis=0)
170
+ tile_conf = np.max(prediction, axis=0)
171
+ else:
172
+ tile_conf = np.abs(prediction - 0.5)
173
+ tile_pred = prediction.round()
174
+
175
+ if tile_pred.shape != tile_image.shape[:2]:
176
+ tile_pred = cv2.resize(tile_pred, tile_image.shape[:2][::-1],
177
+ interpolation=cv2.INTER_NEAREST)
178
+ tile_conf = cv2.resize(tile_conf, tile_image.shape[:2][::-1],
179
+ interpolation=cv2.INTER_LINEAR)
180
+
181
+ # Update prediction and confidence maps
182
+ existing_conf = confidence[bounds[0], bounds[1]]
183
+ existing_pred = pred_raster[bounds[0], bounds[1]]
184
+
185
+ mask = existing_conf < tile_conf
186
+ existing_pred[mask] = tile_pred[mask]
187
+ existing_conf[mask] = tile_conf[mask]
188
+
189
+ pred_raster[bounds[0], bounds[1]] = existing_pred
190
+ confidence[bounds[0], bounds[1]] = existing_conf
191
+
192
+ pred_raster = self._format_prediction(pred_raster, format)
193
+
194
+ if output_path or boundary_path:
195
+ self._save_raster_prediction(
196
+ pred_raster, raster_path, output_path,
197
+ profile, boundary_geom if boundary_path else None
198
+ )
199
+
200
+ return pred_raster, profile
201
+
202
+ def _format_prediction(self, prediction, format):
203
+ """Formats prediction according to specified output type."""
204
+ if format == 'integer':
205
+ return prediction.astype('uint8')
206
+ elif format == 'color':
207
+ return self._apply_color_mapping(prediction)
208
+ else:
209
+ raise ValueError(f"Unsupported format: {format}")
210
+
211
+ def _save_prediction(self, prediction, source_path, output_dir, format):
212
+ """Saves prediction to disk."""
213
+ filename = os.path.splitext(os.path.basename(source_path))[0]
214
+ output_path = os.path.join(output_dir, f"{filename}_pred.png")
215
+ cv2.imwrite(output_path, prediction)
216
+
217
+
218
+ def _save_raster_prediction(self, prediction, source_path, output_path,
219
+ profile, boundary=None):
220
+ """Saves raster prediction with geospatial information."""
221
+ output_path = output_path or source_path.replace(
222
+ os.path.splitext(source_path)[1], '_predicted.tif'
223
+ )
224
+
225
+ profile.update(
226
+ dtype='uint8',
227
+ count=3 if prediction.ndim == 3 else 1
228
+ )
229
+
230
+ with rio.open(output_path, 'w', **profile) as dst:
231
+ if prediction.ndim == 3:
232
+ for i in range(3):
233
+ dst.write(prediction[:,:,i], i+1)
234
+ else:
235
+ dst.write(prediction, 1)
236
+
237
+ if boundary:
238
+ with rio.open(output_path) as src:
239
+ cropped, transform = riomask.mask(src, [boundary], crop=True)
240
+ profile.update(
241
+ height=cropped.shape[1],
242
+ width=cropped.shape[2],
243
+ transform=transform
244
+ )
245
+
246
+ os.remove(output_path)
247
+ with rio.open(output_path, 'w', **profile) as dst:
248
+ dst.write(cropped)
249
+
250
+ print(f'\nPrediction saved to: {output_path}')
251
+
252
+ def predict_video_frames(self, input_dir, target_size=None, output_dir=None):
253
+ """Processes video frames with specialized visualization."""
254
+ output_dir = output_dir or os.path.join(input_dir, 'predictions')
255
+ os.makedirs(output_dir, exist_ok=True)
256
+
257
+ dataset = StreamingDataset(
258
+ input_dir,
259
+ classes=self.classes,
260
+ augmentation=get_validation_augmentations(
261
+ target_size, target_size
262
+ ) if target_size else None,
263
+ preprocessing=self.config.preprocessing
264
+ )
265
+
266
+ image = cv2.imread(dataset.image_paths[0])
267
+ height, width = image.shape[:2]
268
+
269
+ white = 255 * np.ones((height, width))
270
+ black = np.zeros_like(white)
271
+ red = np.dstack((white, black, black))
272
+ blue = np.dstack((black, black, white))
273
+
274
+ # Pre-compute rotated versions
275
+ rotated_red = np.rot90(red)
276
+ rotated_blue = np.rot90(blue)
277
+
278
+ total_frames = len(dataset)
279
+ start_time = time.time()
280
+
281
+ for idx in range(total_frames):
282
+ if (idx + 1) % 10 == 0 or idx == total_frames - 1:
283
+ elapsed = time.time() - start_time
284
+ print(f'\rProcessed {idx+1}/{total_frames} frames in {elapsed:.1f}s', end='')
285
+
286
+ frame, height, width = dataset[idx]
287
+ filename = dataset.filenames[idx]
288
+
289
+ x_tensor = torch.from_numpy(frame).to(self.device).unsqueeze(0)
290
+ with torch.no_grad():
291
+ prediction = self.model.predict(x_tensor)
292
+
293
+ if self.config.n_classes > 1:
294
+ prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
295
+ masks = [prediction == i for i in range(1, self.config.n_classes)]
296
+ else:
297
+ prediction = prediction.squeeze().cpu().numpy().round()
298
+ masks = [prediction == 1]
299
+
300
+ if prediction.shape != (height, width):
301
+ prediction = cv2.resize(prediction, (width, height),
302
+ interpolation=cv2.INTER_NEAREST)
303
+
304
+ original = cv2.imread(os.path.join(input_dir, filename))
305
+ original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
306
+
307
+ try:
308
+ for i, mask in enumerate(masks):
309
+ color = red if i == 0 else blue
310
+ rotated_color = rotated_red if i == 0 else rotated_blue
311
+ try:
312
+ original[mask,:] = 0.45*original[mask,:] + 0.55*color[mask,:]
313
+ except:
314
+ original[mask,:] = 0.45*original[mask,:] + 0.55*rotated_color[mask,:]
315
+ except:
316
+ print(f"\nWarning: Error processing frame {filename}")
317
+ continue
318
+
319
+ output_path = os.path.join(output_dir, filename)
320
+ imageio.imwrite(output_path, original, quality=100)
321
+
322
+ print(f'\nProcessed frames saved to: {output_dir}')
323
+ return output_dir
324
+
325
+ def _apply_color_mapping(self, prediction):
326
+ """Applies color mapping to prediction."""
327
+ height, width = prediction.shape
328
+ colored = np.zeros((height, width, 3), dtype='uint8')
329
+
330
+ for i, class_name in enumerate(self.classes):
331
+ if class_name.lower() == 'background':
332
+ continue
333
+ color = self.colors[i]
334
+ colored[prediction == i] = color
335
+
336
+ return colored
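A hedged usage sketch of the pipeline above; file paths are placeholders and a fresh, untrained model stands in for a real checkpoint:

```python
# Sketch only: in practice model_config would be the SegmentationModel instance
# whose .model was trained; here an untrained one just shows the call pattern.
from SemanticModel.model_core import SegmentationModel
from SemanticModel.prediction import PredictionPipeline

model_config = SegmentationModel(classes=['background', 'crop'],
                                 architecture='unet', encoder='timm-regnety_120')
pipeline = PredictionPipeline(model_config)

# Single image -> uint8 class map saved as <name>_pred.png (output_dir must exist).
mask = pipeline.predict_single_image('field_photo.jpg', target_size=1024,
                                     output_dir='predictions', format='integer')

# Folder of images; creates <input_dir>/predictions automatically.
pipeline.predict_directory('test_images', target_size=1024)

# Large georeferenced raster, tiled with confidence-based blending of overlaps.
pred, profile = pipeline.predict_raster('survey.tif', tile_size=1024, overlap=0.175,
                                        output_path='survey_predicted.tif')
```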
semantic-segmentation/SemanticModel/.ipynb_checkpoints/training-checkpoint.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import wandb
5
+ import datetime
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ from segmentation_models_pytorch.base.modules import Activation
11
+
12
+ from SemanticModel.data_loader import SegmentationDataset
13
+ from SemanticModel.metrics import compute_mean_iou
14
+ from SemanticModel.image_preprocessing import get_training_augmentations, get_validation_augmentations
15
+ from SemanticModel.utilities import list_images, validate_dimensions
16
+
17
+ class ModelTrainer:
18
+ def __init__(self, model_config, root_dir, epochs=40, train_size=1024,
19
+ val_size=None, workers=2, batch_size=2, learning_rate=1e-4,
20
+ step_count=2, decay_factor=0.8, wandb_config=None,
21
+ optimizer='rmsprop', target_class=None, resume_path=None):
22
+
23
+ self.config = model_config
24
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+ self.root_dir = root_dir
26
+ self._initialize_training_params(epochs, train_size, val_size, workers,
27
+ batch_size, learning_rate, step_count,
28
+ decay_factor, optimizer, target_class)
29
+ self._setup_directories()
30
+ self._initialize_datasets()
31
+ self._setup_optimizer()
32
+ self._initialize_tracking()
33
+
34
+ if resume_path:
35
+ self._resume_training(resume_path)
36
+
37
+ def _initialize_training_params(self, epochs, train_size, val_size, workers,
38
+ batch_size, learning_rate, step_count,
39
+ decay_factor, optimizer, target_class):
40
+ self.epochs = epochs
41
+ self.train_size = train_size
42
+ self.val_size = val_size
43
+ self.workers = workers
44
+ self.batch_size = batch_size
45
+ self.learning_rate = learning_rate
46
+ self.step_schedule = self._calculate_step_schedule(epochs, step_count)
47
+ self.decay_factor = decay_factor
48
+ self.optimizer_type = optimizer
49
+ self.target_class = target_class
50
+ self.current_epoch = 1
51
+ self.best_iou = 0.0
52
+ self.best_epoch = 0
53
+ self.classes = ['background'] + self.config.classes if self.config.background_flag else self.config.classes
54
+
55
+ def _setup_directories(self):
56
+ """Verifies and creates necessary directories."""
57
+ self.train_dir = os.path.join(self.root_dir, 'train')
58
+ self.val_dir = os.path.join(self.root_dir, 'val')
59
+
60
+ required_subdirs = ['Images', 'Masks']
61
+ for path in [self.train_dir] + ([self.val_dir] if os.path.exists(self.val_dir) else []):
62
+ for subdir in required_subdirs:
63
+ full_path = os.path.join(path, subdir)
64
+ if not os.path.exists(full_path):
65
+ raise FileNotFoundError(f"Missing directory: {full_path}")
66
+
67
+ def _initialize_datasets(self):
68
+ """Sets up training and validation datasets."""
69
+ self.train_dataset = SegmentationDataset(
70
+ self.train_dir,
71
+ classes=self.classes,
72
+ augmentation=get_training_augmentations(self.train_size, self.train_size),
73
+ preprocessing=self.config.preprocessing
74
+ )
75
+
76
+ if os.path.exists(self.val_dir):
77
+ self.val_dataset = SegmentationDataset(
78
+ self.val_dir,
79
+ classes=self.classes,
80
+ augmentation=get_validation_augmentations(
81
+ self.val_size or self.train_size,
82
+ self.val_size or self.train_size,
83
+ fixed_size=False
84
+ ),
85
+ preprocessing=self.config.preprocessing
86
+ )
87
+ self.val_loader = DataLoader(
88
+ self.val_dataset,
89
+ batch_size=1,
90
+ shuffle=False,
91
+ num_workers=self.workers
92
+ )
93
+ else:
94
+ self.val_dataset = self.train_dataset
95
+ self.val_loader = DataLoader(
96
+ self.val_dataset,
97
+ batch_size=1,
98
+ shuffle=False,
99
+ num_workers=self.workers
100
+ )
101
+
102
+ self.train_loader = DataLoader(
103
+ self.train_dataset,
104
+ batch_size=self.batch_size,
105
+ shuffle=True,
106
+ num_workers=self.workers
107
+ )
108
+
109
+ def _setup_optimizer(self):
110
+ """Configures model optimizer."""
111
+ optimizer_map = {
112
+ 'adam': torch.optim.Adam,
113
+ 'sgd': lambda params: torch.optim.SGD(params, momentum=0.9),
114
+ 'rmsprop': torch.optim.RMSprop
115
+ }
116
+ optimizer_class = optimizer_map.get(self.optimizer_type.lower())
117
+ if not optimizer_class:
118
+ raise ValueError(f"Unsupported optimizer: {self.optimizer_type}")
119
+
120
+ self.optimizer = optimizer_class([{'params': self.config.model.parameters(),
121
+ 'lr': self.learning_rate}])
122
+
123
+ def _initialize_tracking(self):
124
+ """Sets up training progress tracking."""
125
+ timestamp = datetime.datetime.now().strftime("%m-%d-%Y_%H%M%S")
126
+ self.output_dir = os.path.join(
127
+ self.root_dir,
128
+ f'model_outputs-{self.config.architecture}[{self.config.encoder}]-{timestamp}'
129
+ )
130
+ os.makedirs(self.output_dir, exist_ok=True)
131
+
132
+ self.writer = SummaryWriter(log_dir=self.output_dir)
133
+ self.metrics = {
134
+ 'best_epoch': self.best_epoch,
135
+ 'best_epoch_iou': self.best_iou,
136
+ 'last_epoch': 0,
137
+ 'last_epoch_iou': 0.0,
138
+ 'last_epoch_lr': self.learning_rate,
139
+ 'step_schedule': self.step_schedule,
140
+ 'decay_factor': self.decay_factor,
141
+ 'target_class': self.target_class or 'overall'
142
+ }
143
+
144
+ def _calculate_step_schedule(self, epochs, steps):
145
+ """Calculates learning rate step schedule."""
146
+ return list(map(int, np.linspace(0, epochs, steps + 2)[1:-1]))
147
+
148
+ def train(self):
149
+ """Executes training loop."""
150
+ model = self.config.model.to(self.device)
151
+ if torch.cuda.device_count() > 1:
152
+ model = torch.nn.DataParallel(model)
153
+ print(f'Using {torch.cuda.device_count()} GPUs')
154
+
155
+ self._save_config()
156
+
157
+ for epoch in range(self.current_epoch, self.epochs + 1):
158
+ print(f'\nEpoch {epoch}/{self.epochs}')
159
+ print(f'Learning rate: {self.optimizer.param_groups[0]["lr"]:.3e}')
160
+
161
+ train_loss = self._train_epoch(model)
162
+ val_loss, val_metrics = self._validate_epoch(model)
163
+
164
+ self._update_tracking(epoch, train_loss, val_loss, val_metrics)
165
+ self._adjust_learning_rate(epoch)
166
+ self._save_checkpoints(model, epoch, val_metrics)
167
+
168
+ print(f'\nTraining completed. Best {self.metrics["target_class"]} IoU: {self.best_iou:.3f}')
169
+ return model, self.metrics
170
+
171
+ def _train_epoch(self, model):
172
+ """Executes single training epoch."""
173
+ model.train()
174
+ total_loss = 0
175
+ sample_count = 0
176
+
177
+ for batch in tqdm(self.train_loader, desc='Training'):
178
+ images, masks = [x.to(self.device) for x in batch]
179
+ self.optimizer.zero_grad()
180
+
181
+ outputs = model(images)
182
+ loss = self.config.loss(outputs, masks)
183
+ loss.backward()
184
+ self.optimizer.step()
185
+
186
+ total_loss += loss.item() * len(images)
187
+ sample_count += len(images)
188
+
189
+ return total_loss / sample_count
190
+
191
+ def _validate_epoch(self, model):
192
+ """Executes validation pass."""
193
+ model.eval()
194
+ total_loss = 0
195
+ predictions = []
196
+ ground_truth = []
197
+
198
+ with torch.no_grad():
199
+ for batch in tqdm(self.val_loader, desc='Validation'):
200
+ images, masks = [x.to(self.device) for x in batch]
201
+ outputs = model(images)
202
+ loss = self.config.loss(outputs, masks)
203
+
204
+ total_loss += loss.item()
205
+
206
+ if self.config.n_classes > 1:
207
+ predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
208
+ ground_truth.extend([m.cpu().argmax(dim=0) for m in masks])
209
+ else:
210
+ predictions.extend([(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
211
+ for p in outputs])
212
+ ground_truth.extend([m.cpu().squeeze() for m in masks])
213
+
214
+ metrics = compute_mean_iou(
215
+ predictions,
216
+ ground_truth,
217
+ num_classes=len(self.classes),
218
+ ignore_index=255
219
+ )
220
+
221
+ return total_loss / len(self.val_loader), metrics
222
+
223
+ def _update_tracking(self, epoch, train_loss, val_loss, val_metrics):
224
+ """Updates training metrics and logging."""
225
+ mean_iou = val_metrics['mean_iou']
226
+ print(f"\nLosses - Train: {train_loss:.3f}, Val: {val_loss:.3f}")
227
+ print(f"Mean IoU: {mean_iou:.3f}")
228
+
229
+ self.writer.add_scalar('Loss/train', train_loss, epoch)
230
+ self.writer.add_scalar('Loss/val', val_loss, epoch)
231
+ self.writer.add_scalar('IoU/mean', mean_iou, epoch)
232
+
233
+ for idx, iou in enumerate(val_metrics['per_category_iou']):
234
+ print(f"{self.classes[idx]} IoU: {iou:.3f}")
235
+ self.writer.add_scalar(f'IoU/{self.classes[idx]}', iou, epoch)
236
+
237
+ def _adjust_learning_rate(self, epoch):
238
+ """Adjusts learning rate according to schedule."""
239
+ if epoch in self.step_schedule:
240
+ current_lr = self.optimizer.param_groups[0]['lr']
241
+ new_lr = current_lr * self.decay_factor
242
+ for param_group in self.optimizer.param_groups:
243
+ param_group['lr'] = new_lr
244
+ print(f'\nDecreased learning rate: {current_lr:.3e} -> {new_lr:.3e}')
245
+
246
+ def _save_checkpoints(self, model, epoch, metrics):
247
+ """Saves model checkpoints and metrics."""
248
+ epoch_iou = (metrics['mean_iou'] if self.target_class is None
249
+ else metrics['per_category_iou'][self.classes.index(self.target_class)])
250
+
251
+ self.metrics.update({
252
+ 'last_epoch': epoch,
253
+ 'last_epoch_iou': round(float(epoch_iou), 3),
254
+ 'last_epoch_lr': self.optimizer.param_groups[0]['lr']
255
+ })
256
+
257
+ if epoch_iou > self.best_iou:
258
+ self.best_iou = epoch_iou
259
+ self.best_epoch = epoch
260
+ self.metrics.update({
261
+ 'best_epoch': epoch,
262
+ 'best_epoch_iou': round(float(epoch_iou), 3),
263
+ 'overall_iou': round(float(metrics['mean_iou']), 3)
264
+ })
265
+ torch.save(model, os.path.join(self.output_dir, 'best_model.pth'))
266
+ print(f'New best model saved (IoU: {epoch_iou:.3f})')
267
+
268
+ torch.save(model, os.path.join(self.output_dir, 'last_model.pth'))
269
+ with open(os.path.join(self.output_dir, 'metrics.json'), 'w') as f:
270
+ json.dump(self.metrics, f, indent=4)
271
+
272
+ def _save_config(self):
273
+ """Saves training configuration."""
274
+ config = {
275
+ **self.config.config_data,
276
+ 'train_size': self.train_size,
277
+ 'val_size': self.val_size,
278
+ 'epochs': self.epochs,
279
+ 'batch_size': self.batch_size,
280
+ 'optimizer': self.optimizer_type,
281
+ 'workers': self.workers,
282
+ 'target_class': self.target_class or 'overall'
283
+ }
284
+
285
+ with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
286
+ json.dump(config, f, indent=4)
287
+
288
+ def _resume_training(self, resume_path):
289
+ """Resumes training from checkpoint."""
290
+ if not os.path.exists(resume_path):
291
+ raise FileNotFoundError(f"Resume path not found: {resume_path}")
292
+
293
+ required_files = {
294
+ 'model': 'last_model.pth',
295
+ 'metrics': 'metrics.json',
296
+ 'config': 'config.json'
297
+ }
298
+
299
+ paths = {k: os.path.join(resume_path, v) for k, v in required_files.items()}
300
+ if not all(os.path.exists(p) for p in paths.values()):
301
+ raise FileNotFoundError("Missing required checkpoint files")
302
+
303
+ with open(paths['config']) as f:
304
+ config = json.load(f)
305
+ with open(paths['metrics']) as f:
306
+ metrics = json.load(f)
307
+
308
+ self.current_epoch = metrics['last_epoch'] + 1
309
+ self.best_iou = metrics['best_epoch_iou']
310
+ self.best_epoch = metrics['best_epoch']
311
+ self.learning_rate = metrics['last_epoch_lr']
312
+
313
+ print(f'Resuming training from epoch {self.current_epoch}')
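A hedged training sketch; 'dataset_root' must contain train/Images and train/Masks (optionally val/Images and val/Masks), and the hyperparameters are illustrative:

```python
from SemanticModel.model_core import SegmentationModel
from SemanticModel.training import ModelTrainer

model_config = SegmentationModel(classes=['background', 'crop'],
                                 architecture='unet', encoder='timm-regnety_120')

trainer = ModelTrainer(model_config, root_dir='dataset_root',
                       epochs=40, train_size=1024, batch_size=2,
                       learning_rate=1e-4, optimizer='rmsprop')

# Writes best_model.pth, last_model.pth, metrics.json and TensorBoard logs
# into dataset_root/model_outputs-<architecture>[<encoder>]-<timestamp>/.
model, metrics = trainer.train()
```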
semantic-segmentation/SemanticModel/.ipynb_checkpoints/utilities-checkpoint.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import cv2
3
+ import shutil
4
+ import imageio
5
+ import numpy as np
6
+ from glob import glob
7
+ from pathlib import Path
8
+ from typing import List, Tuple, Optional
9
+
10
+ def validate_dimensions(width: int, height: int, stride: int = 32) -> Tuple[int, int]:
11
+ new_height = ((height // stride + 1) * stride
12
+ if height % stride != 0 else height)
13
+ new_width = ((width // stride + 1) * stride
14
+ if width % stride != 0 else width)
15
+ if (new_width, new_height) != (width, height):
16
+ print(f'Adjusted dimensions to: {new_height}H x {new_width}W')
17
+ return new_width, new_height
18
+
19
+ def calc_image_size(image: np.ndarray, target_size: int) -> Tuple[int, int]:
20
+ height, width = image.shape[:2]
21
+ aspect_ratio = width / height
22
+
23
+ if aspect_ratio >= 1:
24
+ new_width = target_size
25
+ new_height = int(target_size / aspect_ratio)
26
+ else:
27
+ new_height = target_size
28
+ new_width = int(target_size * aspect_ratio)
29
+
30
+ return validate_dimensions(new_width, new_height)
31
+
32
+ def convert_coordinates(transform: np.ndarray, x: float, y: float) -> Tuple[float, float]:
33
+ transformed = transform @ np.array([x, y, 1])
34
+ return transformed[0], transformed[1]
35
+
36
+ def list_images(directory: str, mask_format: bool = False) -> List[str]:
37
+ extensions = ['*.png', '*.PNG'] if mask_format else [
38
+ '*.jpg', '*.jpeg', '*.png', '*.tif', '*.tiff',
39
+ '*.JPG', '*.JPEG', '*.PNG', '*.TIF', '*.TIFF'
40
+ ]
41
+
42
+ image_paths = []
43
+ for ext in extensions:
44
+ image_paths.extend(glob(os.path.join(directory, ext)))
45
+
46
+ return sorted(list(set(image_paths)))
47
+
48
+ def prepare_dataset_split(root_dir: str, train_ratio: float = 0.7,
49
+ generate_empty_masks: bool = False) -> None:
50
+ image_dir = os.path.join(root_dir, 'Images')
51
+ mask_dir = os.path.join(root_dir, 'Masks')
52
+
53
+ if not all(os.path.exists(d) for d in [image_dir, mask_dir]):
54
+ raise Exception("Required 'Images' and 'Masks' directories not found")
55
+
56
+ image_paths = np.array(list_images(image_dir))
57
+ mask_paths = np.array(list_images(mask_dir, mask_format=True))
58
+
59
+ if generate_empty_masks:
60
+ temp_dir = os.path.join(mask_dir, 'temp')
61
+ create_empty_masks(image_dir, outdir=temp_dir)
62
+
63
+ for mask_path in list_images(temp_dir, mask_format=True):
64
+ target_path = os.path.join(mask_dir, os.path.basename(mask_path))
65
+ if not os.path.exists(target_path):
66
+ shutil.move(mask_path, target_path)
67
+
68
+ shutil.rmtree(temp_dir)
69
+ mask_paths = np.array(list_images(mask_dir, mask_format=True))
70
+
71
+ if len(image_paths) != len(mask_paths):
72
+ raise Exception(f"Unmatched images ({len(image_paths)}) and masks ({len(mask_paths)})")
73
+
74
+ train_ratio = float(train_ratio)
75
+ if not (0 < train_ratio <= 1):
76
+ raise ValueError(f"Invalid train ratio: {train_ratio}")
77
+
78
+ train_size = int(np.floor(train_ratio * len(image_paths)))
79
+ indices = np.random.permutation(len(image_paths))
80
+
81
+ splits = {
82
+ 'train': {'indices': indices[:train_size]},
83
+ 'val': {'indices': indices[train_size:]} if train_ratio < 1 else None
84
+ }
85
+
86
+ for split_name, split_data in splits.items():
87
+ if split_data is None:
88
+ continue
89
+
90
+ split_dir = os.path.join(root_dir, split_name)
91
+ for subdir in ['Images', 'Masks']:
92
+ subdir_path = os.path.join(split_dir, subdir)
93
+ os.makedirs(subdir_path, exist_ok=True)
94
+
95
+ sources = image_paths if subdir == 'Images' else mask_paths
96
+ for idx in split_data['indices']:
97
+ source = sources[idx]
98
+ destination = os.path.join(subdir_path, os.path.basename(source))
99
+ shutil.copyfile(source, destination)
100
+
101
+ print(f"Created {split_name} split with {len(split_data['indices'])} samples")
102
+
103
+ def create_empty_masks(image_dir: str, pixel_value: int = 0,
104
+ outdir: Optional[str] = None) -> str:
105
+ outdir = outdir or os.path.join(image_dir, 'Masks')
106
+ os.makedirs(outdir, exist_ok=True)
107
+
108
+ image_paths = list_images(image_dir)
109
+ print(f"Generating {len(image_paths)} empty masks...")
110
+
111
+ for image_path in image_paths:
112
+ image = imageio.imread(image_path)
113
+ mask = np.full((image.shape[0], image.shape[1]), pixel_value, dtype='uint8')
114
+
115
+ output_path = os.path.join(outdir,
116
+ f"{Path(image_path).stem}.png")
117
+ imageio.imwrite(output_path, mask)
118
+
119
+ return outdir
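A hedged example of the dataset-splitting helper above; it expects 'dataset_root/Images' and 'dataset_root/Masks' and copies files into train/ and val/ subfolders:

```python
from SemanticModel.utilities import prepare_dataset_split

# 70/30 split; generate_empty_masks fills in all-background masks for any
# image that has no annotation yet (useful for purely negative examples).
prepare_dataset_split('dataset_root', train_ratio=0.7, generate_empty_masks=True)
```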
semantic-segmentation/SemanticModel/.ipynb_checkpoints/visualization-checkpoint.py ADDED
@@ -0,0 +1,115 @@
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+
6
+ def plot_predictions(model, images, masks, device, num_samples=4):
7
+ """Visualize model predictions against ground truth."""
8
+ with torch.no_grad():
9
+ model.eval()
10
+ predictions = model.predict(images.to(device))
11
+
12
+ fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))
13
+
14
+ for idx in range(num_samples):
15
+ # Original image
16
+ img = images[idx].permute(1, 2, 0).cpu().numpy()
17
+ axes[idx, 0].imshow(img)
18
+ axes[idx, 0].set_title('Original Image')
19
+
20
+ # Ground truth
21
+ truth = masks[idx].argmax(dim=0).cpu().numpy()
22
+ axes[idx, 1].imshow(truth, cmap='tab20')
23
+ axes[idx, 1].set_title('Ground Truth')
24
+
25
+ # Prediction
26
+ pred = predictions[idx].argmax(dim=0).cpu().numpy()
27
+ axes[idx, 2].imshow(pred, cmap='tab20')
28
+ axes[idx, 2].set_title('Prediction')
29
+
30
+ for ax in axes[idx]:
31
+ ax.axis('off')
32
+
33
+ plt.tight_layout()
34
+ return fig
35
+
36
+ def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
37
+ """Create transparent overlay of segmentation mask on image."""
38
+ if color_map is None:
39
+ color_map = {
40
+ 0: [0, 0, 0], # background
41
+ 1: [255, 0, 0], # class 1 (red)
42
+ 2: [0, 255, 0], # class 2 (green)
43
+ 3: [0, 0, 255], # class 3 (blue)
44
+ }
45
+
46
+ overlay = image.copy()
47
+ mask_colored = np.zeros_like(image)
48
+
49
+ for label, color in color_map.items():
50
+ mask_colored[mask == label] = color
51
+
52
+ cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
53
+ return overlay
54
+
55
+ def plot_training_history(history):
56
+ """Plot training and validation metrics."""
57
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
58
+
59
+ # Loss plot
60
+ ax1.plot(history['train_loss'], label='Training Loss')
61
+ ax1.plot(history['val_loss'], label='Validation Loss')
62
+ ax1.set_xlabel('Epoch')
63
+ ax1.set_ylabel('Loss')
64
+ ax1.set_title('Training and Validation Loss')
65
+ ax1.legend()
66
+
67
+ # IoU plot
68
+ ax2.plot(history['mean_iou'], label='Mean IoU')
69
+ for class_name, ious in history['class_ious'].items():
70
+ ax2.plot(ious, label=f'{class_name} IoU')
71
+ ax2.set_xlabel('Epoch')
72
+ ax2.set_ylabel('IoU')
73
+ ax2.set_title('IoU Metrics')
74
+ ax2.legend()
75
+
76
+ plt.tight_layout()
77
+ return fig
78
+
79
+ def visualize_predictions_on_batch(model, batch_images, batch_size=8):
80
+ """Create grid visualization for a batch of predictions."""
81
+ with torch.no_grad():
82
+ predictions = model.predict(batch_images)
83
+
84
+ fig = plt.figure(figsize=(15, 5))
85
+ for idx in range(min(batch_size, len(batch_images))):
86
+ plt.subplot(2, 4, idx + 1)
87
+ img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
88
+ mask = predictions[idx].argmax(dim=0).cpu().numpy()
89
+ overlay = create_overlay_mask(img, mask)
90
+ plt.imshow(overlay)
91
+ plt.axis('off')
92
+
93
+ plt.tight_layout()
94
+ return fig
95
+
96
+ def save_visualization(fig, save_path):
97
+ """Save visualization figure."""
98
+ fig.savefig(save_path, bbox_inches='tight', dpi=300)
99
+ plt.close(fig)
100
+
101
+ def generate_color_mapping(num_classes):
102
+ """Generate distinct colors for segmentation classes."""
103
+ colors = [
104
+ [0, 0, 0], # Background (black)
105
+ [255, 0, 0], # Red
106
+ [0, 255, 0], # Green
107
+ [0, 0, 255], # Blue
108
+ [255, 255, 0], # Yellow
109
+ [255, 0, 255], # Magenta
110
+ [0, 255, 255], # Cyan
111
+ [128, 0, 0], # Dark Red
112
+ [0, 128, 0], # Dark Green
113
+ [0, 0, 128] # Dark Blue
114
+ ]
115
+ return colors[:num_classes]
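A hedged sketch of overlaying a saved prediction on its source image ('photo.jpg' and 'photo_pred.png' are placeholder paths):

```python
import cv2
from SemanticModel.visualization import create_overlay_mask

image = cv2.cvtColor(cv2.imread('photo.jpg'), cv2.COLOR_BGR2RGB)
mask = cv2.imread('photo_pred.png', 0)                 # uint8 class indices
overlay = create_overlay_mask(image, mask, alpha=0.5)  # default 4-colour map
cv2.imwrite('photo_overlay.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
```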
semantic-segmentation/SemanticModel/__init__.py ADDED
File without changes
semantic-segmentation/SemanticModel/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (187 Bytes). View file
 
semantic-segmentation/SemanticModel/__pycache__/custom_losses.cpython-38.pyc ADDED
Binary file (3.45 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/data_loader.cpython-38.pyc ADDED
Binary file (6.72 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/encoder_management.cpython-38.pyc ADDED
Binary file (4.17 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/evaluation_utils.cpython-38.pyc ADDED
Binary file (3.62 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/image_preprocessing.cpython-38.pyc ADDED
Binary file (3.62 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/metrics.cpython-38.pyc ADDED
Binary file (2.56 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/model_core.cpython-38.pyc ADDED
Binary file (4.52 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/prediction.cpython-38.pyc ADDED
Binary file (9.62 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/training.cpython-38.pyc ADDED
Binary file (10.8 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/utilities.cpython-38.pyc ADDED
Binary file (4.03 kB). View file
 
semantic-segmentation/SemanticModel/__pycache__/visualization.cpython-38.pyc ADDED
Binary file (3.41 kB). View file
 
semantic-segmentation/SemanticModel/custom_losses.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from segmentation_models_pytorch.utils import base
4
+ from segmentation_models_pytorch.base.modules import Activation
5
+
6
+ class FocalLossFunction(base.Loss):
7
+ def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
8
+ super().__init__(**kwargs)
9
+ self.activation = Activation(activation)
10
+ self.alpha = alpha
11
+ self.gamma = gamma
12
+ self.reduction = reduction
13
+
14
+ def forward(self, inputs, targets):
15
+ if inputs.shape[1] == 1: # Binary case
16
+ inputs = torch.cat((inputs, 1 - inputs), dim=1)
17
+ targets = torch.cat((targets, 1 - targets), dim=1)
18
+
19
+ targets = torch.argmax(targets, dim=1)
20
+ cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
21
+ probability = torch.exp(-cross_entropy)
22
+ alpha_factor = self.alpha if inputs.shape[1] > 1 else torch.where(
23
+ targets == 1, 1-self.alpha, self.alpha)
24
+
25
+ focal_weight = alpha_factor * (1 - probability) ** self.gamma * cross_entropy
26
+
27
+ if self.reduction == 'mean':
28
+ return focal_weight.mean()
29
+ elif self.reduction == 'sum':
30
+ return focal_weight.sum()
31
+ return focal_weight
32
+
33
+ class TverskyLossFunction(base.Loss):
34
+ def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
35
+ reduction='mean', **kwargs):
36
+ super().__init__(**kwargs)
37
+ self.activation = Activation(activation)
38
+ self.alpha = alpha
39
+ self.beta = beta
40
+ self.ignore_channels = ignore_channels
41
+ self.reduction = reduction
42
+
43
+ def forward(self, inputs, targets):
44
+ if self.ignore_channels is not None:
45
+ mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
46
+ mask[self.ignore_channels] = False
47
+ inputs = inputs[:, mask, ...]
48
+
49
+ num_classes = inputs.shape[1]
50
+ inputs_softmax = (torch.sigmoid(inputs) if num_classes == 1
51
+ else F.softmax(inputs, dim=1))
52
+
53
+ if num_classes == 1:
54
+ inputs_softmax = inputs_softmax.squeeze(1)
55
+ targets = targets.squeeze(1)
56
+
57
+ tversky_loss = 0
58
+ for class_idx in range(num_classes):
59
+ if num_classes == 1:
60
+ flat_inputs = inputs_softmax.reshape(-1)
61
+ flat_targets = targets.reshape(-1)
62
+ else:
63
+ flat_inputs = inputs_softmax[:, class_idx].reshape(-1)
64
+ flat_targets = targets[:, class_idx].reshape(-1)
65
+
66
+ intersection = (flat_inputs * flat_targets).sum()
67
+ fps = ((1 - flat_targets) * flat_inputs).sum()
68
+ fns = (flat_targets * (1 - flat_inputs)).sum()
69
+
70
+ tversky_index = intersection + self.alpha * fps + self.beta * fns + 1e-10
71
+ tversky_loss += 1 - intersection / tversky_index
72
+
73
+ if self.reduction == 'mean':
74
+ return tversky_loss / (1 if num_classes == 1 else num_classes)
75
+ elif self.reduction == 'sum':
76
+ return tversky_loss
77
+ return tversky_loss / inputs.shape[0]
78
+
79
+ class EnhancedCrossEntropy(base.Loss):
80
+ def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
81
+ super().__init__(**kwargs)
82
+ self.activation = Activation(activation)
83
+ self.ignore_channels = ignore_channels
84
+ self.reduction = reduction
85
+
86
+ def forward(self, inputs, targets):
87
+ inputs = self.activation(inputs)
88
+
89
+ if self.ignore_channels is not None:
90
+ mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
91
+ mask[self.ignore_channels] = False
92
+ inputs = inputs[:, mask, ...]
93
+
94
+ if targets.dim() == 4: # Convert one-hot to class indices
95
+ targets = torch.argmax(targets, dim=1)
96
+
97
+ return F.cross_entropy(inputs, targets, reduction=self.reduction)
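A hedged sanity check of the custom losses on random tensors: both expect raw model outputs of shape (N, C, H, W) and one-hot targets of the same shape (C = 1 in the binary case):

```python
import torch
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction

logits = torch.randn(2, 3, 64, 64)                     # 3-class example
labels = torch.randint(0, 3, (2, 64, 64))
targets = torch.nn.functional.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()

focal = FocalLossFunction(alpha=0.25, gamma=1.5)
tversky = TverskyLossFunction(alpha=0.5, beta=0.5)
print(focal(logits, targets).item(), tversky(logits, targets).item())
```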
semantic-segmentation/SemanticModel/data_loader.py ADDED
@@ -0,0 +1,129 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ from torch.utils.data import Dataset as BaseDataset
5
+
6
+ class SegmentationDataset(BaseDataset):
7
+ """Dataset class for semantic segmentation task."""
8
+
9
+ def __init__(self, data_dir, classes=['background', 'object'],
10
+ augmentation=None, preprocessing=None):
11
+
12
+ self.image_dir = os.path.join(data_dir, 'Images')
13
+ self.mask_dir = os.path.join(data_dir, 'Masks')
14
+
15
+ for dir_path in [self.image_dir, self.mask_dir]:
16
+ if not os.path.exists(dir_path):
17
+ raise FileNotFoundError(f"Directory not found: {dir_path}")
18
+
19
+ self.filenames = self._get_filenames()
20
+ self.image_paths = [os.path.join(self.image_dir, fname) for fname in self.filenames]
21
+ self.mask_paths = self._get_mask_paths()
22
+
23
+ self.target_classes = [cls for cls in classes if cls.lower() != 'background']
24
+ self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
25
+
26
+ self.augmentation = augmentation
27
+ self.preprocessing = preprocessing
28
+
29
+ def __getitem__(self, index):
30
+ image = self._load_image(self.image_paths[index])
31
+ mask = self._load_mask(self.mask_paths[index])
32
+
33
+ if self.augmentation:
34
+ processed = self.augmentation(image=image, mask=mask)
35
+ image, mask = processed['image'], processed['mask']
36
+
37
+ if self.preprocessing:
38
+ processed = self.preprocessing(image=image, mask=mask)
39
+ image, mask = processed['image'], processed['mask']
40
+
41
+ return image, mask
42
+
43
+ def __len__(self):
44
+ return len(self.filenames)
45
+
46
+ def _get_filenames(self):
47
+ """Returns sorted list of filenames, excluding directories."""
48
+ files = sorted(os.listdir(self.image_dir))
49
+ return [f for f in files if not os.path.isdir(os.path.join(self.image_dir, f))]
50
+
51
+ def _get_mask_paths(self):
52
+ """Generates corresponding mask paths for each image."""
53
+ mask_paths = []
54
+ for image_file in self.filenames:
55
+ name, _ = os.path.splitext(image_file)
56
+ mask_paths.append(os.path.join(self.mask_dir, f"{name}.png"))
57
+ return mask_paths
58
+
59
+ def _load_image(self, path):
60
+ """Loads and converts image to RGB."""
61
+ image = cv2.imread(path)
62
+ return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
63
+
64
+ def _load_mask(self, path):
65
+ """Loads and processes segmentation mask."""
66
+ mask = cv2.imread(path, 0)
67
+ masks = [(mask == value) for value in self.class_values]
68
+ mask = np.stack(masks, axis=-1).astype('float')
69
+ return mask
70
+
71
+ class InferenceDataset(BaseDataset):
72
+ """Dataset class for inference without ground truth masks."""
73
+
74
+ def __init__(self, data_dir, classes=['background', 'object'],
75
+ augmentation=None, preprocessing=None):
76
+ self.filenames = sorted([
77
+ f for f in os.listdir(data_dir)
78
+ if not os.path.isdir(os.path.join(data_dir, f))
79
+ ])
80
+ self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]
81
+
82
+ self.target_classes = [cls for cls in classes if cls.lower() != 'background']
83
+ self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
84
+
85
+ self.augmentation = augmentation
86
+ self.preprocessing = preprocessing
87
+
88
+ def __getitem__(self, index):
89
+ image = cv2.imread(self.image_paths[index])
90
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
91
+ original_height, original_width = image.shape[:2]
92
+
93
+ if self.augmentation:
94
+ image = self.augmentation(image=image)['image']
95
+
96
+ if self.preprocessing:
97
+ image = self.preprocessing(image=image)['image']
98
+
99
+ return image, original_height, original_width
100
+
101
+ def __len__(self):
102
+ return len(self.filenames)
103
+
104
+ class StreamingDataset(BaseDataset):
105
+ """Dataset class optimized for video frame processing."""
106
+
107
+ def __init__(self, data_dir, classes=['background', 'object'],
108
+ augmentation=None, preprocessing=None):
109
+ self.filenames = self._get_frame_filenames(data_dir)
110
+ self.image_paths = [os.path.join(data_dir, fname) for fname in self.filenames]
111
+
112
+ self.target_classes = [cls for cls in classes if cls.lower() != 'background']
113
+ self.class_values = [i for i, cls in enumerate(classes) if cls.lower() != 'background']
114
+
115
+ self.augmentation = augmentation
116
+ self.preprocessing = preprocessing
117
+
118
+ def _get_frame_filenames(self, directory):
119
+ """Returns sorted list of frame filenames."""
120
+ files = sorted(os.listdir(directory))
121
+ return [f for f in files if (('frame' in f or 'Image' in f) and
122
+ f.lower().endswith('jpg') and
123
+ not os.path.isdir(os.path.join(directory, f)))]
124
+
125
+ def __getitem__(self, index):
126
+ return InferenceDataset.__getitem__(self, index)
127
+
128
+ def __len__(self):
129
+ return len(self.filenames)
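A hedged sketch wiring SegmentationDataset into a PyTorch DataLoader; 'dataset_root/train' must hold Images/ and Masks/ subfolders, and preprocessing is omitted here (normally model_config.preprocessing converts arrays to CHW float tensors):

```python
from torch.utils.data import DataLoader
from SemanticModel.data_loader import SegmentationDataset
from SemanticModel.image_preprocessing import get_training_augmentations

train_set = SegmentationDataset('dataset_root/train',
                                classes=['background', 'crop'],
                                augmentation=get_training_augmentations(1024, 1024),
                                preprocessing=None)
train_loader = DataLoader(train_set, batch_size=2, shuffle=True, num_workers=2)
images, masks = next(iter(train_loader))   # without preprocessing: HWC uint8 / one-hot HWC
```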
semantic-segmentation/SemanticModel/encoder_management.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+ import ssl
3
+ import shutil
4
+ import tempfile
5
+ import hashlib
6
+ from tqdm import tqdm
7
+ from torch.hub import get_dir
8
+ from urllib.request import urlopen, Request
9
+
10
+ from segmentation_models_pytorch.encoders import (
11
+ resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
12
+ densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
13
+ efficient_net_encoders, mobilenet_encoders, xception_encoders,
14
+ timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
15
+ timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
16
+ timm_gernet_encoders
17
+ )
18
+
19
+ from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder
20
+
21
+ def initialize_encoders():
22
+ """Initialize dictionary of available encoders."""
23
+ available_encoders = {}
24
+ encoder_modules = [
25
+ resnet_encoders, dpn_encoders, vgg_encoders, senet_encoders,
26
+ densenet_encoders, inceptionresnetv2_encoders, inceptionv4_encoders,
27
+ efficient_net_encoders, mobilenet_encoders, xception_encoders,
28
+ timm_efficientnet_encoders, timm_resnest_encoders, timm_res2net_encoders,
29
+ timm_regnet_encoders, timm_sknet_encoders, timm_mobilenetv3_encoders,
30
+ timm_gernet_encoders
31
+ ]
32
+
33
+ for module in encoder_modules:
34
+ available_encoders.update(module)
35
+
36
+ try:
37
+ import segmentation_models_pytorch
38
+ from packaging import version
39
+ if version.parse(segmentation_models_pytorch.__version__) >= version.parse("0.3.3"):
40
+ from segmentation_models_pytorch.encoders.mix_transformer import mix_transformer_encoders
41
+ from segmentation_models_pytorch.encoders.mobileone import mobileone_encoders
42
+ available_encoders.update(mix_transformer_encoders)
43
+ available_encoders.update(mobileone_encoders)
44
+ except ImportError:
45
+ pass
46
+
47
+ return available_encoders
48
+
49
+ def download_weights(url, destination, hash_prefix=None, show_progress=True):
50
+ """Downloads model weights with progress tracking and verification."""
51
+ ssl._create_default_https_context = ssl._create_unverified_context
52
+
53
+ req = Request(url, headers={"User-Agent": "torch.hub"})
54
+ response = urlopen(req)
55
+ content_length = response.headers.get("Content-Length")
56
+ file_size = int(content_length[0]) if content_length else None
57
+
58
+ destination = os.path.expanduser(destination)
59
+ temp_file = tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination))
60
+
61
+ try:
62
+ hasher = hashlib.sha256() if hash_prefix else None
63
+
64
+ with tqdm(total=file_size, disable=not show_progress,
65
+ unit='B', unit_scale=True, unit_divisor=1024) as pbar:
66
+ while True:
67
+ buffer = response.read(8192)
68
+ if not buffer:
69
+ break
70
+
71
+ temp_file.write(buffer)
72
+ if hasher:
73
+ hasher.update(buffer)
74
+ pbar.update(len(buffer))
75
+
76
+ temp_file.close()
77
+
78
+ if hasher and hash_prefix:
79
+ digest = hasher.hexdigest()
80
+ if digest[:len(hash_prefix)] != hash_prefix:
81
+ raise RuntimeError(f'Invalid hash value (expected "{hash_prefix}", got "{digest}")')
82
+
83
+ shutil.move(temp_file.name, destination)
84
+
85
+ finally:
86
+ temp_file.close()
87
+ if os.path.exists(temp_file.name):
88
+ os.remove(temp_file.name)
89
+
90
+ def initialize_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
91
+ """Initializes and returns configured encoder."""
92
+ encoders = initialize_encoders()
93
+
94
+ if name.startswith("tu-"):
95
+ name = name[3:]
96
+ return TimmUniversalEncoder(
97
+ name=name,
98
+ in_channels=in_channels,
99
+ depth=depth,
100
+ output_stride=output_stride,
101
+ pretrained=weights is not None,
102
+ **kwargs
103
+ )
104
+
105
+ try:
106
+ encoder_config = encoders[name]
107
+ except KeyError:
108
+ raise KeyError(f"Invalid encoder name '{name}'. Available encoders: {list(encoders.keys())}")
109
+
110
+ encoder_class = encoder_config["encoder"]
111
+ encoder_params = encoder_config["params"]
112
+ encoder_params.update(depth=depth)
113
+
114
+ if weights:
115
+ try:
116
+ weights_config = encoder_config["pretrained_settings"][weights]
117
+ except KeyError:
118
+ raise KeyError(
119
+ f"Invalid weights '{weights}' for encoder '{name}'. "
120
+ f"Available options: {list(encoder_config['pretrained_settings'].keys())}"
121
+ )
122
+
123
+ cache_dir = os.path.join(get_dir(), 'checkpoints')
124
+ os.makedirs(cache_dir, exist_ok=True)
125
+
126
+ weights_file = os.path.basename(weights_config["url"])
127
+ weights_path = os.path.join(cache_dir, weights_file)
128
+
129
+ if not os.path.exists(weights_path):
130
+ print(f'Downloading {weights_file}...')
131
+ download_weights(
132
+ weights_config["url"].replace("https", "http"),
133
+ weights_path
134
+ )
135
+
136
+ return encoder_class(**encoder_params)
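A hedged example of the encoder helper above: it instantiates the backbone and pre-downloads the matching checkpoint into torch.hub's cache; the pretrained weights themselves are attached later, when segmentation_models_pytorch builds the full model:

```python
from SemanticModel.encoder_management import initialize_encoder

# Downloads the resnet34 ImageNet checkpoint into the torch.hub cache on first use.
encoder = initialize_encoder(name='resnet34', depth=5, weights='imagenet')
```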
semantic-segmentation/SemanticModel/evaluation_utils.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ from tqdm import tqdm
5
+ from torch.utils.data import DataLoader
6
+ from segmentation_models_pytorch.base.modules import Activation
7
+
8
+ from SemanticModel.data_loader import SegmentationDataset
9
+ from SemanticModel.metrics import compute_mean_iou
10
+ from SemanticModel.image_preprocessing import get_validation_augmentations
11
+
12
+ def evaluate_model(model_config, data_path, image_size=None):
13
+ """Evaluates model performance on a dataset."""
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ classes = ['background'] + model_config.classes if model_config.background_flag else model_config.classes
17
+
18
+ data_path = os.path.realpath(data_path)
19
+ image_subdir = os.path.join(data_path, 'Images')
20
+ mask_subdir = os.path.join(data_path, 'Masks')
21
+
22
+ if not all(os.path.exists(d) for d in [image_subdir, mask_subdir]):
23
+ raise Exception("Missing required subdirectories: 'Images' and 'Masks'")
24
+
25
+ if not image_size:
26
+ sample_image = cv2.imread(os.path.join(image_subdir, os.listdir(image_subdir)[0]))
27
+ height, width = sample_image.shape[:2]
28
+ image_size = max(height, width)
29
+
30
+ evaluation_dataset = SegmentationDataset(
31
+ data_path,
32
+ classes=classes,
33
+ augmentation=get_validation_augmentations(
34
+ im_width=image_size,
35
+ im_height=image_size,
36
+ fixed_size=False
37
+ ),
38
+ preprocessing=model_config.preprocessing
39
+ )
40
+
41
+ evaluation_loader = DataLoader(
42
+ evaluation_dataset,
43
+ batch_size=1,
44
+ shuffle=False,
45
+ num_workers=2
46
+ )
47
+
48
+ model = model_config.model.to(device)
49
+ model.eval()
50
+
51
+ requires_sigmoid = False
52
+ if model_config.n_classes == 1:
53
+ current_activation = _check_activation_function(model)
54
+ if current_activation != 'Sigmoid':
55
+ requires_sigmoid = True
56
+
57
+ predictions = []
58
+ ground_truth = []
59
+
60
+ print("Evaluating model performance...")
61
+ with torch.no_grad():
62
+ for images, masks in tqdm(evaluation_loader):
63
+ images = images.to(device)
64
+ masks = masks.to(device)
65
+
66
+ outputs = model.forward(images)
67
+
68
+ if model_config.n_classes > 1:
69
+ predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
70
+ ground_truth.extend([gt.cpu().argmax(dim=0) for gt in masks])
71
+ else:
72
+ if requires_sigmoid:
73
+ predictions.extend([
74
+ (torch.sigmoid(p) > 0.5).float().squeeze().cpu()
75
+ for p in outputs
76
+ ])
77
+ else:
78
+ predictions.extend([
79
+ (p > 0.5).float().squeeze().cpu()
80
+ for p in outputs
81
+ ])
82
+ ground_truth.extend([gt.cpu().squeeze() for gt in masks])
83
+
84
+ metrics = compute_mean_iou(
85
+ predictions,
86
+ ground_truth,
87
+ num_classes=len(classes),
88
+ ignore_index=255
89
+ )
90
+
91
+ print("\nEvaluation Results:")
92
+ print(f"Mean IoU: {metrics['mean_iou']:.3f}")
93
+ print("\nPer-class IoU:")
94
+ for idx, iou in enumerate(metrics['per_category_iou']):
95
+ print(f"{classes[idx]}: {iou:.3f}")
96
+
97
+ return metrics
98
+
99
+ def _check_activation_function(model):
100
+ """Checks the activation function used in model's segmentation head."""
101
+ from segmentation_models_pytorch.base.modules import Activation
102
+
103
+ activation_functions = []
104
+ for _, module in model.segmentation_head.named_children():
105
+ if isinstance(module, Activation):
106
+ activation_functions.append(type(module.activation).__name__)
107
+
108
+ return activation_functions[-1] if activation_functions else None
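A hedged evaluation sketch; 'dataset_root/val' must contain Images/ and Masks/, and an untrained model is constructed only to show the call pattern:

```python
from SemanticModel.model_core import SegmentationModel
from SemanticModel.evaluation_utils import evaluate_model

# In practice, pass the trained SegmentationModel instance instead.
model_config = SegmentationModel(classes=['background', 'crop'],
                                 architecture='unet', encoder='timm-regnety_120')
metrics = evaluate_model(model_config, 'dataset_root/val', image_size=1024)
print(metrics['mean_iou'], metrics['per_category_iou'])
```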
semantic-segmentation/SemanticModel/image_preprocessing.py ADDED
@@ -0,0 +1,81 @@
1
+ import cv2
2
+ import numpy as np
3
+ import albumentations as albu
4
+ from albumentations.augmentations.geometric.resize import LongestMaxSize
5
+
6
+ def round_pixel_dim(dimension: float) -> int:
7
+ """Rounds pixel dimensions consistently."""
8
+ if abs(round(dimension) - dimension) == 0.5:
9
+ return int(2.0 * round(dimension / 2.0))
10
+ return int(round(dimension))
11
+
12
+ def resize_with_padding(image, target_size, stride=32, interpolation=cv2.INTER_LINEAR):
13
+ """Resizes image maintaining aspect ratio and ensures dimensions are stride-compatible."""
14
+ height, width = image.shape[:2]
15
+ max_dimension = max(height, width)
16
+
17
+ if ((height % stride == 0) and (width % stride == 0) and
18
+ (max_dimension <= target_size)):
19
+ return image
20
+
21
+ scale = target_size / float(max(width, height))
22
+ new_dims = tuple(round_pixel_dim(dim * scale) for dim in (height, width))
23
+ new_height, new_width = new_dims
24
+
25
+ new_height = ((new_height // stride + 1) * stride
26
+ if new_height % stride != 0 else new_height)
27
+ new_width = ((new_width // stride + 1) * stride
28
+ if new_width % stride != 0 else new_width)
29
+
30
+ return cv2.resize(image, (new_width, new_height), interpolation=interpolation)
31
+
32
+ class PaddedResize(LongestMaxSize):
33
+ def apply(self, img: np.ndarray, target_size: int = 1024,
34
+ interpolation: int = cv2.INTER_LINEAR, **params) -> np.ndarray:
35
+ return resize_with_padding(img, target_size=target_size, interpolation=interpolation)
36
+
37
+ def get_training_augmentations(width=768, height=576):
38
+ """Configures training-time augmentations."""
39
+ target_size = max([width, height])
40
+ transforms = [
41
+ albu.HorizontalFlip(p=0.5),
42
+ albu.ShiftScaleRotate(
43
+ scale_limit=0.5, rotate_limit=90, shift_limit=0.1, p=0.5, border_mode=0),
44
+ albu.PadIfNeeded(min_height=target_size, min_width=target_size, always_apply=True),
45
+ albu.RandomCrop(height=target_size, width=target_size, always_apply=True),
46
+ albu.GaussNoise(p=0.2),
47
+ albu.Perspective(p=0.2),
48
+ albu.OneOf([albu.CLAHE(p=1), albu.RandomGamma(p=1)], p=0.33),
49
+ albu.OneOf([
50
+ albu.Sharpen(p=1),
51
+ albu.Blur(blur_limit=3, p=1),
52
+ albu.MotionBlur(blur_limit=3, p=1)], p=0.33),
53
+ albu.OneOf([
54
+ albu.RandomBrightnessContrast(p=1),
55
+ albu.HueSaturationValue(p=1)], p=0.33),
56
+ ]
57
+ return albu.Compose(transforms)
58
+
59
+ def get_validation_augmentations(width=1920, height=1440, fixed_size=True):
60
+ """Configures validation/inference-time augmentations."""
61
+ if fixed_size:
62
+ transforms = [albu.Resize(height=height, width=width, always_apply=True)]
63
+ return albu.Compose(transforms)
64
+
65
+ target_size = max(width, height)
66
+ transforms = [PaddedResize(max_size=target_size, always_apply=True)]
67
+ return albu.Compose(transforms)
68
+
69
+ def convert_to_tensor(x, **kwargs):
70
+ """Converts image array to PyTorch tensor format."""
71
+ if x.ndim == 2:
72
+ x = np.expand_dims(x, axis=-1)
73
+ return x.transpose(2, 0, 1).astype('float32')
74
+
75
+ def get_preprocessing_pipeline(preprocessing_fn):
76
+ """Builds preprocessing pipeline including normalization and tensor conversion."""
77
+ transforms = [
78
+ albu.Lambda(image=preprocessing_fn),
79
+ albu.Lambda(image=convert_to_tensor, mask=convert_to_tensor),
80
+ ]
81
+ return albu.Compose(transforms)
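A hedged sketch of the full preprocessing path for a single image (encoder name and image shape are illustrative):

```python
import numpy as np
import segmentation_models_pytorch as smp
from SemanticModel.image_preprocessing import (get_preprocessing_pipeline,
                                               get_validation_augmentations)

preprocess_fn = smp.encoders.get_preprocessing_fn('resnet34', 'imagenet')
pipeline = get_preprocessing_pipeline(preprocess_fn)

image = np.zeros((1440, 1920, 3), dtype=np.uint8)            # placeholder RGB image
resized = get_validation_augmentations(1920, 1440)(image=image)['image']
ready = pipeline(image=resized)['image']                     # float32, shape (3, H, W)
```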
semantic-segmentation/SemanticModel/metrics.py ADDED
@@ -0,0 +1,94 @@
1
+ from typing import Dict, Optional
2
+ import numpy as np
3
+
4
+ def compute_intersection_union(prediction, ground_truth, num_classes, ignore_index: int,
5
+ label_mapping: Optional[Dict[int, int]] = None,
6
+ reduce_labels: bool = False):
7
+ """Computes intersection and union for IoU calculation."""
8
+
9
+ if label_mapping:
10
+ for old_id, new_id in label_mapping.items():
11
+ ground_truth[ground_truth == old_id] = new_id
12
+
13
+ prediction = np.array(prediction)
14
+ ground_truth = np.array(ground_truth)
15
+
16
+ if reduce_labels:
17
+ ground_truth[ground_truth == 0] = 255
18
+ ground_truth = ground_truth - 1
19
+ ground_truth[ground_truth == 254] = 255
20
+
21
+ valid_mask = np.not_equal(ground_truth, ignore_index)
22
+ prediction = prediction[valid_mask]
23
+ ground_truth = ground_truth[valid_mask]
24
+
25
+ intersection_mask = prediction == ground_truth
26
+ intersection = prediction[intersection_mask]
27
+
28
+ area_intersection = np.histogram(intersection, bins=num_classes,
29
+ range=(0, num_classes - 1))[0]
30
+ area_prediction = np.histogram(prediction, bins=num_classes,
31
+ range=(0, num_classes - 1))[0]
32
+ area_ground_truth = np.histogram(ground_truth, bins=num_classes,
33
+ range=(0, num_classes - 1))[0]
34
+ area_union = area_prediction + area_ground_truth - area_intersection
35
+
36
+ return area_intersection, area_union, area_prediction, area_ground_truth
37
+
38
+ def compute_total_intersection_union(predictions, ground_truths, num_classes, ignore_index: int,
39
+ label_mapping: Optional[Dict[int, int]] = None,
40
+ reduce_labels: bool = False):
41
+ """Computes total intersection and union across all samples."""
42
+
43
+ totals = {
44
+ 'intersection': np.zeros((num_classes,), dtype=np.float64),
45
+ 'union': np.zeros((num_classes,), dtype=np.float64),
46
+ 'prediction': np.zeros((num_classes,), dtype=np.float64),
47
+ 'ground_truth': np.zeros((num_classes,), dtype=np.float64)
48
+ }
49
+
50
+ for pred, gt in zip(predictions, ground_truths):
51
+ intersection, union, pred_area, gt_area = compute_intersection_union(
52
+ pred, gt, num_classes, ignore_index, label_mapping, reduce_labels
53
+ )
54
+ totals['intersection'] += intersection
55
+ totals['union'] += union
56
+ totals['prediction'] += pred_area
57
+ totals['ground_truth'] += gt_area
58
+
59
+ return tuple(totals.values())
60
+
61
+ def compute_mean_iou(predictions, ground_truths, num_classes, ignore_index: int,
62
+ nan_to_num: Optional[int] = None,
63
+ label_mapping: Optional[Dict[int, int]] = None,
64
+ reduce_labels: bool = False):
65
+ """Computes mean IoU and related metrics."""
66
+
67
+ intersection, union, prediction_area, ground_truth_area = compute_total_intersection_union(
68
+ predictions, ground_truths, num_classes, ignore_index, label_mapping, reduce_labels
69
+ )
70
+
71
+ metrics = {}
72
+
73
+ # Compute overall accuracy
74
+ total_accuracy = intersection.sum() / ground_truth_area.sum()
75
+
76
+ # Compute IoU per class
77
+ iou_per_class = intersection / union
78
+ accuracy_per_class = intersection / ground_truth_area
79
+
80
+ metrics.update({
81
+ "mean_iou": np.nanmean(iou_per_class),
82
+ "mean_accuracy": np.nanmean(accuracy_per_class),
83
+ "overall_accuracy": total_accuracy,
84
+ "per_category_iou": iou_per_class,
85
+ "per_category_accuracy": accuracy_per_class
86
+ })
87
+
88
+ if nan_to_num is not None:
89
+ metrics = {
90
+ metric: np.nan_to_num(value, nan=nan_to_num)
91
+ for metric, value in metrics.items()
92
+ }
93
+
94
+ return metrics
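A small, self-contained sketch of how these helpers can be called (illustration only; the toy masks and the 255 ignore value are assumptions, chosen to match how the trainer calls this module).

```python
import numpy as np
from SemanticModel.metrics import compute_mean_iou

predictions = [np.array([[0, 0, 1, 1],
                         [0, 2, 1, 1],
                         [2, 2, 2, 1],
                         [0, 0, 2, 2]])]
ground_truths = [np.array([[0, 0, 1, 1],
                           [0, 2, 2, 1],
                           [2, 2, 2, 255],   # 255 marks ignored pixels
                           [0, 0, 2, 2]])]

results = compute_mean_iou(predictions, ground_truths,
                           num_classes=3, ignore_index=255)
print(results['mean_iou'], results['overall_accuracy'])
print(results['per_category_iou'])  # one IoU per class index
```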
semantic-segmentation/SemanticModel/model_core.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import segmentation_models_pytorch as smp
4
+ from segmentation_models_pytorch import utils
5
+
6
+ from SemanticModel.encoder_management import initialize_encoder
7
+ from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
8
+ from SemanticModel.image_preprocessing import get_preprocessing_pipeline
9
+
10
+ class SegmentationModel:
11
+ def __init__(self, classes=['background', 'foreground'], architecture='unet',
12
+ encoder='timm-regnety_120', weights='imagenet', loss=None):
13
+ self._initialize_classes(classes)
14
+ self.architecture = architecture
15
+ self.encoder = encoder
16
+ self.weights = weights
17
+ self._setup_loss_function(loss)
18
+ self._initialize_model()
19
+
20
+ def _initialize_classes(self, classes):
21
+ """Sets up class configuration."""
22
+ if len(classes) <= 2:
23
+ self.classes = [c for c in classes if c.lower() != 'background']
24
+ self.class_values = [i for i, c in enumerate(classes) if c.lower() != 'background']
25
+ self.background_flag = 'background' in classes
26
+ else:
27
+ self.classes = classes
28
+ self.class_values = list(range(len(classes)))
29
+ self.background_flag = False
30
+ self.n_classes = len(self.classes)
31
+
32
+ def _setup_loss_function(self, loss):
33
+ """Configures model's loss function."""
34
+ if not loss:
35
+ loss = 'bce_with_logits' if self.n_classes > 1 else 'dice'
36
+
37
+ if loss.lower() not in ['dice', 'bce_with_logits', 'focal', 'tversky']:
38
+ print(f'Invalid loss: {loss}, defaulting to dice')
39
+ loss = 'dice'
40
+
41
+ loss_configs = {
42
+ 'bce_with_logits': {
43
+ 'activation': None,
44
+ 'loss': EnhancedCrossEntropy() if self.n_classes > 1 else utils.losses.BCEWithLogitsLoss()
45
+ },
46
+ 'dice': {
47
+ 'activation': 'softmax' if self.n_classes > 1 else 'sigmoid',
48
+ 'loss': utils.losses.DiceLoss()
49
+ },
50
+ 'focal': {
51
+ 'activation': None,
52
+ 'loss': FocalLossFunction()
53
+ },
54
+ 'tversky': {
55
+ 'activation': None,
56
+ 'loss': TverskyLossFunction()
57
+ }
58
+ }
59
+
60
+ config = loss_configs[loss.lower()]
61
+ self.activation = config['activation']
62
+ self.loss = config['loss']
63
+ self.loss_name = loss
64
+
65
+ def _initialize_model(self):
66
+ """Initializes the segmentation model architecture."""
67
+ if self.weights.endswith('pth'):
68
+ self._load_pretrained_model()
69
+ else:
70
+ self._create_new_model()
71
+
72
+ def _load_pretrained_model(self):
73
+ """Loads model from pretrained weights."""
74
+ print('Loading pretrained model...')
75
+ self.model = torch.load(self.weights)
76
+ if isinstance(self.model, torch.nn.DataParallel):
77
+ self.model = self.model.module
78
+
79
+ try:
80
+ preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
81
+ self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
82
+ except:
83
+ print('Failed to configure preprocessing. Setting to None.')
84
+ self.preprocessing = None
85
+
86
+ def _create_new_model(self):
87
+ """Creates new model with specified architecture."""
88
+ preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
89
+ self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
90
+ initialize_encoder(name=self.encoder, weights=self.weights)
91
+
92
+ architectures = {
93
+ 'unet': smp.Unet,
94
+ 'unet++': smp.UnetPlusPlus,
95
+ 'deeplabv3': smp.DeepLabV3,
96
+ 'deeplabv3+': smp.DeepLabV3Plus,
97
+ 'fpn': smp.FPN,
98
+ 'linknet': smp.Linknet,
99
+ 'manet': smp.MAnet,
100
+ 'pan': smp.PAN,
101
+ 'pspnet': smp.PSPNet
102
+ }
103
+
104
+ if self.architecture not in architectures:
105
+ raise ValueError(f'Unsupported architecture: {self.architecture}')
106
+
107
+ self.model = architectures[self.architecture](
108
+ encoder_name=self.encoder,
109
+ encoder_weights=self.weights,
110
+ classes=self.n_classes,
111
+ activation=self.activation
112
+ )
113
+
114
+ @property
115
+ def config_data(self):
116
+ """Returns model configuration data."""
117
+ return {
118
+ 'architecture': self.architecture,
119
+ 'encoder': self.encoder,
120
+ 'weights': self.weights,
121
+ 'activation': self.activation,
122
+ 'loss': self.loss_name,
123
+ 'classes': ['background'] + self.classes if self.background_flag else self.classes
124
+ }
125
+
126
+ def list_architectures():
127
+ """Returns available architecture options."""
128
+ return ['unet', 'unet++', 'deeplabv3', 'deeplabv3+', 'fpn',
129
+ 'linknet', 'manet', 'pan', 'pspnet']
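A minimal construction sketch (not part of the upload); the class list is illustrative, and passing weights='imagenet' downloads encoder weights on first use.

```python
from SemanticModel.model_core import SegmentationModel

model_config = SegmentationModel(
    classes=['background', 'building'],  # two-class setup: 'background' is dropped internally
    architecture='unet',
    encoder='timm-regnety_120',
    weights='imagenet',
    loss='dice',
)
print(model_config.n_classes)    # 1 for the binary case
print(model_config.config_data)  # dict later written to the trainer's config.json
```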
semantic-segmentation/SemanticModel/prediction.py ADDED
@@ -0,0 +1,336 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import torch
5
+ import imageio
6
+ import tifffile
7
+ import numpy as np
8
+ import slidingwindow
9
+ import rasterio as rio
10
+ import geopandas as gpd
11
+ from shapely.geometry import Polygon
12
+ from rasterio import mask as riomask
13
+ from torch.utils.data import DataLoader
14
+ from SemanticModel.visualization import generate_color_mapping
15
+ from SemanticModel.image_preprocessing import get_validation_augmentations
16
+ from SemanticModel.data_loader import InferenceDataset, StreamingDataset
17
+ from SemanticModel.utilities import calc_image_size, convert_coordinates
18
+
19
+ class PredictionPipeline:
20
+ def __init__(self, model_config, device=None):
21
+ self.config = model_config
22
+ self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ self.classes = ['background'] + model_config.classes if model_config.background_flag else model_config.classes
24
+ self.colors = generate_color_mapping(len(self.classes))
25
+ self.model = model_config.model.to(self.device)
26
+ self.model.eval()
27
+
28
+ def _preprocess_image(self, image_path, target_size=None):
29
+ """Preprocesses single image for prediction."""
30
+ image = cv2.imread(image_path)
31
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
32
+ height, width = image.shape[:2]
33
+
34
+ target_size = target_size or max(height, width)
35
+ test_height, test_width = calc_image_size(image, target_size)
36
+
37
+ augmentation = get_validation_augmentations(test_width, test_height)
38
+ image = augmentation(image=image)['image']
39
+ image = self.config.preprocessing(image=image)['image']
40
+
41
+ return image, (height, width)
42
+
43
+ def predict_single_image(self, image_path, target_size=None, output_dir=None,
44
+ format='integer', save_output=True):
45
+ """Generates prediction for a single image."""
46
+ image, original_dims = self._preprocess_image(image_path, target_size)
47
+ x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
48
+
49
+ with torch.no_grad():
50
+ prediction = self.model.predict(x_tensor)
51
+
52
+ if self.config.n_classes > 1:
53
+ prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
54
+ else:
55
+ prediction = prediction.squeeze().cpu().numpy().round()
56
+
57
+ # Resize to original dimensions if needed
58
+ if prediction.shape[:2] != original_dims:
59
+ prediction = cv2.resize(prediction, original_dims[::-1],
60
+ interpolation=cv2.INTER_NEAREST)
61
+
62
+ prediction = self._format_prediction(prediction, format)
63
+
64
+ if save_output:
65
+ self._save_prediction(prediction, image_path, output_dir, format)
66
+
67
+ return prediction
68
+
69
+ def predict_directory(self, input_dir, target_size=None, output_dir=None,
70
+ fixed_size=True, format='integer'):
71
+ """Generates predictions for all images in directory."""
72
+ output_dir = output_dir or os.path.join(input_dir, 'predictions')
73
+ os.makedirs(output_dir, exist_ok=True)
74
+
75
+ dataset = InferenceDataset(
76
+ input_dir,
77
+ classes=self.classes,
78
+ augmentation=get_validation_augmentations(
79
+ target_size, target_size, fixed_size=fixed_size
80
+ ) if target_size else None,
81
+ preprocessing=self.config.preprocessing
82
+ )
83
+
84
+ total_images = len(dataset)
85
+ start_time = time.time()
86
+
87
+ for idx in range(total_images):
88
+ if (idx + 1) % 10 == 0 or idx == total_images - 1:
89
+ elapsed = time.time() - start_time
90
+ print(f'\rProcessed {idx+1}/{total_images} images in {elapsed:.1f}s',
91
+ end='')
92
+
93
+ image, height, width = dataset[idx]
94
+ filename = dataset.filenames[idx]
95
+
96
+ x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
97
+ with torch.no_grad():
98
+ prediction = self.model.predict(x_tensor)
99
+
100
+ if self.config.n_classes > 1:
101
+ prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
102
+ else:
103
+ prediction = prediction.squeeze().cpu().numpy().round()
104
+
105
+ if prediction.shape != (height, width):
106
+ prediction = cv2.resize(prediction, (width, height),
107
+ interpolation=cv2.INTER_NEAREST)
108
+
109
+ prediction = self._format_prediction(prediction, format)
110
+ self._save_prediction(prediction, filename, output_dir, format)
111
+
112
+ print(f'\nPredictions saved to: {output_dir}')
113
+ return output_dir
114
+
115
+ def predict_raster(self, raster_path, tile_size=1024, overlap=0.175,
116
+ boundary_path=None, output_path=None, format='integer'):
117
+ """Processes large raster images using tiling approach."""
118
+ print('Loading raster...')
119
+ with rio.open(raster_path) as src:
120
+ raster = src.read()
121
+ raster = np.moveaxis(raster, 0, 2)[:,:,:3]
122
+ profile = src.profile
123
+ transform = src.transform
124
+
125
+ if boundary_path:
126
+ boundary = gpd.read_file(boundary_path)
127
+ boundary = boundary.to_crs(profile['crs'])
128
+ boundary_geom = boundary.iloc[0].geometry
129
+
130
+ tiles = slidingwindow.generate(
131
+ raster,
132
+ slidingwindow.DimOrder.HeightWidthChannel,
133
+ tile_size,
134
+ overlap
135
+ )
136
+
137
+ pred_raster = np.zeros_like(raster[:,:,0], dtype='uint8')
138
+ confidence = np.zeros_like(pred_raster, dtype=np.float32)
139
+
140
+ aug = get_validation_augmentations(tile_size, tile_size, fixed_size=False)
141
+
142
+ for idx, tile in enumerate(tiles):
143
+ if (idx + 1) % 10 == 0 or idx == len(tiles) - 1:
144
+ print(f'\rProcessed {idx+1}/{len(tiles)} tiles', end='')
145
+
146
+ bounds = tile.indices()
147
+
148
+ tile_image = raster[bounds[0], bounds[1]]
149
+
150
+ if boundary_path:
151
+ corners = [
152
+ convert_coordinates(transform, bounds[1].start, bounds[0].start),
153
+ convert_coordinates(transform, bounds[1].stop, bounds[0].start),
154
+ convert_coordinates(transform, bounds[1].stop, bounds[0].stop),
155
+ convert_coordinates(transform, bounds[1].start, bounds[0].stop)
156
+ ]
157
+ if not Polygon(corners).intersects(boundary_geom):
158
+ continue
159
+
160
+ processed = aug(image=tile_image)['image']
161
+ processed = self.config.preprocessing(image=processed)['image']
162
+
163
+ x_tensor = torch.from_numpy(processed).to(self.device).unsqueeze(0)
164
+ with torch.no_grad():
165
+ prediction = self.model.predict(x_tensor)
166
+ prediction = prediction.squeeze().cpu().numpy()
167
+
168
+ if self.config.n_classes > 1:
169
+ tile_pred = np.argmax(prediction, axis=0)
170
+ tile_conf = np.max(prediction, axis=0)
171
+ else:
172
+ tile_conf = np.abs(prediction - 0.5)
173
+ tile_pred = prediction.round()
174
+
175
+ if tile_pred.shape != tile_image.shape[:2]:
176
+ tile_pred = cv2.resize(tile_pred, tile_image.shape[:2][::-1],
177
+ interpolation=cv2.INTER_NEAREST)
178
+ tile_conf = cv2.resize(tile_conf, tile_image.shape[:2][::-1],
179
+ interpolation=cv2.INTER_LINEAR)
180
+
181
+ # Update prediction and confidence maps
182
+ existing_conf = confidence[bounds[0], bounds[1]]
183
+ existing_pred = pred_raster[bounds[0], bounds[1]]
184
+
185
+ mask = existing_conf < tile_conf
186
+ existing_pred[mask] = tile_pred[mask]
187
+ existing_conf[mask] = tile_conf[mask]
188
+
189
+ pred_raster[bounds[0], bounds[1]] = existing_pred
190
+ confidence[bounds[0], bounds[1]] = existing_conf
191
+
192
+ pred_raster = self._format_prediction(pred_raster, format)
193
+
194
+ if output_path or boundary_path:
195
+ self._save_raster_prediction(
196
+ pred_raster, raster_path, output_path,
197
+ profile, boundary_geom if boundary_path else None
198
+ )
199
+
200
+ return pred_raster, profile
201
+
202
+ def _format_prediction(self, prediction, format):
203
+ """Formats prediction according to specified output type."""
204
+ if format == 'integer':
205
+ return prediction.astype('uint8')
206
+ elif format == 'color':
207
+ return self._apply_color_mapping(prediction)
208
+ else:
209
+ raise ValueError(f"Unsupported format: {format}")
210
+
211
+ def _save_prediction(self, prediction, source_path, output_dir, format):
212
+ """Saves prediction to disk."""
213
+ filename = os.path.splitext(os.path.basename(source_path))[0]
214
+ output_path = os.path.join(output_dir, f"{filename}_pred.png")
215
+ cv2.imwrite(output_path, cv2.cvtColor(prediction, cv2.COLOR_RGB2BGR) if prediction.ndim == 3 else prediction)  # color masks are RGB; OpenCV writes BGR
216
+
217
+
218
+ def _save_raster_prediction(self, prediction, source_path, output_path,
219
+ profile, boundary=None):
220
+ """Saves raster prediction with geospatial information."""
221
+ output_path = output_path or source_path.replace(
222
+ os.path.splitext(source_path)[1], '_predicted.tif'
223
+ )
224
+
225
+ profile.update(
226
+ dtype='uint8',
227
+ count=3 if prediction.ndim == 3 else 1
228
+ )
229
+
230
+ with rio.open(output_path, 'w', **profile) as dst:
231
+ if prediction.ndim == 3:
232
+ for i in range(3):
233
+ dst.write(prediction[:,:,i], i+1)
234
+ else:
235
+ dst.write(prediction, 1)
236
+
237
+ if boundary:
238
+ with rio.open(output_path) as src:
239
+ cropped, transform = riomask.mask(src, [boundary], crop=True)
240
+ profile.update(
241
+ height=cropped.shape[1],
242
+ width=cropped.shape[2],
243
+ transform=transform
244
+ )
245
+
246
+ os.remove(output_path)
247
+ with rio.open(output_path, 'w', **profile) as dst:
248
+ dst.write(cropped)
249
+
250
+ print(f'\nPrediction saved to: {output_path}')
251
+
252
+ def predict_video_frames(self, input_dir, target_size=None, output_dir=None):
253
+ """Processes video frames with specialized visualization."""
254
+ output_dir = output_dir or os.path.join(input_dir, 'predictions')
255
+ os.makedirs(output_dir, exist_ok=True)
256
+
257
+ dataset = StreamingDataset(
258
+ input_dir,
259
+ classes=self.classes,
260
+ augmentation=get_validation_augmentations(
261
+ target_size, target_size
262
+ ) if target_size else None,
263
+ preprocessing=self.config.preprocessing
264
+ )
265
+
266
+ image = cv2.imread(dataset.image_paths[0])
267
+ height, width = image.shape[:2]
268
+
269
+ white = 255 * np.ones((height, width))
270
+ black = np.zeros_like(white)
271
+ red = np.dstack((white, black, black))
272
+ blue = np.dstack((black, black, white))
273
+
274
+ # Pre-compute rotated versions
275
+ rotated_red = np.rot90(red)
276
+ rotated_blue = np.rot90(blue)
277
+
278
+ total_frames = len(dataset)
279
+ start_time = time.time()
280
+
281
+ for idx in range(total_frames):
282
+ if (idx + 1) % 10 == 0 or idx == total_frames - 1:
283
+ elapsed = time.time() - start_time
284
+ print(f'\rProcessed {idx+1}/{total_frames} frames in {elapsed:.1f}s', end='')
285
+
286
+ frame, height, width = dataset[idx]
287
+ filename = dataset.filenames[idx]
288
+
289
+ x_tensor = torch.from_numpy(frame).to(self.device).unsqueeze(0)
290
+ with torch.no_grad():
291
+ prediction = self.model.predict(x_tensor)
292
+
293
+ if self.config.n_classes > 1:
294
+ prediction = np.argmax(prediction.squeeze().cpu().numpy(), axis=0)
295
+ masks = [prediction == i for i in range(1, self.config.n_classes)]
296
+ else:
297
+ prediction = prediction.squeeze().cpu().numpy().round()
298
+ masks = [prediction == 1]
299
+
300
+ if prediction.shape != (height, width):
301
+ prediction = cv2.resize(prediction, (width, height),
302
+ interpolation=cv2.INTER_NEAREST)
303
+
304
+ original = cv2.imread(os.path.join(input_dir, filename))
305
+ original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
306
+
307
+ try:
308
+ for i, mask in enumerate(masks):
309
+ color = red if i == 0 else blue
310
+ rotated_color = rotated_red if i == 0 else rotated_blue
311
+ try:
312
+ original[mask,:] = 0.45*original[mask,:] + 0.55*color[mask,:]
313
+ except:
314
+ original[mask,:] = 0.45*original[mask,:] + 0.55*rotated_color[mask,:]
315
+ except:
316
+ print(f"\nWarning: Error processing frame {filename}")
317
+ continue
318
+
319
+ output_path = os.path.join(output_dir, filename)
320
+ imageio.imwrite(output_path, original, quality=100)
321
+
322
+ print(f'\nProcessed frames saved to: {output_dir}')
323
+ return output_dir
324
+
325
+ def _apply_color_mapping(self, prediction):
326
+ """Applies color mapping to prediction."""
327
+ height, width = prediction.shape
328
+ colored = np.zeros((height, width, 3), dtype='uint8')
329
+
330
+ for i, class_name in enumerate(self.classes):
331
+ if class_name.lower() == 'background':
332
+ continue
333
+ color = self.colors[i]
334
+ colored[prediction == i] = color
335
+
336
+ return colored
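A usage sketch under assumed paths (the checkpoint and image locations below are hypothetical; any model saved by the trainer as best_model.pth should load the same way).

```python
from SemanticModel.model_core import SegmentationModel
from SemanticModel.prediction import PredictionPipeline

config = SegmentationModel(classes=['background', 'building'],
                           weights='model_outputs/best_model.pth')  # hypothetical checkpoint
pipeline = PredictionPipeline(config)  # uses CUDA automatically when available

# Single image: returns the mask and writes <name>_pred.png into output_dir.
mask = pipeline.predict_single_image('samples/site_01.jpg',
                                     target_size=1024,
                                     output_dir='samples/predictions')

# Every image in a folder.
pipeline.predict_directory('samples', target_size=1024,
                           output_dir='samples/predictions')
```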
semantic-segmentation/SemanticModel/training.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import wandb
5
+ import datetime
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from torch.utils.data import DataLoader
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ from segmentation_models_pytorch.base.modules import Activation
11
+
12
+ from SemanticModel.data_loader import SegmentationDataset
13
+ from SemanticModel.metrics import compute_mean_iou
14
+ from SemanticModel.image_preprocessing import get_training_augmentations, get_validation_augmentations
15
+ from SemanticModel.utilities import list_images, validate_dimensions
16
+
17
+ class ModelTrainer:
18
+ def __init__(self, model_config, root_dir, epochs=40, train_size=1024,
19
+ val_size=None, workers=2, batch_size=2, learning_rate=1e-4,
20
+ step_count=2, decay_factor=0.8, wandb_config=None,
21
+ optimizer='rmsprop', target_class=None, resume_path=None):
22
+
23
+ self.config = model_config
24
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+ self.root_dir = root_dir
26
+ self._initialize_training_params(epochs, train_size, val_size, workers,
27
+ batch_size, learning_rate, step_count,
28
+ decay_factor, optimizer, target_class)
29
+ self._setup_directories()
30
+ self._initialize_datasets()
31
+ self._setup_optimizer()
32
+ self._initialize_tracking()
33
+
34
+ if resume_path:
35
+ self._resume_training(resume_path)
36
+
37
+ def _initialize_training_params(self, epochs, train_size, val_size, workers,
38
+ batch_size, learning_rate, step_count,
39
+ decay_factor, optimizer, target_class):
40
+ self.epochs = epochs
41
+ self.train_size = train_size
42
+ self.val_size = val_size
43
+ self.workers = workers
44
+ self.batch_size = batch_size
45
+ self.learning_rate = learning_rate
46
+ self.step_schedule = self._calculate_step_schedule(epochs, step_count)
47
+ self.decay_factor = decay_factor
48
+ self.optimizer_type = optimizer
49
+ self.target_class = target_class
50
+ self.current_epoch = 1
51
+ self.best_iou = 0.0
52
+ self.best_epoch = 0
53
+ self.classes = ['background'] + self.config.classes if self.config.background_flag else self.config.classes
54
+
55
+ def _setup_directories(self):
56
+ """Verifies and creates necessary directories."""
57
+ self.train_dir = os.path.join(self.root_dir, 'train')
58
+ self.val_dir = os.path.join(self.root_dir, 'val')
59
+
60
+ required_subdirs = ['Images', 'Masks']
61
+ for path in [self.train_dir] + ([self.val_dir] if os.path.exists(self.val_dir) else []):
62
+ for subdir in required_subdirs:
63
+ full_path = os.path.join(path, subdir)
64
+ if not os.path.exists(full_path):
65
+ raise FileNotFoundError(f"Missing directory: {full_path}")
66
+
67
+ def _initialize_datasets(self):
68
+ """Sets up training and validation datasets."""
69
+ self.train_dataset = SegmentationDataset(
70
+ self.train_dir,
71
+ classes=self.classes,
72
+ augmentation=get_training_augmentations(self.train_size, self.train_size),
73
+ preprocessing=self.config.preprocessing
74
+ )
75
+
76
+ if os.path.exists(self.val_dir):
77
+ self.val_dataset = SegmentationDataset(
78
+ self.val_dir,
79
+ classes=self.classes,
80
+ augmentation=get_validation_augmentations(
81
+ self.val_size or self.train_size,
82
+ self.val_size or self.train_size,
83
+ fixed_size=False
84
+ ),
85
+ preprocessing=self.config.preprocessing
86
+ )
87
+ self.val_loader = DataLoader(
88
+ self.val_dataset,
89
+ batch_size=1,
90
+ shuffle=False,
91
+ num_workers=self.workers
92
+ )
93
+ else:
94
+ self.val_dataset = self.train_dataset
95
+ self.val_loader = DataLoader(
96
+ self.val_dataset,
97
+ batch_size=1,
98
+ shuffle=False,
99
+ num_workers=self.workers
100
+ )
101
+
102
+ self.train_loader = DataLoader(
103
+ self.train_dataset,
104
+ batch_size=self.batch_size,
105
+ shuffle=True,
106
+ num_workers=self.workers
107
+ )
108
+
109
+ def _setup_optimizer(self):
110
+ """Configures model optimizer."""
111
+ optimizer_map = {
112
+ 'adam': torch.optim.Adam,
113
+ 'sgd': lambda params: torch.optim.SGD(params, momentum=0.9),
114
+ 'rmsprop': torch.optim.RMSprop
115
+ }
116
+ optimizer_class = optimizer_map.get(self.optimizer_type.lower())
117
+ if not optimizer_class:
118
+ raise ValueError(f"Unsupported optimizer: {self.optimizer_type}")
119
+
120
+ self.optimizer = optimizer_class([{'params': self.config.model.parameters(),
121
+ 'lr': self.learning_rate}])
122
+
123
+ def _initialize_tracking(self):
124
+ """Sets up training progress tracking."""
125
+ timestamp = datetime.datetime.now().strftime("%m-%d-%Y_%H%M%S")
126
+ self.output_dir = os.path.join(
127
+ self.root_dir,
128
+ f'model_outputs-{self.config.architecture}[{self.config.encoder}]-{timestamp}'
129
+ )
130
+ os.makedirs(self.output_dir, exist_ok=True)
131
+
132
+ self.writer = SummaryWriter(log_dir=self.output_dir)
133
+ self.metrics = {
134
+ 'best_epoch': self.best_epoch,
135
+ 'best_epoch_iou': self.best_iou,
136
+ 'last_epoch': 0,
137
+ 'last_epoch_iou': 0.0,
138
+ 'last_epoch_lr': self.learning_rate,
139
+ 'step_schedule': self.step_schedule,
140
+ 'decay_factor': self.decay_factor,
141
+ 'target_class': self.target_class or 'overall'
142
+ }
143
+
144
+ def _calculate_step_schedule(self, epochs, steps):
145
+ """Calculates learning rate step schedule."""
146
+ return list(map(int, np.linspace(0, epochs, steps + 2)[1:-1]))
147
+
148
+ def train(self):
149
+ """Executes training loop."""
150
+ model = self.config.model.to(self.device)
151
+ if torch.cuda.device_count() > 1:
152
+ model = torch.nn.DataParallel(model)
153
+ print(f'Using {torch.cuda.device_count()} GPUs')
154
+
155
+ self._save_config()
156
+
157
+ for epoch in range(self.current_epoch, self.epochs + 1):
158
+ print(f'\nEpoch {epoch}/{self.epochs}')
159
+ print(f'Learning rate: {self.optimizer.param_groups[0]["lr"]:.3e}')
160
+
161
+ train_loss = self._train_epoch(model)
162
+ val_loss, val_metrics = self._validate_epoch(model)
163
+
164
+ self._update_tracking(epoch, train_loss, val_loss, val_metrics)
165
+ self._adjust_learning_rate(epoch)
166
+ self._save_checkpoints(model, epoch, val_metrics)
167
+
168
+ print(f'\nTraining completed. Best {self.metrics["target_class"]} IoU: {self.best_iou:.3f}')
169
+ return model, self.metrics
170
+
171
+ def _train_epoch(self, model):
172
+ """Executes single training epoch."""
173
+ model.train()
174
+ total_loss = 0
175
+ sample_count = 0
176
+
177
+ for batch in tqdm(self.train_loader, desc='Training'):
178
+ images, masks = [x.to(self.device) for x in batch]
179
+ self.optimizer.zero_grad()
180
+
181
+ outputs = model(images)
182
+ loss = self.config.loss(outputs, masks)
183
+ loss.backward()
184
+ self.optimizer.step()
185
+
186
+ total_loss += loss.item() * len(images)
187
+ sample_count += len(images)
188
+
189
+ return total_loss / sample_count
190
+
191
+ def _validate_epoch(self, model):
192
+ """Executes validation pass."""
193
+ model.eval()
194
+ total_loss = 0
195
+ predictions = []
196
+ ground_truth = []
197
+
198
+ with torch.no_grad():
199
+ for batch in tqdm(self.val_loader, desc='Validation'):
200
+ images, masks = [x.to(self.device) for x in batch]
201
+ outputs = model(images)
202
+ loss = self.config.loss(outputs, masks)
203
+
204
+ total_loss += loss.item()
205
+
206
+ if self.config.n_classes > 1:
207
+ predictions.extend([p.cpu().argmax(dim=0) for p in outputs])
208
+ ground_truth.extend([m.cpu().argmax(dim=0) for m in masks])
209
+ else:
210
+ predictions.extend([(torch.sigmoid(p) > 0.5).float().squeeze().cpu()
211
+ for p in outputs])
212
+ ground_truth.extend([m.cpu().squeeze() for m in masks])
213
+
214
+ metrics = compute_mean_iou(
215
+ predictions,
216
+ ground_truth,
217
+ num_classes=len(self.classes),
218
+ ignore_index=255
219
+ )
220
+
221
+ return total_loss / len(self.val_loader), metrics
222
+
223
+ def _update_tracking(self, epoch, train_loss, val_loss, val_metrics):
224
+ """Updates training metrics and logging."""
225
+ mean_iou = val_metrics['mean_iou']
226
+ print(f"\nLosses - Train: {train_loss:.3f}, Val: {val_loss:.3f}")
227
+ print(f"Mean IoU: {mean_iou:.3f}")
228
+
229
+ self.writer.add_scalar('Loss/train', train_loss, epoch)
230
+ self.writer.add_scalar('Loss/val', val_loss, epoch)
231
+ self.writer.add_scalar('IoU/mean', mean_iou, epoch)
232
+
233
+ for idx, iou in enumerate(val_metrics['per_category_iou']):
234
+ print(f"{self.classes[idx]} IoU: {iou:.3f}")
235
+ self.writer.add_scalar(f'IoU/{self.classes[idx]}', iou, epoch)
236
+
237
+ def _adjust_learning_rate(self, epoch):
238
+ """Adjusts learning rate according to schedule."""
239
+ if epoch in self.step_schedule:
240
+ current_lr = self.optimizer.param_groups[0]['lr']
241
+ new_lr = current_lr * self.decay_factor
242
+ for param_group in self.optimizer.param_groups:
243
+ param_group['lr'] = new_lr
244
+ print(f'\nDecreased learning rate: {current_lr:.3e} -> {new_lr:.3e}')
245
+
246
+ def _save_checkpoints(self, model, epoch, metrics):
247
+ """Saves model checkpoints and metrics."""
248
+ epoch_iou = (metrics['mean_iou'] if self.target_class is None
249
+ else metrics['per_category_iou'][self.classes.index(self.target_class)])
250
+
251
+ self.metrics.update({
252
+ 'last_epoch': epoch,
253
+ 'last_epoch_iou': round(float(epoch_iou), 3),
254
+ 'last_epoch_lr': self.optimizer.param_groups[0]['lr']
255
+ })
256
+
257
+ if epoch_iou > self.best_iou:
258
+ self.best_iou = epoch_iou
259
+ self.best_epoch = epoch
260
+ self.metrics.update({
261
+ 'best_epoch': epoch,
262
+ 'best_epoch_iou': round(float(epoch_iou), 3),
263
+ 'overall_iou': round(float(metrics['mean_iou']), 3)
264
+ })
265
+ torch.save(model, os.path.join(self.output_dir, 'best_model.pth'))
266
+ print(f'New best model saved (IoU: {epoch_iou:.3f})')
267
+
268
+ torch.save(model, os.path.join(self.output_dir, 'last_model.pth'))
269
+ with open(os.path.join(self.output_dir, 'metrics.json'), 'w') as f:
270
+ json.dump(self.metrics, f, indent=4)
271
+
272
+ def _save_config(self):
273
+ """Saves training configuration."""
274
+ config = {
275
+ **self.config.config_data,
276
+ 'train_size': self.train_size,
277
+ 'val_size': self.val_size,
278
+ 'epochs': self.epochs,
279
+ 'batch_size': self.batch_size,
280
+ 'optimizer': self.optimizer_type,
281
+ 'workers': self.workers,
282
+ 'target_class': self.target_class or 'overall'
283
+ }
284
+
285
+ with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
286
+ json.dump(config, f, indent=4)
287
+
288
+ def _resume_training(self, resume_path):
289
+ """Resumes training from checkpoint."""
290
+ if not os.path.exists(resume_path):
291
+ raise FileNotFoundError(f"Resume path not found: {resume_path}")
292
+
293
+ required_files = {
294
+ 'model': 'last_model.pth',
295
+ 'metrics': 'metrics.json',
296
+ 'config': 'config.json'
297
+ }
298
+
299
+ paths = {k: os.path.join(resume_path, v) for k, v in required_files.items()}
300
+ if not all(os.path.exists(p) for p in paths.values()):
301
+ raise FileNotFoundError("Missing required checkpoint files")
302
+
303
+ with open(paths['config']) as f:
304
+ config = json.load(f)
305
+ with open(paths['metrics']) as f:
306
+ metrics = json.load(f)
307
+
308
+ self.current_epoch = metrics['last_epoch'] + 1
309
+ self.best_iou = metrics['best_epoch_iou']
310
+ self.best_epoch = metrics['best_epoch']
311
+ self.learning_rate = metrics['last_epoch_lr']
312
+
313
+ print(f'Resuming training from epoch {self.current_epoch}')
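A training sketch with assumed paths and hyperparameters; the dataset root is expected to contain train/Images, train/Masks and, optionally, val/Images, val/Masks.

```python
from SemanticModel.model_core import SegmentationModel
from SemanticModel.training import ModelTrainer

config = SegmentationModel(classes=['background', 'building'],
                           architecture='unet',
                           encoder='timm-regnety_120',
                           weights='imagenet')

trainer = ModelTrainer(config,
                       root_dir='datasets/buildings',  # hypothetical dataset root
                       epochs=40,
                       train_size=1024,
                       batch_size=2,
                       learning_rate=1e-4)

model, metrics = trainer.train()  # checkpoints and metrics.json land in model_outputs-*/
print(metrics['best_epoch'], metrics['best_epoch_iou'])
```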
semantic-segmentation/SemanticModel/utilities.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import cv2
3
+ import shutil
4
+ import imageio
5
+ import numpy as np
6
+ from glob import glob
7
+ from pathlib import Path
8
+ from typing import List, Tuple, Optional
9
+
10
+ def validate_dimensions(width: int, height: int, stride: int = 32) -> Tuple[int, int]:
11
+ if height % stride != 0 or width % stride != 0:
12
+ height = ((height // stride + 1) * stride
13
+ if height % stride != 0 else height)
14
+ width = ((width // stride + 1) * stride
15
+ if width % stride != 0 else width)
16
+ print(f'Adjusted dimensions to: {height}H x {width}W')
17
+ return width, height
18
+
19
+ def calc_image_size(image: np.ndarray, target_size: int) -> Tuple[int, int]:
20
+ height, width = image.shape[:2]
21
+ aspect_ratio = width / height
22
+
23
+ if aspect_ratio >= 1:
24
+ new_width = target_size
25
+ new_height = int(target_size / aspect_ratio)
26
+ else:
27
+ new_height = target_size
28
+ new_width = int(target_size * aspect_ratio)
29
+
30
+ return validate_dimensions(new_width, new_height)
31
+
32
+ def convert_coordinates(transform: np.ndarray, x: float, y: float) -> Tuple[float, float]:
33
+ transformed = transform @ np.array([x, y, 1])
34
+ return transformed[0], transformed[1]
35
+
36
+ def list_images(directory: str, mask_format: bool = False) -> List[str]:
37
+ extensions = ['*.png', '*.PNG'] if mask_format else [
38
+ '*.jpg', '*.jpeg', '*.png', '*.tif', '*.tiff',
39
+ '*.JPG', '*.JPEG', '*.PNG', '*.TIF', '*.TIFF'
40
+ ]
41
+
42
+ image_paths = []
43
+ for ext in extensions:
44
+ image_paths.extend(glob(os.path.join(directory, ext)))
45
+
46
+ return sorted(list(set(image_paths)))
47
+
48
+ def prepare_dataset_split(root_dir: str, train_ratio: float = 0.7,
49
+ generate_empty_masks: bool = False) -> None:
50
+ image_dir = os.path.join(root_dir, 'Images')
51
+ mask_dir = os.path.join(root_dir, 'Masks')
52
+
53
+ if not all(os.path.exists(d) for d in [image_dir, mask_dir]):
54
+ raise Exception("Required 'Images' and 'Masks' directories not found")
55
+
56
+ image_paths = np.array(list_images(image_dir))
57
+ mask_paths = np.array(list_images(mask_dir, mask_format=True))
58
+
59
+ if generate_empty_masks:
60
+ temp_dir = os.path.join(mask_dir, 'temp')
61
+ create_empty_masks(image_dir, outdir=temp_dir)
62
+
63
+ for mask_path in list_images(temp_dir, mask_format=True):
64
+ target_path = os.path.join(mask_dir, os.path.basename(mask_path))
65
+ if not os.path.exists(target_path):
66
+ shutil.move(mask_path, target_path)
67
+
68
+ shutil.rmtree(temp_dir)
69
+ mask_paths = np.array(list_images(mask_dir, mask_format=True))
70
+
71
+ if len(image_paths) != len(mask_paths):
72
+ raise Exception(f"Unmatched images ({len(image_paths)}) and masks ({len(mask_paths)})")
73
+
74
+ train_ratio = float(train_ratio)
75
+ if not (0 < train_ratio <= 1):
76
+ raise ValueError(f"Invalid train ratio: {train_ratio}")
77
+
78
+ train_size = int(np.floor(train_ratio * len(image_paths)))
79
+ indices = np.random.permutation(len(image_paths))
80
+
81
+ splits = {
82
+ 'train': {'indices': indices[:train_size]},
83
+ 'val': {'indices': indices[train_size:]} if train_ratio < 1 else None
84
+ }
85
+
86
+ for split_name, split_data in splits.items():
87
+ if split_data is None:
88
+ continue
89
+
90
+ split_dir = os.path.join(root_dir, split_name)
91
+ for subdir in ['Images', 'Masks']:
92
+ subdir_path = os.path.join(split_dir, subdir)
93
+ os.makedirs(subdir_path, exist_ok=True)
94
+
95
+ sources = image_paths if subdir == 'Images' else mask_paths
96
+ for idx in split_data['indices']:
97
+ source = sources[idx]
98
+ destination = os.path.join(subdir_path, os.path.basename(source))
99
+ shutil.copyfile(source, destination)
100
+
101
+ print(f"Created {split_name} split with {len(split_data['indices'])} samples")
102
+
103
+ def create_empty_masks(image_dir: str, pixel_value: int = 0,
104
+ outdir: Optional[str] = None) -> str:
105
+ outdir = outdir or os.path.join(image_dir, 'Masks')
106
+ os.makedirs(outdir, exist_ok=True)
107
+
108
+ image_paths = list_images(image_dir)
109
+ print(f"Generating {len(image_paths)} empty masks...")
110
+
111
+ for image_path in image_paths:
112
+ image = imageio.imread(image_path)
113
+ mask = np.full((image.shape[0], image.shape[1]), pixel_value, dtype='uint8')
114
+
115
+ output_path = os.path.join(outdir,
116
+ f"{Path(image_path).stem}.png")
117
+ imageio.imwrite(output_path, mask)
118
+
119
+ return outdir
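A short sketch of the dataset-splitting helper (the path is hypothetical); it expects Images/ and Masks/ under the given root and produces the train/ and val/ folders the trainer looks for.

```python
from SemanticModel.utilities import prepare_dataset_split

# 80/20 split; all-background masks are generated first for any unlabeled images.
prepare_dataset_split('datasets/buildings',
                      train_ratio=0.8,
                      generate_empty_masks=True)
```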
semantic-segmentation/SemanticModel/visualization.py ADDED
@@ -0,0 +1,115 @@
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+
6
+ def plot_predictions(model, images, masks, device, num_samples=4):
7
+ """Visualize model predictions against ground truth."""
8
+ with torch.no_grad():
9
+ model.eval()
10
+ predictions = model.predict(images.to(device))
11
+
12
+ fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))
13
+
14
+ for idx in range(num_samples):
15
+ # Original image
16
+ img = images[idx].permute(1, 2, 0).cpu().numpy()
17
+ axes[idx, 0].imshow(img)
18
+ axes[idx, 0].set_title('Original Image')
19
+
20
+ # Ground truth
21
+ truth = masks[idx].argmax(dim=0).cpu().numpy()
22
+ axes[idx, 1].imshow(truth, cmap='tab20')
23
+ axes[idx, 1].set_title('Ground Truth')
24
+
25
+ # Prediction
26
+ pred = predictions[idx].argmax(dim=0).cpu().numpy()
27
+ axes[idx, 2].imshow(pred, cmap='tab20')
28
+ axes[idx, 2].set_title('Prediction')
29
+
30
+ for ax in axes[idx]:
31
+ ax.axis('off')
32
+
33
+ plt.tight_layout()
34
+ return fig
35
+
36
+ def create_overlay_mask(image, mask, alpha=0.5, color_map=None):
37
+ """Create transparent overlay of segmentation mask on image."""
38
+ if color_map is None:
39
+ color_map = {
40
+ 0: [0, 0, 0], # background
41
+ 1: [255, 0, 0], # class 1 (red)
42
+ 2: [0, 255, 0], # class 2 (green)
43
+ 3: [0, 0, 255], # class 3 (blue)
44
+ }
45
+
46
+ overlay = image.copy()
47
+ mask_colored = np.zeros_like(image)
48
+
49
+ for label, color in color_map.items():
50
+ mask_colored[mask == label] = color
51
+
52
+ cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)
53
+ return overlay
54
+
55
+ def plot_training_history(history):
56
+ """Plot training and validation metrics."""
57
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
58
+
59
+ # Loss plot
60
+ ax1.plot(history['train_loss'], label='Training Loss')
61
+ ax1.plot(history['val_loss'], label='Validation Loss')
62
+ ax1.set_xlabel('Epoch')
63
+ ax1.set_ylabel('Loss')
64
+ ax1.set_title('Training and Validation Loss')
65
+ ax1.legend()
66
+
67
+ # IoU plot
68
+ ax2.plot(history['mean_iou'], label='Mean IoU')
69
+ for class_name, ious in history['class_ious'].items():
70
+ ax2.plot(ious, label=f'{class_name} IoU')
71
+ ax2.set_xlabel('Epoch')
72
+ ax2.set_ylabel('IoU')
73
+ ax2.set_title('IoU Metrics')
74
+ ax2.legend()
75
+
76
+ plt.tight_layout()
77
+ return fig
78
+
79
+ def visualize_predictions_on_batch(model, batch_images, batch_size=8):
80
+ """Create grid visualization for a batch of predictions."""
81
+ with torch.no_grad():
82
+ predictions = model.predict(batch_images)
83
+
84
+ fig = plt.figure(figsize=(15, 5))
85
+ for idx in range(min(batch_size, len(batch_images))):
86
+ plt.subplot(2, 4, idx + 1)
87
+ img = batch_images[idx].permute(1, 2, 0).cpu().numpy()
88
+ mask = predictions[idx].argmax(dim=0).cpu().numpy()
89
+ overlay = create_overlay_mask(img, mask)
90
+ plt.imshow(overlay)
91
+ plt.axis('off')
92
+
93
+ plt.tight_layout()
94
+ return fig
95
+
96
+ def save_visualization(fig, save_path):
97
+ """Save visualization figure."""
98
+ fig.savefig(save_path, bbox_inches='tight', dpi=300)
99
+ plt.close(fig)
100
+
101
+ def generate_color_mapping(num_classes):
102
+ """Generate distinct colors for segmentation classes."""
103
+ colors = [
104
+ [0, 0, 0], # Background (black)
105
+ [255, 0, 0], # Red
106
+ [0, 255, 0], # Green
107
+ [0, 0, 255], # Blue
108
+ [255, 255, 0], # Yellow
109
+ [255, 0, 255], # Magenta
110
+ [0, 255, 255], # Cyan
111
+ [128, 0, 0], # Dark Red
112
+ [0, 128, 0], # Dark Green
113
+ [0, 0, 128] # Dark Blue
114
+ ]
115
+ return colors[:num_classes]
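An overlay sketch using assumed file paths; the mask is expected to hold integer class indices, as written by the prediction pipeline's 'integer' format.

```python
import cv2
import matplotlib.pyplot as plt
from SemanticModel.visualization import create_overlay_mask, save_visualization

image = cv2.cvtColor(cv2.imread('samples/site_01.jpg'), cv2.COLOR_BGR2RGB)        # hypothetical image
mask = cv2.imread('samples/predictions/site_01_pred.png', cv2.IMREAD_GRAYSCALE)   # hypothetical mask

overlay = create_overlay_mask(image, mask, alpha=0.5)

fig = plt.figure(figsize=(8, 6))
plt.imshow(overlay)
plt.axis('off')
save_visualization(fig, 'samples/site_01_overlay.png')
```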
setup.py ADDED
@@ -0,0 +1,34 @@
1
+ # setup.py
2
+ from setuptools import setup, find_packages
3
+
4
+ setup(
5
+ name="SemanticModel",
6
+ version="0.1.0",
7
+ description="Deep learning framework for semantic segmentation",
8
+ author="Your Name",
9
+ packages=find_packages(),
10
+ python_requires=">=3.8",
11
+ install_requires=[
12
+ 'torch',
13
+ 'torchvision',
14
+ 'tensorboard',
15
+ 'pyproj',
16
+ 'fiona==1.8.20',
17
+ 'rtree',
18
+ 'geopandas',
19
+ 'rasterio',
20
+ 'slidingwindow',
21
+ 'opencv-python',
22
+ 'wandb',
23
+ 'tifffile',
24
+ 'imagecodecs',
25
+ 'albumentations',
26
+ 'segmentation-models-pytorch>=0.3.3'
27
+ ],
28
+ classifiers=[
29
+ "Development Status :: 3 - Alpha",
30
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
31
+ "License :: OSI Approved :: MIT License",
32
+ "Programming Language :: Python :: 3.8",
33
+ ],
34
+ )
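Assuming setup.py sits alongside the SemanticModel package directory, the usual route is an editable install from that folder with `pip install -e .`, after which the `from SemanticModel...` imports used in the examples and notebooks resolve.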