# NASA-Galileo / src / eval / datasets / breizhcrops.py
# openfree's picture
# Deploy from GitHub repository
# b20c769 verified
import json
import warnings
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from breizhcrops import BreizhCrops
from breizhcrops.datasets.breizhcrops import SELECTED_BANDS
from einops import repeat
from torch.utils.data import ConcatDataset, Dataset
from src.data.config import DATA_FOLDER
from ..preprocess import normalize_bands
# Processing level: passed to BreizhCrops as `level` and used as the key into
# breizhcrops' SELECTED_BANDS.
LEVEL = "L1C"

# Location of the BreizhCrops files underneath the project's DATA_FOLDER.
DATAPATH = DATA_FOLDER / "breizhcrops"

# Order in which band columns are selected from the raw time series.
OUTPUT_BAND_ORDER = [
    "B1", "B2", "B3", "B4", "B5", "B6", "B7",
    "B8", "B8A", "B9", "B10", "B11", "B12",
]

# For every entry of OUTPUT_BAND_ORDER, the column position of that band in
# SELECTED_BANDS[LEVEL]; used to fancy-index the band axis.
INPUT_TO_OUTPUT_BAND_MAPPING = list(map(SELECTED_BANDS[LEVEL].index, OUTPUT_BAND_ORDER))
class BreizhCropsDataset(Dataset):
    """BreizhCrops time series wrapped as a torch ``Dataset``.

    Every item is a dict with the keys ``"s2"`` (normalized bands, reshaped to
    ``h w t d`` with ``h = w = 1``), ``"months"`` (the ``doa`` column of the
    time series) and ``"target"`` (the label returned by ``BreizhCrops``).
    """

    def __init__(
        self,
        path_to_splits: Path,
        split: str,
        norm_operation,
        augmentation,
        partition,
        monthly_average: bool = True,
    ):
        """
        Build the dataset for one split.

        The region-to-split assignment follows the BreizhCrops paper
        https://isprs-archives.copernicus.org/articles/XLIII-B2-2020/1545/2020/
        isprs-archives-XLIII-B2-2020-1545-2020.pdf

            We partitioned all acquired field parcels
            according to the NUTS-3 regions and suggest to subdivide the
            dataset into training (FRH01, FRH02), validation (FRH03), and
            evaluation (FRH04) subsets based on these spatially distinct
            regions.

        Args:
            path_to_splits: passed to ``BreizhCrops`` as ``root``.
            split: one of ``"train"``, ``"valid"``, ``"test"``, ``"belle-ile"``.
            norm_operation: forwarded to ``normalize_bands`` for every item.
            augmentation: stored, but not applied (a warning is emitted).
            partition: only ``"default"`` is supported.
            monthly_average: if True, each item's time series is averaged per
                calendar month before normalization.

        Raises:
            AssertionError: if ``split`` is not one of the supported values.
            NotImplementedError: if ``partition`` is not ``"default"``.
        """
        # belle-ile is small, so it's useful for testing
        assert split in ["train", "valid", "test", "belle-ile"]
        # Validate the cheap arguments first so a bad call fails before any
        # dataset is constructed or any file is read.
        if partition != "default":
            raise NotImplementedError(f"partition {partition} not implemented yet")

        kwargs = {
            "root": path_to_splits,
            "preload_ram": False,
            "level": LEVEL,
            # keep the raw values; selection and normalization happen in __getitem__
            "transform": raw_transform,
        }
        if split == "train":
            self.ds: Dataset = ConcatDataset(
                [BreizhCrops(region=r, **kwargs) for r in ["frh01", "frh02"]]
            )
        elif split == "valid":
            self.ds = BreizhCrops(region="frh03", **kwargs)
        elif split == "test":
            self.ds = BreizhCrops(region="frh04", **kwargs)
        else:
            self.ds = BreizhCrops(region="belle-ile", **kwargs)
        self.monthly_average = monthly_average

        # Band metadata lives in configs/breizhcrops.json next to this module.
        config_path = Path(__file__).parent / "configs" / "breizhcrops.json"
        with config_path.open("r") as f:
            config = json.load(f)
        self.band_info = config["band_info"]

        self.norm_operation = norm_operation
        self.augmentation = augmentation
        warnings.warn("Augmentations ignored for time series")

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as a dict of ``s2``, ``months``, ``target``."""
        # The third element returned by the wrapped dataset is not used here.
        x, y_true, _ = self.ds[idx]
        if self.monthly_average:
            x = self.average_over_month(x)
        eo = normalize_bands(
            x[:, INPUT_TO_OUTPUT_BAND_MAPPING], self.norm_operation, self.band_info
        )
        # add singleton spatial dimensions in front of (time, band)
        eo = repeat(eo, "t d -> h w t d", h=1, w=1)
        # NOTE(review): with monthly_average=False this column still holds the
        # raw "doa" values rather than 0-based month indices; confirm that the
        # consumers of "months" expect that.
        months = x[:, SELECTED_BANDS[LEVEL].index("doa")]
        return {"s2": torch.tensor(eo), "months": torch.tensor(months), "target": y_true}

    @staticmethod
    def average_over_month(x: np.ndarray):
        """Average the rows of ``x`` that fall into the same calendar month.

        The ``doa`` column is parsed with ``pd.to_datetime`` and replaced by
        the 0-based month index; all rows sharing a month index are then
        averaged column-wise.

        Args:
            x: 2-D array whose columns follow ``SELECTED_BANDS[LEVEL]``.

        Returns:
            Array with one row per distinct month present in ``x``, ordered by
            ascending month index. The ``doa`` column of each row holds that
            month index. The input array is left untouched.
        """
        doa_idx = SELECTED_BANDS[LEVEL].index("doa")
        # Work on a copy so the caller's array keeps its original doa values.
        x = np.array(x)
        x[:, doa_idx] = np.array([t.month - 1 for t in pd.to_datetime(x[:, doa_idx])])
        months = x[:, doa_idx]
        # Group with boolean masks instead of np.split on first-occurrence
        # indices: the result no longer depends on the rows being sorted by
        # date, and is identical when they are.
        return np.array([x[months == m].mean(axis=0) for m in np.unique(months)])
def raw_transform(input_timeseries):
    """Identity transform: hand back ``input_timeseries`` untouched.

    Supplied to ``BreizhCrops`` as its ``transform`` argument.
    """
    return input_timeseries