# This file is copied from https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/data.py

# MIT License
# Copyright (c) 2022 Lorenzo Breschi

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Callable, Dict, Optional, Tuple

import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as F
from torchvision import transforms

import pytorch_lightning as pl

from collections.abc import Iterable

# image reader writer
from pathlib import Path
from PIL import Image


def read_image(filepath: Path, mode: Optional[str] = None) -> Image.Image:
    with open(filepath, 'rb') as file:
        image = Image.open(file)
        # convert() forces a full load, so the file can be closed afterwards;
        # with mode=None, PIL returns a copy in the image's own mode
        return image.convert(mode)

image2tensor = transforms.ToTensor()
tensor2image = transforms.ToPILImage()


def write_image(image: Image.Image, filepath: Path):
    filepath.parent.mkdir(parents=True, exist_ok=True)
    image.save(str(filepath))


def read_image_tensor(filepath: Path, mode: str = 'RGB') -> torch.Tensor:
    return image2tensor(read_image(filepath, mode))


def write_image_tensor(input: torch.Tensor, filepath: Path):
    write_image(tensor2image(input), filepath)


def get_valid_indices(H: int, W: int, patch_size: int, random_overlap: int = 0):
    # top-left corners of a regular patch grid, each jittered by up to
    # `random_overlap` pixels in both directions
    vih = torch.arange(random_overlap, H - patch_size - random_overlap + 1, patch_size)
    viw = torch.arange(random_overlap, W - patch_size - random_overlap + 1, patch_size)
    if random_overlap > 0:
        rih = torch.randint_like(vih, -random_overlap, random_overlap)
        riw = torch.randint_like(viw, -random_overlap, random_overlap)
        vih += rih
        viw += riw
    vi = torch.stack(torch.meshgrid(vih, viw)).view(2, -1).t()
    return vi
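

# A minimal sketch of the returned shape (not part of the original file):
# for a 64x64 image, 32-pixel patches, and no jitter, the valid top-left
# corners are (0, 0), (0, 32), (32, 0) and (32, 32), as an (N, 2) tensor.
def _demo_get_valid_indices():
    indices = get_valid_indices(64, 64, patch_size=32)
    print(indices.shape)  # torch.Size([4, 2])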


def cut_patches(input: torch.Tensor, indices: Tuple[Tuple[int, int]], patch_size: int, padding: int = 0):
    # `indices` is in practice the (N, 2) tensor returned by get_valid_indices
    # TODO use slices to get all patches at the same time ?
    patches_l = []
    for n in range(len(indices)):
        patch = F.crop(input, *(indices[n] - padding),
                       *(patch_size + padding * 2,) * 2)
        patches_l.append(patch)
    # for a batched (B, C, H, W) input, patches are concatenated along dim 0
    patches = torch.cat(patches_l, dim=0)
    return patches
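

# A minimal sketch (an assumption, not in the original repo): cutting the four
# patches above out of a batch of two 3-channel 64x64 images yields a
# (2 * 4, 3, 32, 32) tensor.
def _demo_cut_patches():
    images = torch.rand(2, 3, 64, 64)
    indices = get_valid_indices(64, 64, patch_size=32)
    patches = cut_patches(images, indices, patch_size=32)
    print(patches.shape)  # torch.Size([8, 3, 32, 32])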


def prepare_data(data_path: Path, read_func: Callable = read_image_tensor) -> Dict:
    """
    Takes a data_path of a folder which contains subfolders with input, target, etc.
    labelled by the same names.

    :param data_path: Path of the folder containing data
    :param read_func: function that reads data and returns a tensor
    """
    data_dict = {}
    subdir_names = ["target", "input", "mask"]  # ,"helper"

    # checks only files for which there is a target
    # TODO check for images
    name_ls = [file.name for file in (
        data_path / "target").iterdir() if file.is_file()]

    subdirs = [data_path / sdn for sdn in subdir_names]
    for sd in subdirs:
        if sd.is_dir():
            data_ls = []
            files = [sd / name for name in name_ls]
            for file in files:
                tensor = read_func(file)
                H, W = tensor.shape[-2:]
                data_ls.append(tensor)
            # TODO check that all sizes match
            data_dict[sd.name] = torch.stack(data_ls, dim=0)

    data_dict['name'] = name_ls
    data_dict['len'] = len(data_dict['name'])
    data_dict['H'] = H
    data_dict['W'] = W
    return data_dict
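

# A minimal usage sketch; the layout below follows `subdir_names`, and
# "data/" is a hypothetical path:
#
#   data/
#       target/  000.png  001.png  ...
#       input/   000.png  001.png  ...
#       mask/    000.png  001.png  ...   (optional)
def _demo_prepare_data():
    data_dict = prepare_data(Path("data"))
    print(data_dict['len'], data_dict['H'], data_dict['W'])
    print(data_dict['input'].shape)  # (len, C, H, W)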


# TODO an image is loaded whenever a patch is needed, this may be a bottleneck
class DataDictLoader():

    def __init__(self, data_dict: Dict,
                 batch_size: int = 16,
                 max_length: int = 128,
                 shuffle: bool = False):
        """
        Iterates over a data_dict in batches: tensor and list values are
        sliced per batch, scalar values are passed through unchanged.
        """
        self.batch_size = batch_size
        self.shuffle = shuffle

        self.data_dict = data_dict
        self.dataset_len = data_dict['len']
        self.len = self.dataset_len if max_length is None else min(
            self.dataset_len, max_length)

        # Calculate # batches
        num_batches, remainder = divmod(self.len, self.batch_size)
        if remainder > 0:
            num_batches += 1
        self.num_batches = num_batches

    def __iter__(self):
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            # tensors can be index-shuffled directly; lists need a comprehension
            self.data_dict = {k: v[r] if isinstance(v, torch.Tensor)
                              else [v[i] for i in r.tolist()] if isinstance(v, list)
                              else v
                              for k, v in self.data_dict.items()}
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.len:
            raise StopIteration
        batch = {k: v[self.i:self.i + self.batch_size]
                 if isinstance(v, Iterable) else v for k, v in self.data_dict.items()}
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.num_batches
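

# A minimal usage sketch, assuming the hypothetical "data/" folder above:
def _demo_data_dict_loader():
    loader = DataDictLoader(prepare_data(Path("data")), batch_size=4, shuffle=True)
    for batch in loader:
        print(batch['input'].shape, batch['name'][:2])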


class PatchDataModule(pl.LightningDataModule):

    def __init__(self, data_dict,
                 patch_size: int = 2**5,
                 batch_size: int = 2**4,
                 patch_num: int = 2**6):
        super().__init__()
        self.data_dict = data_dict
        self.H, self.W = data_dict['H'], data_dict['W']
        self.len = data_dict['len']

        self.batch_size = batch_size
        self.patch_size = patch_size
        self.patch_num = patch_num

    def dataloader(self, data_dict, **kwargs):
        return DataDictLoader(data_dict, **kwargs)

    def train_dataloader(self):
        patches = self.cut_patches()
        return self.dataloader(patches, batch_size=self.batch_size, shuffle=True,
                               max_length=self.patch_num)

    def val_dataloader(self):
        return self.dataloader(self.data_dict, batch_size=1)

    def test_dataloader(self):
        return self.dataloader(self.data_dict)  # TODO batch size

    def cut_patches(self):
        # TODO cycle once
        patch_indices = get_valid_indices(
            self.H, self.W, self.patch_size, self.patch_size // 4)
        dd = {k: cut_patches(
            v, patch_indices, self.patch_size) for k, v in self.data_dict.items()
            if isinstance(v, torch.Tensor)
        }
        # keep only patches whose mean mask coverage exceeds the threshold;
        # without a mask, every patch is kept
        threshold = 0.1
        mask_p = torch.mean(
            dd.get('mask', torch.ones_like(dd['input'])), dim=(-1, -2, -3))
        masked_idx = (mask_p > threshold).nonzero(as_tuple=True)[0]
        dd = {k: v[masked_idx] for k, v in dd.items()}
        dd['len'] = len(masked_idx)
        dd['H'], dd['W'] = (self.patch_size,) * 2
        return dd
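

# A minimal training sketch; `MyModel` is a hypothetical pl.LightningModule
# and "data/" the hypothetical folder from above. DataDictLoader is a plain
# iterable, which pytorch_lightning accepts as a dataloader.
def _demo_patch_data_module():
    dm = PatchDataModule(prepare_data(Path("data")), patch_size=32, batch_size=8)
    trainer = pl.Trainer(max_epochs=1)
    # trainer.fit(MyModel(), datamodule=dm)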


class ImageDataset(Dataset):

    def __init__(self, file_paths: Iterable, read_func: Callable = read_image_tensor):
        self.file_paths = file_paths
        # store the reader so a custom read_func is actually used
        self.read_func = read_func

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        file = self.file_paths[idx]
        return self.read_func(file), file.name

    def __len__(self) -> int:
        return len(self.file_paths)
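

# A minimal usage sketch with a standard DataLoader, assuming all images in
# the hypothetical "data/input" folder share the same size:
def _demo_image_dataset():
    from torch.utils.data import DataLoader
    paths = sorted(Path("data/input").iterdir())
    loader = DataLoader(ImageDataset(paths), batch_size=4)
    for images, names in loader:
        print(images.shape, names)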