Spaces:
Sleeping
Sleeping
File size: 7,103 Bytes
b20c769 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import json
from pathlib import Path
import pandas as pd
import rioxarray
import torch
from einops import rearrange
from torch.utils.data import Dataset
from tqdm import tqdm
from src.utils import data_dir
from ..preprocess import normalize_bands
# Root folder of the Sen1Floods11 data. The code in this module looks for the
# "LabelHand", "s1" and "s2" raster sub-folders and the split CSVs under it,
# and writes the tiled ``.pt`` caches back into it.
flood_folder = data_dir / "sen1floods"
class Sen1Floods11Processor:
    """Cut the hand-labelled scenes of one split into small training tiles.

    One item corresponds to one label raster found under ``folder/LabelHand``
    whose file name is listed in the split CSV. Indexing an item loads the
    label raster plus the matching s1 and s2 rasters and returns only the
    tiles whose label contains at least one value greater than zero.
    """

    # Side length (in pixels) every input raster must have.
    input_hw = 512
    # Side length (in pixels) of the square tiles that are produced.
    output_tile_size = 64
    s1_bands = ("VV", "VH")
    s2_bands = ("B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12")

    def __init__(self, folder: Path, split_path: Path):
        """
        Collect the label rasters that belong to the requested split.

        Args:
            folder: Dataset root; label rasters are searched in ``folder/LabelHand``.
            split_path: Header-less CSV whose second column (index 1) holds the
                label file names of the split.
        """
        # A set keeps the per-file membership test O(1) instead of scanning a list.
        split_labelnames = set(pd.read_csv(split_path, header=None)[1].tolist())
        self.all_labels = [
            label for label in folder.glob("LabelHand/*.tif") if label.name in split_labelnames
        ]

    def __len__(self):
        """Return the number of label rasters in the split."""
        return len(self.all_labels)

    @classmethod
    def split_and_filter_tensors(cls, s1, s2, labels):
        """
        Split the s1, s2 and label rasters of one scene into non-overlapping tiles.

        The rasters are cut on a regular grid of ``output_tile_size`` squares, i.e.
        ``(input_hw // output_tile_size) ** 2`` candidate tiles (64 with the class
        defaults). A grid cell is kept only if its label tile has at least one
        value greater than zero.

        Args:
            s1 (torch.Tensor): Tensor of shape (len(s1_bands), input_hw, input_hw)
            s2 (torch.Tensor): Tensor of shape (len(s2_bands), input_hw, input_hw)
            labels (torch.Tensor): Tensor of shape (1, input_hw, input_hw)
        Returns:
            tuple of three lists ``(s1_tiles, s2_tiles, label_tiles)``; entry ``k`` of
            each list comes from the same grid cell. Tiles are slices of the inputs.
        Raises:
            AssertionError: If any input does not have the expected shape.
        """
        expected = (
            ("s1", s1, len(cls.s1_bands)),
            ("s2", s2, len(cls.s2_bands)),
            ("labels", labels, 1),
        )
        for name, tensor, channels in expected:
            assert tensor.shape == (channels, cls.input_hw, cls.input_hw), (
                f"{name} tensor must be of shape ({channels}, {cls.input_hw}, {cls.input_hw}), "
                f"got {tensor.shape}"
            )

        tile_size = cls.output_tile_size
        num_tiles_per_dim = cls.input_hw // tile_size
        s1_list, s2_list, labels_list = [], [], []
        for i in range(num_tiles_per_dim):
            rows = slice(i * tile_size, (i + 1) * tile_size)
            for j in range(num_tiles_per_dim):
                cols = slice(j * tile_size, (j + 1) * tile_size)
                label_tile = labels[:, rows, cols]
                # Skip grid cells whose label has no value above zero.
                if torch.any(label_tile > 0):
                    s1_list.append(s1[:, rows, cols])
                    s2_list.append(s2[:, rows, cols])
                    labels_list.append(label_tile)
        return s1_list, s2_list, labels_list

    @staticmethod
    def label_to(label: Path, to: str = "s1"):
        """
        Map a label raster path to the path of its s1 or s2 raster.

        The label stem must consist of exactly three ``_``-separated parts
        (``<location>_<tile_id>_<suffix>``); the result lives in the ``s1`` or
        ``s2`` folder next to the label's own folder.

        Args:
            label: Path of the label raster.
            to: Either ``"s1"`` or ``"s2"``.
        Raises:
            ValueError: If ``to`` is neither ``"s1"`` nor ``"s2"``, or if the stem
                does not split into exactly three parts.
        """
        sen_root = label.parents[1]
        location, tile_id, _ = label.stem.split("_")
        if to == "s1":
            return sen_root / f"s1/{location}_{tile_id}_S1Hand.tif"
        elif to == "s2":
            return sen_root / f"s2/{location}_{tile_id}_S2Hand.tif"
        else:
            raise ValueError(f"Expected `to` to be s1 or s2, got {to}")

    def __getitem__(self, idx: int):
        """Load scene ``idx`` and return its filtered ``(s1, s2, label)`` tile lists."""
        labels_path = self.all_labels[idx]
        with rioxarray.open_rasterio(labels_path) as ds:  # type: ignore
            labels = torch.from_numpy(ds.values)  # type: ignore
        with rioxarray.open_rasterio(self.label_to(labels_path, "s1")) as ds:  # type: ignore
            s1 = torch.from_numpy(ds.values)  # type: ignore
        with rioxarray.open_rasterio(self.label_to(labels_path, "s2")) as ds:  # type: ignore
            s2 = torch.from_numpy(ds.values)  # type: ignore
        return self.split_and_filter_tensors(s1, s2, labels)
def get_sen1floods11(split_name: str = "flood_bolivia_data.csv"):
    """
    Tile every scene of one split and cache the result as a single ``.pt`` file.

    The split CSV is read from ``flood_folder / split_name``. The tiles of all
    scenes are stacked per modality and saved as a dict with the keys ``"s1"``,
    ``"labels"`` and ``"s2"`` to ``flood_folder / "<csv stem>.pt"``.

    Args:
        split_name: File name of the split CSV inside ``flood_folder``.
    """
    split_path = flood_folder / split_name
    processor = Sen1Floods11Processor(folder=flood_folder, split_path=split_path)

    s1_tiles, s2_tiles, label_tiles = [], [], []
    for scene_idx in tqdm(range(len(processor))):
        scene_s1, scene_s2, scene_labels = processor[scene_idx]
        s1_tiles.extend(scene_s1)
        s2_tiles.extend(scene_s2)
        label_tiles.extend(scene_labels)

    payload = {
        "s1": torch.stack(s1_tiles),
        "labels": torch.stack(label_tiles),
        "s2": torch.stack(s2_tiles),
    }
    torch.save(obj=payload, f=flood_folder / f"{split_path.stem}.pt")
def remove_nan(s1, target):
    """
    Drop every sample whose s1 tile contains a NaN or an infinite value.

    Args:
        s1: Tensor indexed by sample along its first dimension; the in-file
            caller passes shape (N, H, W, C).
        target: Labels indexed in lockstep with ``s1`` along the first
            dimension; the in-file caller passes shape (N, 1, H, W).
    Returns:
        tuple ``(s1, target)`` restricted to the samples whose s1 tile is
        entirely finite. If no sample survives (or N == 0), empty slices of
        the inputs are returned instead of failing.
    """
    keep = [
        i
        for i in range(s1.shape[0])
        if not (torch.any(torch.isnan(s1[i])) or torch.any(torch.isinf(s1[i])))
    ]
    if not keep:
        # torch.stack rejects an empty list, so hand back empty results that
        # still carry the per-sample shape of the inputs.
        return s1[:0], target[:0]
    return torch.stack([s1[i] for i in keep]), torch.stack([target[i] for i in keep])
class Sen1Floods11Dataset(Dataset):
    """Torch dataset over the pre-tiled s1 data stored in ``flood_<split>_data.pt``.

    Samples whose s1 tile contains NaN or infinite values are dropped when the
    dataset is built. Each item is a dict with the normalised (and augmented)
    s1 tile under ``"s1"`` and the label map under ``"target"``.
    """

    def __init__(
        self,
        path_to_splits: Path,
        split: str,
        norm_operation,
        augmentation,
        partition,
        mode: str = "s1",  # not sure if we would ever want s2?
    ):
        """
        Load one split from disk and prepare it for indexing.

        Args:
            path_to_splits: Folder holding ``flood_<split>_data.pt`` and, when a
                partition is used, ``<partition>_partition.json``.
            split: One of "train", "val", "valid", "test", "bolivia"; "val" is
                treated as "valid".
            norm_operation: Passed through to ``normalize_bands`` for every item.
            augmentation: Object whose ``apply(image, label, "seg")`` is called
                for every item and must return ``(image, label)``.
            partition: ``"default"`` to use the full split; any other value
                selects the indices stored in ``<partition>_partition.json``
                (only applied to the "train" split).
            mode: Only ``"s1"`` is supported.
        Raises:
            ValueError: If ``mode`` is not ``"s1"``.
            AssertionError: If ``split`` is not one of the accepted names.
        """
        # Validate the arguments first so that a bad call fails before any
        # file is read or any tensor is loaded.
        if mode != "s1":
            raise ValueError(f"Modes other than s1 not yet supported, got {mode}")
        assert split in ["train", "val", "valid", "test", "bolivia"]
        if split == "val":
            split = "valid"

        config_path = Path(__file__).parents[0] / Path("configs") / Path("sen1floods11.json")
        with config_path.open("r") as f:
            config = json.load(f)

        self.band_info = config["band_info"]["s1"]
        self.split = split
        self.augmentation = augmentation
        self.norm_operation = norm_operation

        torch_obj = torch.load(path_to_splits / f"flood_{split}_data.pt")
        self.s1 = torch_obj["s1"]  # (N, 2, 64, 64)
        # Move the channel axis last: (N, C, H, W) -> (N, H, W, C).
        self.s1 = rearrange(self.s1, "n c h w -> n h w c")
        self.labels = torch_obj["labels"]
        self.s1, self.labels = remove_nan(
            self.s1, self.labels
        )  # should we remove the tile or impute the pixel?

        if (partition != "default") and (split == "train"):
            # The partition file stores the sample indices to keep; they index
            # the data after NaN/inf filtering.
            with open(path_to_splits / f"{partition}_partition.json", "r") as json_file:
                subset_indices = json.load(json_file)
            self.s1 = self.s1[subset_indices]
            self.labels = self.labels[subset_indices]

    def __len__(self):
        """Return the number of samples that survived filtering and partitioning."""
        return self.s1.shape[0]

    def __getitem__(self, idx):
        """Return ``{"s1": image, "target": label}`` for sample ``idx``."""
        image = self.s1[idx]
        # Labels carry a leading singleton channel axis; drop it.
        label = self.labels[idx][0]
        image = torch.tensor(normalize_bands(image.numpy(), self.norm_operation, self.band_info))
        image, label = self.augmentation.apply(image, label, "seg")
        return {"s1": image, "target": label.long()}
|