|
import os |
|
import numpy as np |
|
import albumentations |
|
from torch.utils.data import Dataset |
|
|
|
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex |
|
|
|
|
|
class CustomBase(Dataset):
    """Minimal ``Dataset`` that delegates length and indexing to ``self.data``.

    Subclasses are expected to replace ``self.data`` with an object that
    supports ``len()`` and ``[]`` indexing. Until then it is ``None``, and
    ``len()`` / indexing will raise ``TypeError``.
    """

    def __init__(self, *args, **kwargs):
        # Any positional / keyword arguments are accepted and ignored here.
        super().__init__()
        self.data = None  # placeholder; populated by subclasses

    def __len__(self):
        """Return the size of the backing ``self.data`` collection."""
        backing = self.data
        return len(backing)

    def __getitem__(self, i):
        """Return element ``i`` of ``self.data`` without any transformation."""
        backing = self.data
        return backing[i]
|
|
|
|
|
|
|
class CustomTrain(CustomBase):
    """Training split backed by a text file listing one image path per line.

    Args:
        size: Forwarded unchanged to ``ImagePaths`` as its ``size`` argument.
        training_images_list_file: Path to a text file containing one image
            path per line. Blank (whitespace-only) lines are skipped.
        random_crop: Forwarded to ``ImagePaths``. Defaults to ``False``,
            which matches the previously hard-coded value.
    """

    def __init__(self, size, training_images_list_file, *, random_crop=False):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            # Drop blank lines (e.g. separators or stray whitespace) so they are
            # not passed to ImagePaths as empty paths, which would presumably
            # only fail later when that item is loaded.
            paths = [line for line in f.read().splitlines() if line.strip()]
        self.data = ImagePaths(paths=paths, size=size, random_crop=random_crop)
|
|
|
|
|
class CustomTest(CustomBase):
    """Test split backed by a text file listing one image path per line.

    Args:
        size: Forwarded unchanged to ``ImagePaths`` as its ``size`` argument.
        test_images_list_file: Path to a text file containing one image path
            per line. Blank (whitespace-only) lines are skipped.
        random_crop: Forwarded to ``ImagePaths``. Defaults to ``False``,
            which matches the previously hard-coded value.
    """

    def __init__(self, size, test_images_list_file, *, random_crop=False):
        super().__init__()
        with open(test_images_list_file, "r") as f:
            # Drop blank lines (e.g. separators or stray whitespace) so they are
            # not passed to ImagePaths as empty paths, which would presumably
            # only fail later when that item is loaded.
            paths = [line for line in f.read().splitlines() if line.strip()]
        self.data = ImagePaths(paths=paths, size=size, random_crop=random_crop)
|
|
|
|
|
|