improved_sen1flood11_lr

#4
Files changed (2) hide show
  1. README.md +1 -1
  2. sen1floods11_Prithvi_100M.py +5 -3
README.md CHANGED
@@ -28,7 +28,7 @@ We extract the following bands for flood mapping:
28
  5. SWIR 1
29
  6. SWIR 2
30
 
31
- Labels represent no water (class 0), water/flood (class 1), and no data/clouds (class -1).
32
 
33
  The Prithvi-100m model was initially pretrained using a sequence length of 3 timesteps. Based on the characteristics of this benchmark dataset, we focus on single-timestamp segmentation. This demonstrates that our model can be utilized with an arbitrary number of timestamps during finetuning.
34
 
 
28
  5. SWIR 1
29
  6. SWIR 2
30
 
31
+ Labels represent no water (class 0), water/flood (class 1), and no data/clouds (class 2).
32
 
33
  The Prithvi-100m model was initially pretrained using a sequence length of 3 timesteps. Based on the characteristics of this benchmark dataset, we focus on single-timestamp segmentation. This demonstrates that our model can be utilized with an arbitrary number of timestamps during finetuning.
34
 
sen1floods11_Prithvi_100M.py CHANGED
@@ -74,7 +74,8 @@ train_pipeline = [
74
  type="LoadGeospatialImageFromFile",
75
  to_float32=False,
76
  nodata=image_nodata,
77
- nodata_replace=image_nodata_replace
 
78
  ),
79
  dict(
80
  type="LoadGeospatialAnnotations",
@@ -106,7 +107,8 @@ test_pipeline = [
106
  type="LoadGeospatialImageFromFile",
107
  to_float32=False,
108
  nodata=image_nodata,
109
- nodata_replace=image_nodata_replace
 
110
  ),
111
  dict(type="BandsExtract", bands=bands),
112
  dict(type="ConstantMultiply", constant=constant),
@@ -224,9 +226,9 @@ ce_weights = [0.3, 0.7]
224
  model = dict(
225
  type="TemporalEncoderDecoder",
226
  frozen_backbone=False,
227
- pretrained=pretrained_weights_path,
228
  backbone=dict(
229
  type="TemporalViTEncoder",
 
230
  img_size=img_size,
231
  patch_size=patch_size,
232
  num_frames=num_frames,
 
74
  type="LoadGeospatialImageFromFile",
75
  to_float32=False,
76
  nodata=image_nodata,
77
+ nodata_replace=image_nodata_replace,
78
+ channels_last=False
79
  ),
80
  dict(
81
  type="LoadGeospatialAnnotations",
 
107
  type="LoadGeospatialImageFromFile",
108
  to_float32=False,
109
  nodata=image_nodata,
110
+ nodata_replace=image_nodata_replace,
111
+ channels_last=False
112
  ),
113
  dict(type="BandsExtract", bands=bands),
114
  dict(type="ConstantMultiply", constant=constant),
 
226
  model = dict(
227
  type="TemporalEncoderDecoder",
228
  frozen_backbone=False,
 
229
  backbone=dict(
230
  type="TemporalViTEncoder",
231
+ pretrained=pretrained_weights_path,
232
  img_size=img_size,
233
  patch_size=patch_size,
234
  num_frames=num_frames,