danurahul's picture
Add application file
d4bc11a
import os
import numpy as np
import cv2
import albumentations
from PIL import Image
from torch.utils.data import Dataset
from taming.data.sflckr import SegmentationBase # for examples included in repo
class Examples(SegmentationBase):
    """ADE20k example images and segmentations shipped with the repository.

    Thin configuration wrapper around ``SegmentationBase``: it fixes the data
    locations, uses 151 labels and disables segmentation shifting.
    """

    def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
        locations = dict(
            data_csv="data/ade20k_examples.txt",
            data_root="data/ade20k_images",
            segmentation_root="data/ade20k_segmentations",
        )
        super().__init__(
            size=size,
            random_crop=random_crop,
            interpolation=interpolation,
            n_labels=151,
            shift_segmentation=False,
            **locations,
        )
# ADE20k dataset whose examples carry both a semantic segmentation map and a scene-category label.
class ADE20kBase(Dataset):
    """ADE20k examples with an RGB image, a one-hot segmentation and a scene label.

    Subclasses implement ``get_split()`` returning ``"train"`` or
    ``"validation"``, which selects the file list that is read
    (``data/ade20k_train.txt`` or ``data/ade20k_test.txt``).

    Args:
        config: Accepted for interface compatibility; not used here.
        size: If given and > 0, images and segmentations are rescaled so that
            their smallest side equals ``size`` and are then cropped.
            ``None`` or a value <= 0 disables both rescaling and cropping.
        random_crop: Use a random crop instead of a center crop.
        interpolation: Image interpolation used when rescaling: one of
            ``"nearest"``, ``"bilinear"``, ``"bicubic"``, ``"area"``,
            ``"lanczos"``. Segmentations always use nearest-neighbour.
        crop_size: Side length of the square crop; defaults to ``size``.
            Only applied when ``size`` is set.
    """

    def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
        self.split = self.get_split()
        self.n_labels = 151  # unknown + 150
        self.data_csv = {"train": "data/ade20k_train.txt",
                         "validation": "data/ade20k_test.txt"}[self.split]
        self.data_root = "data/ade20k_root"

        # One "<image stem> <scene label>" pair per line.
        self.scene_categories = dict(
            line.split()
            for line in self._read_lines(os.path.join(self.data_root, "sceneCategories.txt"))
        )
        self.image_paths = self._read_lines(self.data_csv)
        self._length = len(self.image_paths)
        self.labels = self._build_labels(self.image_paths, self.data_root, self.scene_categories)
        self._setup_transforms(size, random_crop, interpolation, crop_size)

    def get_split(self):
        """Return the split name (``"train"`` or ``"validation"``); subclasses must override."""
        raise NotImplementedError(f"{type(self).__name__} must implement get_split()")

    @staticmethod
    def _read_lines(path):
        """Return the whitespace-stripped, non-blank lines of the text file at ``path``."""
        with open(path, "r") as f:
            # Skipping blank lines avoids crashes on trailing/empty lines later
            # (dict() of an empty split, or indexing into an empty path).
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _build_labels(image_paths, data_root, scene_categories):
        """Build the per-example metadata columns for ``image_paths``.

        Segmentations are expected next to the images, under ``annotations``
        instead of ``images``, with ``.jpg`` replaced by ``.png``. The scene
        label is looked up by the second path component without ``.jpg``.
        """
        segmentation_paths = [p.replace(".jpg", ".png") for p in image_paths]
        return {
            "relative_file_path_": list(image_paths),
            "file_path_": [os.path.join(data_root, "images", p) for p in image_paths],
            "relative_segmentation_path_": segmentation_paths,
            "segmentation_path_": [os.path.join(data_root, "annotations", p)
                                   for p in segmentation_paths],
            "scene_category": [scene_categories[p.split("/")[1].replace(".jpg", "")]
                               for p in image_paths],
        }

    def _setup_transforms(self, size, random_crop, interpolation, crop_size):
        """Configure rescaling and cropping from the constructor arguments."""
        # Non-positive sizes disable rescaling.
        size = None if size is not None and size <= 0 else size
        self.size = size
        self.crop_size = size if crop_size is None else crop_size

        if self.size is not None:
            self.interpolation = {
                "nearest": cv2.INTER_NEAREST,
                "bilinear": cv2.INTER_LINEAR,
                "bicubic": cv2.INTER_CUBIC,
                "area": cv2.INTER_AREA,
                "lanczos": cv2.INTER_LANCZOS4}[interpolation]
            self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                 interpolation=self.interpolation)
            # Label maps must never be blended, so always use nearest-neighbour.
            self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                        interpolation=cv2.INTER_NEAREST)

        # Test the *effective* crop size (which defaults to ``size``) so that a
        # preprocessor exists whenever ``__getitem__`` uses it, i.e. whenever
        # ``size`` is set, even if ``crop_size`` was not passed explicitly.
        if self.crop_size is not None:
            self.center_crop = not random_crop
            if self.center_crop:
                self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            self.preprocessor = self.cropper

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        """Load example ``i``.

        Returns the metadata entries from ``self.labels`` plus:
            ``"image"``: float32 array of shape (H, W, 3) scaled to [-1, 1].
            ``"segmentation"``: one-hot array of shape mask.shape + (n_labels,).
        """
        example = {key: column[i] for key, column in self.labels.items()}

        # Use context managers so the underlying file handles are closed.
        with Image.open(example["file_path_"]) as img:
            if img.mode != "RGB":
                img = img.convert("RGB")
            image = np.array(img).astype(np.uint8)
        # NOTE(review): assumes single-channel label PNGs with values < n_labels;
        # otherwise the one-hot indexing below adds dims or raises IndexError.
        with Image.open(example["segmentation_path_"]) as seg:
            segmentation = np.array(seg).astype(np.uint8)

        if self.size is not None:
            image = self.image_rescaler(image=image)["image"]
            segmentation = self.segmentation_rescaler(image=segmentation)["image"]
            # Crop image and mask together so they stay aligned.
            processed = self.preprocessor(image=image, mask=segmentation)
        else:
            processed = {"image": image, "mask": segmentation}

        # Map uint8 [0, 255] to float32 [-1, 1].
        example["image"] = (processed["image"] / 127.5 - 1.0).astype(np.float32)
        example["segmentation"] = np.eye(self.n_labels)[processed["mask"]]
        return example
class ADE20kTrain(ADE20kBase):
    """Training split of ADE20k.

    Same as the base dataset except that ``random_crop`` defaults to ``True``,
    so crops are sampled at random rather than taken from the center.
    """

    def get_split(self):
        return "train"

    def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
        options = dict(
            config=config,
            size=size,
            random_crop=random_crop,
            interpolation=interpolation,
            crop_size=crop_size,
        )
        super().__init__(**options)
class ADE20kValidation(ADE20kBase):
    """Validation split of ADE20k; inherits the base defaults (center crop)."""

    def get_split(self):
        split_name = "validation"
        return split_name
if __name__ == "__main__":
    # Smoke test: load the first validation example and report its contents.
    dset = ADE20kValidation()
    ex = dset[0]
    for k in ["image", "scene_category", "segmentation"]:
        print(type(ex[k]))
        try:
            print(ex[k].shape)
        except AttributeError:
            # Non-array values (e.g. the scene-category string) have no shape;
            # catch only that instead of a bare except that would also swallow
            # KeyboardInterrupt/SystemExit.
            print(ex[k])