|
import os |
|
import os.path |
|
from typing import Any, Callable, cast, Dict, List, Optional, Tuple |
|
from typing import Union |
|
|
|
from PIL import Image |
|
import pandas as pd |
|
from torchvision.datasets import VisionDataset |
|
import torch |
|
|
|
|
|
def pil_loader(path: str) -> Image.Image:
    """Load the image stored at ``path`` and return it as an RGB PIL image.

    The file is opened explicitly in binary mode and closed by the ``with``
    statement. The RGB conversion is performed inside that block, so it runs
    before the file handle is closed.
    """
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")
|
|
|
class BinaryWaterbirds(VisionDataset):
    """Image dataset driven by ``root/metadata.csv``, labelled by its ``y`` column.

    ``metadata.csv`` must contain at least these columns:

    * ``img_filename``: image path, joined onto ``root``;
    * ``y``: the label returned as the target (presumably binary given the
      class name; not checked here);
    * ``split``: integer split code (0 = train, 1 = valid, 2 = test).

    Each item is ``(sample, target)``. ``sample`` is ``loader(path)`` with
    ``transform`` applied, and ``target`` is the ``y`` value with
    ``target_transform`` applied.

    Args:
        root: Directory containing ``metadata.csv``; image paths are resolved
            relative to it.
        split: One of ``"train"``, ``"valid"`` or ``"test"``.
        loader: Callable turning an image path into a sample; defaults to
            ``pil_loader``.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each target.

    Raises:
        KeyError: If ``split`` is not one of the accepted names.
    """

    # Split name -> integer code stored in the `split` column of metadata.csv.
    _SPLIT_CODES = {"train": 0, "valid": 1, "test": 2}

    def __init__(
        self,
        root: str,
        split: str,
        loader: Callable[[str], Any] = pil_loader,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.loader = loader

        # Validate the split before touching the filesystem so that a typo
        # produces a clear message instead of a bare KeyError('...').
        try:
            split_code = self._SPLIT_CODES[split]
        except KeyError:
            raise KeyError(
                f"Unknown split {split!r}; expected one of {sorted(self._SPLIT_CODES)}"
            ) from None

        metadata = pd.read_csv(os.path.join(root, "metadata.csv"))
        metadata = metadata[metadata["split"] == split_code]

        # Read whole columns rather than calling `.iloc[i]` per row, which
        # materialises a Series for every row. `int(...)` stores plain Python
        # ints instead of NumPy scalars.
        self.samples: List[Tuple[str, int]] = [
            (os.path.join(root, filename), int(label))
            for filename, label in zip(metadata["img_filename"], metadata["y"])
        ]
        # Labels only, in the same order as `samples`.
        self.targets: List[int] = [label for _, label in self.samples]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.samples)
|
|