# Ultralytics YOLO 🚀, GPL-3.0 license

import glob
import math
import os
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm

from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT
from .utils import HELP_URL, IMG_FORMATS


class BaseDataset(Dataset):
    """Base dataset class for loading images and labels.

    Args:
        img_path (str): Image directory path, or a txt file listing image paths.
        imgsz (int): Target image size. Defaults to 640.
        cache (bool | str): Cache images to RAM (True) or to disk ('disk'). Defaults to False.
        augment (bool): Apply data augmentation. Defaults to True.
        hyp (dict): Hyperparameters used to build the transforms. Defaults to None.
        prefix (str): Prefix for log messages. Defaults to ''.
        rect (bool): Rectangular training; requires batch_size. Defaults to False.
        single_cls (bool): Treat all labels as a single class. Defaults to False.
        classes (list): Class indices to keep; other labels are filtered out. Defaults to None (keep all).
    """

    def __init__(self,
                 img_path,
                 imgsz=640,
                 cache=False,
                 augment=True,
                 hyp=None,
                 prefix='',
                 rect=False,
                 batch_size=None,
                 stride=32,
                 pad=0.5,
                 single_cls=False,
                 classes=None):
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.single_cls = single_cls
        self.prefix = prefix
        self.im_files = self.get_img_files(self.img_path)
        self.labels = self.get_labels()
        self.update_labels(include_class=classes)  # single_cls and include_class
        self.ni = len(self.labels)

        # rect stuff
        self.rect = rect
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None
            self.set_rectangle()

        # cache stuff
        self.ims = [None] * self.ni
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache:
            self.cache_images(cache)

        # transforms
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """Read image files."""
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise FileNotFoundError(f'{self.prefix}{p} does not exist')
            im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert im_files, f'{self.prefix}No images found'
        except Exception as e:
            raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
        return im_files

    def update_labels(self, include_class: Optional[list]):
        """Filter labels to include only the given classes (optional) and apply single_cls."""
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
            if include_class is not None:
                cls = self.labels[i]['cls']
                bboxes = self.labels[i]['bboxes']
                segments = self.labels[i]['segments']
                j = (cls == include_class_array).any(1)
                self.labels[i]['cls'] = cls[j]
                self.labels[i]['bboxes'] = bboxes[j]
                if segments:
                    self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
            if self.single_cls:
                self.labels[i]['cls'][:, 0] = 0

    def load_image(self, i):
        # Loads 1 image from dataset index 'i', returns (im, resized hw)
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                if im is None:
                    raise FileNotFoundError(f'Image Not Found {f}')
            h0, w0 = im.shape[:2]  # orig hw
            r = self.imgsz / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
                im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    def cache_images(self, cache):
        # cache images to memory or disk
        gb = 0  # Gigabytes of cached images
        self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
        fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))
            pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if cache == 'disk':
                    gb += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    gb += self.ims[i].nbytes
                pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})'
            pbar.close()

    def cache_images_to_disk(self, i):
        # Saves an image as an *.npy file for faster loading
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))

    def set_rectangle(self):
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        s = np.array([x.pop('shape') for x in self.labels])  # hw
        ar = s[:, 0] / s[:, 1]  # aspect ratio
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of image

    def __getitem__(self, index):
        return self.transforms(self.get_label_info(index))

    def get_label_info(self, index):
        label = self.labels[index].copy()
        label.pop('shape', None)  # shape is for rect, remove it
        label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
        label['ratio_pad'] = (
            label['resized_shape'][0] / label['ori_shape'][0],
            label['resized_shape'][1] / label['ori_shape'][1],
        )  # for evaluation
        if self.rect:
            label['rect_shape'] = self.batch_shapes[self.batch[index]]
        label = self.update_labels_info(label)
        return label

    def __len__(self):
        return len(self.labels)

    def update_labels_info(self, label):
        """Customize your label format here."""
        return label

    def build_transforms(self, hyp=None):
        """Users can customize augmentations here, e.g.:
            if self.augment:
                # training transforms
                return Compose([])
            else:
                # val transforms
                return Compose([])
        """
        raise NotImplementedError

    def get_labels(self):
        """Users can customize their own label format here.
        Make sure your output is a list, with each element a dict like below:
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
        """
        raise NotImplementedError
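

# ---------------------------------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the Ultralytics API): a minimal BaseDataset subclass showing one way `get_labels`
# and `build_transforms` could be implemented. The class name `ToyDataset`, the empty dummy labels and the identity
# transform are hypothetical placeholders; a real subclass would parse annotation files and return an augmentation
# pipeline (e.g. a Compose of transforms) built from `hyp`.
class ToyDataset(BaseDataset):

    def get_labels(self):
        # Build one label dict per image; here every image gets an empty set of boxes.
        labels = []
        for im_file in self.im_files:
            h, w = cv2.imread(im_file).shape[:2]
            labels.append(
                dict(
                    im_file=im_file,
                    shape=(h, w),  # (height, width), consumed by set_rectangle()
                    cls=np.zeros((0, 1), dtype=np.float32),
                    bboxes=np.zeros((0, 4), dtype=np.float32),
                    segments=[],
                    keypoints=None,
                    normalized=True,
                    bbox_format='xywh'))
        return labels

    def build_transforms(self, hyp=None):
        # Identity transform; a real implementation would return training/val transforms here.
        return lambda label: label


# Example usage of the sketch above (the image directory path is a placeholder):
#   ds = ToyDataset('path/to/images', imgsz=640, augment=False)
#   sample = ds[0]  # dict with 'img', 'ori_shape', 'resized_shape', 'ratio_pad', ...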