Spaces:
Runtime error
Runtime error
| from os.path import expanduser | |
| import torch | |
| import json | |
| import torchvision | |
| from general_utils import get_from_repository | |
| from general_utils import log | |
| from torchvision import transforms | |
# Five pairs of class identifiers written in a WordNet-synset-like
# ``<name>.n.<sense>`` form. Nothing in the visible part of this module reads
# the constant.
# NOTE(review): presumably each pair is the set of classes added per zero-shot
# setting -- confirm against the code that consumes it.
PASCAL_VOC_CLASSES_ZS = [
    ['cattle.n.01', 'motorcycle.n.01'],
    ['aeroplane.n.01', 'sofa.n.01'],
    ['cat.n.01', 'television.n.03'],
    ['train.n.01', 'bottle.n.01'],
    ['chair.n.01', 'pot_plant.n.01'],
]
class PascalZeroShot(object):
    """PASCAL VOC zero-shot segmentation dataset backed by JoEm's ``VOCSegmentation``.

    Indexing returns ``((image,), (label,))`` where both the image and the
    label map have been resized to ``image_size x image_size``.

    Args:
        split: Either ``'train'`` or ``'val'``; selects how the underlying
            ``VOCSegmentation`` dataset is configured.
        n_unseen: Forwarded to JoEm's ``get_unseen_idx`` / ``get_seen_idx``
            to obtain the unseen and seen class index sets.
        image_size: Side length (in pixels) of the square output image and
            label map.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, split, n_unseen, image_size=224) -> None:
        super().__init__()

        # Fail fast: without this check an unknown split would leave
        # ``self.voc`` unset and only surface later as an AttributeError.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        import sys
        # Make the JoEm checkout importable, but do not grow sys.path with a
        # duplicate entry every time a dataset is instantiated.
        joem_root = 'third_party/JoEm'
        if joem_root not in sys.path:
            sys.path.append(joem_root)
        from third_party.JoEm.data_loader.dataset import VOCSegmentation
        from third_party.JoEm.data_loader import get_seen_idx, get_unseen_idx, VOC

        self.pascal_classes = VOC
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
        ])

        if split == 'train':
            # Training uses JoEm's own transform pipeline (base/crop size 312)
            # and asks it to remove images containing unseen classes.
            self.voc = VOCSegmentation(
                get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                split=split, transform=True,
                transform_args=dict(base_size=312, crop_size=312),
                ignore_bg=False, ignore_unseen=False, remv_unseen_img=True)
        else:  # split == 'val' (guaranteed by the check above)
            self.voc = VOCSegmentation(
                get_unseen_idx(n_unseen), get_seen_idx(n_unseen),
                split=split, transform=False,
                ignore_bg=False, ignore_unseen=False)

        self.unseen_idx = get_unseen_idx(n_unseen)

    def __len__(self):
        """Return the number of samples in the wrapped VOC dataset."""
        return len(self.voc)

    def __getitem__(self, i):
        """Return sample ``i`` as ``((image,), (label,))``.

        The image goes through ``self.transform``; the label map is cast to
        an integer (long) tensor and resized to the same square size.
        """
        sample = self.voc[i]
        label = sample['label'].long()

        image = self.transform(sample['image'])

        # Nearest-neighbour keeps label values intact; any other interpolation
        # would blend neighbouring class ids into meaningless values.
        resize_label = transforms.Resize(
            (self.image_size, self.image_size),
            interpolation=torchvision.transforms.InterpolationMode.NEAREST)
        # A leading dimension is added for the resize and stripped afterwards
        # (assumes the label map is 2-D (H, W) -- TODO confirm).
        label = resize_label(label.unsqueeze(0))[0]

        return (image,), (label, )