|
import bisect |
|
import numpy as np |
|
import albumentations |
|
from PIL import Image |
|
from torch.utils.data import Dataset, ConcatDataset |
|
|
|
|
|
class ConcatDatasetWithIndex(ConcatDataset):
    """ConcatDataset variant whose items also report their source dataset.

    Indexing returns a ``(sample, dataset_idx)`` tuple, where ``dataset_idx``
    is the position in ``self.datasets`` of the dataset the sample came from.
    """

    def __getitem__(self, idx):
        """Return ``(sample, dataset_idx)`` for the global index ``idx``.

        Negative indices count from the end of the concatenation.

        Raises:
            ValueError: If ``idx`` is negative and its absolute value exceeds
                the total length.
        """
        if idx < 0:
            total = len(self)
            if -idx > total:
                raise ValueError("absolute value of index should not exceed dataset length")
            # Rebind instead of updating in place so a caller-owned index
            # object is never modified.
            idx = total + idx

        # cumulative_sizes holds the running end offsets of the datasets, so
        # the insertion point of idx is the dataset that contains it.
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        start = self.cumulative_sizes[which - 1] if which != 0 else 0
        return self.datasets[which][idx - start], which
|
|
|
|
|
class ImagePaths(Dataset):
    """Dataset that loads images from file paths as float32 arrays.

    Each item is a dict holding the preprocessed image under ``"image"`` plus
    the i-th entry of every per-sample label sequence, including the source
    path under ``"file_path_"``.
    """

    def __init__(self, paths, size=None, random_crop=False, labels=None):
        """Store the paths and labels and build the resize/crop pipeline.

        Args:
            paths: Sequence of image file paths, one per sample.
            size: Target side length. When positive, images are rescaled with
                ``albumentations.SmallestMaxSize(max_size=size)`` and then
                cropped to ``size`` x ``size``. ``None`` or a non-positive
                value disables rescaling and cropping.
            random_crop: Use a random crop instead of a center crop. Only
                relevant when ``size`` enables cropping.
            labels: Optional mapping from label name to a sequence indexable
                by sample index. It is shallow-copied, so the caller's
                mapping is left untouched.
        """
        self.size = size
        self.random_crop = random_crop

        # Copy before inserting "file_path_" so the caller's dict is not
        # silently modified.
        self.labels = {} if labels is None else dict(labels)
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            crop_cls = (
                albumentations.RandomCrop
                if self.random_crop
                else albumentations.CenterCrop
            )
            self.cropper = crop_cls(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            # A named function rather than a lambda: lambdas cannot be
            # pickled, which would make the whole dataset unpicklable.
            self.preprocessor = self._identity

    @staticmethod
    def _identity(**kwargs):
        """Return the keyword arguments unchanged (no-op preprocessor)."""
        return kwargs

    def __len__(self):
        """Return the number of samples (the number of paths given)."""
        return self._length

    def preprocess_image(self, image_path):
        """Load one image and return it as a float32 array.

        The image is converted to RGB if needed, passed through the
        rescale/crop pipeline, and its uint8 values are mapped with
        ``x / 127.5 - 1.0`` (0 -> -1.0, 255 -> 1.0).
        """
        # Image.open is lazy; the context manager guarantees the underlying
        # file is closed once the pixel data has been copied out.
        with Image.open(image_path) as raw:
            rgb = raw if raw.mode == "RGB" else raw.convert("RGB")
            pixels = np.array(rgb).astype(np.uint8)
        pixels = self.preprocessor(image=pixels)["image"]
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def __getitem__(self, i):
        """Return the example dict for sample ``i``.

        Labels are written after the image, so a label named ``"image"``
        would replace the loaded image in the returned dict.
        """
        example = {"image": self.preprocess_image(self.labels["file_path_"][i])}
        for key, values in self.labels.items():
            example[key] = values[i]
        return example
|
|
|
|
|
class NumpyPaths(ImagePaths):
    """ImagePaths variant that reads each sample with ``np.load``."""

    def preprocess_image(self, image_path):
        """Load the array stored at ``image_path`` as a float32 image.

        The loaded array must have a leading axis of length 1 (``squeeze(0)``
        raises otherwise). The remaining three axes are reordered with
        ``(1, 2, 0)`` — presumably channels-first to channels-last; confirm
        against the files' producer. The result then goes through the same
        crop pipeline and ``x / 127.5 - 1.0`` scaling as regular images.
        """
        stored = np.load(image_path)
        reordered = np.transpose(stored.squeeze(0), (1, 2, 0))
        # Round trip through PIL so the data is interpreted as an RGB image.
        as_rgb = Image.fromarray(reordered, mode="RGB")
        pixels = np.array(as_rgb).astype(np.uint8)
        processed = self.preprocessor(image=pixels)["image"]
        return (processed / 127.5 - 1.0).astype(np.float32)
|
|