import json
import os
from pathlib import Path

import numpy as np
import torch
import torch.multiprocessing
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset

from ..preprocess import normalize_bands

torch.multiprocessing.set_sharing_strategy("file_system")


def split_and_filter_tensors(image_tensor, label_tensor):
    """
    Split image and label tensors into 9 tiles and filter based on label content.

    The 240x240 inputs are cut into a 3x3 grid of 80x80 tiles, visited in
    row-major order. A tile is kept only if its label tile contains at least
    one value greater than zero.

    Args:
        image_tensor (torch.Tensor): Input tensor of shape (13, 240, 240)
        label_tensor (torch.Tensor): Label tensor of shape (240, 240)

    Returns:
        tuple[list, list]: ``(tiles, labels)`` -- two parallel lists where
        ``tiles[k]`` has shape (13, 80, 80) and ``labels[k]`` has shape
        (80, 80). Both lists are empty when no label tile has a positive value.

    Raises:
        AssertionError: If either input does not have the expected shape.
    """
    assert image_tensor.shape == (13, 240, 240), "Image tensor must be of shape (13, 240, 240)"
    assert label_tensor.shape == (240, 240), "Label tensor must be of shape (240, 240)"

    tile_size = 80
    tiles = []
    labels = []

    for i in range(3):
        rows = slice(i * tile_size, (i + 1) * tile_size)
        for j in range(3):
            cols = slice(j * tile_size, (j + 1) * tile_size)

            # Extract the label tile first: it decides whether the pair is kept.
            label_tile = label_tensor[rows, cols]

            # Keep only tiles whose label has any non-zero (positive) values.
            if torch.any(label_tile > 0):
                tiles.append(image_tensor[:, rows, cols])
                labels.append(label_tile)

    return tiles, labels


class PrepMADOSDataset(Dataset):
    """Read raw MADOS scene files from disk and yield labelled 80x80 tiles.

    Each item is the output of ``split_and_filter_tensors`` for one scene:
    a list of image tiles and the parallel list of label tiles.
    """

    def __init__(self, root_dir, split_file):
        """
        Args:
            root_dir: Dataset root; must contain a ``splits`` directory and the
                ``Scene_*`` directories referenced by ``_load_band``.
            split_file: Name of a text file inside ``<root_dir>/splits`` listing
                one scene name per line.
        """
        self.root_dir = root_dir
        with open(os.path.join(root_dir, "splits", split_file), "r") as f:
            # Skip blank lines: an empty scene name cannot be parsed in
            # __getitem__ and would fail there with an unrelated IndexError.
            self.scene_list = [name for name in (line.strip() for line in f) if name]

    def __len__(self):
        return len(self.scene_list)

    def __getitem__(self, idx):
        """Load one scene, assemble its 13-band image and return its kept tiles.

        Returns:
            tuple[list, list]: ``(images, masks)`` as produced by
            ``split_and_filter_tensors``.

        Raises:
            IndexError: If ``idx`` is outside the scene list.
            FileNotFoundError: If a required band file is missing.
        """
        scene_name = self.scene_list[idx]
        # Scene names are underscore separated; fields 1 and 2 are used to
        # build the file paths in _load_band / _load_mask.
        name_parts = scene_name.split("_")
        scene_num_1 = name_parts[1]
        scene_num_2 = name_parts[2]

        # Load all bands. Each list holds the candidate wavelength identifiers
        # that may appear in the file name; the first existing file wins.
        B1 = self._load_band(scene_num_1, scene_num_2, [442, 443], 60)
        B2 = self._load_band(scene_num_1, scene_num_2, [492], 10)
        B3 = self._load_band(scene_num_1, scene_num_2, [559, 560], 10)
        B4 = self._load_band(scene_num_1, scene_num_2, [665], 10)
        B5 = self._load_band(scene_num_1, scene_num_2, [704], 20)
        B7 = self._load_band(scene_num_1, scene_num_2, [780, 783], 20)
        B8 = self._load_band(scene_num_1, scene_num_2, [833], 10)
        B8A = self._load_band(scene_num_1, scene_num_2, [864, 865], 20)
        B11 = self._load_band(scene_num_1, scene_num_2, [1610, 1614], 20)
        B12 = self._load_band(scene_num_1, scene_num_2, [2186, 2202], 20)

        # Bands read from the 60 and 20 directories are resampled to 240x240
        # so they can be concatenated with the bands from the 10 directory.
        B1 = self._resize(B1)
        B5 = self._resize(B5)
        B7 = self._resize(B7)
        B8A = self._resize(B8A)
        B11 = self._resize(B11)
        B12 = self._resize(B12)

        # Interpolate missing bands from their (already resized) neighbours.
        B6 = (B5 + B7) / 2
        B9 = B8A
        B10 = (B8A + B11) / 2

        # Each band is (1, 1, H, W); concatenating on dim 1 stacks channels.
        image = torch.cat(
            [B1, B2, B3, B4, B5, B6, B7, B8, B8A, B9, B10, B11, B12], dim=1
        ).squeeze(0)  # (13, 240, 240)
        mask = self._load_mask(scene_num_1, scene_num_2).squeeze(0).squeeze(0)  # (240, 240)
        images, masks = split_and_filter_tensors(image, mask)

        return images, masks

    def _load_band(self, scene_num_1, scene_num_2, bands, resolution):
        """Load the first existing band file among ``bands`` as a float tensor.

        Args:
            scene_num_1: Scene identifier used in the directory and file name.
            scene_num_2: Trailing identifier used in the file name.
            bands: Candidate wavelength identifiers, tried in order.
            resolution: Name of the sub-directory holding the file.

        Returns:
            torch.Tensor: Float tensor of shape (1, 1, H, W).

        Raises:
            FileNotFoundError: If none of the candidate files exists. The
                caller needs every band, so failing here gives a clear error
                instead of a later crash on a ``None`` band.
        """
        for band in bands:
            band_path = f"{self.root_dir}/Scene_{scene_num_1}/{resolution}/Scene_{scene_num_1}_L2R_rhorc_{band}_{scene_num_2}.tif"
            if os.path.exists(band_path):
                # Materialise the pixels while the file is open, then close it
                # so file handles do not accumulate over a full dataset pass.
                with Image.open(band_path) as band_image:
                    pixels = np.array(band_image)
                return torch.from_numpy(pixels).float().unsqueeze(0).unsqueeze(0)

        raise FileNotFoundError(
            f"Could not find a band file for scene_num_1={scene_num_1}, "
            f"scene_num_2={scene_num_2}, bands={bands}, resolution={resolution} "
            f"under {self.root_dir}"
        )

    def _resize(self, image):
        """Bilinearly resample a (1, 1, H, W) band to (1, 1, 240, 240)."""
        return F.interpolate(image, size=240, mode="bilinear", align_corners=False)

    def _load_mask(self, scene_num_1, scene_num_2):
        """Load the class mask of a scene as a long tensor of shape (1, 1, H, W)."""
        mask_path = (
            f"{self.root_dir}/Scene_{scene_num_1}/10/Scene_{scene_num_1}_L2R_cl_{scene_num_2}.tif"
        )
        # Read inside a context manager so the file handle is released promptly.
        with Image.open(mask_path) as mask_image:
            pixels = np.array(mask_image)
        return torch.from_numpy(pixels).long().unsqueeze(0).unsqueeze(0)


def get_mados(save_path, root_dir="MADOS", split_file="test_X.txt"):
    """Tile every scene of a split and save the stacked tensors to ``save_path``.

    The saved object is a dict with keys ``"images"`` (N, 13, 80, 80) and
    ``"labels"`` (N, 80, 80), where N is the total number of kept tiles.

    Args:
        save_path: Destination passed to ``torch.save``.
        root_dir: Dataset root directory handed to ``PrepMADOSDataset``.
        split_file: Split file name handed to ``PrepMADOSDataset``.
    """
    dataset = PrepMADOSDataset(root_dir=root_dir, split_file=split_file)
    all_images = []
    all_masks = []
    # Index explicitly rather than relying on the implicit __getitem__ /
    # IndexError iteration protocol.
    for idx in range(len(dataset)):
        scene_images, scene_masks = dataset[idx]
        all_images.extend(scene_images)
        all_masks.extend(scene_masks)

    split_images = torch.stack(all_images)  # shape (N, 13, 80, 80)
    split_masks = torch.stack(all_masks)  # shape (N, 80, 80)
    torch.save(obj={"images": split_images, "labels": split_masks}, f=save_path)


class MADOSDataset(Dataset):
    """Serve pre-tiled MADOS tensors loaded from ``MADOS_<split>.pt`` files."""

    def __init__(self, path_to_splits: Path, split: str, norm_operation, augmentation, partition):
        """
        Args:
            path_to_splits: Directory holding ``MADOS_<split>.pt`` and, when a
                partition is requested, ``<partition>_partition.json``.
            split: One of "train", "val", "valid" or "test"; "valid" is treated
                as "val".
            norm_operation: Passed through to ``normalize_bands``.
            augmentation: Object whose ``apply(image, label, "seg")`` is called
                on every item.
            partition: ``"default"`` to use the full split; any other value
                selects the training subset listed in
                ``<partition>_partition.json`` (only applied when split is
                "train").
        """
        with (Path(__file__).parents[0] / Path("configs") / Path("mados.json")).open("r") as f:
            config = json.load(f)

        # NOTE: I imputed bands for this dataset before saving the tensors, so no imputation is necessary
        assert split in ["train", "val", "valid", "test"]
        if split == "valid":
            split = "val"

        self.band_info = config["band_info"]
        self.split = split
        self.augmentation = augmentation
        self.norm_operation = norm_operation

        torch_obj = torch.load(path_to_splits / f"MADOS_{split}.pt")
        self.images = torch_obj["images"]
        self.labels = torch_obj["labels"]

        if (partition != "default") and (split == "train"):
            with open(path_to_splits / f"{partition}_partition.json", "r") as json_file:
                subset_indices = json.load(json_file)

            # Keep only the samples listed in the partition file.
            self.images = self.images[subset_indices]
            self.labels = self.labels[subset_indices]

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return one normalised, augmented sample as ``{"s2": image, "target": label}``."""
        # NOTE(review): an earlier comment gave the image layout as
        # (80, 80, 13), but get_mados in this file saves (N, 13, 80, 80);
        # confirm which layout the stored .pt files and normalize_bands use.
        image = self.images[idx]
        label = self.labels[idx]  # (80, 80)

        image = torch.tensor(normalize_bands(image.numpy(), self.norm_operation, self.band_info))

        image, label = self.augmentation.apply(image, label, "seg")
        return {"s2": image, "target": label}