# NASA-Galileo: src/eval/datasets/sen1floods11.py
# Deployed from the GitHub repository by openfree (commit b20c769, verified).
import json
from pathlib import Path
import pandas as pd
import rioxarray
import torch
from einops import rearrange
from torch.utils.data import Dataset
from tqdm import tqdm
from src.utils import data_dir
from ..preprocess import normalize_bands
# Root directory of the Sen1Floods11 data. It holds the split CSVs read by
# get_sen1floods11, the LabelHand/, s1/ and s2/ raster folders used by
# Sen1Floods11Processor, and the generated ``.pt`` tile files.
flood_folder = data_dir / "sen1floods"
class Sen1Floods11Processor:
    """Turn hand-labelled Sen1Floods11 scenes into filtered, fixed-size tiles.

    Each scene is a ``LabelHand/*.tif`` raster plus matching ``s1`` and ``s2``
    rasters (located via ``label_to``). ``__getitem__`` loads one scene and cuts
    it into ``output_tile_size`` x ``output_tile_size`` tiles. Only tiles whose
    label contains at least one positive value are kept.
    """

    input_hw = 512
    output_tile_size = 64
    s1_bands = ("VV", "VH")
    s2_bands = ("B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12")

    def __init__(self, folder: Path, split_path: Path):
        """Collect the label rasters that belong to one split.

        Args:
            folder: dataset root containing a ``LabelHand`` directory of ``*.tif`` files.
            split_path: header-less CSV whose second column (index 1) lists the
                label file names belonging to this split.
        """
        # A set gives O(1) membership tests instead of scanning the CSV column per label.
        split_labelnames = set(pd.read_csv(split_path, header=None)[1].tolist())
        self.all_labels = [
            label for label in folder.glob("LabelHand/*.tif") if label.name in split_labelnames
        ]

    def __len__(self):
        """Number of label rasters (scenes) in this split."""
        return len(self.all_labels)

    @classmethod
    def split_and_filter_tensors(cls, s1, s2, labels):
        """Split one scene into non-overlapping square tiles and keep the labelled ones.

        Args:
            s1: tensor of shape ``(len(cls.s1_bands), cls.input_hw, cls.input_hw)``,
                i.e. ``(2, 512, 512)`` with the class defaults.
            s2: tensor of shape ``(len(cls.s2_bands), cls.input_hw, cls.input_hw)``,
                i.e. ``(13, 512, 512)`` with the class defaults.
            labels: tensor of shape ``(1, cls.input_hw, cls.input_hw)``.

        Returns:
            Tuple ``(s1_tiles, s2_tiles, label_tiles)`` of equally long lists.
            The scene is cut into ``(input_hw // output_tile_size) ** 2`` tiles
            (64 with the defaults), visited in row-major order. Only tiles whose
            label tile has at least one value ``> 0`` are returned. Tiles are
            slices (views) of the inputs.

        Raises:
            AssertionError: if any input does not have the expected shape.
        """
        hw = (cls.input_hw, cls.input_hw)
        assert s1.shape == (len(cls.s1_bands), *hw), (
            f"s1 tensor must be of shape ({len(cls.s1_bands)}, {cls.input_hw}, {cls.input_hw}), "
            f"got {s1.shape}"
        )
        assert s2.shape == (len(cls.s2_bands), *hw), (
            f"s2 tensor must be of shape ({len(cls.s2_bands)}, {cls.input_hw}, {cls.input_hw}), "
            f"got {s2.shape}"
        )
        assert labels.shape == (1, *hw), (
            f"labels tensor must be of shape (1, {cls.input_hw}, {cls.input_hw}), "
            f"got {labels.shape}"
        )

        size = cls.output_tile_size
        # Integer division: trailing pixels are dropped if input_hw is not a multiple of size.
        tiles_per_dim = cls.input_hw // size
        s1_tiles, s2_tiles, label_tiles = [], [], []
        for i in range(tiles_per_dim):
            rows = slice(i * size, (i + 1) * size)
            for j in range(tiles_per_dim):
                cols = slice(j * size, (j + 1) * size)
                label_tile = labels[:, rows, cols]
                # Skip tiles without a single positive label pixel.
                if not torch.any(label_tile > 0):
                    continue
                s1_tiles.append(s1[:, rows, cols])
                s2_tiles.append(s2[:, rows, cols])
                label_tiles.append(label_tile)
        return s1_tiles, s2_tiles, label_tiles

    @staticmethod
    def label_to(label: Path, to: str = "s1"):
        """Map a ``LabelHand`` raster path to its matching S1 or S2 raster path.

        ``<root>/LabelHand/<location>_<tile_id>_LabelHand.tif`` maps to
        ``<root>/s1/<location>_<tile_id>_S1Hand.tif`` (or the ``s2``/``S2Hand``
        equivalent).

        Raises:
            ValueError: if ``to`` is neither ``"s1"`` nor ``"s2"``.
        """
        sen_root = label.parents[1]
        # Split from the right so a location containing "_" still parses.
        location, tile_id, _ = label.stem.rsplit("_", 2)
        if to == "s1":
            return sen_root / f"s1/{location}_{tile_id}_S1Hand.tif"
        elif to == "s2":
            return sen_root / f"s2/{location}_{tile_id}_S2Hand.tif"
        else:
            raise ValueError(f"Expected `to` to be s1 or s2, got {to}")

    def __getitem__(self, idx: int):
        """Load scene ``idx`` (label, S1, S2 rasters) and return its filtered tiles.

        Returns the same ``(s1_tiles, s2_tiles, label_tiles)`` tuple as
        ``split_and_filter_tensors``.
        """
        labels_path = self.all_labels[idx]
        with rioxarray.open_rasterio(labels_path) as ds:  # type: ignore
            labels = torch.from_numpy(ds.values)  # type: ignore
        with rioxarray.open_rasterio(self.label_to(labels_path, "s1")) as ds:  # type: ignore
            s1 = torch.from_numpy(ds.values)  # type: ignore
        with rioxarray.open_rasterio(self.label_to(labels_path, "s2")) as ds:  # type: ignore
            s2 = torch.from_numpy(ds.values)  # type: ignore
        return self.split_and_filter_tensors(s1, s2, labels)
def get_sen1floods11(split_name: str = "flood_bolivia_data.csv"):
    """Tile every scene of one split and save the tiles as a single ``.pt`` file.

    Args:
        split_name: name of the split CSV inside ``flood_folder``.

    Returns:
        Path of the written file, ``flood_folder / "<split CSV stem>.pt"``. The
        file holds a dict with stacked ``"s1"``, ``"labels"`` and ``"s2"`` tensors.

    Raises:
        ValueError: if the split yields no labelled tiles. ``torch.stack`` cannot
            stack an empty list, so fail with a clear message instead.
    """
    split_path = flood_folder / split_name
    dataset = Sen1Floods11Processor(folder=flood_folder, split_path=split_path)

    all_s1, all_s2, all_labels = [], [], []
    for idx in tqdm(range(len(dataset))):
        s1_tiles, s2_tiles, label_tiles = dataset[idx]
        all_s1.extend(s1_tiles)
        all_s2.extend(s2_tiles)
        all_labels.extend(label_tiles)

    if not all_labels:
        raise ValueError(f"No labelled tiles found for split {split_path}; nothing to save.")

    save_path = flood_folder / f"{split_path.stem}.pt"
    torch.save(
        obj={
            "s1": torch.stack(all_s1),
            "labels": torch.stack(all_labels),
            "s2": torch.stack(all_s2),
        },
        f=save_path,
    )
    return save_path
def remove_nan(s1, target):
    """Drop samples whose ``s1`` tensor contains any NaN or +/-inf value.

    Args:
        s1: tensor whose first dimension indexes samples. The dataset passes
            channels-last tiles of shape ``(N, H, W, C)``.
        target: tensor with the same first dimension ``N`` (one label per sample).

    Returns:
        ``(s1_kept, target_kept)`` holding only the samples whose ``s1`` values
        are all finite, in their original order. Both may be empty (for
        example, for empty input or when every sample is non-finite).

    Raises:
        ValueError: if ``s1`` and ``target`` disagree on the number of samples.
    """
    if s1.shape[0] != target.shape[0]:
        raise ValueError(
            f"s1 and target must have the same number of samples, "
            f"got {s1.shape[0]} and {target.shape[0]}"
        )
    finite = torch.isfinite(s1)  # False for NaN, +inf and -inf
    # Collapse all non-sample dimensions so each sample gets a single keep flag.
    keep = finite if finite.ndim == 1 else finite.flatten(start_dim=1).all(dim=1)
    return s1[keep], target[keep]
class Sen1Floods11Dataset(Dataset):
    """Segmentation dataset over pre-tiled Sen1Floods11 Sentinel-1 tiles.

    Loads ``flood_{split}_data.pt`` from ``path_to_splits``. That file presumably
    comes from ``get_sen1floods11``; confirm for your setup. S1 tiles whose
    values are not all finite are dropped. For the training split, the data can
    optionally be restricted to a saved index partition.
    """

    def __init__(
        self,
        path_to_splits: Path,
        split: str,
        norm_operation,
        augmentation,
        partition,
        mode: str = "s1",  # not sure if we would ever want s2?
    ):
        """Load, clean and optionally subset one split.

        Args:
            path_to_splits: directory holding ``flood_{split}_data.pt`` and, for
                partitioned training, ``{partition}_partition.json``.
            split: one of ``"train"``, ``"val"``, ``"valid"``, ``"test"``,
                ``"bolivia"``. ``"val"`` is mapped to ``"valid"``.
            norm_operation: forwarded to ``normalize_bands``.
            augmentation: object whose ``apply(image, label, "seg")`` is called
                per sample.
            partition: ``"default"`` uses every tile. Any other value names a
                JSON list of tile indices used to subset the training split
                (ignored for other splits).
            mode: only ``"s1"`` is supported.

        Raises:
            ValueError: for an unknown ``split`` or an unsupported ``mode``.
        """
        # Validate arguments before any file I/O or tensor loading.
        if split not in ("train", "val", "valid", "test", "bolivia"):
            raise ValueError(
                f"split must be one of train, val, valid, test, bolivia; got {split}"
            )
        if mode != "s1":
            raise ValueError(f"Modes other than s1 not yet supported, got {mode}")
        if split == "val":
            split = "valid"

        config_path = Path(__file__).parent / "configs" / "sen1floods11.json"
        with config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        self.band_info = config["band_info"]["s1"]
        self.split = split
        self.augmentation = augmentation
        self.norm_operation = norm_operation

        torch_obj = torch.load(path_to_splits / f"flood_{split}_data.pt")
        # Stored as (N, 2, 64, 64); rearranged to channels-last (N, H, W, C).
        s1 = rearrange(torch_obj["s1"], "n c h w -> n h w c")
        # Tiles containing NaN/inf are dropped, not imputed. (Open question: should
        # we impute the pixel instead?)
        self.s1, self.labels = remove_nan(s1, torch_obj["labels"])

        if (partition != "default") and (split == "train"):
            partition_path = path_to_splits / f"{partition}_partition.json"
            with open(partition_path, "r", encoding="utf-8") as json_file:
                subset_indices = json.load(json_file)
            # Indices are applied after NaN filtering, so they refer to the cleaned tiles.
            self.s1 = self.s1[subset_indices]
            self.labels = self.labels[subset_indices]

    def __len__(self):
        """Number of tiles left after NaN filtering and optional partitioning."""
        return self.s1.shape[0]

    def __getitem__(self, idx):
        """Return one normalized, augmented sample.

        Returns:
            dict with ``"s1"`` (the normalized and augmented S1 tile) and
            ``"target"`` (channel 0 of the stored label tile after augmentation,
            cast to ``torch.long``).
        """
        image = self.s1[idx]
        label = self.labels[idx][0]
        image = torch.tensor(normalize_bands(image.numpy(), self.norm_operation, self.band_info))
        image, label = self.augmentation.apply(image, label, "seg")
        return {"s1": image, "target": label.long()}