import os |
import numpy as np |
import albumentations |
from torch.utils.data import Dataset |
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex |
class FacesBase(Dataset):
    """Dataset that forwards length and indexing to ``self.data``.

    Subclasses are expected to assign ``self.data`` (a sized, indexable
    collection of example mappings) and, optionally, ``self.keys`` (the
    subset of example keys to expose). Any constructor arguments are
    accepted and ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Populated by subclasses after this constructor runs.
        self.data = None
        self.keys = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i``.

        Without ``self.keys`` the underlying example object is returned
        as-is (no copy). With ``self.keys`` a new dict containing only those
        keys, in ``self.keys`` order, is returned.
        """
        item = self.data[i]
        if self.keys is None:
            return item
        return {key: item[key] for key in self.keys}
class CelebAHQTrain(FacesBase):
    """CelebA-HQ training split, loaded through ``NumpyPaths``.

    Args:
        size: Forwarded unchanged to ``NumpyPaths`` as its ``size`` argument.
        keys: Optional subset of example keys to expose (see ``FacesBase``).
        root: Directory that each relative path from ``list_file`` is
            joined onto. Defaults to the original hard-coded location.
        list_file: Text file listing one relative path per line. Blank or
            whitespace-only lines are ignored.
    """

    def __init__(self, size, keys=None, *, root="data/celebahq",
                 list_file="data/celebahqtrain.txt"):
        super().__init__()
        with open(list_file, "r") as f:
            # Skip blank lines: an empty relpath would join to ``root``
            # itself and only fail later, when the loader tries to open it.
            relpaths = [line for line in f.read().splitlines() if line.strip()]
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
class CelebAHQValidation(FacesBase):
    """CelebA-HQ validation split, loaded through ``NumpyPaths``.

    Args:
        size: Forwarded unchanged to ``NumpyPaths`` as its ``size`` argument.
        keys: Optional subset of example keys to expose (see ``FacesBase``).
        root: Directory that each relative path from ``list_file`` is
            joined onto. Defaults to the original hard-coded location.
        list_file: Text file listing one relative path per line. Blank or
            whitespace-only lines are ignored.
    """

    def __init__(self, size, keys=None, *, root="data/celebahq",
                 list_file="data/celebahqvalidation.txt"):
        super().__init__()
        with open(list_file, "r") as f:
            # Skip blank lines: an empty relpath would join to ``root``
            # itself and only fail later, when the loader tries to open it.
            relpaths = [line for line in f.read().splitlines() if line.strip()]
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
class FFHQTrain(FacesBase):
    """FFHQ training split, loaded through ``ImagePaths``.

    Args:
        size: Forwarded unchanged to ``ImagePaths`` as its ``size`` argument.
        keys: Optional subset of example keys to expose (see ``FacesBase``).
        root: Directory that each relative path from ``list_file`` is
            joined onto. Defaults to the original hard-coded location.
        list_file: Text file listing one relative path per line. Blank or
            whitespace-only lines are ignored.
    """

    def __init__(self, size, keys=None, *, root="data/ffhq",
                 list_file="data/ffhqtrain.txt"):
        super().__init__()
        with open(list_file, "r") as f:
            # Skip blank lines: an empty relpath would join to ``root``
            # itself and only fail later, when the loader tries to open it.
            relpaths = [line for line in f.read().splitlines() if line.strip()]
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
class FFHQValidation(FacesBase):
    """FFHQ validation split, loaded through ``ImagePaths``.

    Args:
        size: Forwarded unchanged to ``ImagePaths`` as its ``size`` argument.
        keys: Optional subset of example keys to expose (see ``FacesBase``).
        root: Directory that each relative path from ``list_file`` is
            joined onto. Defaults to the original hard-coded location.
        list_file: Text file listing one relative path per line. Blank or
            whitespace-only lines are ignored.
    """

    def __init__(self, size, keys=None, *, root="data/ffhq",
                 list_file="data/ffhqvalidation.txt"):
        super().__init__()
        with open(list_file, "r") as f:
            # Skip blank lines: an empty relpath would join to ``root``
            # itself and only fail later, when the loader tries to open it.
            relpaths = [line for line in f.read().splitlines() if line.strip()]
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
class FacesHQTrain(Dataset):
    """Training union of CelebA-HQ and FFHQ with optional random cropping.

    Each item is the example dict produced by the underlying dataset plus a
    ``"class"`` entry. That entry holds the second value returned by
    ``ConcatDatasetWithIndex``. Presumably it is the index of the source
    dataset (0 = CelebA-HQ, 1 = FFHQ); confirm against ``taming.data.base``.

    Args:
        size: Forwarded to both component datasets.
        keys: Forwarded to both component datasets.
        crop_size: If given, images are cropped to ``crop_size x crop_size``
            with ``albumentations.RandomCrop``.
        coord: If true, a per-pixel coordinate map is cropped with the same
            window as the image and stored under ``"coord"``. It has no
            effect unless ``crop_size`` is given.
    """

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        super().__init__()
        d1 = CelebAHQTrain(size=size, keys=keys)
        d2 = FFHQTrain(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([d1, d2])
        self.coord = coord
        if crop_size is not None:
            self.cropper = albumentations.RandomCrop(height=crop_size, width=crop_size)
            if self.coord:
                # Treat "coord" as an extra image target so it receives the
                # exact same crop window as the image.
                self.cropper = albumentations.Compose(
                    [self.cropper], additional_targets={"coord": "image"})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        ex, y = self.data[i]
        # ``cropper`` is only set when a crop_size was given.
        cropper = getattr(self, "cropper", None)
        if cropper is not None:
            targets = {"image": ex["image"]}
            if self.coord:
                # Row-major flat pixel index scaled into [0, 1), shape (h, w, 1).
                h, w, _ = ex["image"].shape
                targets["coord"] = np.arange(h * w).reshape(h, w, 1) / (h * w)
            out = cropper(**targets)
            ex["image"] = out["image"]
            if self.coord:
                ex["coord"] = out["coord"]
        ex["class"] = y
        return ex
class FacesHQValidation(Dataset):
    """Validation union of CelebA-HQ and FFHQ with optional center cropping.

    Each item is the example dict produced by the underlying dataset plus a
    ``"class"`` entry. That entry holds the second value returned by
    ``ConcatDatasetWithIndex``. Presumably it is the index of the source
    dataset (0 = CelebA-HQ, 1 = FFHQ); confirm against ``taming.data.base``.

    Args:
        size: Forwarded to both component datasets.
        keys: Forwarded to both component datasets.
        crop_size: If given, images are cropped to ``crop_size x crop_size``
            with ``albumentations.CenterCrop``.
        coord: If true, a per-pixel coordinate map is cropped with the same
            window as the image and stored under ``"coord"``. It has no
            effect unless ``crop_size`` is given.
    """

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        super().__init__()
        d1 = CelebAHQValidation(size=size, keys=keys)
        d2 = FFHQValidation(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([d1, d2])
        self.coord = coord
        if crop_size is not None:
            self.cropper = albumentations.CenterCrop(height=crop_size, width=crop_size)
            if self.coord:
                # Treat "coord" as an extra image target so it receives the
                # exact same crop window as the image.
                self.cropper = albumentations.Compose(
                    [self.cropper], additional_targets={"coord": "image"})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        ex, y = self.data[i]
        # ``cropper`` is only set when a crop_size was given.
        cropper = getattr(self, "cropper", None)
        if cropper is not None:
            targets = {"image": ex["image"]}
            if self.coord:
                # Row-major flat pixel index scaled into [0, 1), shape (h, w, 1).
                h, w, _ = ex["image"].shape
                targets["coord"] = np.arange(h * w).reshape(h, w, 1) / (h * w)
            out = cropper(**targets)
            ex["image"] = out["image"]
            if self.coord:
                ex["coord"] = out["coord"]
        ex["class"] = y
        return ex