Upload 3 files
- README.md +61 -0
- burn_scars_config.yaml +104 -0
- inference.py +335 -0
README.md
CHANGED
@@ -1,3 +1,64 @@
---
license: apache-2.0
library_name: terratorch
---

### Model and Inputs

The pretrained [Prithvi-EO-2.0-300M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M) model is fine-tuned to segment the extent of burned areas in HLS images from the [HLS Burn Scars dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars).

The dataset consists of ~800 labeled 512x512 chips from the continental US.

We use the following six bands for the predictions: Blue, Green, Red, Narrow NIR, SWIR 1, and SWIR 2.

Labels represent not-burned areas (class 0), burned areas (class 1), and no data/clouds (class -1).

The Prithvi-EO-2.0-300M model was initially pretrained with a sequence length of four timestamps. Based on the characteristics of this benchmark dataset, we focus on single-timestamp segmentation, which demonstrates that the model can be fine-tuned with an arbitrary number of timestamps.
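
During fine-tuning, the no data/cloud class is excluded from the loss via `ignore_index: -1` in [burn_scars_config.yaml](burn_scars_config.yaml). The following is a minimal sketch of that label convention with made-up, illustrative tensors:

```python
import torch
import torch.nn.functional as F

# Illustrative 2-class logits and a label patch using the dataset convention:
# 0 = not burned, 1 = burn scar, -1 = no data / clouds (excluded from the loss).
logits = torch.randn(1, 2, 4, 4)  # (batch, classes, height, width)
labels = torch.tensor([[[0, 1, 1, -1],
                        [0, 0, 1, -1],
                        [0, 1, 1, -1],
                        [-1, -1, -1, -1]]])

# Pixels labeled -1 do not contribute to the cross-entropy loss.
loss = F.cross_entropy(logits, labels, ignore_index=-1)
```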

### Fine-tuning

The model was fine-tuned using [TerraTorch](https://github.com/IBM/terratorch):

```shell
terratorch fit -c burn_scars_config.yaml
```

The configuration used for fine-tuning is available in [burn_scars_config.yaml](burn_scars_config.yaml).

We created new non-overlapping train, validation, and test splits, which you can find in [splits](splits). The same splits were used for the evaluation in the Prithvi-EO-2.0 paper. Compared to the paper, we used a UNetDecoder instead of a UperNetDecoder for this model. We repeated the run five times and selected the model with the lowest validation loss across all runs and epochs. Finally, we evaluated the selected model on the test split with the following results:

| Model | Decoder | test IoU Burned | test mIoU | val IoU Burned | val mIoU |
|:--------------------|:------------|:---------------:|:---------:|:--------------:|:--------:|
| Prithvi EO 2.0 300M | UNetDecoder | 87.52 | 93.00 | 84.28 | 90.95 |
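
The test metrics can be recomputed with the TerraTorch CLI. This is a command sketch: it assumes the standard Lightning `test` subcommand and `--ckpt_path` argument exposed by the CLI, and `path/to/checkpoint.ckpt` is a placeholder for your fine-tuned checkpoint:

```shell
terratorch test -c burn_scars_config.yaml --ckpt_path path/to/checkpoint.ckpt
```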
### Inference and demo

A **demo** running this model is available **[here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-BurnScars-demo)**.

This repo includes an inference script for running the model on HLS images:

```shell
python inference.py --data_file examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif
```
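
The model can also be loaded programmatically. The following is a minimal sketch based on the helpers defined in [inference.py](inference.py); it assumes you run it from the repository root so that `inference.py` is importable, and it uses the script's default paths as placeholders for your own files:

```python
from terratorch.cli_tools import LightningInferenceModel

from inference import load_example, run_model

# Restore the fine-tuned task from the training config and checkpoint.
model = LightningInferenceModel.from_config(
    "burn_scars_config.yaml", "Prithvi_EO_V2_300M_BurnScars.pt"
)
model.model.eval()

# Load one HLS chip (bands, single timestamp), exactly as inference.py does.
input_data, _, _, meta = load_example(
    file_paths=["examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif"]
)
if input_data.mean() > 1:
    input_data = input_data / 10000  # HLS scaling to reflectance, as in inference.py

# Tiled prediction over 512x512 windows; returns a (1, H, W) class map.
pred = run_model(input_data, model.model, model.datamodule, img_size=512)
```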

### Feedback

Your feedback is invaluable to us. If you have any feedback about the model, please share it with us by submitting issues on GitHub or starting a discussion on HuggingFace.

### Citation

If this model helped your research, please cite [Prithvi-EO-2.0](https://arxiv.org/abs/2412.02732) in your publications.

```
@article{Prithvi-EO-V2-preprint,
  author = {Szwarcman, Daniela and Roy, Sujit and Fraccaro, Paolo and Gíslason, Þorsteinn Elí and Blumenstiel, Benedikt and Ghosal, Rinki and de Oliveira, Pedro Henrique and de Sousa Almeida, João Lucas and Sedona, Rocco and Kang, Yanghui and Chakraborty, Srija and Wang, Sizhe and Kumar, Ankur and Truong, Myscon and Godwin, Denys and Lee, Hyunho and Hsu, Chia-Yu and Akbari Asanjan, Ata and Mujeci, Besart and Keenan, Trevor and Arévolo, Paulo and Li, Wenwen and Alemohammad, Hamed and Olofsson, Pontus and Hain, Christopher and Kennedy, Robert and Zadrozny, Bianca and Cavallaro, Gabriele and Watson, Campbell and Maskey, Manil and Ramachandran, Rahul and Bernabe Moreno, Juan},
  title = {{Prithvi-EO-2.0: A Versatile Multi-Temporal Foundation Model for Earth Observation Applications}},
  journal = {arXiv preprint arXiv:2412.02732},
  year = {2024}
}
```
burn_scars_config.yaml
ADDED
@@ -0,0 +1,104 @@
# lightning.pytorch==2.4.0
seed_everything: 2
trainer:
  logger: true
  max_epochs: 100
  log_every_n_steps: 1
  callbacks:
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 15
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
  enable_progress_bar: false
  precision: bf16-mixed

model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_factory: EncoderDecoderFactory
    model_args:
      backbone: prithvi_eo_v2_300
      backbone_pretrained: true
      backbone_bands: ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"]
      necks:
        - name: SelectIndices
          indices: [5, 11, 17, 23]
        - name: ReshapeTokensToImage
        - name: LearnedInterpolateToPyramidal
      decoder: UNetDecoder
      decoder_channels: [512, 256, 128, 64]
      num_classes: 2
    loss: ce
    ignore_index: -1
    freeze_backbone: false
    plot_on_val: false
    class_names: [Not burned, Burn scar]

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1.e-4
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
    factor: 0.5
    patience: 4

data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 8
    num_workers: 8
    dataset_bands: # Dataset bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    output_bands: # Model input bands
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: hls_burn_scars/data
    val_data_root: hls_burn_scars/data
    test_data_root: hls_burn_scars/data
    train_split: hls_burn_scars/splits/train.txt
    val_split: hls_burn_scars/splits/val.txt
    test_split: hls_burn_scars/splits/test.txt
    img_grep: "*_merged.tif"
    label_grep: "*.mask.tif"
    means:
      - 0.033349706741586264
      - 0.05701185520536176
      - 0.05889748132001316
      - 0.2323245113436119
      - 0.1972854853760658
      - 0.11944914225186566
    stds:
      - 0.02269135568823774
      - 0.026807560223070237
      - 0.04004109844362779
      - 0.07791732423672691
      - 0.08708738838140137
      - 0.07241979477437814
    num_classes: 2
    train_transform:
      - class_path: albumentations.D4
      - class_path: ToTensorV2
    test_transform:
      - class_path: ToTensorV2

    no_data_replace: 0
    no_label_replace: -1
inference.py
ADDED
@@ -0,0 +1,335 @@
import argparse
import os
from typing import List, Union
import re
import datetime
import numpy as np
import rasterio
import torch
import yaml
from einops import rearrange
from terratorch.cli_tools import LightningInferenceModel

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99


def process_channel_group(orig_img, channels):
    """
    Args:
        orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
        torch.Tensor with shape (num_channels, height, width) for original image
    """

    orig_img = orig_img[channels, ...]
    valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
    valid_mask[orig_img == NO_DATA_FLOAT] = False

    # Rescale (enhancing contrast)
    max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
    min_value = OFFSET

    orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)

    # No data as zeros
    orig_img[~valid_mask] = 0

    return orig_img


def read_geotiff(file_path: str):
    """Read all bands from *file_path* and return image + meta info.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
    """

    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta
        try:
            coords = src.lnglat()
        except Exception:
            # Cannot read coords
            coords = None

    return img, meta, coords


def save_geotiff(image, output_path: str, meta: dict):
    """Save multi-band image in Geotiff file.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info.
    """

    with rasterio.open(output_path, "w", **meta) as dest:
        for i in range(image.shape[0]):
            dest.write(image[i, :, :], i + 1)

    return


def _convert_np_uint8(float_image: torch.Tensor):
    image = float_image.numpy() * 255.0
    image = image.astype(dtype=np.uint8)

    return image


def load_example(
    file_paths: List[str],
    mean: List[float] = None,
    std: List[float] = None,
    indices: Union[list[int], None] = None,
):
    """Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths.
        mean: list containing mean values for each band in the images in *file_paths*.
        std: list containing std values for each band in the images in *file_paths*.
        indices: indices of the bands to keep from the input images.

    Returns:
        np.array containing created example
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        try:
            match = re.search(r'(\d{7,8}T\d{6})', file)
            if match:
                year = int(match.group(1)[:4])
                julian_day = match.group(1).split('T')[0][4:]
                if len(julian_day) == 3:
                    julian_day = int(julian_day)
                else:
                    julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
                temporal_coords.append([year, julian_day])
        except Exception as e:
            print(f'Could not extract timestamp for {file} ({e})')

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas


def run_model(input_data, model, datamodule, img_size):
    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - (original_h % img_size)) % img_size
    pad_w = (img_size - (original_w % img_size)) % img_size
    input_data = np.pad(
        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
    )

    # Build sliding window
    batch_size = 1
    batch = torch.tensor(input_data, device="cpu")
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
    )

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    # Run model
    pred_imgs = []
    for x in windows:
        # Apply standardization
        x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
        x['image'] = x['image'].unsqueeze(0)
        x = datamodule.aug(x)['image']

        with torch.no_grad():
            x = x.to(model.device)
            pred = model(x)
            pred = pred.output.detach().cpu()

        y_hat = pred.argmax(dim=1)

        y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), size=img_size, mode="nearest")

        pred_imgs.append(y_hat)

    pred_imgs = torch.concat(pred_imgs, dim=0)

    # Build images from patches
    pred_imgs = rearrange(
        pred_imgs,
        "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=1,
        h1=h1,
        w1=w1,
    )

    # Cut padded area back to original size
    pred_imgs = pred_imgs[..., :original_h, :original_w]

    # Squeeze (batch size 1)
    pred_imgs = pred_imgs[0]

    return pred_imgs


def main(
    data_file: str,
    config: str,
    checkpoint: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] = None,
):
    os.makedirs(output_dir, exist_ok=True)

    with open(config, "r") as f:
        config_dict = yaml.safe_load(f)

    # Load model ---------------------------------------------------------------------------------

    lightning_model = LightningInferenceModel.from_config(config, checkpoint)
    img_size = 512  # Size of BurnScars chips

    # Loading data -------------------------------------------------------------------------------

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=[data_file], indices=input_indices,
    )

    meta_data = meta_data[0]  # only one image

    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

    # Running model ------------------------------------------------------------------------------

    lightning_model.model.eval()

    channels = config_dict['data']['init_args']['rgb_indices']

    pred = run_model(input_data, lightning_model.model, lightning_model.datamodule, img_size)

    # Save prediction
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
    save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)

    # Save image + prediction overlay
    meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)

    if input_data.mean() < 1:
        input_data = input_data * 10000  # Scale to 0-10000

    rgb_orig = process_channel_group(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )

    pred[pred == 0.] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
    img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]

    img_pred_file = os.path.join(output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
    save_geotiff(
        image=_convert_np_uint8(img_pred),
        output_path=img_pred_file,
        meta=meta_data,
    )

    # Save original RGB image
    if rgb_outputs:
        rgb_file = os.path.join(output_dir, f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
            output_path=rgb_file,
            meta=meta_data,
        )

    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("run inference", add_help=False)

    parser.add_argument(
        "--data_file",
        type=str,
        default="examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif",
        help="Path to the input HLS file.",
    )
    parser.add_argument(
        "--config",
        "-c",
        type=str,
        default="burn_scars_config.yaml",
        help="Path to yaml file containing model parameters.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="Prithvi_EO_V2_300M_BurnScars.pt",
        help="Path to a checkpoint file to load from.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[0, 1, 2, 3, 4, 5],
        type=int,
        nargs="+",
        help="0-based indices of the six Prithvi channels to be selected from the input. "
        "By default selects [0,1,2,3,4,5] for filtered HLS data.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, also save an RGB composite of the original input image.",
    )
    args = parser.parse_args()

    main(**vars(args))