Spaces:
Runtime error
Runtime error
| # --------------------------------------------------------------- | |
| # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # This file has been modified from a file in the torchvision library | |
| # which was released under the BSD 3-Clause License. | |
| # | |
| # Source: | |
| # https://github.com/pytorch/vision/blob/ea6b879e90459006e71a164dc76b7e2cc3bff9d9/torchvision/datasets/lsun.py | |
| # | |
| # The license for the original version of this file can be | |
| # found in this directory (LICENSE_torchvision). The modifications | |
| # to this file are subject to the same BSD 3-Clause License. | |
| # --------------------------------------------------------------- | |
| from torchvision.datasets.vision import VisionDataset | |
| from PIL import Image | |
| import os | |
| import os.path | |
| import io | |
| import string | |
| from collections.abc import Iterable | |
| import pickle | |
| from torchvision.datasets.utils import verify_str_arg, iterable_to_str | |
class LSUNClass(VisionDataset):
    """Read-only view of a single LSUN category stored in an LMDB database.

    Each item is ``(image, target)``: ``image`` is the stored value decoded
    with PIL and converted to RGB (then passed through ``transform``), and
    ``target`` is always ``-1`` (then passed through ``target_transform``).

    Args:
        root (string): Path to the LMDB directory of one category.
        transform (callable, optional): Applied to the decoded PIL image.
        target_transform (callable, optional): Applied to the ``-1`` target.
    """

    def __init__(self, root, transform=None, target_transform=None):
        import lmdb
        super(LSUNClass, self).__init__(root, transform=transform,
                                        target_transform=target_transform)
        self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = txn.stat()['entries']
        # cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters)
        # av begin
        # We only modified the location of cache_file: the key list is cached
        # inside the database directory instead of the working directory.
        cache_file = os.path.join(self.root, '_cache_')
        # av end
        if os.path.isfile(cache_file):
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this class at dataset directories you trust.
            with open(cache_file, "rb") as f:
                self.keys = pickle.load(f)
        else:
            with self.env.begin(write=False) as txn:
                # Iterate keys only so the (large) image values are not read.
                self.keys = list(txn.cursor().iternext(keys=True, values=False))
            self._save_key_cache(cache_file)

    def _save_key_cache(self, cache_file):
        """Persist ``self.keys`` to ``cache_file`` on a best-effort basis.

        The pickle is written to a per-process temporary file and then renamed
        over ``cache_file`` with ``os.replace``, so an interrupted write cannot
        leave a truncated cache behind. If writing fails (for example because
        the dataset directory is read-only), only a warning is emitted: the
        keys are already in memory, so the dataset remains usable.
        """
        import warnings
        tmp_file = '{}.{}.tmp'.format(cache_file, os.getpid())
        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(self.keys, f)
            os.replace(tmp_file, cache_file)
        except OSError as err:
            warnings.warn('Could not write LSUN key cache {}: {}'.format(cache_file, err))
            # Best-effort cleanup of a partially written temporary file.
            try:
                os.remove(tmp_file)
            except OSError:
                pass

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th cached key.

        Raises:
            IndexError: if ``index`` is out of range for ``self.keys``.
            KeyError: if the cached key is no longer present in the database.
        """
        key = self.keys[index]
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(key)
        if imgbuf is None:
            raise KeyError('LSUN key {!r} not found in {}'.format(key, self.root))
        img = Image.open(io.BytesIO(imgbuf)).convert('RGB')
        target = -1
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of entries reported by the LMDB ``stat()`` call."""
        return self.length
class LSUN(VisionDataset):
    """
    `LSUN <https://www.yf.io/p/lsun>`_ dataset.

    Args:
        root (string): Root directory for the database files.
        classes (string or list): One of {'train', 'val', 'test'}, a single
            category such as 'bedroom_train', or a list of categories to load,
            e.g. ['bedroom_train', 'church_outdoor_train'].
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, classes='train', transform=None, target_transform=None):
        super(LSUN, self).__init__(root, transform=transform,
                                   target_transform=target_transform)
        self.classes = self._verify_classes(classes)

        # For each class, create an LSUNClass reading '<root>/<class>_lmdb'.
        # target_transform is intentionally not forwarded: this class applies
        # it to the category index in __getitem__.
        self.dbs = []
        for c in self.classes:
            self.dbs.append(LSUNClass(
                root=os.path.join(root, c + '_lmdb'),
                transform=transform))

        # indices[i] is the cumulative item count of dbs[0..i], i.e. the
        # exclusive global end index of category i.
        self.indices = []
        count = 0
        for db in self.dbs:
            count += len(db)
            self.indices.append(count)
        self.length = count

    def _verify_classes(self, classes):
        """Normalize ``classes`` to a list of '<category>_<split>' names.

        Raises:
            ValueError: if ``classes`` is neither a str nor an iterable of str,
                or if an element names an unknown category or split.
        """
        categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
                      'conference_room', 'dining_room', 'kitchen',
                      'living_room', 'restaurant', 'tower', 'cat']
        dset_opts = ['train', 'val', 'test']

        try:
            verify_str_arg(classes, "classes", dset_opts)
            if classes == 'test':
                classes = [classes]
            else:
                classes = [c + '_' + classes for c in categories]
        except ValueError:
            if isinstance(classes, str):
                # A single category name such as 'bedroom_train'; without this,
                # list(classes) below would split it into characters.
                classes = [classes]
            elif not isinstance(classes, Iterable):
                msg = ("Expected type str or Iterable for argument classes, "
                       "but got type {}.")
                raise ValueError(msg.format(type(classes)))

            classes = list(classes)
            # Keep the two templates under separate names: reusing one name
            # would make the type-check message of every element after the
            # first use the 3-field template below and raise IndexError.
            type_msg_fmtstr = ("Expected type str for elements in argument classes, "
                               "but got type {}.")
            value_msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
            for c in classes:
                verify_str_arg(c, custom_msg=type_msg_fmtstr.format(type(c)))
                c_short = c.split('_')
                category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]

                msg = value_msg_fmtstr.format(category, "LSUN class",
                                              iterable_to_str(categories))
                verify_str_arg(category, valid_values=categories, custom_msg=msg)

                msg = value_msg_fmtstr.format(dset_opt, "postfix",
                                              iterable_to_str(dset_opts))
                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)

        return classes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index; negative values count from the end.

        Returns:
            tuple: Tuple (image, target) where target is the index of the target category.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError('LSUN index out of range')

        # Locate the first category whose cumulative end exceeds index;
        # sub is the global start index of that category.
        target = 0
        sub = 0
        for ind in self.indices:
            if index < ind:
                break
            target += 1
            sub = ind

        db = self.dbs[target]
        index = index - sub

        if self.target_transform is not None:
            target = self.target_transform(target)

        img, _ = db[index]
        return img, target

    def __len__(self):
        """Total number of items across all loaded categories."""
        return self.length

    def extra_repr(self):
        return "Classes: {}".format(self.classes)