import os, tarfile, glob, shutil

import yaml
import numpy as np
from tqdm import tqdm
from PIL import Image
import albumentations
from omegaconf import OmegaConf
from torch.utils.data import Dataset

from taming.data.base import ImagePaths
from taming.util import download, retrieve
import taming.data.utils as bdu


def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
    synsets = []
    with open(path_to_yaml) as f:
        di2s = yaml.safe_load(f)
    for idx in indices:
        synsets.append(str(di2s[idx]))
    print("Using {} different synsets for construction of Restricted ImageNet.".format(len(synsets)))
    return synsets
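
# The yaml file is assumed to map integer class indices to WordNet synset ids,
# e.g. `0: n01440764`.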


def str_to_indices(string):
    """Expects a string in the format '32-123, 256, 280-321'. As with
    range(), the upper bound of each range is exclusive."""
    assert not string.endswith(","), "provided string '{}' ends with a comma, please remove it".format(string)
    subs = string.split(",")
    indices = []
    for sub in subs:
        subsubs = sub.split("-")
        if len(subsubs) == 1:
            indices.append(int(subsubs[0]))
        else:
            indices.extend(range(int(subsubs[0]), int(subsubs[1])))
    return sorted(indices)
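
# Example: str_to_indices("30-32, 50") == [30, 31, 50].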


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if rpath.split("/")[-1] not in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                os.path.getsize(self.human_dict) != SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        self.class_labels = [class_dict[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }
        self.data = ImagePaths(self.abspaths,
                               labels=labels,
                               size=retrieve(self.config, "size", default=0),
                               random_crop=self.random_crop)
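
# Usage sketch: a plain dict works as config, e.g. restricting a subclass to a
# subset of classes (indices refer to data/imagenet_idx_to_synset.yaml):
#   dset = ImageNetTrain({"sub_indices": "0-20, 30"})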


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        if not bdu.is_prepared(self.root):
            # download (if needed), extract, and build the filelist
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or os.path.getsize(path) != self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            # note: without recursive=True, "**" matches exactly one directory
            # level here, i.e. the per-synset folders
            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            bdu.mark_prepared(self.root)
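
# Note: with the default cache dir everything lands in
# ~/.cache/autoencoders/data/ILSVRC2012_train/ (tar file, data/, filelist.txt).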


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        if not bdu.is_prepared(self.root):
            # download (if needed), extract, and build the filelist
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or os.path.getsize(path) != self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or os.path.getsize(vspath) != self.SIZES[1]:
                    download(self.VS_URL, vspath)

                # validation_synset.txt maps each image filename to its synset
                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            bdu.mark_prepared(self.root)


def get_preprocessor(size=None, random_crop=False, additional_targets=None,
                     crop_size=None):
    if size is not None and size > 0:
        transforms = list()
        rescaler = albumentations.SmallestMaxSize(max_size=size)
        transforms.append(rescaler)
        if not random_crop:
            cropper = albumentations.CenterCrop(height=size, width=size)
            transforms.append(cropper)
        else:
            cropper = albumentations.RandomCrop(height=size, width=size)
            transforms.append(cropper)
            flipper = albumentations.HorizontalFlip()
            transforms.append(flipper)
        preprocessor = albumentations.Compose(transforms,
                                              additional_targets=additional_targets)
    elif crop_size is not None and crop_size > 0:
        if not random_crop:
            cropper = albumentations.CenterCrop(height=crop_size, width=crop_size)
        else:
            cropper = albumentations.RandomCrop(height=crop_size, width=crop_size)
        transforms = [cropper]
        preprocessor = albumentations.Compose(transforms,
                                              additional_targets=additional_targets)
    else:
        preprocessor = lambda **kwargs: kwargs
    return preprocessor
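
# Minimal usage sketch (img and depth are HWC numpy arrays):
#   pre = get_preprocessor(size=256, additional_targets={"depth": "image"})
#   out = pre(image=img, depth=depth)  # identically cropped image and depth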


def rgba_to_depth(x):
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    y = x.copy()
    # reinterpret the four uint8 channels as the bytes of a single float32
    y.dtype = np.float32
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)
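

# For reference, a minimal sketch of the inverse packing assumed above: the
# depth pngs store the raw bytes of a float32 map in their RGBA channels
# (hypothetical helper, not part of the original pipeline).
def depth_to_rgba(depth):
    assert depth.dtype == np.float32 and depth.ndim == 2
    rgba = depth.copy().view(np.uint8).reshape(depth.shape + (4,))
    return rgba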


class BaseWithDepth(Dataset):
    DEFAULT_DEPTH_ROOT = "data/imagenet_depth"

    def __init__(self, config=None, size=None, random_crop=False,
                 crop_size=None, root=None):
        self.config = config
        self.base_dset = self.get_base_dset()
        # "depth" is transformed exactly like "image" (same crop/flip)
        self.preprocessor = get_preprocessor(
            size=size,
            crop_size=crop_size,
            random_crop=random_crop,
            additional_targets={"depth": "image"})
        self.crop_size = crop_size
        if self.crop_size is not None:
            self.rescaler = albumentations.Compose(
                [albumentations.SmallestMaxSize(max_size=self.crop_size)],
                additional_targets={"depth": "image"})
        if root is not None:
            self.DEFAULT_DEPTH_ROOT = root

    def __len__(self):
        return len(self.base_dset)

    def preprocess_depth(self, path):
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        # normalize to [-1, 1]
        depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
        depth = 2.0*depth - 1.0
        return depth

    def __getitem__(self, i):
        e = self.base_dset[i]
        e["depth"] = self.preprocess_depth(self.get_depth_path(e))
        # upscale if the image is smaller than the crop size
        h, w, c = e["image"].shape
        if self.crop_size and min(h, w) < self.crop_size:
            out = self.rescaler(image=e["image"], depth=e["depth"])
            e["image"] = out["image"]
            e["depth"] = out["depth"]
        transformed = self.preprocessor(image=e["image"], depth=e["depth"])
        e["image"] = transformed["image"]
        e["depth"] = transformed["depth"]
        return e


class ImageNetTrainWithDepth(BaseWithDepth):
    # default to random_crop=True
    def __init__(self, random_crop=True, sub_indices=None, **kwargs):
        # sub_indices must be set before super().__init__ calls get_base_dset
        self.sub_indices = sub_indices
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base_dset(self):
        if self.sub_indices is None:
            return ImageNetTrain()
        else:
            return ImageNetTrain({"sub_indices": self.sub_indices})

    def get_depth_path(self, e):
        fid = os.path.splitext(e["relpath"])[0] + ".png"
        fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
        return fid


class ImageNetValidationWithDepth(BaseWithDepth):
    def __init__(self, sub_indices=None, **kwargs):
        # sub_indices must be set before super().__init__ calls get_base_dset
        self.sub_indices = sub_indices
        super().__init__(**kwargs)

    def get_base_dset(self):
        if self.sub_indices is None:
            return ImageNetValidation()
        else:
            return ImageNetValidation({"sub_indices": self.sub_indices})

    def get_depth_path(self, e):
        fid = os.path.splitext(e["relpath"])[0] + ".png"
        fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
        return fid


class RINTrainWithDepth(ImageNetTrainWithDepth):
    def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
        sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config, size=size, random_crop=random_crop,
                         sub_indices=sub_indices, crop_size=crop_size)


class RINValidationWithDepth(ImageNetValidationWithDepth):
    def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
        sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(config=config, size=size, random_crop=random_crop,
                         sub_indices=sub_indices, crop_size=crop_size)


class DRINExamples(Dataset):
    def __init__(self):
        self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
        with open("data/drin_examples.txt", "r") as f:
            relpaths = f.read().splitlines()
        self.image_paths = [os.path.join("data/drin_images",
                                         relpath) for relpath in relpaths]
        self.depth_paths = [os.path.join("data/drin_depth",
                                         relpath.replace(".JPEG", ".png")) for relpath in relpaths]

    def __len__(self):
        return len(self.image_paths)

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        image = (image/127.5 - 1.0).astype(np.float32)
        return image

    def preprocess_depth(self, path):
        rgba = np.array(Image.open(path))
        depth = rgba_to_depth(rgba)
        # normalize to [-1, 1]
        depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
        depth = 2.0*depth - 1.0
        return depth

    def __getitem__(self, i):
        e = dict()
        e["image"] = self.preprocess_image(self.image_paths[i])
        e["depth"] = self.preprocess_depth(self.depth_paths[i])
        # joint transform so that depth is resized and cropped like the image
        transformed = self.preprocessor(image=e["image"], depth=e["depth"])
        e["image"] = transformed["image"]
        e["depth"] = transformed["depth"]
        return e


def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
    if factor is None or factor == 1:
        return x

    dtype = x.dtype
    assert dtype in [np.float32, np.float64]
    assert x.min() >= -1
    assert x.max() <= 1

    keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
                "bicubic": Image.BICUBIC}[keepmode]

    # map [-1, 1] floats to uint8 for PIL
    lr = (x + 1.0)*127.5
    lr = lr.clip(0, 255).astype(np.uint8)
    lr = Image.fromarray(lr)

    h, w, _ = x.shape
    nh = h//factor
    nw = w//factor
    assert nh > 0 and nw > 0, (nh, nw)

    lr = lr.resize((nw, nh), Image.BICUBIC)
    if keepshapes:
        # resize back to the original resolution with the requested mode
        lr = lr.resize((w, h), keepmode)
    lr = np.array(lr)/127.5 - 1.0
    lr = lr.astype(dtype)

    return lr
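
# Example: for a 256x256x3 input in [-1, 1], imscale(x, 4) returns 64x64x3;
# with keepshapes=True the result is resized back to 256x256x3 using keepmode.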


class ImageNetScale(Dataset):
    def __init__(self, size=None, crop_size=None, random_crop=False,
                 up_factor=None, hr_factor=None, keep_mode="bicubic"):
        self.base = self.get_base()

        self.size = size
        self.crop_size = crop_size if crop_size is not None else self.size
        self.random_crop = random_crop
        self.up_factor = up_factor
        self.hr_factor = hr_factor
        self.keep_mode = keep_mode

        transforms = list()

        if self.size is not None and self.size > 0:
            rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            self.rescaler = rescaler
            transforms.append(rescaler)

        if self.crop_size is not None and self.crop_size > 0:
            if len(transforms) == 0:
                self.rescaler = albumentations.SmallestMaxSize(max_size=self.crop_size)

            if not self.random_crop:
                cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
            else:
                cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            transforms.append(cropper)

        if len(transforms) > 0:
            if self.up_factor is not None:
                additional_targets = {"lr": "image"}
            else:
                additional_targets = None
            self.preprocessor = albumentations.Compose(transforms,
                                                       additional_targets=additional_targets)
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]

        # optionally downscale the HR image first
        image = imscale(image, self.hr_factor, keepshapes=False)
        h, w, c = image.shape
        if self.crop_size and min(h, w) < self.crop_size:
            # upscale if the image is smaller than the crop size
            image = self.rescaler(image=image)["image"]
        if self.up_factor is None:
            image = self.preprocessor(image=image)["image"]
            example["image"] = image
        else:
            lr = imscale(image, self.up_factor, keepshapes=True,
                         keepmode=self.keep_mode)

            out = self.preprocessor(image=image, lr=lr)
            example["image"] = out["image"]
            example["lr"] = out["lr"]

        return example


class ImageNetScaleTrain(ImageNetScale):
    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        return ImageNetTrain()


class ImageNetScaleValidation(ImageNetScale):
    def get_base(self):
        return ImageNetValidation()
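
# Usage sketch (hyperparameters are illustrative): superresolution pairs, e.g.
#   dset = ImageNetScaleTrain(size=256, up_factor=16)
#   ex = dset[0]  # ex["image"] and its bicubically degraded counterpart ex["lr"]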


from skimage.feature import canny
from skimage.color import rgb2gray


class ImageNetEdges(ImageNetScale):
    def __init__(self, up_factor=1, **kwargs):
        # up_factor is fixed to 1 so that the parent registers the additional
        # "lr" target; here "lr" carries the edge map rather than a rescale
        super().__init__(up_factor=1, **kwargs)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]
        h, w, c = image.shape
        if self.crop_size and min(h, w) < self.crop_size:
            # upscale if the image is smaller than the crop size
            image = self.rescaler(image=image)["image"]

        lr = canny(rgb2gray(image), sigma=2)
        lr = lr.astype(np.float32)
        # replicate the single edge channel to three channels
        lr = lr[:, :, None][:, :, [0, 0, 0]]

        out = self.preprocessor(image=image, lr=lr)
        example["image"] = out["image"]
        example["lr"] = out["lr"]

        return example


class ImageNetEdgesTrain(ImageNetEdges):
    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base(self):
        return ImageNetTrain()


class ImageNetEdgesValidation(ImageNetEdges):
    def get_base(self):
        return ImageNetValidation()
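
# Usage sketch: edge-conditioned examples, e.g.
#   dset = ImageNetEdgesValidation(size=256)
#   ex = dset[0]  # ex["lr"] is a 3-channel Canny edge map with values in {0.0, 1.0}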