Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
import xml.etree.ElementTree as ET | |
import mmcv | |
import numpy as np | |
from PIL import Image | |
from .builder import DATASETS | |
from .custom import CustomDataset | |
class XMLDataset(CustomDataset):
    """XML dataset for detection.

    Each line of ``ann_file`` is an image id.
    - The image is referenced as ``<img_subdir>/<img_id>.jpg``, relative to
      ``img_prefix``.
    - Its annotation is read from ``<img_prefix>/<ann_subdir>/<img_id>.xml``.

    Args:
        min_size (int | float, optional): The minimum size of bounding
            boxes in the images. If the width or height of a bounding box
            is less than ``min_size``, it is added to the ignored field.
            Default: None.
        img_subdir (str): Subdir where images are stored. Default: JPEGImages.
        ann_subdir (str): Subdir where annotations are. Default: Annotations.
        **kwargs: Forwarded to ``CustomDataset``. Must provide ``classes``
            unless the class already defines ``CLASSES``.
    """

    def __init__(self,
                 min_size=None,
                 img_subdir='JPEGImages',
                 ann_subdir='Annotations',
                 **kwargs):
        assert self.CLASSES or kwargs.get(
            'classes', None), 'CLASSES in `XMLDataset` can not be None.'
        # Set before ``super().__init__``.
        # NOTE(review): presumably the parent constructor calls
        # ``load_annotations``, which reads these. Confirm in CustomDataset.
        self.img_subdir = img_subdir
        self.ann_subdir = ann_subdir
        super(XMLDataset, self).__init__(**kwargs)
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        self.min_size = min_size

    def _load_xml_root(self, img_id):
        """Parse the annotation XML of ``img_id`` and return its root element.

        Args:
            img_id (str): Image id, as listed in ``ann_file``.

        Returns:
            xml.etree.ElementTree.Element: Root of
            ``<img_prefix>/<ann_subdir>/<img_id>.xml``.
        """
        xml_path = osp.join(self.img_prefix, self.ann_subdir, f'{img_id}.xml')
        return ET.parse(xml_path).getroot()

    def load_annotations(self, ann_file):
        """Load annotation from XML style ann_file.

        Args:
            ann_file (str): Path of a text file listing one image id per line.

        Returns:
            list[dict]: One dict per image with keys ``id``, ``filename``,
            ``width`` and ``height``.
        """
        data_infos = []
        img_ids = mmcv.list_from_file(ann_file)
        for img_id in img_ids:
            filename = osp.join(self.img_subdir, f'{img_id}.jpg')
            root = self._load_xml_root(img_id)
            size = root.find('size')
            if size is not None:
                width = int(size.find('width').text)
                height = int(size.find('height').text)
            else:
                # No <size> element, so read the dimensions from the image
                # itself. The context manager closes the file handle that
                # ``Image.open`` would otherwise leave open.
                img_path = osp.join(self.img_prefix, filename)
                with Image.open(img_path) as img:
                    width, height = img.size
            data_infos.append(
                dict(id=img_id, filename=filename, width=width, height=height))
        return data_infos

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without annotation.

        Args:
            min_size (int): Images whose shorter side is below this value
                are dropped. Default: 32.

        Returns:
            list[int]: Indices into ``self.data_infos`` of the images to keep.
            When ``self.filter_empty_gt`` is set, an image is kept only if
            its XML contains at least one object whose name is in
            ``self.CLASSES``.
        """
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) < min_size:
                continue
            if self.filter_empty_gt:
                # This check uses ``self.CLASSES`` rather than
                # ``self.cat2label``, which is only built after
                # ``super().__init__`` returns. This method may run before
                # that point.
                root = self._load_xml_root(img_info['id'])
                if not any(obj.find('name').text in self.CLASSES
                           for obj in root.findall('object')):
                    continue
            valid_inds.append(i)
        return valid_inds

    @staticmethod
    def _bbox_arrays(bboxes, labels):
        """Convert bbox/label lists into numpy arrays.

        Args:
            bboxes (list[list[int]]): Boxes as ``[xmin, ymin, xmax, ymax]``.
            labels (list[int]): One label per box.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``(N, 4)`` float32 boxes and
            ``(N,)`` int64 labels. Empty input gives ``(0, 4)`` and ``(0,)``.
        """
        if not bboxes:
            return (np.zeros((0, 4)).astype(np.float32),
                    np.zeros((0, )).astype(np.int64))
        # Every coordinate is shifted by -1.
        # NOTE(review): presumably this converts VOC's 1-based pixel indices
        # to 0-based ones. Confirm.
        return ((np.array(bboxes, ndmin=2) - 1).astype(np.float32),
                np.array(labels).astype(np.int64))

    def get_ann_info(self, idx):
        """Get annotation from XML file by index.

        Objects whose name is not in ``self.CLASSES`` are skipped. An object
        goes to the ignore arrays when it is marked ``difficult`` or when
        ``self.min_size`` is set and the box width or height is below it.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index, with keys ``bboxes``,
            ``labels``, ``bboxes_ignore`` and ``labels_ignore``.
        """
        root = self._load_xml_root(self.data_infos[idx]['id'])
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            name = obj.find('name').text
            if name not in self.CLASSES:
                continue
            label = self.cat2label[name]
            difficult = obj.find('difficult')
            difficult = 0 if difficult is None else int(difficult.text)
            bnd_box = obj.find('bndbox')
            # TODO: check whether it is necessary to use int
            # Coordinates may be written as floats; they are truncated to int.
            bbox = [
                int(float(bnd_box.find(key).text))
                for key in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            ignore = False
            if self.min_size:
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                ignore = w < self.min_size or h < self.min_size
            if difficult or ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        bboxes, labels = self._bbox_arrays(bboxes, labels)
        bboxes_ignore, labels_ignore = self._bbox_arrays(
            bboxes_ignore, labels_ignore)
        ann = dict(
            bboxes=bboxes,
            labels=labels,
            bboxes_ignore=bboxes_ignore,
            labels_ignore=labels_ignore)
        return ann

    def get_cat_ids(self, idx):
        """Get category ids in XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: Labels of the objects whose name is in
            ``self.CLASSES``, in XML order. Duplicates are kept.
        """
        root = self._load_xml_root(self.data_infos[idx]['id'])
        names = (obj.find('name').text for obj in root.findall('object'))
        return [self.cat2label[name] for name in names if name in self.CLASSES]