File size: 1,725 Bytes
7acde1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union
from PIL import Image
import pandas as pd
from torchvision.datasets import VisionDataset
import torch
def pil_loader(path: str) -> Image.Image:
    """Load the image file at ``path`` and return it converted to RGB mode.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        The result of ``Image.open(...).convert("RGB")`` for that file.
    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as handle:
        # The conversion is done inside the ``with`` block, i.e. before the
        # underlying file object is closed.
        return Image.open(handle).convert("RGB")
class BinaryWaterbirds(VisionDataset):
    """Image dataset described by a ``metadata.csv`` file inside ``root``.

    The CSV must provide at least these columns:

    * ``split`` -- integer split code (``0`` train, ``1`` valid, ``2`` test),
    * ``img_filename`` -- image path relative to ``root``,
    * ``y`` -- the target returned for the image (presumably a binary class
      label, going by the class name -- confirm against the data).

    Only the rows belonging to the requested split are kept. They are stored
    in ``self.samples`` as ``(absolute_or_joined_path, target)`` tuples, in
    CSV order.

    Args:
        root: Directory containing ``metadata.csv`` and the image files.
        split: One of ``'train'``, ``'valid'`` or ``'test'``.
        loader: Callable that takes an image path and returns the sample.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each target.

    Raises:
        KeyError: If ``split`` is not a known split name.
    """

    # Integer codes used in the ``split`` column of ``metadata.csv``.
    _SPLIT_CODES = {'train': 0, 'valid': 1, 'test': 2}

    def __init__(
        self,
        root: str,
        split: str,
        loader: Callable[[str], Any] = pil_loader,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.loader = loader

        # Validate the split name before touching the disk so that a typo
        # fails fast. The exception type stays KeyError for callers that
        # already catch it; only the message is more descriptive.
        try:
            split_code = self._SPLIT_CODES[split]
        except KeyError:
            raise KeyError(
                f"unknown split {split!r}; expected one of {sorted(self._SPLIT_CODES)}"
            ) from None

        metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
        rows = metadata[metadata['split'] == split_code]

        # Pull the two needed columns out once instead of materialising a
        # full row Series per sample with ``iloc[i]``. ``to_numpy()`` keeps
        # the targets as NumPy scalars, as the row-wise lookup did.
        filenames = rows['img_filename'].to_numpy()
        targets = rows['y'].to_numpy()
        self.samples = [
            (os.path.join(root, filename), target)
            for filename, target in zip(filenames, targets)
        ]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)
|