Spaces:
Runtime error
Runtime error
| # --------------------------------------------------------------- | |
| # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # This work is licensed under the NVIDIA Source Code License | |
| # for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file. | |
| # --------------------------------------------------------------- | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.datasets as dset | |
| import torchvision.transforms as transforms | |
class StackedMNIST(dset.MNIST):
    """MNIST variant whose samples are three base images stacked as RGB channels.

    Each sample combines three randomly chosen entries of the underlying
    MNIST split: one per colour channel. The dataset exposes twice as many
    samples as the base split. The label of a sample is the three source
    labels read as decimal digits (first channel = hundreds, second = tens,
    third = ones).
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        """Load the base MNIST split and draw the random channel assignments.

        Args:
            root: Dataset root directory, forwarded to ``dset.MNIST``.
            train: Forwarded to ``dset.MNIST``.
            transform: Optional callable applied to the stacked PIL image.
            target_transform: Optional callable applied to the combined label.
            download: Forwarded to ``dset.MNIST``.
        """
        super().__init__(root=root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        base_count = len(self.data)
        # One index sequence per colour channel; each is two independent
        # permutations of the base split laid end to end. Drawn from the
        # global NumPy RNG, so the pairing depends on its current state.
        channel_orders = [
            np.hstack([np.random.permutation(base_count),
                       np.random.permutation(base_count)])
            for _ in range(3)
        ]
        self.num_images = 2 * base_count
        # self.index[k] is the (first, second, third)-channel source triple
        # for stacked sample k.
        self.index = list(zip(*channel_orders))

    def __len__(self):
        """Return the number of stacked samples (twice the base split size)."""
        return self.num_images

    def __getitem__(self, index):
        """Build the stacked image and combined label for sample ``index``.

        Returns:
            A tuple ``(img, target)``: the 28x28 RGB PIL image (after
            ``transform``, if set) and the combined label (after
            ``target_transform``, if set).
        """
        stacked = np.zeros((28, 28, 3), dtype=np.uint8)
        label = 0
        for channel, source_idx in enumerate(self.index[index]):
            plane = self.data[source_idx]
            digit = int(self.targets[source_idx])
            stacked[:, :, channel] = plane
            # Positional weighting: channel 0 -> x100, 1 -> x10, 2 -> x1.
            label += digit * 10 ** (2 - channel)

        img = Image.fromarray(stacked, mode="RGB")
        if self.transform is not None:
            img = self.transform(img)

        target = label
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
def _data_transforms_stacked_mnist():
    """Get data transforms for Stacked MNIST.

    The training and validation pipelines are identical: pad the image by
    2 pixels on every side, convert it to a tensor, then normalize each of
    the three channels with mean 0.5 and std 0.5.

    Returns:
        A tuple ``(train_transform, valid_transform)`` of two separate
        ``transforms.Compose`` instances.
    """
    def _build_pipeline():
        # Built fresh on every call so the two returned pipelines stay
        # independent objects, as they were when written out twice.
        return transforms.Compose([
            transforms.Pad(padding=2),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    train_transform = _build_pipeline()
    valid_transform = _build_pipeline()
    return train_transform, valid_transform