# danurahul's picture
# Add application file
# d4bc11a
import os
import numpy as np
import albumentations
from torch.utils.data import Dataset
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
class FacesBase(Dataset):
    """Base dataset that forwards indexing to ``self.data``.

    Subclasses are expected to assign ``self.data`` (an indexable, sized
    container of dict-like examples) and, optionally, ``self.keys`` (an
    iterable of keys to keep from each example).
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used here.
        super().__init__()
        self.keys = None
        self.data = None

    def __len__(self):
        """Return the number of examples in the wrapped ``self.data``."""
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i``, restricted to ``self.keys`` when that is set.

        With ``self.keys`` left as ``None`` the underlying example object is
        returned unchanged (no copy); otherwise a new dict holding only the
        requested keys is built.
        """
        record = self.data[i]
        if self.keys is None:
            return record
        return {name: record[name] for name in self.keys}
class CelebAHQTrain(FacesBase):
    """CelebA-HQ training split.

    Relative file paths are read, one per line, from
    ``data/celebahqtrain.txt`` and resolved against ``data/celebahq``.
    The resulting paths are wrapped in ``NumpyPaths`` without random cropping.

    Args:
        size: Forwarded to ``NumpyPaths`` as ``size``.
        keys: Optional iterable of example keys to keep; ``None`` keeps all.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_dir = "data/celebahq"
        with open("data/celebahqtrain.txt", "r") as listing:
            entries = listing.read().splitlines()
        self.data = NumpyPaths(
            paths=[os.path.join(image_dir, entry) for entry in entries],
            size=size,
            random_crop=False,
        )
        self.keys = keys
class CelebAHQValidation(FacesBase):
    """CelebA-HQ validation split.

    Relative file paths are read, one per line, from
    ``data/celebahqvalidation.txt`` and resolved against ``data/celebahq``.
    The resulting paths are wrapped in ``NumpyPaths`` without random cropping.

    Args:
        size: Forwarded to ``NumpyPaths`` as ``size``.
        keys: Optional iterable of example keys to keep; ``None`` keeps all.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_dir = "data/celebahq"
        with open("data/celebahqvalidation.txt", "r") as listing:
            entries = listing.read().splitlines()
        self.data = NumpyPaths(
            paths=[os.path.join(image_dir, entry) for entry in entries],
            size=size,
            random_crop=False,
        )
        self.keys = keys
class FFHQTrain(FacesBase):
    """FFHQ training split.

    Relative file paths are read, one per line, from ``data/ffhqtrain.txt``
    and resolved against ``data/ffhq``. The resulting paths are wrapped in
    ``ImagePaths`` without random cropping.

    Args:
        size: Forwarded to ``ImagePaths`` as ``size``.
        keys: Optional iterable of example keys to keep; ``None`` keeps all.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_dir = "data/ffhq"
        with open("data/ffhqtrain.txt", "r") as listing:
            entries = listing.read().splitlines()
        self.data = ImagePaths(
            paths=[os.path.join(image_dir, entry) for entry in entries],
            size=size,
            random_crop=False,
        )
        self.keys = keys
class FFHQValidation(FacesBase):
    """FFHQ validation split.

    Relative file paths are read, one per line, from
    ``data/ffhqvalidation.txt`` and resolved against ``data/ffhq``. The
    resulting paths are wrapped in ``ImagePaths`` without random cropping.

    Args:
        size: Forwarded to ``ImagePaths`` as ``size``.
        keys: Optional iterable of example keys to keep; ``None`` keeps all.
    """

    def __init__(self, size, keys=None):
        super().__init__()
        image_dir = "data/ffhq"
        with open("data/ffhqvalidation.txt", "r") as listing:
            entries = listing.read().splitlines()
        self.data = ImagePaths(
            paths=[os.path.join(image_dir, entry) for entry in entries],
            size=size,
            random_crop=False,
        )
        self.keys = keys
class FacesHQTrain(Dataset):
    """Training set concatenating CelebA-HQ and FFHQ.

    Each item is the example dict from the underlying dataset with a
    ``"class"`` entry added, taken from the second element returned by
    ``ConcatDatasetWithIndex``.

    Args:
        size: Forwarded to both underlying datasets.
        keys: Forwarded to both underlying datasets.
        crop_size: When not ``None``, a square ``RandomCrop`` of this side
            length is applied to ``"image"``. When ``None`` no cropper is
            created at all.
        coord: When true *and* ``crop_size`` is given, a coordinate map is
            built for each image, passed through the cropper as an additional
            image-like target, and stored under ``"coord"``. Without
            ``crop_size`` this flag has no effect.
    """
    # CelebAHQ [0] + FFHQ [1]

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        self.coord = coord
        self.data = ConcatDatasetWithIndex([
            CelebAHQTrain(size=size, keys=keys),
            FFHQTrain(size=size, keys=keys),
        ])
        if crop_size is None:
            # No ``cropper`` attribute is set; __getitem__ checks for it.
            return
        crop = albumentations.RandomCrop(height=crop_size, width=crop_size)
        if coord:
            # Register "coord" so it is transformed like the image.
            crop = albumentations.Compose(
                [crop], additional_targets={"coord": "image"}
            )
        self.cropper = crop

    def __len__(self):
        """Return the length of the concatenated dataset."""
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i`` (cropped if configured) with its ``"class"``."""
        sample, source_idx = self.data[i]
        if hasattr(self, "cropper"):
            if self.coord:
                height, width, _ = sample["image"].shape
                n_pixels = height * width
                # Flat pixel index scaled to [0, 1), shaped (height, width, 1).
                coord = np.arange(n_pixels).reshape(height, width, 1) / n_pixels
                cropped = self.cropper(image=sample["image"], coord=coord)
                sample["image"] = cropped["image"]
                sample["coord"] = cropped["coord"]
            else:
                cropped = self.cropper(image=sample["image"])
                sample["image"] = cropped["image"]
        sample["class"] = source_idx
        return sample
class FacesHQValidation(Dataset):
    """Validation set concatenating CelebA-HQ and FFHQ.

    Each item is the example dict from the underlying dataset with a
    ``"class"`` entry added, taken from the second element returned by
    ``ConcatDatasetWithIndex``.

    Args:
        size: Forwarded to both underlying datasets.
        keys: Forwarded to both underlying datasets.
        crop_size: When not ``None``, a square ``CenterCrop`` of this side
            length is applied to ``"image"``. When ``None`` no cropper is
            created at all.
        coord: When true *and* ``crop_size`` is given, a coordinate map is
            built for each image, passed through the cropper as an additional
            image-like target, and stored under ``"coord"``. Without
            ``crop_size`` this flag has no effect.
    """
    # CelebAHQ [0] + FFHQ [1]

    def __init__(self, size, keys=None, crop_size=None, coord=False):
        self.coord = coord
        self.data = ConcatDatasetWithIndex([
            CelebAHQValidation(size=size, keys=keys),
            FFHQValidation(size=size, keys=keys),
        ])
        if crop_size is None:
            # No ``cropper`` attribute is set; __getitem__ checks for it.
            return
        crop = albumentations.CenterCrop(height=crop_size, width=crop_size)
        if coord:
            # Register "coord" so it is transformed like the image.
            crop = albumentations.Compose(
                [crop], additional_targets={"coord": "image"}
            )
        self.cropper = crop

    def __len__(self):
        """Return the length of the concatenated dataset."""
        return len(self.data)

    def __getitem__(self, i):
        """Return example ``i`` (cropped if configured) with its ``"class"``."""
        sample, source_idx = self.data[i]
        if hasattr(self, "cropper"):
            if self.coord:
                height, width, _ = sample["image"].shape
                n_pixels = height * width
                # Flat pixel index scaled to [0, 1), shaped (height, width, 1).
                coord = np.arange(n_pixels).reshape(height, width, 1) / n_pixels
                cropped = self.cropper(image=sample["image"], coord=coord)
                sample["image"] = cropped["image"]
                sample["coord"] = cropped["coord"]
            else:
                cropped = self.cropper(image=sample["image"])
                sample["image"] = cropped["image"]
        sample["class"] = source_idx
        return sample