Spaces:
Runtime error
Runtime error
File size: 6,672 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from mmcv.cnn import VGG
from mmcv.runner.hooks import HOOKS, Hook
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import (LoadAnnotations, LoadImageFromFile,
LoadPanopticAnnotations)
from mmdet.models.dense_heads import GARPNHead, RPNHead
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
def replace_ImageToTensor(pipelines):
    """Replace the ImageToTensor transform in a data pipeline to
    DefaultFormatBundle, which is normally useful in batch inference.

    The input is deep-copied exactly once, so the caller's configs are never
    modified. ``MultiScaleFlipAug`` entries are descended into through their
    ``transforms`` list; every other entry is returned as copied.

    Args:
        pipelines (list[dict]): Data pipeline configs.

    Returns:
        list: The new pipeline list with all ImageToTensor replaced by
        DefaultFormatBundle.

    Raises:
        AssertionError: If a ``MultiScaleFlipAug`` entry has no
            ``transforms`` key.
        KeyError: If an entry has no ``type`` key.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='ImageToTensor', keys=['img']),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(
        ...        type='MultiScaleFlipAug',
        ...        img_scale=(1333, 800),
        ...        flip=False,
        ...        transforms=[
        ...            dict(type='Resize', keep_ratio=True),
        ...            dict(type='RandomFlip'),
        ...            dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...            dict(type='Pad', size_divisor=32),
        ...            dict(type='DefaultFormatBundle'),
        ...            dict(type='Collect', keys=['img']),
        ...        ])
        ...    ]
        >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
    """

    def _swap_in_place(steps):
        # ``steps`` always belongs to the private copy made below, so it can
        # be rewritten directly instead of being deep-copied again at every
        # nesting level.
        for idx, step in enumerate(steps):
            if step['type'] == 'MultiScaleFlipAug':
                assert 'transforms' in step
                _swap_in_place(step['transforms'])
            elif step['type'] == 'ImageToTensor':
                warnings.warn(
                    '"ImageToTensor" pipeline is replaced by '
                    '"DefaultFormatBundle" for batch inference. It is '
                    'recommended to manually replace it in the test '
                    'data pipeline in your config file.', UserWarning)
                steps[idx] = {'type': 'DefaultFormatBundle'}

    # One deep copy up front keeps the caller's config untouched.
    pipelines = copy.deepcopy(pipelines)
    _swap_in_place(pipelines)
    return pipelines
def get_loading_pipeline(pipeline):
    """Extract the image/annotation loading steps from a data pipeline.

    Each config's ``type`` is looked up in the ``PIPELINES`` registry and the
    config is kept only when the registered class is ``LoadImageFromFile``,
    ``LoadAnnotations`` or ``LoadPanopticAnnotations``. The kept configs are
    the same objects as in the input (no copy) and keep their input order.

    Args:
        pipeline (list[dict]): Data pipeline configs.

    Returns:
        list[dict]: The two loading-related configs found in ``pipeline``.

    Raises:
        AssertionError: If the number of loading-related configs is not
            exactly two.

    Examples:
        >>> pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True),
        ...    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
        ...    dict(type='RandomFlip', flip_ratio=0.5),
        ...    dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
        ...    dict(type='Pad', size_divisor=32),
        ...    dict(type='DefaultFormatBundle'),
        ...    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
        ...    ]
        >>> expected_pipelines = [
        ...    dict(type='LoadImageFromFile'),
        ...    dict(type='LoadAnnotations', with_bbox=True)
        ...    ]
        >>> assert expected_pipelines == get_loading_pipeline(pipelines)
    """
    # TODO: use a more elegant way to distinguish loading modules
    loader_classes = (LoadImageFromFile, LoadAnnotations,
                      LoadPanopticAnnotations)
    # A failed registry lookup yields ``None``, which is never a member of
    # ``loader_classes``, so unknown types are dropped by the same test.
    selected = [
        step for step in pipeline
        if PIPELINES.get(step['type']) in loader_classes
    ]
    assert len(selected) == 2, \
        'The data pipeline in your config file must include ' \
        'loading image and annotations related pipeline.'
    return selected
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Hook that checks head ``num_classes`` against the dataset ``CLASSES``.

    The check is run before every training epoch and before every validation
    epoch, using the dataset of the runner's current data loader.
    """

    def _check_head(self, runner):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        When ``dataset.CLASSES`` is ``None`` only a warning is logged and no
        module is inspected.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.

        Raises:
            AssertionError: If ``dataset.CLASSES`` is a plain string, or if a
                checked module's ``num_classes`` differs from
                ``len(dataset.CLASSES)``.
        """
        model = runner.model
        dataset = runner.data_loader.dataset
        dataset_name = dataset.__class__.__name__

        if dataset.CLASSES is None:
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset_name} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A bare string here usually means a one-element tuple written
        # without its trailing comma; ``len`` would then count characters.
        assert not isinstance(dataset.CLASSES, str), \
            (f'`CLASSES` in {dataset_name} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'CLASSES = ({dataset.CLASSES},)')

        num_dataset_classes = len(dataset.CLASSES)
        # NOTE(review): these module types are skipped, presumably because
        # their `num_classes` is not tied to the dataset categories --
        # confirm against their definitions.
        skipped_types = (RPNHead, VGG, FusedSemanticHead, GARPNHead)
        for _, module in model.named_modules():
            if not hasattr(module, 'num_classes'):
                continue
            if isinstance(module, skipped_types):
                continue
            assert module.num_classes == num_dataset_classes, \
                (f'The `num_classes` ({module.num_classes}) in '
                 f'{module.__class__.__name__} of '
                 f'{model.__class__.__name__} does not match '
                 f'the length of `CLASSES` '
                 f'({num_dataset_classes}) in '
                 f'{dataset_name}')

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)

    def before_val_epoch(self, runner):
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner)
|