vrk05 commited on
Commit
62e8869
·
1 Parent(s): 493f0d3
Files changed (1) hide show
  1. sen1floods11_config.py +287 -0
sen1floods11_config.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os

# ---------------------------------------------------------------------------
# MMSegmentation/MMCV config: fine-tuning a temporal-ViT backbone on the
# sen1floods11 flood-mapping dataset. Base runtime options below.
# ---------------------------------------------------------------------------
dist_params = dict(backend='nccl')  # distributed backend for multi-GPU runs
log_level = 'INFO'
load_from = None    # no full-model checkpoint is loaded
resume_from = None  # fresh run; backbone weights come from the pretrained
                    # path defined further down in this config
cudnn_benchmark = True  # safe speed-up here: all inputs share a fixed shape

# Make mmcv import the project package so its custom datasets, transforms and
# models (the Geospatial*/Torch* pipeline steps, TemporalViTEncoder, ...)
# register themselves before the config is consumed.
custom_imports = dict(imports=["geospatial_fm"])
data_root = "/home"

# Dataset definition --------------------------------------------------------
dataset_type = "GeospatialDataset"
num_classes = 2    # binary segmentation: classes (0, 1)
num_frames = 1     # single timestamp per sample (no temporal stacking)
img_size = 224
num_workers = 2
samples_per_gpu = 4
CLASSES = (0, 1)

# Per-band normalisation statistics: 6 bands -> 6 means / 6 stds.
img_norm_cfg = dict(means=[0.14245495, 0.13921481, 0.12434631, 0.31420089, 0.20743526, 0.12046503],
                    stds=[0.04036231, 0.04186983, 0.05267646, 0.0822221, 0.06834774, 0.05294205])


# Band indices extracted from the source rasters.
# NOTE(review): six band indices up to 12, but img_dir points at S1 data —
# confirm the rasters under files/S1/ actually carry that many bands.
bands = [1, 2, 3, 8, 11, 12]
tile_size = img_size
orig_nsize = 512    # native chip size of the sen1floods11 tiles
crop_size = (tile_size, tile_size)

img_dir = data_root + "/files/S1/"
ann_dir = data_root + "/files/Labels/"
# Plain strings: the original f-prefix had no placeholders (ruff F541).
img_suffix = "_S1Hand.tif"
seg_map_suffix = "_LabelHand.tif"

# CSV files listing which chips belong to each split.
splits = {
    "train": "/home/flood_train_data.csv",
    "val": "/home/flood_val_data.csv",
    "test": "/home/flood_test_data.csv",
}
splits = {k: os.path.abspath(v) for (k, v) in splits.items()}


# Sentinel values for labels/images.
ignore_index = 2           # label value excluded from loss and metrics
label_nodata = -1          # replaced with ignore_index at load time
image_nodata = -9999       # replaced with image_nodata_replace at load time
image_nodata_replace = 0
constant = 0.0001          # scale factor applied to raw pixel values
# Model hyper-parameters ----------------------------------------------------
# TO BE DEFINED BY USER: path to pretrained backbone weights
pretrained_weights_path = "/home/Prithvi_100M.pt"
num_layers = 12     # transformer depth (ViT-Base sized)
patch_size = 16     # spatial patch size of the ViT
embed_dim = 768     # token embedding dimension (ViT-Base sized)
num_heads = 12
tubelet_size = 1    # temporal patch size; with num_frames=1 patches are purely spatial


# Training schedule and output locations.
epochs = 30
eval_epoch_interval = 5
experiment = "/home/output"
save_path = experiment
# Training pipeline: load image + label, augment, normalise, reshape to the
# (C, T, H, W) layout expected by the temporal ViT backbone.
train_pipeline = [
    # Read the raster; nodata pixels are replaced before any arithmetic.
    dict(
        type="LoadGeospatialImageFromFile",
        to_float32=False,
        nodata=image_nodata,
        nodata_replace=image_nodata_replace
    ),
    # Load the label mask; nodata labels become ignore_index so the loss
    # and metrics skip them.
    dict(
        type="LoadGeospatialAnnotations",
        reduce_zero_label=False,
        nodata=label_nodata,
        nodata_replace=ignore_index,
    ),
    dict(type="BandsExtract", bands=bands),            # keep only the configured bands
    dict(type="ConstantMultiply", constant=constant),  # scale raw pixel values
    dict(type="RandomFlip", prob=0.5),
    dict(type="ToTensor", keys=["img", "gt_semantic_seg"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    dict(type="TorchRandomCrop", crop_size=crop_size),
    # (bands, frames, H, W) for the backbone.
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, tile_size, tile_size),
    ),
    dict(type="Reshape", keys=["gt_semantic_seg"],
         new_shape=(1, tile_size, tile_size)),
    # CrossEntropyLoss requires integer class targets.
    dict(type="CastTensor", keys=[
         "gt_semantic_seg"], new_type="torch.LongTensor"),
    dict(type="Collect", keys=["img", "gt_semantic_seg"]),
]
# Inference pipeline: no label loading, no cropping/flipping — the full tile
# is normalised and handed to sliding-window inference (see model.test_cfg).
test_pipeline = [
    dict(
        type="LoadGeospatialImageFromFile",
        to_float32=False,
        nodata=image_nodata,
        nodata_replace=image_nodata_replace
    ),
    dict(type="BandsExtract", bands=bands),
    dict(type="ConstantMultiply", constant=constant),
    dict(type="ToTensor", keys=["img"]),
    # to channels first
    dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
    dict(type="TorchNormalize", **img_norm_cfg),
    # -1 keeps the incoming spatial extent; look_up presumably maps the
    # placeholder positions of new_shape to input axes (project-specific
    # Reshape semantics — confirm in geospatial_fm).
    dict(
        type="Reshape",
        keys=["img"],
        new_shape=(len(bands), num_frames, -1, -1),
        look_up={'2': 1, '3': 2}
    ),
    dict(type="CastTensor", keys=["img"], new_type="torch.FloatTensor"),
    # Test-time collection keeps the full meta info needed for evaluation.
    dict(
        type="CollectTestList",
        keys=["img"],
        meta_keys=[
            "img_info",
            "seg_fields",
            "img_prefix",
            "seg_prefix",
            "filename",
            "ori_filename",
            "img",
            "img_shape",
            "ori_shape",
            "pad_shape",
            "scale_factor",
            "img_norm_cfg",
        ],
    ),
]
# Dataloader configuration. The three splits share directories and suffixes
# and differ only in the split CSV, the pipeline (train vs. test) and how
# ground truth is loaded.
data = dict(
    samples_per_gpu=samples_per_gpu,
    workers_per_gpu=num_workers,
    # Training split: augmented pipeline; labels are loaded inside the pipeline.
    train=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=train_pipeline,
        ignore_index=ignore_index,
        split=splits["train"],
    ),
    # Validation split: test pipeline has no annotation-loading step, so the
    # dataset's GT loader handles the label nodata replacement here instead.
    val=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=test_pipeline,
        ignore_index=ignore_index,
        split=splits["val"],
        gt_seg_map_loader_cfg=dict(
            nodata=label_nodata, nodata_replace=ignore_index)
    ),
    # Test split: identical to val apart from the split CSV.
    test=dict(
        type=dataset_type,
        CLASSES=CLASSES,
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix=img_suffix,
        seg_map_suffix=seg_map_suffix,
        pipeline=test_pipeline,
        ignore_index=ignore_index,
        split=splits["test"],
        gt_seg_map_loader_cfg=dict(
            nodata=label_nodata, nodata_replace=ignore_index),
    ),
)
# Optimiser / schedule ------------------------------------------------------
# NOTE(review): lr=6e-5 is unusually small for SGD (it is a typical AdamW
# value) — confirm against the original fine-tuning recipe.
optimizer = dict(type="SGD", lr=6e-5, weight_decay=0.05)
optimizer_config = dict(grad_clip=None)
# Poly decay with linear warmup, stepped per iteration (by_epoch=False).
lr_config = dict(
    policy="poly",
    warmup="linear",
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False,
)

# Text + TensorBoard logging every 10 iterations.
log_config = dict(
    interval=10,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=True),
        dict(type='TensorboardLoggerHook', by_epoch=True),
    ])

# Checkpoint every 10 epochs into the experiment directory.
checkpoint_config = dict(
    by_epoch=True, interval=10, out_dir=save_path
)

# Evaluate every eval_epoch_interval epochs; keep the best-mIoU checkpoint.
evaluation = dict(
    interval=eval_epoch_interval, metric="mIoU", pre_eval=True, save_best="mIoU", by_epoch=True
)


runner = dict(type="EpochBasedRunner", max_epochs=epochs)

# Alternate one training epoch with one validation epoch.
workflow = [("train", 1), ("val", 1)]

norm_cfg = dict(type="BN", requires_grad=True)

# CrossEntropyLoss class weights for classes (0, 1).
ce_weights = [0.3, 0.7]
# Model: pretrained temporal-ViT encoder + conv neck + FCN decode/aux heads.
model = dict(
    type="TemporalEncoderDecoder",
    frozen_backbone=False,  # fine-tune the backbone end-to-end
    backbone=dict(
        type="TemporalViTEncoder",
        pretrained=pretrained_weights_path,
        img_size=img_size,
        patch_size=patch_size,
        num_frames=num_frames,
        # Reuse the shared hyper-parameter (was a hard-coded 1, which could
        # silently diverge from the tubelet_size defined above).
        tubelet_size=tubelet_size,
        in_chans=len(bands),
        embed_dim=embed_dim,
        depth=num_layers,
        num_heads=num_heads,
        mlp_ratio=4.0,
        norm_pix_loss=False,
    ),
    neck=dict(
        type="ConvTransformerTokensToEmbeddingNeck",
        embed_dim=num_frames*embed_dim,  # per-frame tokens are concatenated
        output_embed_dim=embed_dim,
        drop_cls_token=True,
        # Patch-grid height/width: img_size // patch_size = 14.
        Hp=img_size // patch_size,
        Wp=img_size // patch_size,
    ),
    # Main segmentation head.
    decode_head=dict(
        num_classes=num_classes,
        in_channels=embed_dim,
        type="FCNHead",
        in_index=-1,
        ignore_index=ignore_index,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type="CrossEntropyLoss",
            use_sigmoid=False,
            loss_weight=1,
            class_weight=ce_weights,
        ),
    ),
    # Auxiliary head (deeper: 2 convs) for additional training signal.
    auxiliary_head=dict(
        num_classes=num_classes,
        in_channels=embed_dim,
        ignore_index=ignore_index,
        type="FCNHead",
        in_index=-1,
        channels=256,
        num_convs=2,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type="CrossEntropyLoss",
            use_sigmoid=False,
            loss_weight=1,
            class_weight=ce_weights,
        ),
    ),
    train_cfg=dict(),
    # Sliding-window inference with 50% overlap; crop_size reuses the shared
    # tuple and integer division replaces int(tile_size/2) — same values.
    test_cfg=dict(
        mode="slide",
        stride=(tile_size // 2, tile_size // 2),
        crop_size=crop_size,
    ),
)