TerraTorch
File size: 2,504 Bytes
6a5609b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# lightning.pytorch==2.4.0
seed_everything: 2
trainer:
  # Run length and numeric precision.
  max_epochs: 100
  precision: bf16-mixed
  # Logging: default logger, log every step, no progress bar output.
  logger: true
  log_every_n_steps: 1
  enable_progress_bar: false
  callbacks:
    # Stop training once val/loss has not improved for 15 consecutive validation checks.
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 15
    # Record the learning rate once per epoch.
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      # prithvi_eo_v2_300 encoder initialised from pretrained weights,
      # consuming the six bands listed below.
      backbone: prithvi_eo_v2_300
      backbone_pretrained: true
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      # Necks are applied in list order. As their names suggest: pick the
      # encoder outputs at the given indices, reshape tokens into image-like
      # maps, then build a feature pyramid for the UNet decoder.
      necks:
        - name: SelectIndices
          indices:
            - 5
            - 11
            - 17
            - 23
        - name: ReshapeTokensToImage
        - name: LearnedInterpolateToPyramidal
      decoder: UNetDecoder
      decoder_channels:
        - 512
        - 256
        - 128
        - 64
      num_classes: 2
    # Task-level settings.
    loss: ce
    # Label value to ignore; keep in sync with data.init_args.no_label_replace (-1).
    ignore_index: -1
    freeze_backbone: false  # encoder weights are fine-tuned as well
    plot_on_val: false
    class_names:
      - "Not burned"
      - "Burn scar"

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1.0e-4

lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    # Halve the learning rate (factor 0.5) after val/loss has stopped
    # improving for 4 scheduler steps.
    factor: 0.5
    patience: 4
    monitor: val/loss

data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 8
    num_workers: 8
    dataset_bands:  # Dataset bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    output_bands:  # Model input bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Positions of RED, GREEN and BLUE within output_bands, in that order
    # (presumably used only for RGB visualisation).
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: hls_burn_scars/data
    val_data_root: hls_burn_scars/data
    test_data_root: hls_burn_scars/data
    train_split: hls_burn_scars/splits/train.txt
    val_split: hls_burn_scars/splits/val.txt
    test_split: hls_burn_scars/splits/test.txt
    # Glob patterns must stay quoted: a plain scalar starting with `*`
    # would be parsed as a YAML alias.
    img_grep: "*_merged.tif"
    label_grep: "*.mask.tif"
    # Per-band normalisation statistics: six entries, one per band above.
    means:
      - 0.033349706741586264
      - 0.05701185520536176
      - 0.05889748132001316
      - 0.2323245113436119
      - 0.1972854853760658
      - 0.11944914225186566
    stds:
      - 0.02269135568823774
      - 0.026807560223070237
      - 0.04004109844362779
      - 0.07791732423672691
      - 0.08708738838140137
      - 0.07241979477437814
    num_classes: 2
    train_transform:
      - class_path: albumentations.D4
      - class_path: ToTensorV2
    test_transform:
      - class_path: ToTensorV2
    no_data_replace: 0
    # Must match model.init_args.ignore_index (-1).
    no_label_replace: -1