Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright 2019-present NAVER Corp. | |
| # CC BY-NC-SA 3.0 | |
| # Available only for non-commercial use | |
| import os, pdb | |
| import numpy as np | |
| from PIL import Image | |
| from .dataset import Dataset | |
| from .pair_dataset import PairDataset, StillPairDataset | |
class AachenImages (Dataset):
    """ Loads all images from the Aachen Day-Night dataset.

    Images are searched recursively under ``<root>/images_upright``. A ``.jpg``
    file is kept when the path of its folder, relative to that directory,
    contains at least one component listed in `select`.

    Args:
        select: whitespace-separated folder names to keep (e.g. 'db day night').
        root:   root folder of the Aachen dataset.
    """
    def __init__(self, select='db day night', root='data/aachen'):
        Dataset.__init__(self)
        self.root = root
        self.img_dir = 'images_upright'
        self.select = set(select.split())
        assert self.select, 'Nothing was selected'

        self.imgs = []
        root = os.path.join(root, self.img_dir)
        for dirpath, dirnames, filenames in os.walk(root):
            # sort in place so os.walk descends into sub-folders in a fixed order,
            # making image indices reproducible across machines/filesystems
            dirnames.sort()
            r = dirpath[len(root)+1:]
            # os.walk builds dirpath with os.sep, so split on it (not a hard-coded
            # '/') for nested folders to match on every platform
            if not (self.select & set(r.split(os.sep))):
                continue
            self.imgs += [os.path.join(r, f) for f in sorted(filenames) if f.endswith('.jpg')]
        self.nimg = len(self.imgs)
        assert self.nimg, 'Empty Aachen dataset'

    def get_key(self, idx):
        """ Return the path of image `idx`, relative to ``<root>/images_upright``. """
        return self.imgs[idx]
class AachenImages_DB (AachenImages):
    """ Aachen images restricted to the database ('db') folder.

    On top of the parent's attributes, builds `db_image_idxs`: a dict mapping
    each image tag (see `get_tag`) to the index of that image in `self.imgs`.
    If two images share a tag, the later index wins.
    """
    def __init__(self, **kw):
        AachenImages.__init__(self, select='db', **kw)
        tag_to_index = {}
        for index in range(len(self.imgs)):
            tag_to_index[self.get_tag(index)] = index
        self.db_image_idxs = tag_to_index

    def get_tag(self, idx):
        """ Return the tag of image `idx`.

        The tag is the base file name after dropping its last 4 characters
        (the '.jpg' extension for the images loaded by this class).
        """
        without_ext = self.imgs[idx][:-4]
        _, tag = os.path.split(without_ext)
        return tag
class AachenPairs_StyleTransferDayNight (AachenImages_DB, StillPairDataset):
    """ Synthetic day-night pairs of images.

    (night images obtained using automatic style transfer from web night images)

    Each pair is (db image, style-transferred version of it). Style-transferred
    files are read directly from `root` and must be named
    ``<db tag>.jpg.st_<suffix>``; other files in `root` are ignored.

    Args:
        root: folder containing the style-transferred images.
        **kw: forwarded to `AachenImages_DB` (e.g. the Aachen dataset `root`).
    """
    def __init__(self, root='data/aachen/style_transfer', **kw):
        StillPairDataset.__init__(self)
        AachenImages_DB.__init__(self, **kw)
        old_root = os.path.join(self.root, self.img_dir)

        # Re-root every image on the deepest directory shared by the db images and
        # the style-transferred ones, so all keys are relative to one `self.root`
        # (and img_dir becomes empty). os.path.commonpath works on whole path
        # components; os.path.commonprefix works per character and may return a
        # prefix that is not a directory (e.g. 'data/aachen/images_').
        try:
            common = os.path.commonpath((old_root, root))
        except ValueError:
            # e.g. one absolute and one relative path, or different drives
            common = ''
        self.root = common
        self.img_dir = ''

        def newpath(folder, f):
            # key of `folder/f` relative to the new root ('' means "keep as is")
            path = os.path.join(folder, f)
            return os.path.relpath(path, common) if common else path

        self.imgs = [newpath(old_root, f) for f in self.imgs]

        self.image_pairs = []
        # sorted(): pair indices must not depend on the filesystem listing order
        for fname in sorted(os.listdir(root)):
            if '.jpg.st_' not in fname:
                continue  # not a style-transferred image (e.g. a hidden/system file)
            tag = fname.split('.jpg.st_')[0]
            self.image_pairs.append((self.db_image_idxs[tag], len(self.imgs)))
            self.imgs.append(newpath(root, fname))

        self.nimg = len(self.imgs)
        self.npairs = len(self.image_pairs)
        assert self.nimg and self.npairs
class AachenPairs_OpticalFlow (AachenImages_DB, PairDataset):
    """ Image pairs from Aachen db with optical flow.

    Valid pairs are discovered from the ``.png`` files in ``<root>/flow``, which
    must be the same set of names as in ``<root>/mask``. Each file is named
    ``<tagA>_<tagB>.png`` where the tags are db image tags.

    Args:
        root: folder containing the `flow` and `mask` sub-folders.
        **kw: forwarded to `AachenImages_DB`.
    """
    def __init__(self, root='data/aachen/optical_flow', **kw):
        PairDataset.__init__(self)
        AachenImages_DB.__init__(self, **kw)
        self.root_flow = root

        # find out the subset of valid pairs from the list of flow files
        flows = {f for f in os.listdir(os.path.join(root, 'flow')) if f.endswith('.png')}
        masks = {f for f in os.listdir(os.path.join(root, 'mask')) if f.endswith('.png')}
        assert flows == masks, 'Missing flow or mask pairs'

        def make_pair(f):
            # '<tagA>_<tagB>.png' -> (index of tagA, index of tagB)
            return tuple(self.db_image_idxs[v] for v in f[:-4].split('_'))

        # iterate in sorted order: iteration order of a set of str depends on
        # hash randomization, which would reshuffle pair indices between runs
        self.image_pairs = [make_pair(f) for f in sorted(flows)]
        self.npairs = len(self.image_pairs)
        assert self.nimg and self.npairs

    def _pair_png(self, pair_idx, subdir):
        # path of ``<root_flow>/<subdir>/<tagA>_<tagB>.png`` for pair `pair_idx`
        tag_a, tag_b = map(self.get_tag, self.image_pairs[pair_idx])
        return os.path.join(self.root_flow, subdir, f'{tag_a}_{tag_b}.png')

    def get_mask_filename(self, pair_idx):
        """ Return the path of the mask PNG of pair `pair_idx`. """
        return self._pair_png(pair_idx, 'mask')

    def get_mask(self, pair_idx):
        """ Load the mask of pair `pair_idx` as a numpy array. """
        return np.asarray(Image.open(self.get_mask_filename(pair_idx)))

    def get_flow_filename(self, pair_idx):
        """ Return the path of the flow PNG of pair `pair_idx`. """
        return self._pair_png(pair_idx, 'flow')

    def get_flow(self, pair_idx):
        """ Return the optical flow of pair `pair_idx`.

        The PNG file is decoded with `_png2flow`. If that raises IOError, the raw
        flow file (same path without '.png') is read instead and passed, with the
        PNG path, to `_flow2png`, whose result is returned (presumably that call
        also writes the PNG cache -- see PairDataset).
        """
        fname = self.get_flow_filename(pair_idx)
        try:
            return self._png2flow(fname)
        except IOError:
            raw_fname = fname[:-4]
            # raw layout (.flo style): float32 magic 202021.25, int32 width,
            # int32 height, then H*W*2 float32 values
            with open(raw_fname, 'rb') as fd:
                magic = np.fromfile(fd, np.float32, 1)
                assert magic.size == 1 and magic[0] == 202021.25, \
                    f'Invalid flow file {raw_fname}'
                W, H = np.fromfile(fd, np.int32, 2)
                flow = np.fromfile(fd, np.float32).reshape((H, W, 2))
            return self._flow2png(flow, fname)

    def get_pair(self, idx, output=()):
        """ Return ``(img1, img2, meta)`` for pair `idx`.

        Args:
            idx:    pair index.
            output: iterable or whitespace-separated string of requested extras.
                    'flow' or 'aflow' fills both meta['flow'] and meta['aflow']
                    (flow plus the (x, y) pixel grid, i.e. absolute target
                    coordinates); 'mask' fills meta['mask'].
        """
        if isinstance(output, str):
            output = output.split()

        img1, img2 = map(self.get_image, self.image_pairs[idx])
        meta = {}

        if 'flow' in output or 'aflow' in output:
            flow = self.get_flow(idx)
            # PIL size is (W, H); flow is indexed (H, W, ...)
            assert flow.shape[:2] == img1.size[::-1]
            meta['flow'] = flow
            H, W = flow.shape[:2]
            # mgrid gives (y, x) grids; reverse to (x, y) and move to last axis
            meta['aflow'] = flow + np.mgrid[:H, :W][::-1].transpose(1, 2, 0)

        if 'mask' in output:
            mask = self.get_mask(idx)
            assert mask.shape[:2] == img1.size[::-1]
            meta['mask'] = mask

        return img1, img2, meta
if __name__ == '__main__':
    # Debug entry point: build each dataset with its default paths and inspect
    # them in the debugger. The module uses relative imports (from .dataset ...),
    # so it must be run from its package (e.g. `python -m <package>.aachen`).
    # The three names were previously used without ever being defined here,
    # which raised NameError.
    aachen_db_images = AachenImages_DB()
    aachen_style_transfer_pairs = AachenPairs_StyleTransferDayNight()
    aachen_flow_pairs = AachenPairs_OpticalFlow()
    print(aachen_db_images)
    print(aachen_style_transfer_pairs)
    print(aachen_flow_pairs)
    pdb.set_trace()