Spaces:
Sleeping
Sleeping
File size: 8,819 Bytes
0135475 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
# Ultralytics YOLO π, GPL-3.0 license
import glob
import math
import os
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm
from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT
from .utils import HELP_URL, IMG_FORMATS
class BaseDataset(Dataset):
    """Base dataset that loads images and labels; subclasses supply label parsing and transforms.

    Subclasses must implement ``get_labels`` (returns the per-image label dicts) and ``build_transforms``
    (returns the callable applied in ``__getitem__``).

    Args:
        img_path (str | list): an image directory (searched recursively), a text file listing image paths,
            or a list of these.
        imgsz (int): images are resized so their longest side is ``imgsz``.
        cache (bool | str): falsy disables caching; ``'disk'`` saves decoded images as ``.npy`` files next to the
            originals; any other truthy value caches decoded, resized images in RAM.
        augment (bool): stored as ``self.augment``; also selects the resize interpolation in ``load_image``.
        hyp: forwarded unchanged to ``build_transforms``.
        prefix (str): prepended to progress and error messages.
        rect (bool): group images into batches of similar aspect ratio (requires ``batch_size``).
        batch_size (int, optional): batch size used for rectangular batching.
        stride (int): rectangular batch shapes are rounded up to a multiple of ``stride``.
        pad (float): added (in stride units) to the scaled batch shape before rounding up.
        single_cls (bool): set every class id to 0.
        classes (list, optional): keep only labels whose class id is in this list.
    """

    def __init__(self,
                 img_path,
                 imgsz=640,
                 cache=False,
                 augment=True,
                 hyp=None,
                 prefix='',
                 rect=False,
                 batch_size=None,
                 stride=32,
                 pad=0.5,
                 single_cls=False,
                 classes=None):
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.single_cls = single_cls
        self.prefix = prefix
        self.im_files = self.get_img_files(self.img_path)
        self.labels = self.get_labels()
        self.update_labels(include_class=classes)  # apply `classes` filter and single_cls
        self.ni = len(self.labels)  # number of images

        # Rectangular batching. This runs before the cache setup below because it reorders im_files/labels.
        self.rect = rect
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None, f'{self.prefix}rect=True requires batch_size'
            self.set_rectangle()

        # Image cache: RAM slots plus the .npy paths used by the disk cache
        self.ims = [None] * self.ni
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache:
            self.cache_images(cache)

        # Transforms
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """Collect image file paths.

        Args:
            img_path (str | Path | list): a directory (searched recursively), a text file with one image path per
                line (entries starting with './' are resolved against the text file's directory), or a list of these.

        Returns:
            (list[str]): sorted paths whose extension is in IMG_FORMATS, with '/' replaced by os.sep.

        Raises:
            FileNotFoundError: if a path does not exist, no images are found, or reading fails for any reason.
        """
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                elif p.is_file():
                    with open(p) as fh:
                        lines = fh.read().strip().splitlines()
                    parent = str(p.parent) + os.sep
                    f += [self._to_global_path(x, parent) for x in lines]  # local to global path
                else:
                    raise FileNotFoundError(f'{self.prefix}{p} does not exist')
            im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            if not im_files:  # explicit check: an `assert` would be stripped under `python -O`
                raise FileNotFoundError(f'{self.prefix}No images found')
        except Exception as e:
            raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
        return im_files

    @staticmethod
    def _to_global_path(x, parent):
        """Resolve a './'-relative list entry against ``parent``; any other entry is returned unchanged."""
        # Replace only the leading './': str.replace would also rewrite the './' inside '../'.
        return parent + x[2:] if x.startswith('./') else x

    def update_labels(self, include_class: Optional[list]):
        """Filter labels in place to ``include_class`` (if given) and zero all class ids when ``single_cls``.

        Args:
            include_class (list, optional): class ids to keep; ``None`` keeps every label.
        """
        include_class_array = None if include_class is None else np.array(include_class).reshape(1, -1)
        for label in self.labels:
            if include_class_array is not None:
                cls = label['cls']
                bboxes = label['bboxes']
                segments = label['segments']
                j = (cls == include_class_array).any(1)  # rows whose class id is in include_class
                label['cls'] = cls[j]
                label['bboxes'] = bboxes[j]
                if segments:
                    label['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
            if self.single_cls:
                label['cls'][:, 0] = 0

    def load_image(self, i):
        """Load image ``i`` resized so its longest side is ``self.imgsz``.

        Returns the RAM-cached copy when present. Otherwise it reads the ``.npy`` disk cache if that file exists,
        else decodes the original file with ``cv2.imread`` (BGR).

        Returns:
            (tuple): (image, (h0, w0) original height/width, (h, w) resized height/width).

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is not None:  # cached in RAM
            return im, self.im_hw0[i], self.im_hw[i]
        if fn.exists():  # load npy
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
            if im is None:
                raise FileNotFoundError(f'Image Not Found {f}')
        h0, w0 = im.shape[:2]  # orig hw
        r = self.imgsz / max(h0, w0)  # ratio that maps the longest side to imgsz
        if r != 1:  # if sizes are not equal
            # INTER_AREA when shrinking without augmentation, INTER_LINEAR otherwise
            interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
            # Clamp: float rounding in `w0 * r` can make ceil() overshoot imgsz by one pixel
            w, h = min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)
            im = cv2.resize(im, (w, h), interpolation=interp)
        return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized

    def cache_images(self, cache):
        """Cache every image to RAM, or to ``.npy`` files when ``cache == 'disk'``, using a thread pool.

        RAM caching fills ``self.ims``, ``self.im_hw0`` and ``self.im_hw``. Progress and the running cache size are
        shown with tqdm (disabled when LOCAL_RANK > 0).
        """
        b = 0  # bytes of cached images (displayed in GB)
        self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
        fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))
            pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
            try:
                for i, x in pbar:
                    if cache == 'disk':
                        b += self.npy_files[i].stat().st_size
                    else:  # 'ram'
                        self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                        b += self.ims[i].nbytes
                    pbar.desc = f'{self.prefix}Caching images ({b / 1E9:.1f}GB {cache})'
            finally:
                pbar.close()  # close the bar even if loading an image raises

    def cache_images_to_disk(self, i):
        """Save image ``i`` (decoded at original resolution) as a ``.npy`` file unless that file already exists.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        f = self.npy_files[i]
        if not f.exists():
            im = cv2.imread(self.im_files[i])
            if im is None:  # don't save a pickled None that np.load (allow_pickle=False) cannot read back
                raise FileNotFoundError(f'Image Not Found {self.im_files[i]}')
            np.save(f.as_posix(), im)

    def set_rectangle(self):
        """Sort images by aspect ratio and compute a stride-aligned [h, w] shape for each batch.

        Pops the 'shape' key from every label, and reorders ``self.im_files`` and ``self.labels`` by ascending
        height/width ratio. Sets ``self.batch_shapes`` (one [h, w] row per batch) and ``self.batch`` (the batch index
        of each image).
        """
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        s = np.array([x.pop('shape') for x in self.labels])  # hw
        ar = s[:, 0] / s[:, 1]  # aspect ratio (h / w)
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        ar = ar[irect]

        # Set training image shapes: relative [h, w] per batch, 1 on the side that spans imgsz
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:  # every image in the batch is wider than tall
                shapes[i] = [maxi, 1]
            elif mini > 1:  # every image in the batch is taller than wide
                shapes[i] = [1, 1 / mini]

        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of image

    def __getitem__(self, index):
        """Return the label dict for ``index`` after applying ``self.transforms``."""
        return self.transforms(self.get_label_info(index))

    def get_label_info(self, index):
        """Build the untransformed sample for ``index``.

        Returns a shallow copy of the label dict (without 'shape'), plus:
        'img', 'ori_shape', 'resized_shape', 'ratio_pad' (resized/original per-axis ratio, for evaluation),
        and 'rect_shape' when ``self.rect`` is set. The result is then passed through ``update_labels_info``.
        """
        label = self.labels[index].copy()
        label.pop('shape', None)  # shape is for rect, remove it
        label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
        label['ratio_pad'] = (
            label['resized_shape'][0] / label['ori_shape'][0],
            label['resized_shape'][1] / label['ori_shape'][1],
        )  # for evaluation
        if self.rect:
            label['rect_shape'] = self.batch_shapes[self.batch[index]]
        label = self.update_labels_info(label)
        return label

    def __len__(self):
        """Return the number of labels (images)."""
        return len(self.labels)

    def update_labels_info(self, label):
        """Hook for subclasses to customize the label format; the base implementation returns ``label`` unchanged."""
        return label

    def build_transforms(self, hyp=None):
        """Build the transform pipeline; subclasses must override this.

        Example:
            if self.augment:
                # training transforms
                return Compose([])
            else:
                # val transforms
                return Compose([])
        """
        raise NotImplementedError

    def get_labels(self):
        """Return the list of per-image label dicts; subclasses must override this.

        Each element should look like:
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
        """
        raise NotImplementedError
|