Spaces:
Runtime error
Runtime error
| import os | |
| import matplotlib.pyplot as plt | |
| from pandas.core.common import flatten | |
| import copy | |
| import numpy as np | |
| import random | |
| import torch | |
| from torch import nn | |
| from torch import optim | |
| import torch.nn.functional as F | |
| from torchvision import datasets, transforms, models | |
| from torch.utils.data import Dataset, DataLoader | |
| import torch.nn as nn | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| import cv2 | |
| import glob | |
| from tqdm import tqdm | |
| import random | |
class MotorbikeDataset(torch.utils.data.Dataset):
    """Image dataset over every file in a single directory.

    Labels come from the filename prefix: files whose names start with
    ``t_`` get label 1, all other files get label 0.

    Args:
        image_paths: Directory containing the images. Despite the plural
            name this is a single directory path; the sorted list of file
            names found in it is stored in ``self.image_paths``.
        transform: Optional callable invoked as ``transform(image=...)``
            that returns a mapping with an ``"image"`` key, or ``None``.
    """

    def __init__(self, image_paths, transform=None):
        self.root = image_paths
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.image_paths = sorted(os.listdir(image_paths))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _label_from_name(filename):
        """Return 1 for ``t_``-prefixed filenames, else 0."""
        # A substring test ('t' in filename) would mislabel negatives such as
        # "n_motorbike.jpg" or any file with a ".tif"/".tiff" extension.
        return int(filename.startswith('t_'))

    def __getitem__(self, idx):
        """Load image ``idx`` converted to RGB and return ``(image, label)``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        filename = self.image_paths[idx]
        path = os.path.join(self.root, filename)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self._label_from_name(filename)
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label
class MotorbikeDataset_CV(torch.utils.data.Dataset):
    """Stratified random train/validation split of a directory of images.

    Only files whose names start with ``n_`` or ``t_`` take part. Each
    prefix group is split independently, so both subsets keep the same
    class proportions. The split is computed once at construction; call
    :meth:`get_split` to obtain the two datasets.

    Args:
        root: Directory containing the images.
        train_transforms: Transform handed to the training subset.
        val_transforms: Transform handed to the validation subset.
        trainval_ratio: Fraction of each class group placed in the training
            subset (the count is rounded down).
        seed: Optional seed for a private RNG so the split is reproducible.
            ``None`` (default) draws from the global ``random`` module state.
    """

    def __init__(self, root, train_transforms, val_transforms, trainval_ratio=0.8, seed=None) -> None:
        self.root = root
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.trainval_ratio = trainval_ratio
        self.seed = seed
        self.train_split, self.val_split = self.gen_split()

    def __len__(self):
        # Total number of images assigned to either subset
        # (len(self.root) would be the length of the path string).
        return len(self.train_split) + len(self.val_split)

    def _split_group(self, names, rng):
        """Split one class group into ``(train, val)`` without replacement."""
        k = int(len(names) * self.trainval_ratio)
        # sample() draws distinct items; choices() would draw with replacement,
        # duplicating training images and leaking extra ones into validation.
        train = rng.sample(names, k)
        chosen = set(train)  # O(1) membership instead of scanning a list
        val = [name for name in names if name not in chosen]
        return train, val

    def gen_split(self):
        """Return ``(train_split, val_split)`` as lists of filenames."""
        # Sort so that a fixed seed always yields the same split.
        img_list = sorted(os.listdir(self.root))
        rng = random if self.seed is None else random.Random(self.seed)
        n_list = [img for img in img_list if img.startswith('n_')]
        t_list = [img for img in img_list if img.startswith('t_')]
        n_train, n_val = self._split_group(n_list, rng)
        t_train, t_val = self._split_group(t_list, rng)
        return n_train + t_train, n_val + t_val

    def get_split(self):
        """Return ``(train_dataset, val_dataset)`` built with ``Dataset_from_list``."""
        train_dataset = Dataset_from_list(self.root, self.train_split, self.train_transforms)
        val_dataset = Dataset_from_list(self.root, self.val_split, self.val_transforms)
        return train_dataset, val_dataset
class Dataset_from_list(torch.utils.data.Dataset):
    """Dataset over an explicit list of image filenames inside ``root``.

    The label is 1 for filenames starting with ``t_`` and 0 otherwise.

    Args:
        root: Directory the filenames are relative to.
        img_list: Filenames to serve, in index order.
        transform: Optional callable invoked as ``transform(image=...)``
            that returns a mapping with an ``"image"`` key, or ``None``.
    """

    def __init__(self, root, img_list, transform) -> None:
        self.root = root
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load image ``idx`` converted to RGB and return ``(image, label)``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        filename = self.img_list[idx]
        path = os.path.join(self.root, filename)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None on failure; fail with a clear message
            # instead of an opaque error from cvtColor.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = int(filename.startswith('t_'))
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label