Spaces:
Sleeping
Sleeping
import json | |
import warnings | |
from pathlib import Path | |
import numpy as np | |
import pandas as pd | |
import torch | |
from breizhcrops import BreizhCrops | |
from breizhcrops.datasets.breizhcrops import SELECTED_BANDS | |
from einops import repeat | |
from torch.utils.data import ConcatDataset, Dataset | |
from src.data.config import DATA_FOLDER | |
from ..preprocess import normalize_bands | |
# Processing level requested from BreizhCrops; also selects which raw column
# layout (`SELECTED_BANDS[LEVEL]`) the samples come in.
LEVEL = "L1C"
DATAPATH = DATA_FOLDER / "breizhcrops"

# Order of the Sentinel-2 bands in the "s2" tensor this module produces.
OUTPUT_BAND_ORDER = "B1 B2 B3 B4 B5 B6 B7 B8 B8A B9 B10 B11 B12".split()

# For every band in OUTPUT_BAND_ORDER, the column index of that band in a raw
# BreizhCrops sample (raises ValueError at import time if a band is missing).
INPUT_TO_OUTPUT_BAND_MAPPING = [
    SELECTED_BANDS[LEVEL].index(band_name) for band_name in OUTPUT_BAND_ORDER
]
class BreizhCropsDataset(Dataset):
    """Torch ``Dataset`` over BreizhCrops parcels, one sample dict per parcel.

    Each item is a dict with:

    * ``"s2"`` - Sentinel-2 bands reordered to ``OUTPUT_BAND_ORDER``, passed
      through ``normalize_bands`` and reshaped to ``(1, 1, T, D)``;
    * ``"months"`` - the 0-based month index (January == 0) of each of the
      ``T`` timesteps;
    * ``"target"`` - the label returned by the underlying ``BreizhCrops``.
    """

    # Region datasets that make up each split (see the ``__init__`` docstring).
    _SPLIT_REGIONS = {
        "train": ("frh01", "frh02"),
        "valid": ("frh03",),
        "test": ("frh04",),
        # belle-ile is small, so it's useful for testing
        "belle-ile": ("belle-ile",),
    }

    def __init__(
        self,
        path_to_splits: Path,
        split: str,
        norm_operation,
        augmentation,
        partition,
        monthly_average: bool = True,
    ):
        """
        https://isprs-archives.copernicus.org/articles/XLIII-B2-2020/1545/2020/
        isprs-archives-XLIII-B2-2020-1545-2020.pdf
        We partitioned all acquired field parcels
        according to the NUTS-3 regions and suggest to subdivide the
        dataset into training (FRH01, FRH02), validation (FRH03), and
        evaluation (FRH04) subsets based on these spatially distinct
        regions.

        Args:
            path_to_splits: passed as ``root`` to every ``BreizhCrops`` region.
            split: one of ``"train"``, ``"valid"``, ``"test"``, ``"belle-ile"``.
            norm_operation: forwarded to ``normalize_bands`` for every sample.
            augmentation: stored on the instance but otherwise ignored (a
                warning is always emitted).
            partition: only ``"default"`` is supported.
            monthly_average: if True, observations are averaged per month
                before being returned.

        Raises:
            ValueError: if ``split`` is not a known split.
            NotImplementedError: if ``partition`` is not ``"default"``.
        """
        # Validate arguments before any region dataset is constructed.
        # An explicit raise (unlike ``assert``) is not stripped under ``-O``.
        if split not in self._SPLIT_REGIONS:
            raise ValueError(
                f"split must be one of {list(self._SPLIT_REGIONS)}, got {split!r}"
            )
        if partition != "default":
            raise NotImplementedError(f"partition {partition} not implemented yet")

        dataset_kwargs = {
            "root": path_to_splits,
            "preload_ram": False,
            "level": LEVEL,
            "transform": raw_transform,
        }
        regions = self._SPLIT_REGIONS[split]
        if len(regions) == 1:
            self.ds: Dataset = BreizhCrops(region=regions[0], **dataset_kwargs)
        else:
            self.ds = ConcatDataset(
                [BreizhCrops(region=region, **dataset_kwargs) for region in regions]
            )
        self.monthly_average = monthly_average

        config_path = Path(__file__).parent / "configs" / "breizhcrops.json"
        with config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        self.band_info = config["band_info"]

        self.norm_operation = norm_operation
        self.augmentation = augmentation
        warnings.warn("Augmentations ignored for time series")

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the sample dict (``"s2"``, ``"months"``, ``"target"``) for ``idx``.

        ``T`` (the timestep axis) is the number of distinct months when
        ``monthly_average`` is set, otherwise the number of acquisitions.
        """
        x, y_true, _ = self.ds[idx]
        doa_idx = SELECTED_BANDS[LEVEL].index("doa")
        if self.monthly_average:
            x = self.average_over_month(x)
        else:
            # Convert dates to month indices here too, so "months" means the
            # same thing whether or not monthly averaging is enabled.
            x = self._doa_to_month_index(x, doa_idx)
        eo = normalize_bands(
            x[:, INPUT_TO_OUTPUT_BAND_MAPPING], self.norm_operation, self.band_info
        )
        eo = repeat(eo, "t d -> h w t d", h=1, w=1)
        months = x[:, doa_idx]
        return {"s2": torch.tensor(eo), "months": torch.tensor(months), "target": y_true}

    @staticmethod
    def average_over_month(x: np.ndarray) -> np.ndarray:
        """Average the rows of ``x`` per month.

        The ``doa`` column of the result holds the 0-based month index, and
        rows are ordered by month. ``x`` itself is not modified. Rows are
        grouped by month number only, so observations from the same month of
        different years would be merged (assumes a single-year series).
        """
        doa_idx = SELECTED_BANDS[LEVEL].index("doa")
        with_months = BreizhCropsDataset._doa_to_month_index(x, doa_idx)
        return BreizhCropsDataset._mean_per_month(with_months, doa_idx)

    @staticmethod
    def _doa_to_month_index(x: np.ndarray, doa_idx: int) -> np.ndarray:
        """Return a copy of ``x`` whose ``doa_idx`` column is the 0-based month.

        The column is parsed with ``pd.to_datetime``; presumably it holds
        nanosecond timestamps as produced by BreizhCrops (TODO confirm).
        """
        # Copy so the caller's (possibly cached) sample is never mutated.
        x = np.array(x, copy=True)
        dates = pd.to_datetime(x[:, doa_idx])
        x[:, doa_idx] = np.asarray(dates.month) - 1
        return x

    @staticmethod
    def _mean_per_month(x: np.ndarray, doa_idx: int) -> np.ndarray:
        """Average all rows sharing a month index; one output row per month.

        Rows need not be contiguous or sorted. An empty input yields an empty
        ``(0, x.shape[1])`` array.
        """
        months = x[:, doa_idx]
        unique_months = np.unique(months)
        if unique_months.size == 0:
            return x[:0]
        return np.stack([x[months == month].mean(axis=0) for month in unique_months])
def raw_transform(input_timeseries):
    """Identity transform handed to ``BreizhCrops``: return the sample as-is."""
    unchanged = input_timeseries
    return unchanged