# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
import torch

from mmdet.utils.util_mixins import NiceRepr


def _is_same_meta_value(old, new):
    """Check whether a meta-information value is unchanged.

    Used by :meth:`GeneralData.set_meta_info` to decide whether re-setting an
    existing meta key is a harmless no-op or an illegal modification.

    Args:
        old: The value currently stored for the meta key.
        new: The value the caller is trying to set.

    Returns:
        bool: True when the two values compare equal. When the comparison
        yields a ``torch.Tensor`` / ``np.ndarray`` it must hold element-wise.
        A comparison that cannot be evaluated (e.g. arrays whose shapes do
        not broadcast) counts as "not equal", so the caller raises its
        documented ``KeyError`` instead of leaking a low-level error.
    """
    try:
        result = old == new
        if isinstance(result, (torch.Tensor, np.ndarray)):
            return bool(result.all())
        return bool(result)
    except (RuntimeError, ValueError, TypeError):
        return False


class GeneralData(NiceRepr):
    """A general data structure of OpenMMlab.

    A data structure that stores the meta information,
    the annotations of the images or the model predictions,
    which can be used in communication between components.

    The attributes in `GeneralData` are divided into two parts,
    the `meta_info_fields` and the `data_fields` respectively.

        - `meta_info_fields`: Usually contains the
          information about the image such as filename,
          image_shape, pad_shape, etc. All attributes in
          it are immutable once set,
          but the user can add new meta information with
          `set_meta_info` function, all information can be accessed
          with methods `meta_info_keys`, `meta_info_values`,
          `meta_info_items`.

        - `data_fields`: Annotations or model predictions are
          stored. The attributes can be accessed or modified by
          dict-like or object-like operations, such as
          `.` , `[]`, `in`, `del`, `pop(str)`, `get(str)`, `keys()`,
          `values()`, `items()`. Users can also apply tensor-like methods
          to all obj:`torch.Tensor` in the `data_fields`,
          such as `.cuda()`, `.cpu()`, `.npu()`, `.mlu()`, `.to()`,
          `.detach()` and `.numpy()`. Each of these returns a new
          `GeneralData` and leaves the original untouched.

    Args:
        meta_info (dict, optional): A dict contains the meta information
            of single image. such as `img_shape`, `scale_factor`, etc.
            Default: None.
        data (dict, optional): A dict contains annotations of single image or
            model predictions. Default: None.

    Examples:
        >>> from mmdet.core import GeneralData
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = GeneralData(meta_info=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([0, 1, 2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.01, 0.1, 0.2, 0.3])
        >>> instance_data.det_scores
        tensor([0.0100, 0.1000, 0.2000, 0.3000])
        >>> instance_data.det_labels
        tensor([0, 1, 2, 3])
        >>> instance_data['det_labels']
        tensor([0, 1, 2, 3])
        >>> 'det_labels' in instance_data
        True
        >>> instance_data.img_shape
        (800, 1196, 3)
        >>> 'det_scores' in instance_data
        True
        >>> del instance_data.det_scores
        >>> 'det_scores' in instance_data
        False
        >>> det_labels = instance_data.pop('det_labels', None)
        >>> det_labels
        tensor([0, 1, 2, 3])
        >>> 'det_labels' in instance_data
        False
    """

    def __init__(self, meta_info=None, data=None):
        # The two bookkeeping sets are created first: ``__setattr__`` consults
        # them for every other attribute assignment.
        self._meta_info_fields = set()
        self._data_fields = set()

        if meta_info is not None:
            self.set_meta_info(meta_info=meta_info)
        if data is not None:
            self.set_data(data)

    def set_meta_info(self, meta_info):
        """Add meta information.

        Meta keys are write-once. Setting a key that already exists is
        accepted only when the new value equals the stored one; otherwise a
        ``KeyError`` is raised. The input dict is deep-copied first, so the
        caller mutating ``meta_info`` afterwards does not affect this object.

        Args:
            meta_info (dict): A dict contains the meta information
                of image. such as `img_shape`, `scale_factor`, etc.

        Raises:
            KeyError: If an existing meta key would be changed to a
                different value.
        """
        assert isinstance(meta_info, dict), \
            f'meta should be a `dict` but get {meta_info}'
        meta = copy.deepcopy(meta_info)
        for k, v in meta.items():
            if k in self._meta_info_fields:
                # should be consistent with original meta_info
                ori_value = getattr(self, k)
                if _is_same_meta_value(ori_value, v):
                    continue
                raise KeyError(
                    f'img_meta_info {k} has been set as '
                    f'{ori_value} before, which is immutable ')
            self._meta_info_fields.add(k)
            # Write into ``__dict__`` directly: ``__setattr__`` treats meta
            # keys as read-only and would register other names as data.
            self.__dict__[k] = v

    def set_data(self, data):
        """Update a dict to `data_fields`.

        Every item is assigned through ``__setattr__``, so a key that is
        already meta information raises ``AttributeError``.

        Args:
            data (dict): A dict contains annotations of image or
                model predictions.
        """
        assert isinstance(data, dict), \
            f'data should be a `dict` but get {data}'
        for k, v in data.items():
            self.__setattr__(k, v)

    def new(self, meta_info=None, data=None):
        """Return a new instance carrying the same meta information.

        The current meta information is copied (deep-copied by
        ``set_meta_info``). Data fields are NOT copied; only ``data`` (if
        given) is set on the new instance.

        Args:
            meta_info (dict, optional): Extra meta information to add to the
                new instance. Keys that already exist must keep their value.
                Default: None.
            data (dict, optional): A dict contains annotations of image or
                model predictions. Default: None.

        Returns:
            GeneralData: A new instance of the same class.
        """
        new_data = self.__class__()
        new_data.set_meta_info(dict(self.meta_info_items()))
        if meta_info is not None:
            new_data.set_meta_info(meta_info)
        if data is not None:
            new_data.set_data(data)
        return new_data

    def keys(self):
        """
        Returns:
            list: Contains all keys in data_fields. The order is not
            guaranteed because the keys are tracked in a set.
        """
        return list(self._data_fields)

    def meta_info_keys(self):
        """
        Returns:
            list: Contains all keys in meta_info_fields. The order is not
            guaranteed because the keys are tracked in a set.
        """
        return list(self._meta_info_fields)

    def values(self):
        """
        Returns:
            list: Contains all values in data_fields.
        """
        return [getattr(self, k) for k in self.keys()]

    def meta_info_values(self):
        """
        Returns:
            list: Contains all values in meta_info_fields.
        """
        return [getattr(self, k) for k in self.meta_info_keys()]

    def items(self):
        """Iterate over data fields.

        Yields:
            tuple: ``(key, value)`` for every key in data_fields.
        """
        for k in self.keys():
            yield (k, getattr(self, k))

    def meta_info_items(self):
        """Iterate over meta information.

        Yields:
            tuple: ``(key, value)`` for every key in meta_info_fields.
        """
        for k in self.meta_info_keys():
            yield (k, getattr(self, k))

    def __setattr__(self, name, val):
        """Set an attribute and register it as a data field.

        The private bookkeeping sets ``_meta_info_fields`` and
        ``_data_fields`` can be assigned only once, and meta information keys
        cannot be reassigned.

        Raises:
            AttributeError: On re-assignment of a private bookkeeping set or
                of a meta information key.
        """
        if name in ('_meta_info_fields', '_data_fields'):
            if not hasattr(self, name):
                super().__setattr__(name, val)
            else:
                raise AttributeError(
                    f'{name} has been used as a '
                    f'private attribute, which is immutable. ')
        else:
            if name in self._meta_info_fields:
                raise AttributeError(f'`{name}` is used in meta information,'
                                     f'which is immutable')

            self._data_fields.add(name)
            super().__setattr__(name, val)

    def __delattr__(self, item):
        """Delete a data field.

        Raises:
            AttributeError: If ``item`` is a private bookkeeping set, or (from
                ``object.__delattr__``) if the attribute does not exist.
            KeyError: If ``item`` is a meta information key.
        """
        if item in ('_meta_info_fields', '_data_fields'):
            raise AttributeError(f'{item} has been used as a '
                                 f'private attribute, which is immutable. ')

        if item in self._meta_info_fields:
            raise KeyError(f'{item} is used in meta information, '
                           f'which is immutable.')
        super().__delattr__(item)
        if item in self._data_fields:
            self._data_fields.remove(item)

    # dict-like methods
    __setitem__ = __setattr__
    __delitem__ = __delattr__

    def __getitem__(self, name):
        """Return ``getattr(self, name)``; works for data and meta keys."""
        return getattr(self, name)

    def get(self, *args):
        """Dict-like ``get`` over the instance attributes.

        Reads ``self.__dict__`` directly, so both data fields and meta
        information keys are visible.

        Args:
            *args: ``(key,)`` or ``(key, default)``, as for ``dict.get``.
        """
        assert len(args) < 3, '`get` get more than 2 arguments'
        return self.__dict__.get(*args)

    def pop(self, *args):
        """Dict-like ``pop`` for data fields.

        Args:
            *args: ``(key,)`` or ``(key, default)``, as for ``dict.pop``.

        Returns:
            The removed value, or ``default`` when ``key`` is not a data
            field and a default was given.

        Raises:
            KeyError: If ``key`` is meta information, or if it is not a data
                field and no default was given.
        """
        assert len(args) < 3, '`pop` get more than 2 arguments'
        name = args[0]
        if name in self._meta_info_fields:
            raise KeyError(f'{name} is a key in meta information, '
                           f'which is immutable')

        if name in self._data_fields:
            self._data_fields.remove(name)
            return self.__dict__.pop(*args)
        # with default value
        if len(args) == 2:
            return args[1]
        raise KeyError(f'{args[0]}')

    def __contains__(self, item):
        """Return True if ``item`` is a data field or a meta key."""
        return item in self._data_fields or \
            item in self._meta_info_fields

    def _apply_to_tensors(self, func):
        """Return a copy with ``func`` applied to every tensor data value.

        Non-tensor values are carried over unchanged. Meta information is
        copied via :meth:`new`.
        """
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                v = func(v)
            new_data[k] = v
        return new_data

    # Tensor-like methods
    def to(self, *args, **kwargs):
        """Apply same name function to all tensors in data_fields.

        Any data value exposing a ``to`` attribute is converted, not only
        ``torch.Tensor`` instances.
        """
        new_data = self.new()
        for k, v in self.items():
            if hasattr(v, 'to'):
                v = v.to(*args, **kwargs)
            new_data[k] = v
        return new_data

    # Tensor-like methods
    def cpu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.cpu())

    # Tensor-like methods
    def npu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.npu())

    # Tensor-like methods
    def mlu(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.mlu())

    # Tensor-like methods
    def cuda(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.cuda())

    # Tensor-like methods
    def detach(self):
        """Apply same name function to all tensors in data_fields."""
        return self._apply_to_tensors(lambda t: t.detach())

    # Tensor-like methods
    def numpy(self):
        """Convert all tensors in data_fields to ``np.ndarray``.

        Each tensor is detached and moved to CPU before conversion.
        """
        return self._apply_to_tensors(lambda t: t.detach().cpu().numpy())

    def __nice__(self):
        """Build the text used by ``NiceRepr`` for this object.

        Meta information is printed in full. Tensor/ndarray data fields are
        summarised by their shape, and other data values are printed in full.
        """
        parts = ['\n \n META INFORMATION \n']
        parts.extend(f'{k}: {v} \n' for k, v in self.meta_info_items())
        parts.append('\n DATA FIELDS \n')
        for k, v in self.items():
            if isinstance(v, (torch.Tensor, np.ndarray)):
                parts.append(f'shape of {k}: {v.shape} \n')
            else:
                parts.append(f'{k}: {v} \n')
        parts.append('\n')
        return ''.join(parts)