import os

import numpy as np
import albumentations
from torch.utils.data import Dataset

from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex


class FacesBase(Dataset):
    """Common base class: subclasses populate self.data with an
    ImagePaths/NumpyPaths instance and may set self.keys to restrict
    which entries of an example are returned."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None
        self.keys = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        ex = {}
        if self.keys is not None:
            # return only the requested entries of the example dict
            for k in self.keys:
                ex[k] = example[k]
        else:
            ex = example
        return ex


class CelebAHQTrain(FacesBase):
    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/celebahq"
        with open("data/celebahqtrain.txt", "r") as f:
            relpaths = f.read().splitlines()
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        # CelebA-HQ samples are stored as numpy arrays, hence NumpyPaths
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class CelebAHQValidation(FacesBase):
    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/celebahq"
        with open("data/celebahqvalidation.txt", "r") as f:
            relpaths = f.read().splitlines()
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
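

# A note on the split files read above (illustrative, with hypothetical
# filenames): each data/*.txt file lists one path per line, relative to the
# dataset root, e.g.
#
#   00000.npy
#   00001.npy
#
# so os.path.join(root, relpath) yields e.g. "data/celebahq/00000.npy".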


class FFHQTrain(FacesBase):
    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/ffhq"
        with open("data/ffhqtrain.txt", "r") as f:
            relpaths = f.read().splitlines()
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys


class FFHQValidation(FacesBase):
    def __init__(self, size, keys=None):
        super().__init__()
        root = "data/ffhq"
        with open("data/ffhqvalidation.txt", "r") as f:
            relpaths = f.read().splitlines()
        paths = [os.path.join(root, relpath) for relpath in relpaths]
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
        self.keys = keys
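

# A minimal sketch of the keys-filtering contract (illustration only, assuming
# the split files under data/ exist):
#
#   ds = FFHQValidation(size=256, keys=["image"])
#   ex = ds[0]   # -> {"image": ...} only
#   ds = FFHQValidation(size=256)
#   ex = ds[0]   # -> the full example dict ("image" plus any metadata keys)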


class FacesHQTrain(Dataset):
    # combines CelebA-HQ (class 0) and FFHQ (class 1)
    def __init__(self, size, keys=None, crop_size=None, coord=False):
        super().__init__()
        d1 = CelebAHQTrain(size=size, keys=keys)
        d2 = FFHQTrain(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([d1, d2])
        self.coord = coord
        if crop_size is not None:
            self.cropper = albumentations.RandomCrop(height=crop_size, width=crop_size)
            if self.coord:
                # crop the coordinate channel with the same parameters as the image
                self.cropper = albumentations.Compose(
                    [self.cropper],
                    additional_targets={"coord": "image"})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        ex, y = self.data[i]
        if hasattr(self, "cropper"):
            if not self.coord:
                out = self.cropper(image=ex["image"])
                ex["image"] = out["image"]
            else:
                # normalized flattened pixel positions, cropped jointly with the image
                h, w, _ = ex["image"].shape
                coord = np.arange(h * w).reshape(h, w, 1) / (h * w)
                out = self.cropper(image=ex["image"], coord=coord)
                ex["image"] = out["image"]
                ex["coord"] = out["coord"]
        ex["class"] = y
        return ex


class FacesHQValidation(Dataset):
    # combines CelebA-HQ (class 0) and FFHQ (class 1)
    def __init__(self, size, keys=None, crop_size=None, coord=False):
        super().__init__()
        d1 = CelebAHQValidation(size=size, keys=keys)
        d2 = FFHQValidation(size=size, keys=keys)
        self.data = ConcatDatasetWithIndex([d1, d2])
        self.coord = coord
        if crop_size is not None:
            # deterministic center crop for validation
            self.cropper = albumentations.CenterCrop(height=crop_size, width=crop_size)
            if self.coord:
                self.cropper = albumentations.Compose(
                    [self.cropper],
                    additional_targets={"coord": "image"})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        ex, y = self.data[i]
        if hasattr(self, "cropper"):
            if not self.coord:
                out = self.cropper(image=ex["image"])
                ex["image"] = out["image"]
            else:
                # normalized flattened pixel positions, cropped jointly with the image
                h, w, _ = ex["image"].shape
                coord = np.arange(h * w).reshape(h, w, 1) / (h * w)
                out = self.cropper(image=ex["image"], coord=coord)
                ex["image"] = out["image"]
                ex["coord"] = out["coord"]
        ex["class"] = y
        return ex
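

# A minimal usage sketch (illustration only; it assumes the split files and
# image folders under data/ described above are in place):
if __name__ == "__main__":
    dataset = FacesHQTrain(size=256, crop_size=128, coord=True)
    ex = dataset[0]
    print(ex["image"].shape)  # (128, 128, 3) after random cropping
    print(ex["coord"].shape)  # (128, 128, 1) coordinate channel
    print(ex["class"])        # 0 for CelebA-HQ, 1 for FFHQ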