File size: 2,032 Bytes
7754b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import os
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
import h5py
# Names exported by ``from <this module> import *``. Only ImagenetResults is
# listed; Imagenet_Segmentation (defined below) must be imported explicitly.
__all__ = ['ImagenetResults']
class Imagenet_Segmentation(data.Dataset):
    """Dataset of (image, target) pairs stored in a single HDF5 file.

    The file is expected to contain ``/value/img`` and ``/value/gt``
    datasets whose entries are used as keys to look up the actual arrays
    elsewhere in the same file (see ``__getitem__``).

    Args:
        path: Path to the HDF5 file.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the target
            ``PIL.Image``. When given, its result is converted to an
            ``int32`` array and returned as a ``torch.long`` tensor.
    """

    # NOTE(review): presumably the number of label classes in the targets
    # (e.g. foreground/background) -- confirm against callers.
    CLASSES = 2

    def __init__(self,
                 path,
                 transform=None,
                 target_transform=None):
        super(Imagenet_Segmentation, self).__init__()
        self.path = path
        self.transform = transform
        self.target_transform = target_transform

        # The handle used for reading items is opened on first access in
        # __getitem__, not here.
        # NOTE(review): presumably so each DataLoader worker opens its own
        # handle instead of inheriting one -- confirm.
        self.h5py = None

        # Only the length is needed up front; the context manager guarantees
        # the temporary handle is closed even if the lookup raises.
        with h5py.File(path, 'r') as tmp:
            self.data_length = len(tmp['/value/img'])

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair stored at ``index``.

        ``img`` is an RGB ``PIL.Image`` (or whatever ``transform`` returns).
        ``target`` is a ``torch.long`` tensor when ``target_transform`` is
        set; otherwise it is returned as the untransformed ``PIL.Image``.
        """
        if self.h5py is None:
            self.h5py = h5py.File(self.path, 'r')
        store = self.h5py

        # '/value/img'[index, 0] holds the key of the image array itself.
        img_ref = store['/value/img'][index, 0]
        # Reverse the axis order of the stored array.
        # NOTE(review): assumes the stored layout is (C, W, H) -- confirm.
        img = np.array(store[img_ref]).transpose((2, 1, 0))

        # The ground truth needs two hops: '/value/gt'[index, 0] points at a
        # container whose [0, 0] entry is the key of the 2-D mask array.
        gt_ref = store['/value/gt'][index, 0]
        mask_ref = store[gt_ref][0, 0]
        target = np.array(store[mask_ref]).transpose((1, 0))

        img = Image.fromarray(img).convert('RGB')
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = np.array(self.target_transform(target)).astype('int32')
            target = torch.from_numpy(target).long()

        return img, target

    def __len__(self):
        """Return the number of entries in ``/value/img``, read at construction."""
        return self.data_length
class ImagenetResults(data.Dataset):
    """Dataset over the ``results.hdf5`` file found inside a directory.

    Each item is an ``(image, vis, target)`` triple of tensors read from the
    ``image``, ``vis`` and ``target`` datasets of that file.

    Args:
        path: Directory that contains ``results.hdf5``.
    """

    def __init__(self, path):
        super(ImagenetResults, self).__init__()
        self.path = os.path.join(path, 'results.hdf5')
        # Handle used for item reads; opened on first __getitem__ call.
        self.data = None

        print('Reading dataset length...')
        # Short-lived handle: only the number of stored images is needed here.
        with h5py.File(self.path, 'r') as results_file:
            n_items = len(results_file['/image'])
        self.data_length = n_items

    def __len__(self):
        """Return the number of entries in the ``/image`` dataset."""
        return self.data_length

    def __getitem__(self, item):
        """Return ``(image, vis, target)`` for ``item``; ``target`` is ``torch.long``."""
        handle = self.data
        if handle is None:
            handle = self.data = h5py.File(self.path, 'r')

        image, vis = (torch.tensor(handle[key][item]) for key in ('image', 'vis'))
        target = torch.tensor(handle['target'][item]).long()
        return image, vis, target
|