diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..fcbb4c577d32f25aac39e5310c3bbe14cb8fd704
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,125 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/en/_build/
+docs/zh_cn/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+data/
+data
+.vscode
+.idea
+.DS_Store
+
+# custom
+*.pkl
+*.pkl.json
+*.log.json
+docs/modelzoo_statistics.md
+mmdet/.mim
+work_dirs/
+ckpt/
+
+# Pytorch
+*.pth
+*.py~
+*.sh~
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..1bfc23e48f92245b229cdd57c77e79bc10a1cc27
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2018-2023 OpenMMLab. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2018-2023 OpenMMLab.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..af1ccad55022ef03dd2eb8000ea3931c792117ad
--- /dev/null
+++ b/app.py
@@ -0,0 +1,133 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from collections import OrderedDict
+
+import torch
+from mmcv import Config
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
+from mmdet.apis import init_detector, inference_detector
+from mmdet.datasets import (CocoDataset)
+from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes,
+ update_data_root)
+
+import gradio as gr
+
+config_dict = OrderedDict([
+    ('swin-l-hdetr_sam-vit-b', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py'),
+    ('swin-l-hdetr_sam-vit-l', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py'),
+    ('swin-l-hdetr_sam-vit-h', 'projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py'),
+    ('focalnet-l-dino_sam-vit-b', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py'),
+    ('focalnet-l-dino_sam-vit-l', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py'),
+    ('focalnet-l-dino_sam-vit-h', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py')])
+
+
+def inference(img, config):
+ if img is None:
+ return None
+ config = config_dict[config]
+ cfg = Config.fromfile(config)
+
+ # replace the ${key} with the value of cfg.key
+ cfg = replace_cfg_vals(cfg)
+
+ # update data root according to MMDET_DATASETS
+ update_data_root(cfg)
+
+ cfg = compat_cfg(cfg)
+
+ # set multi-process settings
+ setup_multi_processes(cfg)
+
+    # import modules from plugin/xx, registry will be updated
+ if hasattr(cfg, 'plugin'):
+ if cfg.plugin:
+ import importlib
+ if hasattr(cfg, 'plugin_dir'):
+ plugin_dir = cfg.plugin_dir
+ _module_dir = os.path.dirname(plugin_dir)
+ _module_dir = _module_dir.split('/')
+ _module_path = _module_dir[0]
+
+ for m in _module_dir[1:]:
+ _module_path = _module_path + '.' + m
+ print(_module_path)
+ plg_lib = importlib.import_module(_module_path)
+ else:
+ # import dir is the dirpath for the config file
+ _module_dir = os.path.dirname(config)
+ _module_dir = _module_dir.split('/')
+ _module_path = _module_dir[0]
+ for m in _module_dir[1:]:
+ _module_path = _module_path + '.' + m
+ # print(_module_path)
+ plg_lib = importlib.import_module(_module_path)
+
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+ if IS_CUDA_AVAILABLE or IS_MLU_AVAILABLE:
+ device = "cuda"
+ else:
+ device = "cpu"
+ model = init_detector(cfg, None, device=device)
+ model.CLASSES = CocoDataset.CLASSES
+
+ results = inference_detector(model, img)
+ visualize = model.show_result(
+ img,
+ results,
+ bbox_color=CocoDataset.PALETTE,
+ text_color=CocoDataset.PALETTE,
+ mask_color=CocoDataset.PALETTE,
+ show=False,
+ out_file=None,
+ score_thr=0.3
+ )
+ del model
+ return visualize
+
+
+description = """
+#
Prompt Segment Anything (zero-shot instance segmentation demo)
+Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)
+You can select the model you want to use from the "Model" dropdown menu and click "Submit" to segment the image you uploaded to the "Input Image" box.
+"""
+
+
+def main():
+ with gr.Blocks() as demo:
+ gr.Markdown(description)
+ with gr.Column():
+ with gr.Row():
+ with gr.Column():
+ input_img = gr.Image(type="numpy", label="Input Image")
+ model_type = gr.Dropdown(choices=list(config_dict.keys()),
+ value=list(config_dict.keys())[0],
+ label='Model',
+ multiselect=False)
+ with gr.Row():
+ clear_btn = gr.Button(value="Clear")
+ submit_btn = gr.Button(value="Submit")
+ output_img = gr.Image(type="numpy", label="Output")
+ gr.Examples(
+ examples=[["./assets/img1.jpg", "swin-l-hdetr_sam-vit-b"],
+ ["./assets/img2.jpg", "swin-l-hdetr_sam-vit-l"],
+ ["./assets/img3.jpg", "swin-l-hdetr_sam-vit-l"],
+ ["./assets/img4.jpg", "focalnet-l-dino_sam-vit-b"]],
+ inputs=[input_img, model_type],
+ outputs=output_img,
+ fn=inference
+ )
+
+ submit_btn.click(inference,
+ inputs=[input_img, model_type],
+ outputs=output_img)
+ clear_btn.click(lambda: [None, None], None, [input_img, output_img], queue=False)
+
+ demo.queue()
+ demo.launch(share=True)
+
+
+if __name__ == '__main__':
+ main()
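The plugin block inside `inference()` above converts the directory of the selected config file into a dotted module path and imports it, so that the project's custom detectors register themselves before `init_detector` runs. A minimal sketch of that step, with an illustrative config path taken from `config_dict` (the `replace('/', '.')` call is just a compact equivalent of the split-and-join loop in the function):

```python
# Minimal sketch of the plugin-import step in app.py: the config file's
# directory becomes a dotted module path that importlib loads, which
# triggers the registry updates for the project's custom modules.
import importlib
import os

config = 'projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py'
module_path = os.path.dirname(config).replace('/', '.')  # 'projects.configs.hdetr'
plugin_module = importlib.import_module(module_path)      # registers custom models
```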
diff --git a/assets/img1.jpg b/assets/img1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..677a010839d09b0d03fa4639d4e481b1ccc4a375
Binary files /dev/null and b/assets/img1.jpg differ
diff --git a/assets/img2.jpg b/assets/img2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a23272707413a3e14c85d45ca479b7b583ee763f
Binary files /dev/null and b/assets/img2.jpg differ
diff --git a/assets/img3.jpg b/assets/img3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8eefe6f996c12d54f38f171f38ce08c059ab568f
Binary files /dev/null and b/assets/img3.jpg differ
diff --git a/assets/img4.jpg b/assets/img4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1863c219c821d529b85b3667823c6485790202c5
Binary files /dev/null and b/assets/img4.jpg differ
diff --git a/flagged/Input/tmpaytsmk0e.jpg b/flagged/Input/tmpaytsmk0e.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ec3a13b8b110a77b983fc624f041185f37a7835a
Binary files /dev/null and b/flagged/Input/tmpaytsmk0e.jpg differ
diff --git a/flagged/Output/tmpgs59m7u_.png b/flagged/Output/tmpgs59m7u_.png
new file mode 100644
index 0000000000000000000000000000000000000000..b38f0da4e3c09d8afbcc53ba3e6faf31c133328d
Binary files /dev/null and b/flagged/Output/tmpgs59m7u_.png differ
diff --git a/flagged/log.csv b/flagged/log.csv
new file mode 100644
index 0000000000000000000000000000000000000000..1e659b7e144acea6dbe81c82ed29c6cd40d590fd
--- /dev/null
+++ b/flagged/log.csv
@@ -0,0 +1,2 @@
+Input,Output,flag,username,timestamp
+C:\Users\13502\Documents\msra\prompt_segment_anything_demo\flagged\Input\tmpaytsmk0e.jpg,C:\Users\13502\Documents\msra\prompt_segment_anything_demo\flagged\Output\tmpgs59m7u_.png,,,2023-04-10 20:52:40.908980
diff --git a/mmdet/__init__.py b/mmdet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4df16af56d316e5eb6eff42053173f3e8a074d19
--- /dev/null
+++ b/mmdet/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+
+from .version import __version__, short_version
+
+
+def digit_version(version_str):
+ digit_version = []
+ for x in version_str.split('.'):
+ if x.isdigit():
+ digit_version.append(int(x))
+ elif x.find('rc') != -1:
+ patch_version = x.split('rc')
+ digit_version.append(int(patch_version[0]) - 1)
+ digit_version.append(int(patch_version[1]))
+ return digit_version
+
+
+mmcv_minimum_version = '1.3.17'
+mmcv_maximum_version = '1.8.0'
+mmcv_version = digit_version(mmcv.__version__)
+
+
+assert (mmcv_version >= digit_version(mmcv_minimum_version)
+ and mmcv_version <= digit_version(mmcv_maximum_version)), \
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+ f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
+
+__all__ = ['__version__', 'short_version']
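`digit_version` converts a version string into a list of integers so that plain list comparison gives the intended ordering; a release candidate is encoded by decrementing the patch digit and appending the rc number, which makes it sort below the corresponding final release. A quick sketch of the resulting ordering (assuming the package imports cleanly):

```python
# Sketch of how digit_version orders versions; 'rc' releases compare
# below the corresponding final release.
from mmdet import digit_version

assert digit_version('1.3.17') == [1, 3, 17]
assert digit_version('1.7.0rc1') == [1, 7, -1, 1]   # '0rc1' -> int('0') - 1, int('1')
assert digit_version('1.7.0rc1') < digit_version('1.7.0')
```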
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a865e942afd03ddc60ffedbabf9716e769f5bcfe
--- /dev/null
+++ b/mmdet/apis/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .inference import (async_inference_detector, inference_detector,
+ init_detector, show_result_pyplot)
+from .test import multi_gpu_test, single_gpu_test
+from .train import (get_root_logger, init_random_seed, set_random_seed,
+ train_detector)
+
+__all__ = [
+ 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
+ 'async_inference_detector', 'inference_detector', 'show_result_pyplot',
+ 'multi_gpu_test', 'single_gpu_test', 'init_random_seed'
+]
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2ad4313fce53e9a601e94775a2a70cf5c7f2be7
--- /dev/null
+++ b/mmdet/apis/inference.py
@@ -0,0 +1,258 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from pathlib import Path
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.ops import RoIPool
+from mmcv.parallel import collate, scatter
+from mmcv.runner import load_checkpoint
+
+from mmdet.core import get_classes
+from mmdet.datasets import replace_ImageToTensor
+from mmdet.datasets.pipelines import Compose
+from mmdet.models import build_detector
+
+
+def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
+ """Initialize a detector from config file.
+
+ Args:
+ config (str, :obj:`Path`, or :obj:`mmcv.Config`): Config file path,
+ :obj:`Path`, or the config object.
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
+ will not load any weights.
+ cfg_options (dict): Options to override some settings in the used
+ config.
+
+ Returns:
+ nn.Module: The constructed detector.
+ """
+ if isinstance(config, (str, Path)):
+ config = mmcv.Config.fromfile(config)
+ elif not isinstance(config, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ f'but got {type(config)}')
+ if cfg_options is not None:
+ config.merge_from_dict(cfg_options)
+ if 'pretrained' in config.model:
+ config.model.pretrained = None
+ elif (config.model.get('backbone', None) is not None
+ and 'init_cfg' in config.model.backbone):
+ config.model.backbone.init_cfg = None
+ config.model.train_cfg = None
+ model = build_detector(config.model, test_cfg=config.get('test_cfg'))
+ if checkpoint is not None:
+ checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+ if 'CLASSES' in checkpoint.get('meta', {}):
+ model.CLASSES = checkpoint['meta']['CLASSES']
+ else:
+ warnings.simplefilter('once')
+ warnings.warn('Class names are not saved in the checkpoint\'s '
+ 'meta data, use COCO classes by default.')
+ model.CLASSES = get_classes('coco')
+ model.cfg = config # save the config in the model for convenience
+ model.to(device)
+ model.eval()
+
+ if device == 'npu':
+ from mmcv.device.npu import NPUDataParallel
+ model = NPUDataParallel(model)
+ model.cfg = config
+
+ return model
+
+
+class LoadImage:
+ """Deprecated.
+
+ A simple pipeline to load image.
+ """
+
+ def __call__(self, results):
+ """Call function to load images into results.
+
+ Args:
+ results (dict): A result dict contains the file name
+ of the image to be read.
+ Returns:
+ dict: ``results`` will be returned containing loaded image.
+ """
+ warnings.simplefilter('once')
+ warnings.warn('`LoadImage` is deprecated and will be removed in '
+ 'future releases. You may use `LoadImageFromWebcam` '
+ 'from `mmdet.datasets.pipelines.` instead.')
+ if isinstance(results['img'], str):
+ results['filename'] = results['img']
+ results['ori_filename'] = results['img']
+ else:
+ results['filename'] = None
+ results['ori_filename'] = None
+ img = mmcv.imread(results['img'])
+ results['img'] = img
+ results['img_fields'] = ['img']
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ return results
+
+
+def inference_detector(model, imgs):
+ """Inference image(s) with the detector.
+
+ Args:
+ model (nn.Module): The loaded detector.
+ imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
+ Either image files or loaded images.
+
+ Returns:
+        If imgs is a list or tuple, a list of results of the same length
+        is returned; otherwise, the detection results are returned directly.
+ """
+ ori_img = imgs
+ if isinstance(imgs, (list, tuple)):
+ is_batch = True
+ else:
+ imgs = [imgs]
+ is_batch = False
+
+ cfg = model.cfg
+ device = next(model.parameters()).device # model device
+
+ if isinstance(imgs[0], np.ndarray):
+ cfg = cfg.copy()
+ # set loading pipeline type
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+
+ cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+ test_pipeline = Compose(cfg.data.test.pipeline)
+
+ datas = []
+ for img in imgs:
+ # prepare data
+ if isinstance(img, np.ndarray):
+ # directly add img
+ data = dict(img=img)
+ else:
+ # add information into dict
+ data = dict(img_info=dict(filename=img), img_prefix=None)
+ # build the data pipeline
+ data = test_pipeline(data)
+ datas.append(data)
+
+ data = collate(datas, samples_per_gpu=len(imgs))
+ # just get the actual data from DataContainer
+ data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
+ data['img'] = [img.data[0] for img in data['img']]
+ if next(model.parameters()).is_cuda:
+ # scatter to specified GPU
+ data = scatter(data, [device])[0]
+ else:
+ for m in model.modules():
+ assert not isinstance(
+ m, RoIPool
+ ), 'CPU inference with RoIPool is not supported currently.'
+
+ # forward the model
+ with torch.no_grad():
+ results = model(return_loss=False, rescale=True, **data, ori_img=ori_img)
+
+ if not is_batch:
+ return results[0]
+ else:
+ return results
+
+
+async def async_inference_detector(model, imgs):
+ """Async inference image(s) with the detector.
+
+ Args:
+ model (nn.Module): The loaded detector.
+        imgs (str | ndarray | list[str | ndarray]): Either image files or loaded images.
+
+ Returns:
+ Awaitable detection results.
+ """
+ if not isinstance(imgs, (list, tuple)):
+ imgs = [imgs]
+
+ cfg = model.cfg
+ device = next(model.parameters()).device # model device
+
+ if isinstance(imgs[0], np.ndarray):
+ cfg = cfg.copy()
+ # set loading pipeline type
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+
+ cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+ test_pipeline = Compose(cfg.data.test.pipeline)
+
+ datas = []
+ for img in imgs:
+ # prepare data
+ if isinstance(img, np.ndarray):
+ # directly add img
+ data = dict(img=img)
+ else:
+ # add information into dict
+ data = dict(img_info=dict(filename=img), img_prefix=None)
+ # build the data pipeline
+ data = test_pipeline(data)
+ datas.append(data)
+
+ data = collate(datas, samples_per_gpu=len(imgs))
+ # just get the actual data from DataContainer
+ data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
+ data['img'] = [img.data[0] for img in data['img']]
+ if next(model.parameters()).is_cuda:
+ # scatter to specified GPU
+ data = scatter(data, [device])[0]
+ else:
+ for m in model.modules():
+ assert not isinstance(
+ m, RoIPool
+ ), 'CPU inference with RoIPool is not supported currently.'
+
+ # We don't restore `torch.is_grad_enabled()` value during concurrent
+ # inference since execution can overlap
+ torch.set_grad_enabled(False)
+ results = await model.aforward_test(rescale=True, **data)
+ return results
+
+
+def show_result_pyplot(model,
+ img,
+ result,
+ score_thr=0.3,
+ title='result',
+ wait_time=0,
+ palette=None,
+ out_file=None):
+ """Visualize the detection results on the image.
+
+ Args:
+ model (nn.Module): The loaded detector.
+ img (str or np.ndarray): Image filename or loaded image.
+ result (tuple[list] or list): The detection result, can be either
+ (bbox, segm) or just bbox.
+ score_thr (float): The threshold to visualize the bboxes and masks.
+ title (str): Title of the pyplot figure.
+ wait_time (float): Value of waitKey param. Default: 0.
+ palette (str or tuple(int) or :obj:`Color`): Color.
+ The tuple of color should be in BGR order.
+ out_file (str or None): The path to write the image.
+ Default: None.
+ """
+ if hasattr(model, 'module'):
+ model = model.module
+ model.show_result(
+ img,
+ result,
+ score_thr=score_thr,
+ show=True,
+ wait_time=wait_time,
+ win_name=title,
+ bbox_color=palette,
+ text_color=(200, 200, 200),
+ mask_color=palette,
+ out_file=out_file)
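Taken together, the functions in this module form the usual single-image workflow: build a model from a config, run it on an image, and visualize the result. A minimal usage sketch; the checkpoint path is hypothetical (weights are expected under the git-ignored `ckpt/` directory), while the config and image paths come from elsewhere in this commit:

```python
# Minimal usage sketch of the mmdet.apis inference API added above.
from mmdet.apis import inference_detector, init_detector, show_result_pyplot

config_file = 'projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py'  # path from app.py
checkpoint_file = 'ckpt/swin-l-hdetr_sam-vit-b.pth'               # hypothetical checkpoint path
model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'assets/img1.jpg')             # image added in this commit
show_result_pyplot(model, 'assets/img1.jpg', result, score_thr=0.3)
```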
diff --git a/mmdet/apis/test.py b/mmdet/apis/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..973d3623d6ef9917ae7316214ebcef6f7e4a75e8
--- /dev/null
+++ b/mmdet/apis/test.py
@@ -0,0 +1,209 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+
+import mmcv
+import torch
+import torch.distributed as dist
+from mmcv.image import tensor2imgs
+from mmcv.runner import get_dist_info
+
+from mmdet.core import encode_mask_results
+
+
+def single_gpu_test(model,
+ data_loader,
+ show=False,
+ out_dir=None,
+ show_score_thr=0.3):
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ PALETTE = getattr(dataset, 'PALETTE', None)
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+
+ batch_size = len(result)
+ if show or out_dir:
+ if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
+ img_tensor = data['img'][0]
+ else:
+ img_tensor = data['img'][0].data[0]
+ img_metas = data['img_metas'][0].data[0]
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+ assert len(imgs) == len(img_metas)
+
+ for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
+ h, w, _ = img_meta['img_shape']
+ img_show = img[:h, :w, :]
+
+ ori_h, ori_w = img_meta['ori_shape'][:-1]
+ img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+ if out_dir:
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
+ else:
+ out_file = None
+
+ model.module.show_result(
+ img_show,
+ result[i],
+ bbox_color=PALETTE,
+ text_color=PALETTE,
+ mask_color=PALETTE,
+ show=show,
+ out_file=out_file,
+ score_thr=show_score_thr)
+
+ # encode mask results
+ if isinstance(result[0], tuple):
+ result = [(bbox_results, encode_mask_results(mask_results))
+ for bbox_results, mask_results in result]
+ # This logic is only used in panoptic segmentation test.
+ elif isinstance(result[0], dict) and 'ins_results' in result[0]:
+ for j in range(len(result)):
+ bbox_results, mask_results = result[j]['ins_results']
+ result[j]['ins_results'] = (bbox_results,
+ encode_mask_results(mask_results))
+
+ results.extend(result)
+
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
+
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+ """Test model with multiple gpus.
+
+    This method tests the model with multiple gpus and collects the results
+    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
+    it encodes results to gpu tensors and uses gpu communication for result
+    collection. In cpu mode it saves the results on different gpus to 'tmpdir'
+    and collects them by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (nn.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+ # encode mask results
+ if isinstance(result[0], tuple):
+ result = [(bbox_results, encode_mask_results(mask_results))
+ for bbox_results, mask_results in result]
+ # This logic is only used in panoptic segmentation test.
+ elif isinstance(result[0], dict) and 'ins_results' in result[0]:
+ for j in range(len(result)):
+ bbox_results, mask_results = result[j]['ins_results']
+ result[j]['ins_results'] = (
+ bbox_results, encode_mask_results(mask_results))
+
+ results.extend(result)
+
+ if rank == 0:
+ batch_size = len(result)
+ for _ in range(batch_size * world_size):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ mmcv.mkdir_or_exist('.dist_test')
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
+ part_list.append(mmcv.load(part_file))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_list.append(
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
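`collect_results_gpu` relies on a generic pattern: pickle each rank's partial results into a uint8 CUDA tensor, pad every buffer to the maximum length across ranks, and `all_gather` the padded buffers so they can be unpickled again. A condensed sketch of that pattern in isolation (a sketch, not the function above; it assumes an initialized NCCL process group):

```python
# Condensed sketch of the serialize / pad / all_gather pattern used by
# collect_results_gpu: every rank ends up with every rank's object list.
import pickle

import torch
import torch.distributed as dist


def gather_objects(obj):
    world_size = dist.get_world_size()
    part = torch.tensor(bytearray(pickle.dumps(obj)), dtype=torch.uint8, device='cuda')
    # share each rank's buffer length so everyone can pad to the maximum
    length = torch.tensor([part.numel()], device='cuda')
    lengths = [torch.zeros_like(length) for _ in range(world_size)]
    dist.all_gather(lengths, length)
    max_len = int(torch.stack(lengths).max())
    padded = torch.zeros(max_len, dtype=torch.uint8, device='cuda')
    padded[:part.numel()] = part
    received = [torch.zeros(max_len, dtype=torch.uint8, device='cuda')
                for _ in range(world_size)]
    dist.all_gather(received, padded)
    # strip the padding and unpickle each rank's contribution
    return [pickle.loads(buf[:int(n)].cpu().numpy().tobytes())
            for buf, n in zip(received, lengths)]
```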
diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..f51f862a053e95079c1c0978c17dbdeb9f8eeea4
--- /dev/null
+++ b/mmdet/apis/train.py
@@ -0,0 +1,246 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
+ Fp16OptimizerHook, OptimizerHook, build_runner,
+ get_dist_info)
+
+from mmdet.core import DistEvalHook, EvalHook, build_optimizer
+from mmdet.datasets import (build_dataloader, build_dataset,
+ replace_ImageToTensor)
+from mmdet.utils import (build_ddp, build_dp, compat_cfg,
+ find_latest_checkpoint, get_root_logger)
+
+
+def init_random_seed(seed=None, device='cuda'):
+ """Initialize random seed.
+
+ If the seed is not set, the seed will be automatically randomized,
+ and then broadcast to all processes to prevent some potential bugs.
+
+ Args:
+ seed (int, Optional): The seed. Default to None.
+ device (str): The device where the seed will be put on.
+ Default to 'cuda'.
+
+ Returns:
+ int: Seed to be used.
+ """
+ if seed is not None:
+ return seed
+
+ # Make sure all ranks share the same random seed to prevent
+ # some potential bugs. Please refer to
+ # https://github.com/open-mmlab/mmdetection/issues/6339
+ rank, world_size = get_dist_info()
+ seed = np.random.randint(2**31)
+ if world_size == 1:
+ return seed
+
+ if rank == 0:
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+ else:
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
+ dist.broadcast(random_num, src=0)
+ return random_num.item()
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def auto_scale_lr(cfg, distributed, logger):
+ """Automatically scaling LR according to GPU number and sample per GPU.
+
+ Args:
+ cfg (config): Training config.
+ distributed (bool): Using distributed or not.
+ logger (logging.Logger): Logger.
+ """
+ # Get flag from config
+ if ('auto_scale_lr' not in cfg) or \
+ (not cfg.auto_scale_lr.get('enable', False)):
+ logger.info('Automatic scaling of learning rate (LR)'
+ ' has been disabled.')
+ return
+
+ # Get base batch size from config
+ base_batch_size = cfg.auto_scale_lr.get('base_batch_size', None)
+ if base_batch_size is None:
+ return
+
+ # Get gpu number
+ if distributed:
+ _, world_size = get_dist_info()
+        num_gpus = world_size
+ else:
+ num_gpus = len(cfg.gpu_ids)
+
+ # calculate the batch size
+ samples_per_gpu = cfg.data.train_dataloader.samples_per_gpu
+ batch_size = num_gpus * samples_per_gpu
+ logger.info(f'Training with {num_gpus} GPU(s) with {samples_per_gpu} '
+ f'samples per GPU. The total batch size is {batch_size}.')
+
+ if batch_size != base_batch_size:
+ # scale LR with
+ # [linear scaling rule](https://arxiv.org/abs/1706.02677)
+ scaled_lr = (batch_size / base_batch_size) * cfg.optimizer.lr
+ logger.info('LR has been automatically scaled '
+ f'from {cfg.optimizer.lr} to {scaled_lr}')
+ cfg.optimizer.lr = scaled_lr
+ else:
+        logger.info('The batch size matches the '
+                    f'base batch size: {base_batch_size}, '
+                    f'so the LR ({cfg.optimizer.lr}) will not be scaled.')
+
+
+def train_detector(model,
+ dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ meta=None):
+
+ cfg = compat_cfg(cfg)
+ logger = get_root_logger(log_level=cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+
+ runner_type = 'EpochBasedRunner' if 'runner' not in cfg else cfg.runner[
+ 'type']
+
+ train_dataloader_default_args = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ # `num_gpus` will be ignored if distributed
+ num_gpus=len(cfg.gpu_ids),
+ dist=distributed,
+ seed=cfg.seed,
+ runner_type=runner_type,
+ persistent_workers=False)
+
+ train_loader_cfg = {
+ **train_dataloader_default_args,
+ **cfg.data.get('train_dataloader', {})
+ }
+
+ data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
+
+ # put model on gpus
+ if distributed:
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
+ # Sets the `find_unused_parameters` parameter in
+ # torch.nn.parallel.DistributedDataParallel
+ model = build_ddp(
+ model,
+ cfg.device,
+ device_ids=[int(os.environ['LOCAL_RANK'])],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
+
+ # build optimizer
+ auto_scale_lr(cfg, distributed, logger)
+ optimizer = build_optimizer(model, cfg.optimizer)
+
+ runner = build_runner(
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta))
+
+ # an ugly workaround to make .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ # fp16 setting
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is None and cfg.get('device', None) == 'npu':
+ fp16_cfg = dict(loss_scale='dynamic')
+ if fp16_cfg is not None:
+ optimizer_config = Fp16OptimizerHook(
+ **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+ elif distributed and 'type' not in cfg.optimizer_config:
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
+ else:
+ optimizer_config = cfg.optimizer_config
+
+ # register hooks
+ runner.register_training_hooks(
+ cfg.lr_config,
+ optimizer_config,
+ cfg.checkpoint_config,
+ cfg.log_config,
+ cfg.get('momentum_config', None),
+ custom_hooks_config=cfg.get('custom_hooks', None))
+
+ if distributed:
+ if isinstance(runner, EpochBasedRunner):
+ runner.register_hook(DistSamplerSeedHook())
+
+ # register eval hooks
+ if validate:
+ val_dataloader_default_args = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=2,
+ dist=distributed,
+ shuffle=False,
+ persistent_workers=False)
+
+ val_dataloader_args = {
+ **val_dataloader_default_args,
+ **cfg.data.get('val_dataloader', {})
+ }
+ # Support batch_size > 1 in validation
+
+ if val_dataloader_args['samples_per_gpu'] > 1:
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+ cfg.data.val.pipeline = replace_ImageToTensor(
+ cfg.data.val.pipeline)
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+
+ val_dataloader = build_dataloader(val_dataset, **val_dataloader_args)
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+ eval_hook = DistEvalHook if distributed else EvalHook
+ # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
+ # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
+ runner.register_hook(
+ eval_hook(val_dataloader, **eval_cfg), priority='LOW')
+
+ resume_from = None
+ if cfg.resume_from is None and cfg.get('auto_resume'):
+ resume_from = find_latest_checkpoint(cfg.work_dir)
+ if resume_from is not None:
+ cfg.resume_from = resume_from
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow)
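`auto_scale_lr` above applies the linear scaling rule: the configured LR is multiplied by the ratio of the actual total batch size to `auto_scale_lr.base_batch_size`. A worked example with illustrative numbers (these values are not taken from any config in this commit):

```python
# Worked example of the linear scaling rule used by auto_scale_lr.
base_batch_size = 16   # what the config's LR was tuned for (illustrative)
base_lr = 2e-4         # illustrative configured LR
num_gpus, samples_per_gpu = 4, 2

batch_size = num_gpus * samples_per_gpu              # 8
scaled_lr = (batch_size / base_batch_size) * base_lr
print(scaled_lr)                                     # 0.0001: half the configured LR
```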
diff --git a/mmdet/core/__init__.py b/mmdet/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6203879840c80c7f89b348f02e4d45b33e5de4
--- /dev/null
+++ b/mmdet/core/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .anchor import * # noqa: F401, F403
+from .bbox import * # noqa: F401, F403
+from .data_structures import * # noqa: F401, F403
+from .evaluation import * # noqa: F401, F403
+from .hook import * # noqa: F401, F403
+from .mask import * # noqa: F401, F403
+from .optimizers import * # noqa: F401, F403
+from .post_processing import * # noqa: F401, F403
+from .utils import * # noqa: F401, F403
diff --git a/mmdet/core/anchor/__init__.py b/mmdet/core/anchor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcc7e4af36fd12a7c9de6ffe07f77aafad5731ba
--- /dev/null
+++ b/mmdet/core/anchor/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
+ YOLOAnchorGenerator)
+from .builder import (ANCHOR_GENERATORS, PRIOR_GENERATORS,
+ build_anchor_generator, build_prior_generator)
+from .point_generator import MlvlPointGenerator, PointGenerator
+from .utils import anchor_inside_flags, calc_region, images_to_levels
+
+__all__ = [
+ 'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
+ 'PointGenerator', 'images_to_levels', 'calc_region',
+ 'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator',
+ 'build_prior_generator', 'PRIOR_GENERATORS', 'MlvlPointGenerator'
+]
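These builders follow the usual registry pattern: a generator is described by a config dict whose `type` names a registered class, and `build_prior_generator` instantiates it. A small sketch using the parameters from the `AnchorGenerator` docstring example in the next file; the keyword names follow that class's constructor:

```python
# Sketch of building an anchor generator through the registry exported here.
from mmdet.core.anchor import build_prior_generator

prior_generator = build_prior_generator(
    dict(type='AnchorGenerator', strides=[16], ratios=[1.0],
         scales=[1.0], base_sizes=[9]))
anchors = prior_generator.grid_priors([(2, 2)], device='cpu')  # one (N, 4) tensor per level
print(anchors[0].shape)  # torch.Size([4, 4])
```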
diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..20886fbda65dbf0737565ec6dba59e9fc7bb73ff
--- /dev/null
+++ b/mmdet/core/anchor/anchor_generator.py
@@ -0,0 +1,866 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import mmcv
+import numpy as np
+import torch
+from torch.nn.modules.utils import _pair
+
+from .builder import PRIOR_GENERATORS
+
+
+@PRIOR_GENERATORS.register_module()
+class AnchorGenerator:
+ """Standard anchor generator for 2D anchor-based detectors.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels in order (w, h).
+ ratios (list[float]): The list of ratios between the height and width
+ of anchors in a single level.
+ scales (list[int] | None): Anchor scales for anchors in a single level.
+ It cannot be set at the same time if `octave_base_scale` and
+ `scales_per_octave` are set.
+ base_sizes (list[int] | None): The basic sizes
+ of anchors in multiple levels.
+ If None is given, strides will be used as base_sizes.
+ (If strides are non square, the shortest stride is taken.)
+ scale_major (bool): Whether to multiply scales first when generating
+ base anchors. If true, the anchors in the same row will have the
+ same scales. By default it is True in V2.0
+ octave_base_scale (int): The base scale of octave.
+ scales_per_octave (int): Number of scales for each octave.
+ `octave_base_scale` and `scales_per_octave` are usually used in
+ retinanet and the `scales` should be None when they are set.
+ centers (list[tuple[float, float]] | None): The centers of the anchor
+ relative to the feature grid center in multiple feature levels.
+ By default it is set to be None and not used. If a list of tuple of
+ float is given, they will be used to shift the centers of anchors.
+ center_offset (float): The offset of center in proportion to anchors'
+ width and height. By default it is 0 in V2.0.
+
+ Examples:
+ >>> from mmdet.core import AnchorGenerator
+ >>> self = AnchorGenerator([16], [1.], [1.], [9])
+ >>> all_anchors = self.grid_priors([(2, 2)], device='cpu')
+ >>> print(all_anchors)
+ [tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
+ [11.5000, -4.5000, 20.5000, 4.5000],
+ [-4.5000, 11.5000, 4.5000, 20.5000],
+ [11.5000, 11.5000, 20.5000, 20.5000]])]
+ >>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
+ >>> all_anchors = self.grid_priors([(2, 2), (1, 1)], device='cpu')
+ >>> print(all_anchors)
+ [tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
+ [11.5000, -4.5000, 20.5000, 4.5000],
+ [-4.5000, 11.5000, 4.5000, 20.5000],
+ [11.5000, 11.5000, 20.5000, 20.5000]]), \
+ tensor([[-9., -9., 9., 9.]])]
+ """
+
+ def __init__(self,
+ strides,
+ ratios,
+ scales=None,
+ base_sizes=None,
+ scale_major=True,
+ octave_base_scale=None,
+ scales_per_octave=None,
+ centers=None,
+ center_offset=0.):
+ # check center and center_offset
+ if center_offset != 0:
+ assert centers is None, 'center cannot be set when center_offset' \
+ f'!=0, {centers} is given.'
+ if not (0 <= center_offset <= 1):
+ raise ValueError('center_offset should be in range [0, 1], '
+ f'{center_offset} is given.')
+ if centers is not None:
+ assert len(centers) == len(strides), \
+ 'The number of strides should be the same as centers, got ' \
+ f'{strides} and {centers}'
+
+ # calculate base sizes of anchors
+ self.strides = [_pair(stride) for stride in strides]
+ self.base_sizes = [min(stride) for stride in self.strides
+ ] if base_sizes is None else base_sizes
+ assert len(self.base_sizes) == len(self.strides), \
+ 'The number of strides should be the same as base sizes, got ' \
+ f'{self.strides} and {self.base_sizes}'
+
+ # calculate scales of anchors
+ assert ((octave_base_scale is not None
+ and scales_per_octave is not None) ^ (scales is not None)), \
+ 'scales and octave_base_scale with scales_per_octave cannot' \
+ ' be set at the same time'
+ if scales is not None:
+ self.scales = torch.Tensor(scales)
+ elif octave_base_scale is not None and scales_per_octave is not None:
+ octave_scales = np.array(
+ [2**(i / scales_per_octave) for i in range(scales_per_octave)])
+ scales = octave_scales * octave_base_scale
+ self.scales = torch.Tensor(scales)
+ else:
+ raise ValueError('Either scales or octave_base_scale with '
+ 'scales_per_octave should be set')
+
+ self.octave_base_scale = octave_base_scale
+ self.scales_per_octave = scales_per_octave
+ self.ratios = torch.Tensor(ratios)
+ self.scale_major = scale_major
+ self.centers = centers
+ self.center_offset = center_offset
+ self.base_anchors = self.gen_base_anchors()
+
+ @property
+ def num_base_anchors(self):
+ """list[int]: total number of base anchors in a feature grid"""
+ return self.num_base_priors
+
+ @property
+ def num_base_priors(self):
+ """list[int]: The number of priors (anchors) at a point
+ on the feature grid"""
+ return [base_anchors.size(0) for base_anchors in self.base_anchors]
+
+ @property
+ def num_levels(self):
+ """int: number of feature levels that the generator will be applied"""
+ return len(self.strides)
+
+ def gen_base_anchors(self):
+ """Generate base anchors.
+
+ Returns:
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
+ feature levels.
+ """
+ multi_level_base_anchors = []
+ for i, base_size in enumerate(self.base_sizes):
+ center = None
+ if self.centers is not None:
+ center = self.centers[i]
+ multi_level_base_anchors.append(
+ self.gen_single_level_base_anchors(
+ base_size,
+ scales=self.scales,
+ ratios=self.ratios,
+ center=center))
+ return multi_level_base_anchors
+
+ def gen_single_level_base_anchors(self,
+ base_size,
+ scales,
+ ratios,
+ center=None):
+ """Generate base anchors of a single level.
+
+ Args:
+ base_size (int | float): Basic size of an anchor.
+ scales (torch.Tensor): Scales of the anchor.
+            ratios (torch.Tensor): The ratio between the height
+ and width of anchors in a single level.
+ center (tuple[float], optional): The center of the base anchor
+ related to a single feature grid. Defaults to None.
+
+ Returns:
+            torch.Tensor: Anchors in a single-level feature map.
+ """
+ w = base_size
+ h = base_size
+ if center is None:
+ x_center = self.center_offset * w
+ y_center = self.center_offset * h
+ else:
+ x_center, y_center = center
+
+ h_ratios = torch.sqrt(ratios)
+ w_ratios = 1 / h_ratios
+ if self.scale_major:
+ ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
+ hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
+ else:
+ ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
+ hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
+
+ # use float anchor and the anchor's center is aligned with the
+ # pixel center
+ base_anchors = [
+ x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws,
+ y_center + 0.5 * hs
+ ]
+ base_anchors = torch.stack(base_anchors, dim=-1)
+
+ return base_anchors
+
+ def _meshgrid(self, x, y, row_major=True):
+ """Generate mesh grid of x and y.
+
+ Args:
+ x (torch.Tensor): Grids of x dimension.
+ y (torch.Tensor): Grids of y dimension.
+ row_major (bool, optional): Whether to return y grids first.
+ Defaults to True.
+
+ Returns:
+ tuple[torch.Tensor]: The mesh grids of x and y.
+ """
+ # use shape instead of len to keep tracing while exporting to onnx
+ xx = x.repeat(y.shape[0])
+ yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
+ if row_major:
+ return xx, yy
+ else:
+ return yy, xx
+
+ def grid_priors(self, featmap_sizes, dtype=torch.float32, device='cuda'):
+ """Generate grid anchors in multiple feature levels.
+
+ Args:
+ featmap_sizes (list[tuple]): List of feature map sizes in
+ multiple feature levels.
+ dtype (:obj:`torch.dtype`): Dtype of priors.
+ Default: torch.float32.
+ device (str): The device where the anchors will be put on.
+
+ Return:
+ list[torch.Tensor]: Anchors in multiple feature levels. \
+ The sizes of each tensor should be [N, 4], where \
+ N = width * height * num_base_anchors, width and height \
+ are the sizes of the corresponding feature level, \
+ num_base_anchors is the number of anchors for that level.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_anchors = []
+ for i in range(self.num_levels):
+ anchors = self.single_level_grid_priors(
+ featmap_sizes[i], level_idx=i, dtype=dtype, device=device)
+ multi_level_anchors.append(anchors)
+ return multi_level_anchors
+
+ def single_level_grid_priors(self,
+ featmap_size,
+ level_idx,
+ dtype=torch.float32,
+ device='cuda'):
+ """Generate grid anchors of a single level.
+
+ Note:
+ This function is usually called by method ``self.grid_priors``.
+
+ Args:
+ featmap_size (tuple[int]): Size of the feature maps.
+ level_idx (int): The index of corresponding feature map level.
+            dtype (:obj:`torch.dtype`): Data type of points. Defaults to
+ ``torch.float32``.
+ device (str, optional): The device the tensor will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: Anchors in the overall feature maps.
+ """
+
+ base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
+ feat_h, feat_w = featmap_size
+ stride_w, stride_h = self.strides[level_idx]
+        # First create the range with the default dtype, then convert to
+ # target `dtype` for onnx exporting.
+ shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
+ shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h
+
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+ shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
+ # first feat_w elements correspond to the first row of shifts
+ # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
+ # shifted anchors (K, A, 4), reshape to (K*A, 4)
+
+ all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
+ all_anchors = all_anchors.view(-1, 4)
+ # first A rows correspond to A anchors of (0, 0) in feature map,
+ # then (0, 1), (0, 2), ...
+ return all_anchors
+
+ def sparse_priors(self,
+ prior_idxs,
+ featmap_size,
+ level_idx,
+ dtype=torch.float32,
+ device='cuda'):
+ """Generate sparse anchors according to the ``prior_idxs``.
+
+ Args:
+ prior_idxs (Tensor): The index of corresponding anchors
+ in the feature map.
+ featmap_size (tuple[int]): feature map size arrange as (h, w).
+ level_idx (int): The level index of corresponding feature
+ map.
+            dtype (:obj:`torch.dtype`): Data type of points. Defaults to
+                ``torch.float32``.
+            device (:obj:`torch.device`): The device where the points are
+                located.
+ Returns:
+ Tensor: Anchor with shape (N, 4), N should be equal to
+ the length of ``prior_idxs``.
+ """
+
+ height, width = featmap_size
+ num_base_anchors = self.num_base_anchors[level_idx]
+ base_anchor_id = prior_idxs % num_base_anchors
+ x = (prior_idxs //
+ num_base_anchors) % width * self.strides[level_idx][0]
+ y = (prior_idxs // width //
+ num_base_anchors) % height * self.strides[level_idx][1]
+ priors = torch.stack([x, y, x, y], 1).to(dtype).to(device) + \
+ self.base_anchors[level_idx][base_anchor_id, :].to(device)
+
+ return priors
+
+ def grid_anchors(self, featmap_sizes, device='cuda'):
+ """Generate grid anchors in multiple feature levels.
+
+ Args:
+ featmap_sizes (list[tuple]): List of feature map sizes in
+ multiple feature levels.
+ device (str): Device where the anchors will be put on.
+
+ Return:
+ list[torch.Tensor]: Anchors in multiple feature levels. \
+ The sizes of each tensor should be [N, 4], where \
+ N = width * height * num_base_anchors, width and height \
+ are the sizes of the corresponding feature level, \
+ num_base_anchors is the number of anchors for that level.
+ """
+ warnings.warn('``grid_anchors`` would be deprecated soon. '
+ 'Please use ``grid_priors`` ')
+
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_anchors = []
+ for i in range(self.num_levels):
+ anchors = self.single_level_grid_anchors(
+ self.base_anchors[i].to(device),
+ featmap_sizes[i],
+ self.strides[i],
+ device=device)
+ multi_level_anchors.append(anchors)
+ return multi_level_anchors
+
+ def single_level_grid_anchors(self,
+ base_anchors,
+ featmap_size,
+ stride=(16, 16),
+ device='cuda'):
+ """Generate grid anchors of a single level.
+
+ Note:
+ This function is usually called by method ``self.grid_anchors``.
+
+ Args:
+ base_anchors (torch.Tensor): The base anchors of a feature grid.
+ featmap_size (tuple[int]): Size of the feature maps.
+ stride (tuple[int], optional): Stride of the feature map in order
+ (w, h). Defaults to (16, 16).
+ device (str, optional): Device the tensor will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: Anchors in the overall feature maps.
+ """
+
+ warnings.warn(
+ '``single_level_grid_anchors`` would be deprecated soon. '
+ 'Please use ``single_level_grid_priors`` ')
+
+ # keep featmap_size as Tensor instead of int, so that we
+ # can convert to ONNX correctly
+ feat_h, feat_w = featmap_size
+ shift_x = torch.arange(0, feat_w, device=device) * stride[0]
+ shift_y = torch.arange(0, feat_h, device=device) * stride[1]
+
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+ shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
+ shifts = shifts.type_as(base_anchors)
+ # first feat_w elements correspond to the first row of shifts
+ # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
+ # shifted anchors (K, A, 4), reshape to (K*A, 4)
+
+ all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
+ all_anchors = all_anchors.view(-1, 4)
+ # first A rows correspond to A anchors of (0, 0) in feature map,
+ # then (0, 1), (0, 2), ...
+ return all_anchors
+
+ def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
+ """Generate valid flags of anchors in multiple feature levels.
+
+ Args:
+ featmap_sizes (list(tuple)): List of feature map sizes in
+ multiple feature levels.
+ pad_shape (tuple): The padded shape of the image.
+ device (str): Device where the anchors will be put on.
+
+ Return:
+ list(torch.Tensor): Valid flags of anchors in multiple levels.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_flags = []
+ for i in range(self.num_levels):
+ anchor_stride = self.strides[i]
+ feat_h, feat_w = featmap_sizes[i]
+ h, w = pad_shape[:2]
+ valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)
+ valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
+ flags = self.single_level_valid_flags((feat_h, feat_w),
+ (valid_feat_h, valid_feat_w),
+ self.num_base_anchors[i],
+ device=device)
+ multi_level_flags.append(flags)
+ return multi_level_flags
+
+ def single_level_valid_flags(self,
+ featmap_size,
+ valid_size,
+ num_base_anchors,
+ device='cuda'):
+ """Generate the valid flags of anchor in a single feature map.
+
+ Args:
+ featmap_size (tuple[int]): The size of feature maps, arrange
+ as (h, w).
+ valid_size (tuple[int]): The valid size of the feature maps.
+ num_base_anchors (int): The number of base anchors.
+ device (str, optional): Device where the flags will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: The valid flags of each anchor in a single level \
+ feature map.
+ """
+ feat_h, feat_w = featmap_size
+ valid_h, valid_w = valid_size
+ assert valid_h <= feat_h and valid_w <= feat_w
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+ valid_x[:valid_w] = 1
+ valid_y[:valid_h] = 1
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+ valid = valid_xx & valid_yy
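+ # every grid location holds ``num_base_anchors`` anchors, so the
+ # per-location flag is broadcast to all anchors at that location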
+ valid = valid[:, None].expand(valid.size(0),
+ num_base_anchors).contiguous().view(-1)
+ return valid
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ indent_str = ' '
+ repr_str = self.__class__.__name__ + '(\n'
+ repr_str += f'{indent_str}strides={self.strides},\n'
+ repr_str += f'{indent_str}ratios={self.ratios},\n'
+ repr_str += f'{indent_str}scales={self.scales},\n'
+ repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
+ repr_str += f'{indent_str}scale_major={self.scale_major},\n'
+ repr_str += f'{indent_str}octave_base_scale='
+ repr_str += f'{self.octave_base_scale},\n'
+ repr_str += f'{indent_str}scales_per_octave='
+ repr_str += f'{self.scales_per_octave},\n'
+ repr_str += f'{indent_str}num_levels={self.num_levels},\n'
+ repr_str += f'{indent_str}centers={self.centers},\n'
+ repr_str += f'{indent_str}center_offset={self.center_offset})'
+ return repr_str
+
+
+@PRIOR_GENERATORS.register_module()
+class SSDAnchorGenerator(AnchorGenerator):
+ """Anchor generator for SSD.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels.
+ ratios (list[float]): The list of ratios between the height and width
+ of anchors in a single level.
+ min_sizes (list[float]): The list of minimum anchor sizes on each
+ level.
+ max_sizes (list[float]): The list of maximum anchor sizes on each
+ level.
+ basesize_ratio_range (tuple(float)): Ratio range of anchors. Only
+ used when min_sizes and max_sizes are not set.
+ input_size (int): Size of the input image, 300 for SSD300, 512 for
+ SSD512. Only used when min_sizes and max_sizes are not set.
+ scale_major (bool): Whether to multiply scales first when generating
+ base anchors. If true, the anchors in the same row will have the
+ same scales. It is always set to be False in SSD.
+ """
+
+ def __init__(self,
+ strides,
+ ratios,
+ min_sizes=None,
+ max_sizes=None,
+ basesize_ratio_range=(0.15, 0.9),
+ input_size=300,
+ scale_major=True):
+ assert len(strides) == len(ratios)
+ assert not (min_sizes is None) ^ (max_sizes is None)
+ self.strides = [_pair(stride) for stride in strides]
+ self.centers = [(stride[0] / 2., stride[1] / 2.)
+ for stride in self.strides]
+
+ if min_sizes is None and max_sizes is None:
+ # use hard code to generate SSD anchors
+ self.input_size = input_size
+ assert mmcv.is_tuple_of(basesize_ratio_range, float)
+ self.basesize_ratio_range = basesize_ratio_range
+ # calculate anchor ratios and sizes
+ min_ratio, max_ratio = basesize_ratio_range
+ min_ratio = int(min_ratio * 100)
+ max_ratio = int(max_ratio * 100)
+ step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2))
+ min_sizes = []
+ max_sizes = []
+ for ratio in range(int(min_ratio), int(max_ratio) + 1, step):
+ min_sizes.append(int(self.input_size * ratio / 100))
+ max_sizes.append(int(self.input_size * (ratio + step) / 100))
+ if self.input_size == 300:
+ if basesize_ratio_range[0] == 0.15: # SSD300 COCO
+ min_sizes.insert(0, int(self.input_size * 7 / 100))
+ max_sizes.insert(0, int(self.input_size * 15 / 100))
+ elif basesize_ratio_range[0] == 0.2: # SSD300 VOC
+ min_sizes.insert(0, int(self.input_size * 10 / 100))
+ max_sizes.insert(0, int(self.input_size * 20 / 100))
+ else:
+ raise ValueError(
+ 'basesize_ratio_range[0] should be either 0.15'
+ 'or 0.2 when input_size is 300, got '
+ f'{basesize_ratio_range[0]}.')
+ elif self.input_size == 512:
+ if basesize_ratio_range[0] == 0.1: # SSD512 COCO
+ min_sizes.insert(0, int(self.input_size * 4 / 100))
+ max_sizes.insert(0, int(self.input_size * 10 / 100))
+ elif basesize_ratio_range[0] == 0.15: # SSD512 VOC
+ min_sizes.insert(0, int(self.input_size * 7 / 100))
+ max_sizes.insert(0, int(self.input_size * 15 / 100))
+ else:
+ raise ValueError(
+ 'When not setting min_sizes and max_sizes,'
+ 'basesize_ratio_range[0] should be either 0.1'
+ 'or 0.15 when input_size is 512, got'
+ f' {basesize_ratio_range[0]}.')
+ else:
+ raise ValueError(
+ 'Only support 300 or 512 in SSDAnchorGenerator when '
+ 'not setting min_sizes and max_sizes, '
+ f'got {self.input_size}.')
+
+ assert len(min_sizes) == len(max_sizes) == len(strides)
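+ # Illustrative check (not executed): for SSD300 on COCO, i.e.
+ # input_size=300, basesize_ratio_range=(0.15, 0.9) and 6 levels,
+ # step = int(75 / 4) = 18, the loop yields ratios 15, 33, 51, 69, 87
+ # and, after the extra first level is inserted, this gives
+ # min_sizes = [21, 45, 99, 153, 207, 261] and
+ # max_sizes = [45, 99, 153, 207, 261, 315].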
+
+ anchor_ratios = []
+ anchor_scales = []
+ for k in range(len(self.strides)):
+ scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
+ anchor_ratio = [1.]
+ for r in ratios[k]:
+ anchor_ratio += [1 / r, r] # 4 or 6 ratio
+ anchor_ratios.append(torch.Tensor(anchor_ratio))
+ anchor_scales.append(torch.Tensor(scales))
+
+ self.base_sizes = min_sizes
+ self.scales = anchor_scales
+ self.ratios = anchor_ratios
+ self.scale_major = scale_major
+ self.center_offset = 0
+ self.base_anchors = self.gen_base_anchors()
+
+ def gen_base_anchors(self):
+ """Generate base anchors.
+
+ Returns:
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
+ feature levels.
+ """
+ multi_level_base_anchors = []
+ for i, base_size in enumerate(self.base_sizes):
+ base_anchors = self.gen_single_level_base_anchors(
+ base_size,
+ scales=self.scales[i],
+ ratios=self.ratios[i],
+ center=self.centers[i])
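+ # Reorder rows so that (assuming ``scale_major=False``, the setting
+ # used by the SSD configs) the single sqrt(min_size * max_size)
+ # anchor at ratio 1 comes right after the first base anchor,
+ # matching the conventional SSD anchor order; the remaining
+ # larger-scale rows at non-unit ratios are dropped by index_select.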
+ indices = list(range(len(self.ratios[i])))
+ indices.insert(1, len(indices))
+ base_anchors = torch.index_select(base_anchors, 0,
+ torch.LongTensor(indices))
+ multi_level_base_anchors.append(base_anchors)
+ return multi_level_base_anchors
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ indent_str = ' '
+ repr_str = self.__class__.__name__ + '(\n'
+ repr_str += f'{indent_str}strides={self.strides},\n'
+ repr_str += f'{indent_str}scales={self.scales},\n'
+ repr_str += f'{indent_str}scale_major={self.scale_major},\n'
+ repr_str += f'{indent_str}input_size={self.input_size},\n'
+ repr_str += f'{indent_str}ratios={self.ratios},\n'
+ repr_str += f'{indent_str}num_levels={self.num_levels},\n'
+ repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
+ repr_str += f'{indent_str}basesize_ratio_range='
+ repr_str += f'{self.basesize_ratio_range})'
+ return repr_str
+
+
+@PRIOR_GENERATORS.register_module()
+class LegacyAnchorGenerator(AnchorGenerator):
+ """Legacy anchor generator used in MMDetection V1.x.
+
+ Note:
+ Difference to the V2.0 anchor generator:
+
+ 1. The center offset of V1.x anchors is set to be 0.5 rather than 0.
+ 2. The width/height are reduced by 1 when calculating the anchors' \
+ centers and corners to meet the V1.x coordinate system.
+ 3. The anchors' corners are quantized.
+
+ Args:
+ strides (list[int] | list[tuple[int]]): Strides of anchors
+ in multiple feature levels.
+ ratios (list[float]): The list of ratios between the height and width
+ of anchors in a single level.
+ scales (list[int] | None): Anchor scales for anchors in a single level.
+ It cannot be set at the same time as `octave_base_scale` and
+ `scales_per_octave`.
+ base_sizes (list[int]): The basic sizes of anchors in multiple levels.
+ If None is given, strides will be used to generate base_sizes.
+ scale_major (bool): Whether to multiply scales first when generating
+ base anchors. If true, the anchors in the same row will have the
+ same scales. By default it is True in V2.0
+ octave_base_scale (int): The base scale of octave.
+ scales_per_octave (int): Number of scales for each octave.
+ `octave_base_scale` and `scales_per_octave` are usually used in
+ retinanet and the `scales` should be None when they are set.
+ centers (list[tuple[float, float]] | None): The centers of the anchor
+ relative to the feature grid center in multiple feature levels.
+ By default it is set to be None and not used. If a list of float
+ is given, this list will be used to shift the centers of anchors.
+ center_offset (float): The offset of center in proportion to anchors'
+ width and height. By default it is 0 in V2.0 but it should be 0.5
+ in v1.x models.
+
+ Examples:
+ >>> from mmdet.core import LegacyAnchorGenerator
+ >>> self = LegacyAnchorGenerator(
+ >>> [16], [1.], [1.], [9], center_offset=0.5)
+ >>> all_anchors = self.grid_anchors(((2, 2),), device='cpu')
+ >>> print(all_anchors)
+ [tensor([[ 0., 0., 8., 8.],
+ [16., 0., 24., 8.],
+ [ 0., 16., 8., 24.],
+ [16., 16., 24., 24.]])]
+ """
+
+ def gen_single_level_base_anchors(self,
+ base_size,
+ scales,
+ ratios,
+ center=None):
+ """Generate base anchors of a single level.
+
+ Note:
+ The width/height of anchors are reduced by 1 when calculating \
+ the centers and corners to meet the V1.x coordinate system.
+
+ Args:
+ base_size (int | float): Basic size of an anchor.
+ scales (torch.Tensor): Scales of the anchor.
+ ratios (torch.Tensor): The ratio between the height
+ and width of anchors in a single level.
+ center (tuple[float], optional): The center of the base anchor
+ related to a single feature grid. Defaults to None.
+
+ Returns:
+ torch.Tensor: Anchors in a single-level feature map.
+ """
+ w = base_size
+ h = base_size
+ if center is None:
+ x_center = self.center_offset * (w - 1)
+ y_center = self.center_offset * (h - 1)
+ else:
+ x_center, y_center = center
+
+ h_ratios = torch.sqrt(ratios)
+ w_ratios = 1 / h_ratios
+ if self.scale_major:
+ ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
+ hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
+ else:
+ ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
+ hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
+
+ # use float anchor and the anchor's center is aligned with the
+ # pixel center
+ base_anchors = [
+ x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1),
+ x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1)
+ ]
+ base_anchors = torch.stack(base_anchors, dim=-1).round()
+
+ return base_anchors
+
+
+@PRIOR_GENERATORS.register_module()
+class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
+ """Legacy anchor generator used in MMDetection V1.x.
+
+ The difference between `LegacySSDAnchorGenerator` and `SSDAnchorGenerator`
+ can be found in `LegacyAnchorGenerator`.
+ """
+
+ def __init__(self,
+ strides,
+ ratios,
+ basesize_ratio_range,
+ input_size=300,
+ scale_major=True):
+ super(LegacySSDAnchorGenerator, self).__init__(
+ strides=strides,
+ ratios=ratios,
+ basesize_ratio_range=basesize_ratio_range,
+ input_size=input_size,
+ scale_major=scale_major)
+ self.centers = [((stride - 1) / 2., (stride - 1) / 2.)
+ for stride in strides]
+ self.base_anchors = self.gen_base_anchors()
+
+
+@PRIOR_GENERATORS.register_module()
+class YOLOAnchorGenerator(AnchorGenerator):
+ """Anchor generator for YOLO.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels.
+ base_sizes (list[list[tuple[int, int]]]): The basic sizes
+ of anchors in multiple levels.
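+
+ Examples:
+ >>> # Illustrative only; the base sizes below are the standard
+ >>> # YOLOv3 anchors for the coarsest (stride 32) level.
+ >>> from mmdet.core import YOLOAnchorGenerator
+ >>> self = YOLOAnchorGenerator(
+ >>> strides=[32],
+ >>> base_sizes=[[(116, 90), (156, 198), (373, 326)]])
+ >>> print(self.num_levels)
+ 1
+ >>> print(self.num_base_anchors)
+ [3]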
+ """
+
+ def __init__(self, strides, base_sizes):
+ self.strides = [_pair(stride) for stride in strides]
+ self.centers = [(stride[0] / 2., stride[1] / 2.)
+ for stride in self.strides]
+ self.base_sizes = []
+ num_anchor_per_level = len(base_sizes[0])
+ for base_sizes_per_level in base_sizes:
+ assert num_anchor_per_level == len(base_sizes_per_level)
+ self.base_sizes.append(
+ [_pair(base_size) for base_size in base_sizes_per_level])
+ self.base_anchors = self.gen_base_anchors()
+
+ @property
+ def num_levels(self):
+ """int: number of feature levels that the generator will be applied"""
+ return len(self.base_sizes)
+
+ def gen_base_anchors(self):
+ """Generate base anchors.
+
+ Returns:
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
+ feature levels.
+ """
+ multi_level_base_anchors = []
+ for i, base_sizes_per_level in enumerate(self.base_sizes):
+ center = None
+ if self.centers is not None:
+ center = self.centers[i]
+ multi_level_base_anchors.append(
+ self.gen_single_level_base_anchors(base_sizes_per_level,
+ center))
+ return multi_level_base_anchors
+
+ def gen_single_level_base_anchors(self, base_sizes_per_level, center=None):
+ """Generate base anchors of a single level.
+
+ Args:
+ base_sizes_per_level (list[tuple[int, int]]): Basic sizes of
+ anchors.
+ center (tuple[float], optional): The center of the base anchor
+ related to a single feature grid. Defaults to None.
+
+ Returns:
+ torch.Tensor: Anchors in a single-level feature maps.
+ """
+ x_center, y_center = center
+ base_anchors = []
+ for base_size in base_sizes_per_level:
+ w, h = base_size
+
+ # use float anchor and the anchor's center is aligned with the
+ # pixel center
+ base_anchor = torch.Tensor([
+ x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w,
+ y_center + 0.5 * h
+ ])
+ base_anchors.append(base_anchor)
+ base_anchors = torch.stack(base_anchors, dim=0)
+
+ return base_anchors
+
+ def responsible_flags(self, featmap_sizes, gt_bboxes, device='cuda'):
+ """Generate responsible anchor flags of grid cells in multiple scales.
+
+ Args:
+ featmap_sizes (list(tuple)): List of feature map sizes in multiple
+ feature levels.
+ gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
+ device (str): Device where the anchors will be put on.
+
+ Return:
+ list(torch.Tensor): responsible flags of anchors in multiple levels
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_responsible_flags = []
+ for i in range(self.num_levels):
+ anchor_stride = self.strides[i]
+ flags = self.single_level_responsible_flags(
+ featmap_sizes[i],
+ gt_bboxes,
+ anchor_stride,
+ self.num_base_anchors[i],
+ device=device)
+ multi_level_responsible_flags.append(flags)
+ return multi_level_responsible_flags
+
+ def single_level_responsible_flags(self,
+ featmap_size,
+ gt_bboxes,
+ stride,
+ num_base_anchors,
+ device='cuda'):
+ """Generate the responsible flags of anchor in a single feature map.
+
+ Args:
+ featmap_size (tuple[int]): The size of feature maps.
+ gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
+ stride (tuple(int)): stride of current level
+ num_base_anchors (int): The number of base anchors.
+ device (str, optional): Device where the flags will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: The valid flags of each anchor in a single level \
+ feature map.
+ """
+ feat_h, feat_w = featmap_size
+ gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
+ gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
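+ # the grid cell containing each gt center is the responsible one
+ # (YOLO assignment rule): map center coordinates to cell indices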
+ gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
+ gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()
+
+ # row major indexing
+ gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x
+
+ responsible_grid = torch.zeros(
+ feat_h * feat_w, dtype=torch.uint8, device=device)
+ responsible_grid[gt_bboxes_grid_idx] = 1
+
+ responsible_grid = responsible_grid[:, None].expand(
+ responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
+ return responsible_grid
diff --git a/mmdet/core/anchor/builder.py b/mmdet/core/anchor/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddb25ad37937bcf227832e37469a0e31cae77826
--- /dev/null
+++ b/mmdet/core/anchor/builder.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+from mmcv.utils import Registry, build_from_cfg
+
+PRIOR_GENERATORS = Registry('Generator for anchors and points')
+
+ANCHOR_GENERATORS = PRIOR_GENERATORS
+
+
+def build_prior_generator(cfg, default_args=None):
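+ """Build a prior (anchor / point) generator from a config dict.
+
+ Example config (illustrative; any generator registered in
+ ``PRIOR_GENERATORS`` can be used)::
+
+ dict(type='AnchorGenerator', scales=[8], ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64])
+ """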
+ return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)
+
+
+def build_anchor_generator(cfg, default_args=None):
+ warnings.warn(
+ '``build_anchor_generator`` would be deprecated soon, please use '
+ '``build_prior_generator`` ')
+ return build_prior_generator(cfg, default_args=default_args)
diff --git a/mmdet/core/anchor/point_generator.py b/mmdet/core/anchor/point_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc9c3887dd7c1d3afe30b705f16162d1d03c9b5d
--- /dev/null
+++ b/mmdet/core/anchor/point_generator.py
@@ -0,0 +1,263 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+from torch.nn.modules.utils import _pair
+
+from .builder import PRIOR_GENERATORS
+
+
+@PRIOR_GENERATORS.register_module()
+class PointGenerator:
+
+ def _meshgrid(self, x, y, row_major=True):
+ xx = x.repeat(len(y))
+ yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
+ if row_major:
+ return xx, yy
+ else:
+ return yy, xx
+
+ def grid_points(self, featmap_size, stride=16, device='cuda'):
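+ """Generate grid points of a single level.
+
+ Each returned row is ``(x, y, stride)`` for one location of the
+ ``feat_h x feat_w`` grid, in row-major order.
+ """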
+ feat_h, feat_w = featmap_size
+ shift_x = torch.arange(0., feat_w, device=device) * stride
+ shift_y = torch.arange(0., feat_h, device=device) * stride
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+ stride = shift_x.new_full((shift_xx.shape[0], ), stride)
+ shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
+ all_points = shifts.to(device)
+ return all_points
+
+ def valid_flags(self, featmap_size, valid_size, device='cuda'):
+ feat_h, feat_w = featmap_size
+ valid_h, valid_w = valid_size
+ assert valid_h <= feat_h and valid_w <= feat_w
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+ valid_x[:valid_w] = 1
+ valid_y[:valid_h] = 1
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+ valid = valid_xx & valid_yy
+ return valid
+
+
+@PRIOR_GENERATORS.register_module()
+class MlvlPointGenerator:
+ """Standard points generator for multi-level (Mlvl) feature maps in 2D
+ points-based detectors.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels in order (w, h).
+ offset (float): The offset of points, the value is normalized with
+ corresponding stride. Defaults to 0.5.
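+
+ Examples:
+ >>> # Illustrative sketch: a single level with stride 4 on a
+ >>> # 2x2 feature map; values follow (x + offset) * stride.
+ >>> from mmdet.core import MlvlPointGenerator
+ >>> self = MlvlPointGenerator(strides=[4], offset=0.5)
+ >>> priors = self.grid_priors([(2, 2)], device='cpu')
+ >>> print(priors[0])
+ tensor([[2., 2.],
+ [6., 2.],
+ [2., 6.],
+ [6., 6.]])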
+ """
+
+ def __init__(self, strides, offset=0.5):
+ self.strides = [_pair(stride) for stride in strides]
+ self.offset = offset
+
+ @property
+ def num_levels(self):
+ """int: number of feature levels that the generator will be applied"""
+ return len(self.strides)
+
+ @property
+ def num_base_priors(self):
+ """list[int]: The number of priors (points) at a point
+ on the feature grid"""
+ return [1 for _ in range(len(self.strides))]
+
+ def _meshgrid(self, x, y, row_major=True):
+ yy, xx = torch.meshgrid(y, x)
+ if row_major:
+ # warning .flatten() would cause error in ONNX exporting
+ # have to use reshape here
+ return xx.reshape(-1), yy.reshape(-1)
+
+ else:
+ return yy.reshape(-1), xx.reshape(-1)
+
+ def grid_priors(self,
+ featmap_sizes,
+ dtype=torch.float32,
+ device='cuda',
+ with_stride=False):
+ """Generate grid points of multiple feature levels.
+
+ Args:
+ featmap_sizes (list[tuple]): List of feature map sizes in
+ multiple feature levels, each size arranged as (h, w).
+ dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
+ device (str): The device where the anchors will be put on.
+ with_stride (bool): Whether to concatenate the stride to
+ the last dimension of points.
+
+ Return:
+ list[torch.Tensor]: Points of multiple feature levels.
+ The sizes of each tensor should be (N, 2) when ``with_stride``
+ is ``False``, where N = width * height, width and height
+ are the sizes of the corresponding feature level,
+ and the last dimension 2 represents (coord_x, coord_y);
+ otherwise the shape should be (N, 4),
+ and the last dimension 4 represents
+ (coord_x, coord_y, stride_w, stride_h).
+ """
+
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_priors = []
+ for i in range(self.num_levels):
+ priors = self.single_level_grid_priors(
+ featmap_sizes[i],
+ level_idx=i,
+ dtype=dtype,
+ device=device,
+ with_stride=with_stride)
+ multi_level_priors.append(priors)
+ return multi_level_priors
+
+ def single_level_grid_priors(self,
+ featmap_size,
+ level_idx,
+ dtype=torch.float32,
+ device='cuda',
+ with_stride=False):
+ """Generate grid Points of a single level.
+
+ Note:
+ This function is usually called by method ``self.grid_priors``.
+
+ Args:
+ featmap_size (tuple[int]): Size of the feature maps, arranged as
+ (h, w).
+ level_idx (int): The index of corresponding feature map level.
+ dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
+ device (str, optional): The device the tensor will be put on.
+ Defaults to 'cuda'.
+ with_stride (bool): Concatenate the stride to the last dimension
+ of points.
+
+ Return:
+ Tensor: Points of a single feature level.
+ The shape of the tensor should be (N, 2) when ``with_stride``
+ is ``False``, where N = width * height, width and height
+ are the sizes of the corresponding feature level,
+ and the last dimension 2 represents (coord_x, coord_y);
+ otherwise the shape should be (N, 4),
+ and the last dimension 4 represents
+ (coord_x, coord_y, stride_w, stride_h).
+ """
+ feat_h, feat_w = featmap_size
+ stride_w, stride_h = self.strides[level_idx]
+ shift_x = (torch.arange(0, feat_w, device=device) +
+ self.offset) * stride_w
+ # keep featmap_size as Tensor instead of int, so that we
+ # can convert to ONNX correctly
+ shift_x = shift_x.to(dtype)
+
+ shift_y = (torch.arange(0, feat_h, device=device) +
+ self.offset) * stride_h
+ # keep featmap_size as Tensor instead of int, so that we
+ # can convert to ONNX correctly
+ shift_y = shift_y.to(dtype)
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+ if not with_stride:
+ shifts = torch.stack([shift_xx, shift_yy], dim=-1)
+ else:
+ # use `shape[0]` instead of `len(shift_xx)` for ONNX export
+ stride_w = shift_xx.new_full((shift_xx.shape[0], ),
+ stride_w).to(dtype)
+ stride_h = shift_xx.new_full((shift_yy.shape[0], ),
+ stride_h).to(dtype)
+ shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h],
+ dim=-1)
+ all_points = shifts.to(device)
+ return all_points
+
+ def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
+ """Generate valid flags of points of multiple feature levels.
+
+ Args:
+ featmap_sizes (list(tuple)): List of feature map sizes in
+ multiple feature levels, each size arranged as (h, w).
+ pad_shape (tuple(int)): The padded shape of the image,
+ arrange as (h, w).
+ device (str): The device where the anchors will be put on.
+
+ Return:
+ list(torch.Tensor): Valid flags of points of multiple levels.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_flags = []
+ for i in range(self.num_levels):
+ point_stride = self.strides[i]
+ feat_h, feat_w = featmap_sizes[i]
+ h, w = pad_shape[:2]
+ valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
+ valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
+ flags = self.single_level_valid_flags((feat_h, feat_w),
+ (valid_feat_h, valid_feat_w),
+ device=device)
+ multi_level_flags.append(flags)
+ return multi_level_flags
+
+ def single_level_valid_flags(self,
+ featmap_size,
+ valid_size,
+ device='cuda'):
+ """Generate the valid flags of points of a single feature map.
+
+ Args:
+ featmap_size (tuple[int]): The size of feature maps, arranged
+ as (h, w).
+ valid_size (tuple[int]): The valid size of the feature maps.
+ The size arranged as (h, w).
+ device (str, optional): The device where the flags will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: The valid flags of each points in a single level \
+ feature map.
+ """
+ feat_h, feat_w = featmap_size
+ valid_h, valid_w = valid_size
+ assert valid_h <= feat_h and valid_w <= feat_w
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+ valid_x[:valid_w] = 1
+ valid_y[:valid_h] = 1
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+ valid = valid_xx & valid_yy
+ return valid
+
+ def sparse_priors(self,
+ prior_idxs,
+ featmap_size,
+ level_idx,
+ dtype=torch.float32,
+ device='cuda'):
+ """Generate sparse points according to the ``prior_idxs``.
+
+ Args:
+ prior_idxs (Tensor): The index of corresponding anchors
+ in the feature map.
+ featmap_size (tuple[int]): feature map size arranged as (h, w).
+ level_idx (int): The level index of corresponding feature
+ map.
+ dtype (obj:`torch.dtype`): Data type of points. Defaults to
+ ``torch.float32``.
+ device (obj:`torch.device`): The device where the points are
+ located.
+ Returns:
+ Tensor: Points with shape (N, 2), N should be equal to
+ the length of ``prior_idxs``, and the last dimension
+ 2 represents (coord_x, coord_y).
+ """
+ height, width = featmap_size
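+ # flat ``prior_idxs`` follow the row-major order of ``grid_priors``
+ # (idx = y * width + x), so x and y are recovered below and shifted
+ # by ``offset`` before scaling with the level stride.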
+ x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
+ y = ((prior_idxs // width) % height +
+ self.offset) * self.strides[level_idx][1]
+ priors = torch.stack([x, y], 1).to(dtype)
+ priors = priors.to(device)
+ return priors
diff --git a/mmdet/core/anchor/utils.py b/mmdet/core/anchor/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2f202476ca4413efbca191150719d68777e2be3
--- /dev/null
+++ b/mmdet/core/anchor/utils.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def images_to_levels(target, num_levels):
+ """Convert targets by image to targets by feature level.
+
+ [target_img0, target_img1] -> [target_level0, target_level1, ...]
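+
+ Example (illustrative)::
+
+ >>> import torch
+ >>> # two images, two levels holding 5 and 3 anchors respectively
+ >>> targets = [torch.zeros(8, 4), torch.zeros(8, 4)]
+ >>> [t.shape for t in images_to_levels(targets, [5, 3])]
+ [torch.Size([2, 5, 4]), torch.Size([2, 3, 4])]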
+ """
+ target = torch.stack(target, 0)
+ level_targets = []
+ start = 0
+ for n in num_levels:
+ end = start + n
+ # level_targets.append(target[:, start:end].squeeze(0))
+ level_targets.append(target[:, start:end])
+ start = end
+ return level_targets
+
+
+def anchor_inside_flags(flat_anchors,
+ valid_flags,
+ img_shape,
+ allowed_border=0):
+ """Check whether the anchors are inside the border.
+
+ Args:
+ flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
+ valid_flags (torch.Tensor): An existing valid flags of anchors.
+ img_shape (tuple(int)): Shape of current image.
+ allowed_border (int, optional): The border to allow the valid anchor.
+ Defaults to 0.
+
+ Returns:
+ torch.Tensor: Flags indicating whether the anchors are inside a \
+ valid range.
+ """
+ img_h, img_w = img_shape[:2]
+ if allowed_border >= 0:
+ inside_flags = valid_flags & \
+ (flat_anchors[:, 0] >= -allowed_border) & \
+ (flat_anchors[:, 1] >= -allowed_border) & \
+ (flat_anchors[:, 2] < img_w + allowed_border) & \
+ (flat_anchors[:, 3] < img_h + allowed_border)
+ else:
+ inside_flags = valid_flags
+ return inside_flags
+
+
+def calc_region(bbox, ratio, featmap_size=None):
+ """Calculate a proportional bbox region.
+
+ The bbox center is fixed while each side is moved inwards by ``ratio``,
+ so the new h' and w' are (1 - 2 * ratio) * h and (1 - 2 * ratio) * w.
+
+ Args:
+ bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
+ ratio (float): Ratio of the output region.
+ featmap_size (tuple): Feature map size used for clipping the boundary.
+
+ Returns:
+ tuple: x1, y1, x2, y2
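+
+ Example (illustrative; exact tensor formatting may differ)::
+
+ >>> import torch
+ >>> calc_region(torch.Tensor([0., 0., 10., 10.]), 0.2)
+ (tensor(2), tensor(2), tensor(8), tensor(8))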
+ """
+ x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
+ y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
+ x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
+ y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
+ if featmap_size is not None:
+ x1 = x1.clamp(min=0, max=featmap_size[1])
+ y1 = y1.clamp(min=0, max=featmap_size[0])
+ x2 = x2.clamp(min=0, max=featmap_size[1])
+ y2 = y2.clamp(min=0, max=featmap_size[0])
+ return (x1, y1, x2, y2)
diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..371eba198e9fad1b0c3697d6c9f250c930f844d7
--- /dev/null
+++ b/mmdet/core/bbox/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
+ MaxIoUAssigner, RegionAssigner)
+from .builder import build_assigner, build_bbox_coder, build_sampler
+from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, DistancePointBBoxCoder,
+ PseudoBBoxCoder, TBLRBBoxCoder)
+from .iou_calculators import BboxOverlaps2D, bbox_overlaps
+from .samplers import (BaseSampler, CombinedSampler,
+ InstanceBalancedPosSampler, IoUBalancedNegSampler,
+ OHEMSampler, PseudoSampler, RandomSampler,
+ SamplingResult, ScoreHLRSampler)
+from .transforms import (bbox2distance, bbox2result, bbox2roi,
+ bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
+ bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
+ distance2bbox, find_inside_bboxes, roi2bbox)
+
+__all__ = [
+ 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
+ 'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
+ 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+ 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner',
+ 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
+ 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
+ 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
+ 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder',
+ 'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy',
+ 'bbox_xyxy_to_cxcywh', 'RegionAssigner', 'find_inside_bboxes'
+]
diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6480a783be1afca2e7d414c24c44b20744db779
--- /dev/null
+++ b/mmdet/core/bbox/assigners/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .approx_max_iou_assigner import ApproxMaxIoUAssigner
+from .ascend_assign_result import AscendAssignResult
+from .ascend_max_iou_assigner import AscendMaxIoUAssigner
+from .assign_result import AssignResult
+from .atss_assigner import ATSSAssigner
+from .base_assigner import BaseAssigner
+from .center_region_assigner import CenterRegionAssigner
+from .grid_assigner import GridAssigner
+from .hungarian_assigner import HungarianAssigner
+from .mask_hungarian_assigner import MaskHungarianAssigner
+from .max_iou_assigner import MaxIoUAssigner
+from .point_assigner import PointAssigner
+from .region_assigner import RegionAssigner
+from .sim_ota_assigner import SimOTAAssigner
+from .task_aligned_assigner import TaskAlignedAssigner
+from .uniform_assigner import UniformAssigner
+
+__all__ = [
+ 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
+ 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
+ 'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner',
+ 'TaskAlignedAssigner', 'MaskHungarianAssigner', 'AscendAssignResult',
+ 'AscendMaxIoUAssigner'
+]
diff --git a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..304d09c3fba3def3fb0320eaba67d3b967cf5f11
--- /dev/null
+++ b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .max_iou_assigner import MaxIoUAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class ApproxMaxIoUAssigner(MaxIoUAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposal will be assigned with an integer indicating the ground truth
+ index (semi-positive index: gt index (0-based), -1: background).
+
+ - -1: negative sample, no assigned gt
+ - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+ Args:
+ pos_iou_thr (float): IoU threshold for positive bboxes.
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
+ positive bbox. Positive samples can have smaller IoU than
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
+ highest overlap with some gt to that gt.
+ ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+ `gt_bboxes_ignore` is specified). Negative values mean not
+ ignoring any bboxes.
+ ignore_wrt_candidates (bool): Whether to compute the iof between
+ `bboxes` and `gt_bboxes_ignore`, or the contrary.
+ match_low_quality (bool): Whether to allow low quality matches. This is
+ usually allowed for RPN and single stage detectors, but not allowed
+ in the second stage.
+ gpu_assign_thr (int): The upper bound of the number of GT for GPU
+ assign. When the number of gt is above this threshold, will assign
+ on CPU device. Negative values mean not assign on CPU.
+ """
+
+ def __init__(self,
+ pos_iou_thr,
+ neg_iou_thr,
+ min_pos_iou=.0,
+ gt_max_assign_all=True,
+ ignore_iof_thr=-1,
+ ignore_wrt_candidates=True,
+ match_low_quality=True,
+ gpu_assign_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_iou_thr = pos_iou_thr
+ self.neg_iou_thr = neg_iou_thr
+ self.min_pos_iou = min_pos_iou
+ self.gt_max_assign_all = gt_max_assign_all
+ self.ignore_iof_thr = ignore_iof_thr
+ self.ignore_wrt_candidates = ignore_wrt_candidates
+ self.gpu_assign_thr = gpu_assign_thr
+ self.match_low_quality = match_low_quality
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self,
+ approxs,
+ squares,
+ approxs_per_octave,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ gt_labels=None):
+ """Assign gt to approxs.
+
+ This method assigns a gt bbox to each group of approxs (bboxes);
+ each group of approxs is represented by a base approx (bbox) and
+ will be assigned with -1, or a semi-positive number.
+ background_label (-1) means negative sample,
+ a semi-positive number is the index (0-based) of the assigned gt.
+ The assignment is done in following steps, the order matters.
+
+ 1. assign every bbox to background_label (-1)
+ 2. use the max IoU of each group of approxs to assign
+ 3. assign proposals whose iou with all gts < neg_iou_thr to background
+ 4. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
+ assign it to that bbox
+ 5. for each gt bbox, assign its nearest proposals (may be more than
+ one) to itself
+
+ Args:
+ approxs (Tensor): Bounding boxes to be assigned,
+ shape(approxs_per_octave*n, 4).
+ squares (Tensor): Base Bounding boxes to be assigned,
+ shape(n, 4).
+ approxs_per_octave (int): number of approxs per octave
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ num_squares = squares.size(0)
+ num_gts = gt_bboxes.size(0)
+
+ if num_squares == 0 or num_gts == 0:
+ # No predictions and/or truth, return empty assignment
+ overlaps = approxs.new(num_gts, num_squares)
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+ return assign_result
+
+ # re-organize anchors by approxs_per_octave x num_squares
+ approxs = torch.transpose(
+ approxs.view(num_squares, approxs_per_octave, 4), 0,
+ 1).contiguous().view(-1, 4)
+ assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
+ num_gts > self.gpu_assign_thr) else False
+ # compute overlap and assign gt on CPU when number of GT is large
+ if assign_on_cpu:
+ device = approxs.device
+ approxs = approxs.cpu()
+ gt_bboxes = gt_bboxes.cpu()
+ if gt_bboxes_ignore is not None:
+ gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+ if gt_labels is not None:
+ gt_labels = gt_labels.cpu()
+ all_overlaps = self.iou_calculator(approxs, gt_bboxes)
+
+ overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
+ num_gts).max(dim=0)
+ overlaps = torch.transpose(overlaps, 0, 1)
+
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+ and gt_bboxes_ignore.numel() > 0 and squares.numel() > 0):
+ if self.ignore_wrt_candidates:
+ ignore_overlaps = self.iou_calculator(
+ squares, gt_bboxes_ignore, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+ else:
+ ignore_overlaps = self.iou_calculator(
+ gt_bboxes_ignore, squares, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+ overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
+
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+ if assign_on_cpu:
+ assign_result.gt_inds = assign_result.gt_inds.to(device)
+ assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+ if assign_result.labels is not None:
+ assign_result.labels = assign_result.labels.to(device)
+ return assign_result
diff --git a/mmdet/core/bbox/assigners/ascend_assign_result.py b/mmdet/core/bbox/assigners/ascend_assign_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..03d33c2b59a5e7c8dec8157c43d89987e57ec6c1
--- /dev/null
+++ b/mmdet/core/bbox/assigners/ascend_assign_result.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.utils import util_mixins
+
+
+class AscendAssignResult(util_mixins.NiceRepr):
+ """Stores ascend assignments between predicted and truth boxes.
+
+ Arguments:
+ batch_num_gts (list[int]): the number of truth boxes considered.
+ batch_pos_mask (IntTensor): Positive samples mask in all images.
+ batch_neg_mask (IntTensor): Negative samples mask in all images.
+ batch_max_overlaps (FloatTensor): The max overlaps of all bboxes
+ and ground truth boxes.
+ batch_anchor_gt_indes (None | LongTensor): The assigned truth
+ box index of all anchors.
+ batch_anchor_gt_labels (None | LongTensor): The gt labels
+ of all anchors.
+ """
+
+ def __init__(self,
+ batch_num_gts,
+ batch_pos_mask,
+ batch_neg_mask,
+ batch_max_overlaps,
+ batch_anchor_gt_indes=None,
+ batch_anchor_gt_labels=None):
+ self.batch_num_gts = batch_num_gts
+ self.batch_pos_mask = batch_pos_mask
+ self.batch_neg_mask = batch_neg_mask
+ self.batch_max_overlaps = batch_max_overlaps
+ self.batch_anchor_gt_indes = batch_anchor_gt_indes
+ self.batch_anchor_gt_labels = batch_anchor_gt_labels
+ # Interface for possible user-defined properties
+ self._extra_properties = {}
diff --git a/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8f528aead66d8c5e60d40915d8317e867af163c
--- /dev/null
+++ b/mmdet/core/bbox/assigners/ascend_max_iou_assigner.py
@@ -0,0 +1,178 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ....utils import masked_fill
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .ascend_assign_result import AscendAssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class AscendMaxIoUAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposal will be assigned with `-1`, or a semi-positive integer
+ indicating the ground truth index.
+
+ - -1: negative sample, no assigned gt
+ - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+ Args:
+ pos_iou_thr (float): IoU threshold for positive bboxes.
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
+ positive bbox. Positive samples can have smaller IoU than
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+ `min_pos_iou` is set to avoid assigning bboxes that have extremely
+ small iou with GT as positive samples. It brings about 0.3 mAP
+ improvements in 1x schedule but does not affect the performance of
+ 3x schedule. More comparisons can be found in
+ `PR #7464 `_.
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
+ highest overlap with some gt to that gt.
+ ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+ `gt_bboxes_ignore` is specified). Negative values mean not
+ ignoring any bboxes.
+ ignore_wrt_candidates (bool): Whether to compute the iof between
+ `bboxes` and `gt_bboxes_ignore`, or the contrary.
+ match_low_quality (bool): Whether to allow low quality matches. This is
+ usually allowed for RPN and single stage detectors, but not allowed
+ in the second stage. Details are demonstrated in Step 4.
+ gpu_assign_thr (int): The upper bound of the number of GT for GPU
+ assign. When the number of gt is above this threshold, will assign
+ on CPU device. Negative values mean not assign on CPU.
+ """
+
+ def __init__(self,
+ pos_iou_thr,
+ neg_iou_thr,
+ min_pos_iou=.0,
+ gt_max_assign_all=True,
+ ignore_iof_thr=-1,
+ ignore_wrt_candidates=True,
+ match_low_quality=True,
+ gpu_assign_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_iou_thr = pos_iou_thr
+ self.neg_iou_thr = neg_iou_thr
+ self.min_pos_iou = min_pos_iou
+ self.gt_max_assign_all = gt_max_assign_all
+ self.ignore_iof_thr = ignore_iof_thr
+ self.ignore_wrt_candidates = ignore_wrt_candidates
+ self.gpu_assign_thr = gpu_assign_thr
+ self.match_low_quality = match_low_quality
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self,
+ batch_bboxes,
+ batch_gt_bboxes,
+ batch_gt_bboxes_ignore=None,
+ batch_gt_labels=None,
+ batch_bboxes_ignore_mask=None,
+ batch_num_gts=None):
+ """Assign gt to bboxes.
+
+ Args:
+ batch_bboxes (Tensor): Bounding boxes to be assigned,
+ shape(b, n, 4).
+ batch_gt_bboxes (Tensor): Ground truth boxes,
+ shape (b, k, 4).
+ batch_gt_bboxes_ignore (Tensor, optional): Ground truth
+ bboxes that are labelled as `ignored`,
+ e.g., crowd boxes in COCO.
+ batch_gt_labels (Tensor, optional): Label of gt_bboxes,
+ shape (b, k, ).
+ batch_bboxes_ignore_mask (Tensor): Ignore mask of bboxes,
+ shape (b, n).
+ batch_num_gts (list[int]): The number of gts of each image,
+ length b.
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ batch_overlaps = self.iou_calculator(batch_gt_bboxes, batch_bboxes)
+ batch_overlaps = masked_fill(
+ batch_overlaps,
+ batch_bboxes_ignore_mask.unsqueeze(1).float(),
+ -1,
+ neg=True)
+ if self.ignore_iof_thr > 0 and batch_gt_bboxes_ignore is not None:
+ if self.ignore_wrt_candidates:
+ batch_ignore_overlaps = self.iou_calculator(
+ batch_bboxes, batch_gt_bboxes_ignore, mode='iof')
+ batch_ignore_overlaps = masked_fill(batch_ignore_overlaps,
+ batch_bboxes_ignore_mask,
+ -1)
+ batch_ignore_max_overlaps, _ = batch_ignore_overlaps.max(dim=2)
+ else:
+ batch_ignore_overlaps = self.iou_calculator(
+ batch_gt_bboxes_ignore, batch_bboxes, mode='iof')
+ batch_ignore_overlaps = masked_fill(batch_ignore_overlaps,
+ batch_bboxes_ignore_mask,
+ -1)
+ batch_ignore_max_overlaps, _ = \
+ batch_ignore_overlaps.max(dim=1)
+ batch_ignore_mask = \
+ batch_ignore_max_overlaps > self.ignore_iof_thr
+ batch_overlaps = masked_fill(batch_overlaps, batch_ignore_mask, -1)
+ batch_assign_result = self.batch_assign_wrt_overlaps(
+ batch_overlaps, batch_gt_labels, batch_num_gts)
+ return batch_assign_result
+
+ def batch_assign_wrt_overlaps(self,
+ batch_overlaps,
+ batch_gt_labels=None,
+ batch_num_gts=None):
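+ """Assign gts for a whole batch of images according to IoU overlaps.
+
+ A brief sketch of the logic (batched counterpart of the per-image
+ ``assign_wrt_overlaps`` used by ``MaxIoUAssigner``): bboxes whose
+ max IoU falls inside the negative range become negatives, bboxes
+ whose max IoU is >= ``pos_iou_thr`` become positives, and when
+ ``match_low_quality`` is True each gt additionally claims its
+ best-overlapping bbox(es).
+ """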
+ num_images, num_gts, num_bboxes = batch_overlaps.size()
+ batch_max_overlaps, batch_argmax_overlaps = batch_overlaps.max(dim=1)
+ if isinstance(self.neg_iou_thr, float):
+ batch_neg_mask = \
+ ((batch_max_overlaps >= 0)
+ & (batch_max_overlaps < self.neg_iou_thr)).int()
+ elif isinstance(self.neg_iou_thr, tuple):
+ assert len(self.neg_iou_thr) == 2
+ batch_neg_mask = \
+ ((batch_max_overlaps >= self.neg_iou_thr[0])
+ & (batch_max_overlaps < self.neg_iou_thr[1])).int()
+ else:
+ batch_neg_mask = torch.zeros(
+ batch_max_overlaps.size(),
+ dtype=torch.int,
+ device=batch_max_overlaps.device)
+ batch_pos_mask = (batch_max_overlaps >= self.pos_iou_thr).int()
+ if self.match_low_quality:
+ batch_gt_max_overlaps, batch_gt_argmax_overlaps = \
+ batch_overlaps.max(dim=2)
+ batch_index_bool = (batch_gt_max_overlaps >= self.min_pos_iou) & \
+ (batch_gt_max_overlaps > 0)
+ if self.gt_max_assign_all:
+ pos_inds_low_quality = \
+ (batch_overlaps == batch_gt_max_overlaps.unsqueeze(2)) & \
+ batch_index_bool.unsqueeze(2)
+ for i in range(num_gts):
+ pos_inds_low_quality_gt = pos_inds_low_quality[:, i, :]
+ batch_argmax_overlaps[pos_inds_low_quality_gt] = i
+ batch_pos_mask[pos_inds_low_quality_gt] = 1
+ else:
+ index_temp = torch.arange(
+ 0, num_gts, device=batch_max_overlaps.device)
+ for index_image in range(num_images):
+ gt_argmax_overlaps = batch_gt_argmax_overlaps[index_image]
+ index_bool = batch_index_bool[index_image]
+ pos_inds_low_quality = gt_argmax_overlaps[index_bool]
+ batch_argmax_overlaps[index_image][pos_inds_low_quality] \
+ = index_temp[index_bool]
+ batch_pos_mask[index_image][pos_inds_low_quality] = 1
+ batch_neg_mask = batch_neg_mask * (1 - batch_pos_mask)
+ if batch_gt_labels is not None:
+ batch_anchor_gt_labels = torch.zeros((num_images, num_bboxes),
+ dtype=batch_gt_labels.dtype,
+ device=batch_gt_labels.device)
+ for index_image in range(num_images):
+ batch_anchor_gt_labels[index_image] = torch.index_select(
+ batch_gt_labels[index_image], 0,
+ batch_argmax_overlaps[index_image])
+ else:
+ batch_anchor_gt_labels = None
+ return AscendAssignResult(batch_num_gts, batch_pos_mask,
+ batch_neg_mask, batch_max_overlaps,
+ batch_argmax_overlaps,
+ batch_anchor_gt_labels)
diff --git a/mmdet/core/bbox/assigners/assign_result.py b/mmdet/core/bbox/assigners/assign_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..488010b5d903d0f51ada89a472d6843de1412116
--- /dev/null
+++ b/mmdet/core/bbox/assigners/assign_result.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.utils import util_mixins
+
+
+class AssignResult(util_mixins.NiceRepr):
+ """Stores assignments between predicted and truth boxes.
+
+ Attributes:
+ num_gts (int): the number of truth boxes considered when computing this
+ assignment
+
+ gt_inds (LongTensor): for each predicted box indicates the 1-based
+ index of the assigned truth box. 0 means unassigned and -1 means
+ ignore.
+
+ max_overlaps (FloatTensor): the iou between the predicted box and its
+ assigned truth box.
+
+ labels (None | LongTensor): If specified, for each predicted box
+ indicates the category label of the assigned truth box.
+
+ Example:
+ >>> # An assign result between 4 predicted boxes and 9 true boxes
+ >>> # where only two boxes were assigned.
+ >>> num_gts = 9
+ >>> max_overlaps = torch.FloatTensor([0, .5, .9, 0])
+ >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
+ >>> labels = torch.LongTensor([0, 3, 4, 0])
+ >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
+ >>> print(str(self)) # xdoctest: +IGNORE_WANT
+
+ >>> # Force addition of gt labels (when adding gt as proposals)
+ >>> new_labels = torch.LongTensor([3, 4, 5])
+ >>> self.add_gt_(new_labels)
+ >>> print(str(self)) # xdoctest: +IGNORE_WANT
+
+ """
+
+ def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
+ self.num_gts = num_gts
+ self.gt_inds = gt_inds
+ self.max_overlaps = max_overlaps
+ self.labels = labels
+ # Interface for possible user-defined properties
+ self._extra_properties = {}
+
+ @property
+ def num_preds(self):
+ """int: the number of predictions in this assignment"""
+ return len(self.gt_inds)
+
+ def set_extra_property(self, key, value):
+ """Set user-defined new property."""
+ assert key not in self.info
+ self._extra_properties[key] = value
+
+ def get_extra_property(self, key):
+ """Get user-defined property."""
+ return self._extra_properties.get(key, None)
+
+ @property
+ def info(self):
+ """dict: a dictionary of info about the object"""
+ basic_info = {
+ 'num_gts': self.num_gts,
+ 'num_preds': self.num_preds,
+ 'gt_inds': self.gt_inds,
+ 'max_overlaps': self.max_overlaps,
+ 'labels': self.labels,
+ }
+ basic_info.update(self._extra_properties)
+ return basic_info
+
+ def __nice__(self):
+ """str: a "nice" summary string describing this assign result"""
+ parts = []
+ parts.append(f'num_gts={self.num_gts!r}')
+ if self.gt_inds is None:
+ parts.append(f'gt_inds={self.gt_inds!r}')
+ else:
+ parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
+ if self.max_overlaps is None:
+ parts.append(f'max_overlaps={self.max_overlaps!r}')
+ else:
+ parts.append('max_overlaps.shape='
+ f'{tuple(self.max_overlaps.shape)!r}')
+ if self.labels is None:
+ parts.append(f'labels={self.labels!r}')
+ else:
+ parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
+ return ', '.join(parts)
+
+ @classmethod
+ def random(cls, **kwargs):
+ """Create random AssignResult for tests or debugging.
+
+ Args:
+ num_preds: number of predicted boxes
+ num_gts: number of true boxes
+ p_ignore (float): probability of a predicted box assigned to an
+ ignored truth
+ p_assigned (float): probability of a predicted box not being
+ assigned
+ p_use_label (float | bool): with labels or not
+ rng (None | int | numpy.random.RandomState): seed or state
+
+ Returns:
+ :obj:`AssignResult`: Randomly generated assign results.
+
+ Example:
+ >>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
+ >>> self = AssignResult.random()
+ >>> print(self.info)
+ """
+ from mmdet.core.bbox import demodata
+ rng = demodata.ensure_rng(kwargs.get('rng', None))
+
+ num_gts = kwargs.get('num_gts', None)
+ num_preds = kwargs.get('num_preds', None)
+ p_ignore = kwargs.get('p_ignore', 0.3)
+ p_assigned = kwargs.get('p_assigned', 0.7)
+ p_use_label = kwargs.get('p_use_label', 0.5)
+ num_classes = kwargs.get('num_classes', 3)
+
+ if num_gts is None:
+ num_gts = rng.randint(0, 8)
+ if num_preds is None:
+ num_preds = rng.randint(0, 16)
+
+ if num_gts == 0:
+ max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
+ gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+ if p_use_label is True or p_use_label < rng.rand():
+ labels = torch.zeros(num_preds, dtype=torch.int64)
+ else:
+ labels = None
+ else:
+ import numpy as np
+
+ # Create an overlap for each predicted box
+ max_overlaps = torch.from_numpy(rng.rand(num_preds))
+
+ # Construct gt_inds for each predicted box
+ is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
+ # maximum number of assignments constraints
+ n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
+
+ assigned_idxs = np.where(is_assigned)[0]
+ rng.shuffle(assigned_idxs)
+ assigned_idxs = assigned_idxs[0:n_assigned]
+ assigned_idxs.sort()
+
+ is_assigned[:] = 0
+ is_assigned[assigned_idxs] = True
+
+ is_ignore = torch.from_numpy(
+ rng.rand(num_preds) < p_ignore) & is_assigned
+
+ gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+
+ true_idxs = np.arange(num_gts)
+ rng.shuffle(true_idxs)
+ true_idxs = torch.from_numpy(true_idxs)
+ gt_inds[is_assigned] = true_idxs[:n_assigned].long()
+
+ gt_inds = torch.from_numpy(
+ rng.randint(1, num_gts + 1, size=num_preds))
+ gt_inds[is_ignore] = -1
+ gt_inds[~is_assigned] = 0
+ max_overlaps[~is_assigned] = 0
+
+ if p_use_label is True or p_use_label < rng.rand():
+ if num_classes == 0:
+ labels = torch.zeros(num_preds, dtype=torch.int64)
+ else:
+ labels = torch.from_numpy(
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ rng.randint(0, num_classes, size=num_preds))
+ labels[~is_assigned] = 0
+ else:
+ labels = None
+
+ self = cls(num_gts, gt_inds, max_overlaps, labels)
+ return self
+
+ def add_gt_(self, gt_labels):
+ """Add ground truth as assigned results.
+
+ Args:
+ gt_labels (torch.Tensor): Labels of gt boxes
+ """
+ self_inds = torch.arange(
+ 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
+ self.gt_inds = torch.cat([self_inds, self.gt_inds])
+
+ self.max_overlaps = torch.cat(
+ [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
+
+ if self.labels is not None:
+ self.labels = torch.cat([gt_labels, self.labels])
diff --git a/mmdet/core/bbox/assigners/atss_assigner.py b/mmdet/core/bbox/assigners/atss_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c8281e50b38df5a663ef183ff75e8cf7b0b195
--- /dev/null
+++ b/mmdet/core/bbox/assigners/atss_assigner.py
@@ -0,0 +1,234 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class ATSSAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposal will be assigned with `0` or a positive integer
+ indicating the ground truth index.
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ If ``alpha`` is not None, it means that the dynamic cost
+ ATSSAssigner is adopted, which is currently only used in the DDOD.
+
+ Args:
+ topk (int): number of bboxes selected in each level
+ """
+
+ def __init__(self,
+ topk,
+ alpha=None,
+ iou_calculator=dict(type='BboxOverlaps2D'),
+ ignore_iof_thr=-1):
+ self.topk = topk
+ self.alpha = alpha
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+ self.ignore_iof_thr = ignore_iof_thr
+
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Args:
+ topk (int): number of bboxes selected in each level.
+ alpha (float): param of cost rate for each proposal only in DDOD.
+ Default None.
+ iou_calculator (dict): builder of IoU calculator.
+ Default dict(type='BboxOverlaps2D').
+ ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+ ``gt_bboxes_ignore`` is specified). Negative values mean not
+ ignoring any bboxes. Default -1.
+ """
+
+ # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
+ def assign(self,
+ bboxes,
+ num_level_bboxes,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ gt_labels=None,
+ cls_scores=None,
+ bbox_preds=None):
+ """Assign gt to bboxes.
+
+ The assignment is done in following steps
+
+ 1. compute iou between all bbox (bbox of all pyramid levels) and gt
+ 2. compute center distance between all bbox and gt
+ 3. on each pyramid level, for each gt, select k bboxes whose centers
+ are closest to the gt center, so we select k*l bboxes in total as
+ candidates for each gt
+ 4. get the corresponding iou for these candidates, and compute the
+ mean and std, set mean + std as the iou threshold
+ 5. select these candidates whose iou are greater than or equal to
+ the threshold as positive
+ 6. limit the positive sample's center in gt
+
+ If ``alpha`` is not None, and ``cls_scores`` and `bbox_preds`
+ are not None, the overlaps calculation in the first step
+ will also include dynamic cost, which is currently only used in
+ the DDOD.
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ num_level_bboxes (List): num of bboxes in each level
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO. Default None.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes. Default None.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4. Default None.
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
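+
+        Example:
+            A minimal usage sketch (illustrative only): it assumes a single
+            feature level with two priors and uses made-up coordinates.
+            >>> self = ATSSAssigner(topk=9)
+            >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+            >>> num_level_bboxes = [2]
+            >>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
+            >>> assign_result = self.assign(bboxes, num_level_bboxes,
+            ...                             gt_bboxes)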
+ """
+ INF = 100000000
+ bboxes = bboxes[:, :4]
+ num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+
+ message = 'Invalid alpha parameter because cls_scores or ' \
+ 'bbox_preds are None. If you want to use the ' \
+ 'cost-based ATSSAssigner, please set cls_scores, ' \
+ 'bbox_preds and self.alpha at the same time. '
+
+ if self.alpha is None:
+ # ATSSAssigner
+ overlaps = self.iou_calculator(bboxes, gt_bboxes)
+ if cls_scores is not None or bbox_preds is not None:
+ warnings.warn(message)
+ else:
+ # Dynamic cost ATSSAssigner in DDOD
+ assert cls_scores is not None and bbox_preds is not None, message
+
+ # compute cls cost for bbox and GT
+ cls_cost = torch.sigmoid(cls_scores[:, gt_labels])
+
+ # compute iou between all bbox and gt
+ overlaps = self.iou_calculator(bbox_preds, gt_bboxes)
+
+            # make sure the shapes match for element-wise multiplication
+ assert cls_cost.shape == overlaps.shape
+
+ # overlaps is actually a cost matrix
+ overlaps = cls_cost**(1 - self.alpha) * overlaps**self.alpha
+
+ # assign 0 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+ 0,
+ dtype=torch.long)
+
+ if num_gt == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
+ if num_gt == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+ # compute center distance between all bbox and gt
+ gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
+ gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
+ gt_points = torch.stack((gt_cx, gt_cy), dim=1)
+
+ bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
+ bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
+ bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)
+
+ distances = (bboxes_points[:, None, :] -
+ gt_points[None, :, :]).pow(2).sum(-1).sqrt()
+
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+ and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
+ ignore_overlaps = self.iou_calculator(
+ bboxes, gt_bboxes_ignore, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+ ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
+ distances[ignore_idxs, :] = INF
+ assigned_gt_inds[ignore_idxs] = -1
+
+ # Selecting candidates based on the center distance
+ candidate_idxs = []
+ start_idx = 0
+ for level, bboxes_per_level in enumerate(num_level_bboxes):
+ # on each pyramid level, for each gt,
+ # select k bbox whose center are closest to the gt center
+ end_idx = start_idx + bboxes_per_level
+ distances_per_level = distances[start_idx:end_idx, :]
+ selectable_k = min(self.topk, bboxes_per_level)
+
+ _, topk_idxs_per_level = distances_per_level.topk(
+ selectable_k, dim=0, largest=False)
+ candidate_idxs.append(topk_idxs_per_level + start_idx)
+ start_idx = end_idx
+ candidate_idxs = torch.cat(candidate_idxs, dim=0)
+
+        # get the corresponding iou for these candidates, and compute the
+        # mean and std; set mean + std as the iou threshold
+ candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
+ overlaps_mean_per_gt = candidate_overlaps.mean(0)
+ overlaps_std_per_gt = candidate_overlaps.std(0)
+ overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
+
+ is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
+
+ # limit the positive sample's center in gt
+ for gt_idx in range(num_gt):
+ candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+ ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
+ num_gt, num_bboxes).contiguous().view(-1)
+ ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
+ num_gt, num_bboxes).contiguous().view(-1)
+ candidate_idxs = candidate_idxs.view(-1)
+
+ # calculate the left, top, right, bottom distance between positive
+ # bbox center and gt side
+ l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
+ t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
+ r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
+ b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
+ is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
+
+ is_pos = is_pos & is_in_gts
+
+ # if an anchor box is assigned to multiple gts,
+ # the one with the highest IoU will be selected.
+ overlaps_inf = torch.full_like(overlaps,
+ -INF).t().contiguous().view(-1)
+ index = candidate_idxs.view(-1)[is_pos.view(-1)]
+ overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
+ overlaps_inf = overlaps_inf.view(num_gt, -1).t()
+
+ max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
+ assigned_gt_inds[
+ max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+ return AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
diff --git a/mmdet/core/bbox/assigners/base_assigner.py b/mmdet/core/bbox/assigners/base_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c2d597a5b12275a8941a5d87c56f05dbc955071
--- /dev/null
+++ b/mmdet/core/bbox/assigners/base_assigner.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseAssigner(metaclass=ABCMeta):
+ """Base assigner that assigns boxes to ground truth boxes."""
+
+ @abstractmethod
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign boxes to either a ground truth boxes or a negative boxes."""
diff --git a/mmdet/core/bbox/assigners/center_region_assigner.py b/mmdet/core/bbox/assigners/center_region_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..86e78597d8efa3313d126cc4707d9c6ef1d16e85
--- /dev/null
+++ b/mmdet/core/bbox/assigners/center_region_assigner.py
@@ -0,0 +1,336 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+def scale_boxes(bboxes, scale):
+ """Expand an array of boxes by a given scale.
+
+ Args:
+ bboxes (Tensor): Shape (m, 4)
+ scale (float): The scale factor of bboxes
+
+ Returns:
+ (Tensor): Shape (m, 4). Scaled bboxes
+ """
+ assert bboxes.size(1) == 4
+ w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
+ h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
+ x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
+ y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
+
+ w_half *= scale
+ h_half *= scale
+
+ boxes_scaled = torch.zeros_like(bboxes)
+ boxes_scaled[:, 0] = x_c - w_half
+ boxes_scaled[:, 2] = x_c + w_half
+ boxes_scaled[:, 1] = y_c - h_half
+ boxes_scaled[:, 3] = y_c + h_half
+ return boxes_scaled
+
+
+def is_located_in(points, bboxes):
+ """Are points located in bboxes.
+
+ Args:
+ points (Tensor): Points, shape: (m, 2).
+ bboxes (Tensor): Bounding boxes, shape: (n, 4).
+
+    Returns:
+ Tensor: Flags indicating if points are located in bboxes, shape: (m, n).
+ """
+ assert points.size(1) == 2
+ assert bboxes.size(1) == 4
+ return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
+ (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
+ (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
+ (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))
+
+
+def bboxes_area(bboxes):
+ """Compute the area of an array of bboxes.
+
+ Args:
+        bboxes (Tensor): The coordinates of bboxes. Shape: (m, 4)
+
+ Returns:
+ Tensor: Area of the bboxes. Shape: (m, )
+ """
+ assert bboxes.size(1) == 4
+ w = (bboxes[:, 2] - bboxes[:, 0])
+ h = (bboxes[:, 3] - bboxes[:, 1])
+ areas = w * h
+ return areas
+
+
+@BBOX_ASSIGNERS.register_module()
+class CenterRegionAssigner(BaseAssigner):
+ """Assign pixels at the center region of a bbox as positive.
+
+    Each proposal will be assigned with `-1`, `0`, or a positive integer
+    indicating the ground truth index.
+
+    - -1: ignored sample (e.g. lying inside an ignored gt region)
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ pos_scale (float): Threshold within which pixels are
+ labelled as positive.
+        neg_scale (float): Threshold beyond which pixels are
+            labelled as negative (pixels between ``pos_scale`` and
+            ``neg_scale`` are shadowed / ignored).
+ min_pos_iof (float): Minimum iof of a pixel with a gt to be
+ labelled as positive. Default: 1e-2
+ ignore_gt_scale (float): Threshold within which the pixels
+ are ignored when the gt is labelled as shadowed. Default: 0.5
+ foreground_dominate (bool): If True, the bbox will be assigned as
+ positive when a gt's kernel region overlaps with another's shadowed
+ (ignored) region, otherwise it is set as ignored. Default to False.
+ """
+
+ def __init__(self,
+ pos_scale,
+ neg_scale,
+ min_pos_iof=1e-2,
+ ignore_gt_scale=0.5,
+ foreground_dominate=False,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_scale = pos_scale
+ self.neg_scale = neg_scale
+ self.min_pos_iof = min_pos_iof
+ self.ignore_gt_scale = ignore_gt_scale
+ self.foreground_dominate = foreground_dominate
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def get_gt_priorities(self, gt_bboxes):
+ """Get gt priorities according to their areas.
+
+ Smaller gt has higher priority.
+
+ Args:
+ gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
+
+ Returns:
+            Tensor: The priority of gts so that gts with larger priority are \
+                more likely to be assigned. Shape (k, )
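+
+        Example:
+            A small illustrative sketch (made-up boxes): the smaller gt gets
+            the larger priority value.
+            >>> self = CenterRegionAssigner(0.2, 0.5)
+            >>> gt_bboxes = torch.Tensor([[0, 0, 4, 4], [0, 0, 10, 10]])
+            >>> priorities = self.get_gt_priorities(gt_bboxes)
+            >>> assert (priorities == torch.LongTensor([1, 0])).all()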
+ """
+ gt_areas = bboxes_area(gt_bboxes)
+        # Rank all gt bbox areas. Smaller objects have higher priority
+ _, sort_idx = gt_areas.sort(descending=True)
+ sort_idx = sort_idx.argsort()
+ return sort_idx
+
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign gt to bboxes.
+
+        This method assigns gts to every bbox (proposal/anchor), each bbox \
+        will be assigned with -1, 0, or a positive number. -1 means \
+        ignored, 0 means negative sample, and a positive number is the \
+        index (1-based) of the assigned gt.
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (tensor, optional): Label of gt_bboxes, shape (num_gts,).
+
+ Returns:
+ :obj:`AssignResult`: The assigned result. Note that \
+ shadowed_labels of shape (N, 2) is also added as an \
+ `assign_result` attribute. `shadowed_labels` is a tensor \
+                composed of N pairs of [anchor_ind, class_label], where N \
+ is the number of anchors that lie in the outer region of a \
+ gt, anchor_ind is the shadowed anchor index and class_label \
+ is the shadowed class label.
+
+ Example:
+ >>> self = CenterRegionAssigner(0.2, 0.2)
+ >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+ >>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
+ >>> assign_result = self.assign(bboxes, gt_bboxes)
+ >>> expected_gt_inds = torch.LongTensor([1, 0])
+ >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
+ """
+ # There are in total 5 steps in the pixel assignment
+ # 1. Find core (the center region, say inner 0.2)
+        #   and shadow (the relatively outer part, say inner 0.2-0.5)
+ # regions of every gt.
+ # 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
+ # 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
+ # the image.
+        #   3.1. For overlapping objects, the prior bboxes in gt_core are
+        #     assigned to the object with the smallest area
+ # 4. Assign prior bboxes with class label according to its gt id.
+ # 4.1. Assign -1 to prior bboxes lying in shadowed gts
+ # 4.2. Assign positive prior boxes with the corresponding label
+ # 5. Find pixels lying in the shadow of an object and assign them with
+ # background label, but set the loss weight of its corresponding
+ # gt to zero.
+ assert bboxes.size(1) == 4, 'bboxes must have size of 4'
+ # 1. Find core positive and shadow region of every gt
+ gt_core = scale_boxes(gt_bboxes, self.pos_scale)
+ gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)
+
+ # 2. Find prior bboxes that lie in gt_core and gt_shadow regions
+ bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2
+ # The center points lie within the gt boxes
+ is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes)
+ # Only calculate bbox and gt_core IoF. This enables small prior bboxes
+ # to match large gts
+ bbox_and_gt_core_overlaps = self.iou_calculator(
+ bboxes, gt_core, mode='iof')
+ # The center point of effective priors should be within the gt box
+ is_bbox_in_gt_core = is_bbox_in_gt & (
+ bbox_and_gt_core_overlaps > self.min_pos_iof) # shape (n, k)
+
+ is_bbox_in_gt_shadow = (
+ self.iou_calculator(bboxes, gt_shadow, mode='iof') >
+ self.min_pos_iof)
+ # Rule out center effective positive pixels
+ is_bbox_in_gt_shadow &= (~is_bbox_in_gt_core)
+
+ num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+ if num_gts == 0 or num_bboxes == 0:
+ # If no gts exist, assign all pixels to negative
+ assigned_gt_ids = \
+ is_bbox_in_gt_core.new_zeros((num_bboxes,),
+ dtype=torch.long)
+ pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2))
+ else:
+ # Step 3: assign a one-hot gt id to each pixel, and smaller objects
+            # have higher priority to be assigned to the pixel.
+ sort_idx = self.get_gt_priorities(gt_bboxes)
+ assigned_gt_ids, pixels_in_gt_shadow = \
+ self.assign_one_hot_gt_indices(is_bbox_in_gt_core,
+ is_bbox_in_gt_shadow,
+ gt_priority=sort_idx)
+
+ if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0:
+            # Mark priors whose centers fall in scaled ignored gt boxes as -1
+ gt_bboxes_ignore = scale_boxes(
+ gt_bboxes_ignore, scale=self.ignore_gt_scale)
+ is_bbox_in_ignored_gts = is_located_in(bbox_centers,
+ gt_bboxes_ignore)
+ is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1)
+ assigned_gt_ids[is_bbox_in_ignored_gts] = -1
+
+ # 4. Assign prior bboxes with class label according to its gt id.
+ assigned_labels = None
+ shadowed_pixel_labels = None
+ if gt_labels is not None:
+ # Default assigned label is the background (-1)
+ assigned_labels = assigned_gt_ids.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_ids > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds]
+ - 1]
+ # 5. Find pixels lying in the shadow of an object
+ shadowed_pixel_labels = pixels_in_gt_shadow.clone()
+ if pixels_in_gt_shadow.numel() > 0:
+ pixel_idx, gt_idx =\
+ pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1]
+ assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \
+ 'Some pixels are dually assigned to ignore and gt!'
+ shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1]
+ override = (
+ assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1])
+ if self.foreground_dominate:
+ # When a pixel is both positive and shadowed, set it as pos
+ shadowed_pixel_labels = shadowed_pixel_labels[~override]
+ else:
+ # When a pixel is both pos and shadowed, set it as shadowed
+ assigned_labels[pixel_idx[override]] = -1
+ assigned_gt_ids[pixel_idx[override]] = 0
+
+ assign_result = AssignResult(
+ num_gts, assigned_gt_ids, None, labels=assigned_labels)
+ # Add shadowed_labels as assign_result property. Shape: (num_shadow, 2)
+ assign_result.set_extra_property('shadowed_labels',
+ shadowed_pixel_labels)
+ return assign_result
+
+ def assign_one_hot_gt_indices(self,
+ is_bbox_in_gt_core,
+ is_bbox_in_gt_shadow,
+ gt_priority=None):
+ """Assign only one gt index to each prior box.
+
+ Gts with large gt_priority are more likely to be assigned.
+
+ Args:
+ is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
+ is in the core area of a gt (e.g. 0-0.2).
+ Shape: (num_prior, num_gt).
+ is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
+ center is in the shadowed area of a gt (e.g. 0.2-0.5).
+ Shape: (num_prior, num_gt).
+ gt_priority (Tensor): Priorities of gts. The gt with a higher
+ priority is more likely to be assigned to the bbox when the bbox
+                matches multiple gts. Shape: (num_gt, ).
+
+ Returns:
+ tuple: Returns (assigned_gt_inds, shadowed_gt_inds).
+
+ - assigned_gt_inds: The assigned gt index of each prior bbox \
+ (i.e. index from 1 to num_gts). Shape: (num_prior, ).
+ - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
+ shape (num_ignore, 2) with first column being the \
+ shadowed prior bbox indices and the second column the \
+ shadowed gt indices (1-based).
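+
+        Example:
+            An illustrative sketch with hand-crafted flags (two priors, two
+            gts; prior 0 lies in both gt cores, prior 1 in gt 0's shadow):
+            >>> self = CenterRegionAssigner(0.2, 0.5)
+            >>> is_core = torch.BoolTensor([[True, True], [False, False]])
+            >>> is_shadow = torch.BoolTensor([[False, False], [True, False]])
+            >>> inds, shadowed = self.assign_one_hot_gt_indices(
+            ...     is_core, is_shadow, gt_priority=torch.tensor([0, 1]))
+            >>> # prior 0 goes to gt 1 (higher priority), prior 1 stays 0
+            >>> assert (inds == torch.LongTensor([2, 0])).all()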
+ """
+ num_bboxes, num_gts = is_bbox_in_gt_core.shape
+
+ if gt_priority is None:
+ gt_priority = torch.arange(
+ num_gts, device=is_bbox_in_gt_core.device)
+ assert gt_priority.size(0) == num_gts
+ # The bigger gt_priority, the more preferable to be assigned
+ # The assigned inds are by default 0 (background)
+ assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
+ dtype=torch.long)
+ # Shadowed bboxes are assigned to be background. But the corresponding
+ # label is ignored during loss calculation, which is done through
+ # shadowed_gt_inds
+ shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
+ if is_bbox_in_gt_core.sum() == 0: # No gt match
+            shadowed_gt_inds[:, 1] += 1  # 1-based, to match assigned_gt_inds
+ return assigned_gt_inds, shadowed_gt_inds
+
+        # The priority of each prior box and gt pair. If one prior box is
+        # matched to multiple gts, only the pair with the highest priority
+        # is kept.
+ pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
+ -1,
+ dtype=torch.long)
+
+ # Each bbox could match with multiple gts.
+ # The following codes deal with this situation
+ # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
+ inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
+        # The matched gt index of each positive bbox.
+        # Length >= num_pos_anchor, as one bbox may match multiple gts
+ matched_bbox_gt_inds = torch.nonzero(
+ is_bbox_in_gt_core, as_tuple=False)[:, 1]
+ # Assign priority to each bbox-gt pair.
+ pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
+ _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
+ assigned_gt_inds[inds_of_match] = argmax_priority + 1 # 1-based
+ # Zero-out the assigned anchor box to filter the shadowed gt indices
+ is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
+        # Concat the shadowed indices caused by overlapping with gts outside
+        # the effective scale. Shape: (total_num_ignore, 2)
+ shadowed_gt_inds = torch.cat(
+ (shadowed_gt_inds, torch.nonzero(
+ is_bbox_in_gt_core, as_tuple=False)),
+ dim=0)
+ # `is_bbox_in_gt_core` should be changed back to keep arguments intact.
+ is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
+ # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
+ if shadowed_gt_inds.numel() > 0:
+ shadowed_gt_inds[:, 1] += 1
+ return assigned_gt_inds, shadowed_gt_inds
diff --git a/mmdet/core/bbox/assigners/grid_assigner.py b/mmdet/core/bbox/assigners/grid_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0c814e782ebc79600cae4ca4e66b4ebaf47c81e
--- /dev/null
+++ b/mmdet/core/bbox/assigners/grid_assigner.py
@@ -0,0 +1,156 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class GridAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned with `-1`, `0`, or a positive integer
+ indicating the ground truth index.
+
+ - -1: don't care
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ pos_iou_thr (float): IoU threshold for positive bboxes.
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
+ positive bbox. Positive samples can have smaller IoU than
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
+ highest overlap with some gt to that gt.
+ """
+
+ def __init__(self,
+ pos_iou_thr,
+ neg_iou_thr,
+ min_pos_iou=.0,
+ gt_max_assign_all=True,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_iou_thr = pos_iou_thr
+ self.neg_iou_thr = neg_iou_thr
+ self.min_pos_iou = min_pos_iou
+ self.gt_max_assign_all = gt_max_assign_all
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None):
+ """Assign gt to bboxes. The process is very much like the max iou
+ assigner, except that positive samples are constrained within the cell
+ that the gt boxes fell in.
+
+ This method assign a gt bbox to every bbox (proposal/anchor), each bbox
+ will be assigned with -1, 0, or a positive number. -1 means don't care,
+ 0 means negative sample, positive number is the index (1-based) of
+ assigned gt.
+ The assignment is done in following steps, the order matters.
+
+ 1. assign every bbox to -1
+ 2. assign proposals whose iou with all gts <= neg_iou_thr to 0
+        3. for each bbox within a cell, if the iou with its nearest gt >
+           pos_iou_thr and the center of that gt falls inside the cell,
+           assign that gt to the bbox
+        4. for each gt bbox, assign its nearest proposals within the cell that
+           the gt bbox falls in to itself.
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ box_responsible_flags (Tensor): flag to indicate whether box is
+ responsible for prediction, shape(n, )
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
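+
+        Example:
+            A minimal usage sketch with made-up boxes and responsibility
+            flags (the flags below are hypothetical):
+            >>> self = GridAssigner(pos_iou_thr=0.5, neg_iou_thr=0.3)
+            >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+            >>> responsible_flags = torch.BoolTensor([True, False])
+            >>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
+            >>> assign_result = self.assign(bboxes, responsible_flags,
+            ...                             gt_bboxes)
+            >>> expected = torch.LongTensor([1, 0])
+            >>> assert torch.all(assign_result.gt_inds == expected)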
+ """
+ num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+
+ # compute iou between all gt and bboxes
+ overlaps = self.iou_calculator(gt_bboxes, bboxes)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
+ if num_gts == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ # 2. assign negative: below
+ # for each anchor, which gt best overlaps with it
+ # for each anchor, the max iou of all gts
+ # shape of max_overlaps == argmax_overlaps == num_bboxes
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+ if isinstance(self.neg_iou_thr, float):
+ assigned_gt_inds[(max_overlaps >= 0)
+ & (max_overlaps <= self.neg_iou_thr)] = 0
+ elif isinstance(self.neg_iou_thr, (tuple, list)):
+ assert len(self.neg_iou_thr) == 2
+ assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0])
+ & (max_overlaps <= self.neg_iou_thr[1])] = 0
+
+ # 3. assign positive: falls into responsible cell and above
+ # positive IOU threshold, the order matters.
+ # the prior condition of comparison is to filter out all
+ # unrelated anchors, i.e. not box_responsible_flags
+ overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1.
+
+ # calculate max_overlaps again, but this time we only consider IOUs
+ # for anchors responsible for prediction
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+ # for each gt, which anchor best overlaps with it
+ # for each gt, the max iou of all proposals
+ # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts
+ gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
+
+ pos_inds = (max_overlaps >
+ self.pos_iou_thr) & box_responsible_flags.type(torch.bool)
+ assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
+
+ # 4. assign positive to max overlapped anchors within responsible cell
+ for i in range(num_gts):
+ if gt_max_overlaps[i] > self.min_pos_iou:
+ if self.gt_max_assign_all:
+ max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \
+ box_responsible_flags.type(torch.bool)
+ assigned_gt_inds[max_iou_inds] = i + 1
+ elif box_responsible_flags[gt_argmax_overlaps[i]]:
+ assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
+
+ # assign labels of positive anchors
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+        else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
diff --git a/mmdet/core/bbox/assigners/hungarian_assigner.py b/mmdet/core/bbox/assigners/hungarian_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..435612ada124824962b5176ed4d7d2f804c704b2
--- /dev/null
+++ b/mmdet/core/bbox/assigners/hungarian_assigner.py
@@ -0,0 +1,139 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from scipy.optimize import linear_sum_assignment
+
+from ..builder import BBOX_ASSIGNERS
+from ..match_costs import build_match_cost
+from ..transforms import bbox_cxcywh_to_xyxy
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class HungarianAssigner(BaseAssigner):
+ """Computes one-to-one matching between predictions and ground truth.
+
+ This class computes an assignment between the targets and the predictions
+ based on the costs. The costs are weighted sum of three components:
+ classification cost, regression L1 cost and regression iou cost. The
+ targets don't include the no_object, so generally there are more
+ predictions than targets. After the one-to-one matching, the un-matched
+ are treated as backgrounds. Thus each query prediction will be assigned
+ with `0` or a positive integer indicating the ground truth index:
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        cls_cost (dict, optional): Config of the classification cost.
+            Default dict(type='ClassificationCost', weight=1.).
+        reg_cost (dict, optional): Config of the regression L1 cost.
+            Default dict(type='BBoxL1Cost', weight=1.0).
+        iou_cost (dict, optional): Config of the regression iou cost.
+            Default dict(type='IoUCost', iou_mode='giou', weight=1.0).
+    """
+
+ def __init__(self,
+ cls_cost=dict(type='ClassificationCost', weight=1.),
+ reg_cost=dict(type='BBoxL1Cost', weight=1.0),
+ iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)):
+ self.cls_cost = build_match_cost(cls_cost)
+ self.reg_cost = build_match_cost(reg_cost)
+ self.iou_cost = build_match_cost(iou_cost)
+
+ def assign(self,
+ bbox_pred,
+ cls_pred,
+ gt_bboxes,
+ gt_labels,
+ img_meta,
+ gt_bboxes_ignore=None,
+ eps=1e-7):
+ """Computes one-to-one matching based on the weighted costs.
+
+        This method assigns each query prediction to a ground truth or
+        background. The `assigned_gt_inds` with -1 means don't care,
+        0 means negative sample, and a positive number is the index (1-based)
+        of the assigned gt.
+        The assignment is done in the following steps, and the order matters.
+
+ 1. assign every prediction to -1
+ 2. compute the weighted costs
+ 3. do Hungarian matching on CPU based on the costs
+ 4. assign all to 0 (background) first, then for each matched pair
+ between predictions and gts, treat this prediction as foreground
+ and assign the corresponding gt index (plus 1) to it.
+
+ Args:
+ bbox_pred (Tensor): Predicted boxes with normalized coordinates
+ (cx, cy, w, h), which are all in range [0, 1]. Shape
+ [num_query, 4].
+ cls_pred (Tensor): Predicted classification logits, shape
+ [num_query, num_class].
+ gt_bboxes (Tensor): Ground truth boxes with unnormalized
+ coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+ img_meta (dict): Meta information for current image.
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`. Default None.
+ eps (int | float, optional): A value added to the denominator for
+ numerical stability. Default 1e-7.
+
+ Returns:
+ :obj:`AssignResult`: The assigned result.
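+
+        Example:
+            A toy sketch of the Hungarian step only (step 3); the cost
+            matrix values below are made up and the real costs come from
+            the configured match-cost modules.
+            >>> import numpy as np
+            >>> from scipy.optimize import linear_sum_assignment
+            >>> cost = np.array([[0.1, 0.7],
+            ...                  [0.9, 0.2],
+            ...                  [0.8, 0.6]])  # (num_query, num_gt)
+            >>> rows, cols = linear_sum_assignment(cost)
+            >>> # query rows[i] is matched to gt cols[i]; unmatched queries
+            >>> # are later treated as background (assigned_gt_inds == 0).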
+ """
+ assert gt_bboxes_ignore is None, \
+ 'Only case when gt_bboxes_ignore is None is supported.'
+ num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ assigned_labels = bbox_pred.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ if num_gts == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
+ img_h, img_w, _ = img_meta['img_shape']
+ factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
+ img_h]).unsqueeze(0)
+
+ # 2. compute the weighted costs
+        # classification and bbox cost.
+ cls_cost = self.cls_cost(cls_pred, gt_labels)
+ # regression L1 cost
+ normalize_gt_bboxes = gt_bboxes / factor
+ reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
+        # regression iou cost; giou is used by default as in official DETR.
+ bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
+ iou_cost = self.iou_cost(bboxes, gt_bboxes)
+ # weighted sum of above three costs
+ cost = cls_cost + reg_cost + iou_cost
+
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(
+ bbox_pred.device)
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(
+ bbox_pred.device)
+
+ # 4. assign backgrounds and foregrounds
+ # assign all indices to backgrounds first
+ assigned_gt_inds[:] = 0
+ # assign foregrounds based on matching results
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+ assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
diff --git a/mmdet/core/bbox/assigners/mask_hungarian_assigner.py b/mmdet/core/bbox/assigners/mask_hungarian_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..d83def1d09557abc83875a8923fce7f00bb0c8e5
--- /dev/null
+++ b/mmdet/core/bbox/assigners/mask_hungarian_assigner.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from scipy.optimize import linear_sum_assignment
+
+from mmdet.core.bbox.builder import BBOX_ASSIGNERS
+from mmdet.core.bbox.match_costs.builder import build_match_cost
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class MaskHungarianAssigner(BaseAssigner):
+ """Computes one-to-one matching between predictions and ground truth for
+ mask.
+
+ This class computes an assignment between the targets and the predictions
+ based on the costs. The costs are weighted sum of three components:
+ classification cost, mask focal cost and mask dice cost. The
+ targets don't include the no_object, so generally there are more
+ predictions than targets. After the one-to-one matching, the un-matched
+ are treated as backgrounds. Thus each query prediction will be assigned
+ with `0` or a positive integer indicating the ground truth index:
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ cls_cost (:obj:`mmcv.ConfigDict` | dict): Classification cost config.
+ mask_cost (:obj:`mmcv.ConfigDict` | dict): Mask cost config.
+ dice_cost (:obj:`mmcv.ConfigDict` | dict): Dice cost config.
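+
+    Example:
+        An illustrative construction only; the weights below are arbitrary
+        and not a recommended configuration.
+        >>> assigner = MaskHungarianAssigner(
+        ...     cls_cost=dict(type='ClassificationCost', weight=1.0),
+        ...     mask_cost=dict(type='FocalLossCost', weight=20.0,
+        ...                    binary_input=True),
+        ...     dice_cost=dict(type='DiceCost', weight=1.0))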
+ """
+
+ def __init__(self,
+ cls_cost=dict(type='ClassificationCost', weight=1.0),
+ mask_cost=dict(
+ type='FocalLossCost', weight=1.0, binary_input=True),
+ dice_cost=dict(type='DiceCost', weight=1.0)):
+ self.cls_cost = build_match_cost(cls_cost)
+ self.mask_cost = build_match_cost(mask_cost)
+ self.dice_cost = build_match_cost(dice_cost)
+
+ def assign(self,
+ cls_pred,
+ mask_pred,
+ gt_labels,
+ gt_mask,
+ img_meta,
+ gt_bboxes_ignore=None,
+ eps=1e-7):
+ """Computes one-to-one matching based on the weighted costs.
+
+ Args:
+ cls_pred (Tensor | None): Class prediction in shape
+ (num_query, cls_out_channels).
+ mask_pred (Tensor): Mask prediction in shape (num_query, H, W).
+            gt_labels (Tensor): Label of 'gt_mask' in shape (num_gt, ).
+            gt_mask (Tensor): Ground truth mask in shape (num_gt, H, W).
+ img_meta (dict): Meta information for current image.
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`. Default None.
+ eps (int | float, optional): A value added to the denominator for
+ numerical stability. Default 1e-7.
+
+ Returns:
+ :obj:`AssignResult`: The assigned result.
+ """
+ assert gt_bboxes_ignore is None, \
+ 'Only case when gt_bboxes_ignore is None is supported.'
+ # K-Net sometimes passes cls_pred=None to this assigner.
+ # So we should use the shape of mask_pred
+ num_gt, num_query = gt_labels.shape[0], mask_pred.shape[0]
+
+ # 1. assign -1 by default
+ assigned_gt_inds = mask_pred.new_full((num_query, ),
+ -1,
+ dtype=torch.long)
+ assigned_labels = mask_pred.new_full((num_query, ),
+ -1,
+ dtype=torch.long)
+ if num_gt == 0 or num_query == 0:
+            # No ground truth masks or queries, return empty assignment
+ if num_gt == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ return AssignResult(
+ num_gt, assigned_gt_inds, None, labels=assigned_labels)
+
+ # 2. compute the weighted costs
+        # classification and mask cost.
+ if self.cls_cost.weight != 0 and cls_pred is not None:
+ cls_cost = self.cls_cost(cls_pred, gt_labels)
+ else:
+ cls_cost = 0
+
+ if self.mask_cost.weight != 0:
+ # mask_pred shape = [num_query, h, w]
+ # gt_mask shape = [num_gt, h, w]
+ # mask_cost shape = [num_query, num_gt]
+ mask_cost = self.mask_cost(mask_pred, gt_mask)
+ else:
+ mask_cost = 0
+
+ if self.dice_cost.weight != 0:
+ dice_cost = self.dice_cost(mask_pred, gt_mask)
+ else:
+ dice_cost = 0
+ cost = cls_cost + mask_cost + dice_cost
+
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(
+ mask_pred.device)
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(
+ mask_pred.device)
+
+ # 4. assign backgrounds and foregrounds
+ # assign all indices to backgrounds first
+ assigned_gt_inds[:] = 0
+ # assign foregrounds based on matching results
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+ assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
+ return AssignResult(
+ num_gt, assigned_gt_inds, None, labels=assigned_labels)
diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..676421f7653f37e936c7152ed64bebe80564d147
--- /dev/null
+++ b/mmdet/core/bbox/assigners/max_iou_assigner.py
@@ -0,0 +1,218 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class MaxIoUAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned with `-1`, `0`, or a positive integer
+    indicating the ground truth index.
+
+    - -1: don't care / ignored sample
+    - 0: negative sample, no assigned gt
+    - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ pos_iou_thr (float): IoU threshold for positive bboxes.
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
+ positive bbox. Positive samples can have smaller IoU than
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+ `min_pos_iou` is set to avoid assigning bboxes that have extremely
+ small iou with GT as positive samples. It brings about 0.3 mAP
+ improvements in 1x schedule but does not affect the performance of
+ 3x schedule. More comparisons can be found in
+ `PR #7464 `_.
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
+ highest overlap with some gt to that gt.
+ ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+ `gt_bboxes_ignore` is specified). Negative values mean not
+ ignoring any bboxes.
+ ignore_wrt_candidates (bool): Whether to compute the iof between
+ `bboxes` and `gt_bboxes_ignore`, or the contrary.
+ match_low_quality (bool): Whether to allow low quality matches. This is
+ usually allowed for RPN and single stage detectors, but not allowed
+ in the second stage. Details are demonstrated in Step 4.
+ gpu_assign_thr (int): The upper bound of the number of GT for GPU
+ assign. When the number of gt is above this threshold, will assign
+ on CPU device. Negative values mean not assign on CPU.
+ """
+
+ def __init__(self,
+ pos_iou_thr,
+ neg_iou_thr,
+ min_pos_iou=.0,
+ gt_max_assign_all=True,
+ ignore_iof_thr=-1,
+ ignore_wrt_candidates=True,
+ match_low_quality=True,
+ gpu_assign_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_iou_thr = pos_iou_thr
+ self.neg_iou_thr = neg_iou_thr
+ self.min_pos_iou = min_pos_iou
+ self.gt_max_assign_all = gt_max_assign_all
+ self.ignore_iof_thr = ignore_iof_thr
+ self.ignore_wrt_candidates = ignore_wrt_candidates
+ self.gpu_assign_thr = gpu_assign_thr
+ self.match_low_quality = match_low_quality
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign gt to bboxes.
+
+        This method assigns a gt bbox to every bbox (proposal/anchor); each
+        bbox will be assigned with -1, 0, or a positive number. -1 means
+        ignored, 0 means negative sample, and a positive number is the
+        index (1-based) of the assigned gt.
+        The assignment is done in the following steps, and the order matters.
+
+ 1. assign every bbox to the background
+ 2. assign proposals whose iou with all gts < neg_iou_thr to 0
+ 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
+ assign it to that bbox
+ 4. for each gt bbox, assign its nearest proposals (may be more than
+ one) to itself
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+
+ Example:
+ >>> self = MaxIoUAssigner(0.5, 0.5)
+ >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+ >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
+ >>> assign_result = self.assign(bboxes, gt_bboxes)
+ >>> expected_gt_inds = torch.LongTensor([1, 0])
+ >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
+ """
+ assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
+ gt_bboxes.shape[0] > self.gpu_assign_thr) else False
+ # compute overlap and assign gt on CPU when number of GT is large
+ if assign_on_cpu:
+ device = bboxes.device
+ bboxes = bboxes.cpu()
+ gt_bboxes = gt_bboxes.cpu()
+ if gt_bboxes_ignore is not None:
+ gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+ if gt_labels is not None:
+ gt_labels = gt_labels.cpu()
+
+ overlaps = self.iou_calculator(gt_bboxes, bboxes)
+
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+ and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
+ if self.ignore_wrt_candidates:
+ ignore_overlaps = self.iou_calculator(
+ bboxes, gt_bboxes_ignore, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+ else:
+ ignore_overlaps = self.iou_calculator(
+ gt_bboxes_ignore, bboxes, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+ overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
+
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+ if assign_on_cpu:
+ assign_result.gt_inds = assign_result.gt_inds.to(device)
+ assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+ if assign_result.labels is not None:
+ assign_result.labels = assign_result.labels.to(device)
+ return assign_result
+
+ def assign_wrt_overlaps(self, overlaps, gt_labels=None):
+ """Assign w.r.t. the overlaps of bboxes with gts.
+
+ Args:
+ overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
+ shape(k, n).
+ gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
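+
+        Example:
+            A small sketch with a hand-crafted overlaps matrix:
+            >>> self = MaxIoUAssigner(0.5, 0.5)
+            >>> overlaps = torch.Tensor([[0.6, 0.3],
+            ...                          [0.2, 0.4]])  # (num_gt, num_bbox)
+            >>> result = self.assign_wrt_overlaps(overlaps)
+            >>> # bbox 0 matches gt 1; bbox 1 is rescued by low-quality
+            >>> # matching with gt 2 (both indices are 1-based)
+            >>> assert torch.all(result.gt_inds == torch.LongTensor([1, 2]))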
+ """
+ num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
+ if num_gts == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ # for each anchor, which gt best overlaps with it
+ # for each anchor, the max iou of all gts
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+ # for each gt, which anchor best overlaps with it
+ # for each gt, the max iou of all proposals
+ gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
+
+ # 2. assign negative: below
+ # the negative inds are set to be 0
+ if isinstance(self.neg_iou_thr, float):
+ assigned_gt_inds[(max_overlaps >= 0)
+ & (max_overlaps < self.neg_iou_thr)] = 0
+ elif isinstance(self.neg_iou_thr, tuple):
+ assert len(self.neg_iou_thr) == 2
+ assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
+ & (max_overlaps < self.neg_iou_thr[1])] = 0
+
+ # 3. assign positive: above positive IoU threshold
+ pos_inds = max_overlaps >= self.pos_iou_thr
+ assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
+
+ if self.match_low_quality:
+ # Low-quality matching will overwrite the assigned_gt_inds assigned
+ # in Step 3. Thus, the assigned gt might not be the best one for
+ # prediction.
+            # For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2,
+            # GT bbox 1 will be assigned as the best target for bbox A
+            # in step 3.
+ # However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's
+ # assigned_gt_inds will be overwritten to be bbox 2.
+ # This might be the reason that it is not used in ROI Heads.
+ for i in range(num_gts):
+ if gt_max_overlaps[i] >= self.min_pos_iou:
+ if self.gt_max_assign_all:
+ max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
+ assigned_gt_inds[max_iou_inds] = i + 1
+ else:
+ assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
diff --git a/mmdet/core/bbox/assigners/point_assigner.py b/mmdet/core/bbox/assigners/point_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0dc2246320bd271af644992a4309077bc537076
--- /dev/null
+++ b/mmdet/core/bbox/assigners/point_assigner.py
@@ -0,0 +1,134 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class PointAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each point.
+
+    Each point will be assigned with `0`, or a positive integer
+ indicating the ground truth index.
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+ """
+
+ def __init__(self, scale=4, pos_num=3):
+ self.scale = scale
+ self.pos_num = pos_num
+
+ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign gt to points.
+
+        This method assigns a gt bbox to every point; each point will be
+        assigned with 0 (background) or a positive number, which is the
+        index (1-based) of the assigned gt.
+        The assignment is done in the following steps, and the order matters.
+
+        1. assign every point to the background (0)
+        2. a point is assigned to a gt bbox if
+            (i) the point is among the k points closest to the gt center
+            on the gt's assigned level, and
+            (ii) the point is closer to this gt than to any other gt
+
+ Args:
+ points (Tensor): points to be assigned, shape(n, 3) while last
+ dimension stands for (x, y, stride).
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ NOTE: currently unused.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
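+
+        Example:
+            A minimal usage sketch (made-up points at stride 8; ``pos_num``
+            is set to 1 because only two points share that level):
+            >>> self = PointAssigner(scale=4, pos_num=1)
+            >>> points = torch.Tensor([[4.5, 4.5, 8], [20., 20., 8]])
+            >>> gt_bboxes = torch.Tensor([[0, 0, 9, 9]])
+            >>> assign_result = self.assign(points, gt_bboxes)
+            >>> expected = torch.LongTensor([1, 0])
+            >>> assert torch.all(assign_result.gt_inds == expected)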
+ """
+ num_points = points.shape[0]
+ num_gts = gt_bboxes.shape[0]
+
+ if num_gts == 0 or num_points == 0:
+ # If no truth assign everything to the background
+ assigned_gt_inds = points.new_full((num_points, ),
+ 0,
+ dtype=torch.long)
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = points.new_full((num_points, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
+
+ points_xy = points[:, :2]
+ points_stride = points[:, 2]
+ points_lvl = torch.log2(
+ points_stride).int() # [3...,4...,5...,6...,7...]
+ lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
+
+ # assign gt box
+ gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
+ gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
+ scale = self.scale
+ gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
+ torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
+ gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)
+
+ # stores the assigned gt index of each point
+ assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
+ # stores the assigned gt dist (to this point) of each point
+ assigned_gt_dist = points.new_full((num_points, ), float('inf'))
+ points_range = torch.arange(points.shape[0])
+
+ for idx in range(num_gts):
+ gt_lvl = gt_bboxes_lvl[idx]
+ # get the index of points in this level
+ lvl_idx = gt_lvl == points_lvl
+ points_index = points_range[lvl_idx]
+ # get the points in this level
+ lvl_points = points_xy[lvl_idx, :]
+ # get the center point of gt
+ gt_point = gt_bboxes_xy[[idx], :]
+ # get width and height of gt
+ gt_wh = gt_bboxes_wh[[idx], :]
+ # compute the distance between gt center and
+ # all points in this level
+ points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
+ # find the nearest k points to gt center in this level
+ min_dist, min_dist_index = torch.topk(
+ points_gt_dist, self.pos_num, largest=False)
+ # the index of nearest k points to gt center in this level
+ min_dist_points_index = points_index[min_dist_index]
+            # The less_than_recorded_index stores the index
+            # of min_dist that is less than the assigned_gt_dist, where
+            # assigned_gt_dist stores the dist from the previously assigned gt
+            # (if it exists) to each point.
+ less_than_recorded_index = min_dist < assigned_gt_dist[
+ min_dist_points_index]
+ # The min_dist_points_index stores the index of points satisfy:
+ # (1) it is k nearest to current gt center in this level.
+ # (2) it is closer to current gt center than other gt center.
+ min_dist_points_index = min_dist_points_index[
+ less_than_recorded_index]
+ # assign the result
+ assigned_gt_inds[min_dist_points_index] = idx + 1
+ assigned_gt_dist[min_dist_points_index] = min_dist[
+ less_than_recorded_index]
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
diff --git a/mmdet/core/bbox/assigners/region_assigner.py b/mmdet/core/bbox/assigners/region_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1833b89418820562333c7abfc2acea57deba4893
--- /dev/null
+++ b/mmdet/core/bbox/assigners/region_assigner.py
@@ -0,0 +1,222 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core import anchor_inside_flags
+from ..builder import BBOX_ASSIGNERS
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+def calc_region(bbox, ratio, stride, featmap_size=None):
+ """Calculate region of the box defined by the ratio, the ratio is from the
+ center of the box to every edge."""
+ # project bbox on the feature
+ f_bbox = bbox / stride
+ x1 = torch.round((1 - ratio) * f_bbox[0] + ratio * f_bbox[2])
+ y1 = torch.round((1 - ratio) * f_bbox[1] + ratio * f_bbox[3])
+ x2 = torch.round(ratio * f_bbox[0] + (1 - ratio) * f_bbox[2])
+ y2 = torch.round(ratio * f_bbox[1] + (1 - ratio) * f_bbox[3])
+ if featmap_size is not None:
+ x1 = x1.clamp(min=0, max=featmap_size[1])
+ y1 = y1.clamp(min=0, max=featmap_size[0])
+ x2 = x2.clamp(min=0, max=featmap_size[1])
+ y2 = y2.clamp(min=0, max=featmap_size[0])
+ return (x1, y1, x2, y2)
+
+
+def anchor_ctr_inside_region_flags(anchors, stride, region):
+ """Get the flag indicate whether anchor centers are inside regions."""
+ x1, y1, x2, y2 = region
+ f_anchors = anchors / stride
+ x = (f_anchors[:, 0] + f_anchors[:, 2]) * 0.5
+ y = (f_anchors[:, 1] + f_anchors[:, 3]) * 0.5
+ flags = (x >= x1) & (x <= x2) & (y >= y1) & (y <= y2)
+ return flags
+
+
+@BBOX_ASSIGNERS.register_module()
+class RegionAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+    Each proposal will be assigned with `-1`, `0`, or a positive integer
+ indicating the ground truth index.
+
+ - -1: don't care
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+    Args:
+        center_ratio (float): Ratio of the center region of the bbox used
+            to define positive samples. Default 0.2.
+        ignore_ratio (float): Ratio of the region used to define ignore
+            samples. Default 0.5.
+ """
+
+ def __init__(self, center_ratio=0.2, ignore_ratio=0.5):
+ self.center_ratio = center_ratio
+ self.ignore_ratio = ignore_ratio
+
+ def assign(self,
+ mlvl_anchors,
+ mlvl_valid_flags,
+ gt_bboxes,
+ img_meta,
+ featmap_sizes,
+ anchor_scale,
+ anchor_strides,
+ gt_bboxes_ignore=None,
+ gt_labels=None,
+ allowed_border=0):
+ """Assign gt to anchors.
+
+        This method assigns a gt bbox to every bbox (proposal/anchor); each
+        bbox will be assigned with -1, 0, or a positive number. -1 means
+        don't care, 0 means negative sample, and a positive number is the
+        index (1-based) of the assigned gt.
+
+        The assignment is done in the following steps, and the order matters.
+
+ 1. Assign every anchor to 0 (negative)
+ 2. (For each gt_bboxes) Compute ignore flags based on ignore_region
+ then assign -1 to anchors w.r.t. ignore flags
+ 3. (For each gt_bboxes) Compute pos flags based on center_region then
+ assign gt_bboxes to anchors w.r.t. pos flags
+ 4. (For each gt_bboxes) Compute ignore flags based on adjacent anchor
+ level then assign -1 to anchors w.r.t. ignore flags
+ 5. Assign anchor outside of image to -1
+
+ Args:
+ mlvl_anchors (list[Tensor]): Multi level anchors.
+ mlvl_valid_flags (list[Tensor]): Multi level valid flags.
+            gt_bboxes (Tensor): Ground truth boxes of the image, shape (k, 4).
+            img_meta (dict): Meta info of the image.
+            featmap_sizes (list[Tensor]): Feature map size of each level.
+            anchor_scale (int): Scale of the anchor.
+            anchor_strides (list[int]): Stride of the anchors at each level.
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+ allowed_border (int, optional): The border to allow the valid
+ anchor. Defaults to 0.
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
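+
+        Example:
+            An illustration of the per-gt region computation used in steps
+            2 and 3 only (the boxes, stride and ratio below are made up;
+            ``0.4`` corresponds to the default ``center_ratio=0.2``):
+            >>> gt_bbox = torch.Tensor([0., 0., 64., 64.])
+            >>> stride, featmap_size = 8, (16, 16)
+            >>> ctr_region = calc_region(gt_bbox, 0.4, stride, featmap_size)
+            >>> anchors = torch.Tensor([[28., 28., 36., 36.],
+            ...                         [100., 100., 108., 108.]])
+            >>> flags = anchor_ctr_inside_region_flags(anchors, stride,
+            ...                                        ctr_region)
+            >>> assert flags.tolist() == [True, False]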
+ """
+ if gt_bboxes_ignore is not None:
+ raise NotImplementedError
+
+ num_gts = gt_bboxes.shape[0]
+ num_bboxes = sum(x.shape[0] for x in mlvl_anchors)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = gt_bboxes.new_zeros((num_bboxes, ))
+ assigned_gt_inds = gt_bboxes.new_zeros((num_bboxes, ),
+ dtype=torch.long)
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = gt_bboxes.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ num_lvls = len(mlvl_anchors)
+ r1 = (1 - self.center_ratio) / 2
+ r2 = (1 - self.ignore_ratio) / 2
+
+ scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
+ (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
+ min_anchor_size = scale.new_full(
+ (1, ), float(anchor_scale * anchor_strides[0]))
+ target_lvls = torch.floor(
+ torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
+ target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
+
+ # 1. assign 0 (negative) by default
+ mlvl_assigned_gt_inds = []
+ mlvl_ignore_flags = []
+ for lvl in range(num_lvls):
+ h, w = featmap_sizes[lvl]
+ assert h * w == mlvl_anchors[lvl].shape[0]
+ assigned_gt_inds = gt_bboxes.new_full((h * w, ),
+ 0,
+ dtype=torch.long)
+ ignore_flags = torch.zeros_like(assigned_gt_inds)
+ mlvl_assigned_gt_inds.append(assigned_gt_inds)
+ mlvl_ignore_flags.append(ignore_flags)
+
+ for gt_id in range(num_gts):
+ lvl = target_lvls[gt_id].item()
+ featmap_size = featmap_sizes[lvl]
+ stride = anchor_strides[lvl]
+ anchors = mlvl_anchors[lvl]
+ gt_bbox = gt_bboxes[gt_id, :4]
+
+ # Compute regions
+ ignore_region = calc_region(gt_bbox, r2, stride, featmap_size)
+ ctr_region = calc_region(gt_bbox, r1, stride, featmap_size)
+
+ # 2. Assign -1 to ignore flags
+ ignore_flags = anchor_ctr_inside_region_flags(
+ anchors, stride, ignore_region)
+ mlvl_assigned_gt_inds[lvl][ignore_flags] = -1
+
+ # 3. Assign gt_bboxes to pos flags
+ pos_flags = anchor_ctr_inside_region_flags(anchors, stride,
+ ctr_region)
+ mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1
+
+ # 4. Assign -1 to ignore adjacent lvl
+ if lvl > 0:
+ d_lvl = lvl - 1
+ d_anchors = mlvl_anchors[d_lvl]
+ d_featmap_size = featmap_sizes[d_lvl]
+ d_stride = anchor_strides[d_lvl]
+ d_ignore_region = calc_region(gt_bbox, r2, d_stride,
+ d_featmap_size)
+ ignore_flags = anchor_ctr_inside_region_flags(
+ d_anchors, d_stride, d_ignore_region)
+ mlvl_ignore_flags[d_lvl][ignore_flags] = 1
+ if lvl < num_lvls - 1:
+ u_lvl = lvl + 1
+ u_anchors = mlvl_anchors[u_lvl]
+ u_featmap_size = featmap_sizes[u_lvl]
+ u_stride = anchor_strides[u_lvl]
+ u_ignore_region = calc_region(gt_bbox, r2, u_stride,
+ u_featmap_size)
+ ignore_flags = anchor_ctr_inside_region_flags(
+ u_anchors, u_stride, u_ignore_region)
+ mlvl_ignore_flags[u_lvl][ignore_flags] = 1
+
+ # 4. (cont.) Assign -1 to ignore adjacent lvl
+ for lvl in range(num_lvls):
+ ignore_flags = mlvl_ignore_flags[lvl]
+ mlvl_assigned_gt_inds[lvl][ignore_flags] = -1
+
+ # 5. Assign -1 to anchor outside of image
+ flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
+ flat_anchors = torch.cat(mlvl_anchors)
+ flat_valid_flags = torch.cat(mlvl_valid_flags)
+ assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
+ flat_valid_flags.shape[0])
+ inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
+ img_meta['img_shape'],
+ allowed_border)
+ outside_flags = ~inside_flags
+ flat_assigned_gt_inds[outside_flags] = -1
+
+ if gt_labels is not None:
+ assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
+            pos_flags = flat_assigned_gt_inds > 0
+ assigned_labels[pos_flags] = gt_labels[
+ flat_assigned_gt_inds[pos_flags] - 1]
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, flat_assigned_gt_inds, None, labels=assigned_labels)
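+
+
+# A minimal illustrative sketch (not used by the assigner itself): it mirrors
+# the FPN level-mapping rule in ``assign`` above, where a gt box is routed to
+# the pyramid level whose anchor size best matches sqrt(w * h). The concrete
+# numbers (anchor_scale=4, strides, box size) are made-up assumptions.
+if __name__ == '__main__':
+    import math
+
+    def map_gt_to_level(w, h, anchor_scale=4, strides=(8, 16, 32, 64, 128)):
+        scale = math.sqrt(w * h)
+        min_anchor_size = anchor_scale * strides[0]  # 32 for these strides
+        lvl = math.floor(math.log2(scale) - math.log2(min_anchor_size) + 0.5)
+        return min(max(lvl, 0), len(strides) - 1)
+
+    # A 100x60 box has scale ~77.5, which lands on level 1 (stride 16).
+    print(map_gt_to_level(100, 60))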
diff --git a/mmdet/core/bbox/assigners/sim_ota_assigner.py b/mmdet/core/bbox/assigners/sim_ota_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..58bfef433bad3f1c43df0b950ad92a0619db7641
--- /dev/null
+++ b/mmdet/core/bbox/assigners/sim_ota_assigner.py
@@ -0,0 +1,257 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn.functional as F
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import bbox_overlaps
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class SimOTAAssigner(BaseAssigner):
+ """Computes matching between predictions and ground truth.
+
+ Args:
+ center_radius (int | float, optional): Ground truth center size
+ to judge whether a prior is in center. Default 2.5.
+ candidate_topk (int, optional): The candidate top-k which used to
+ get top-k ious to calculate dynamic-k. Default 10.
+ iou_weight (int | float, optional): The scale factor for regression
+ iou cost. Default 3.0.
+ cls_weight (int | float, optional): The scale factor for classification
+ cost. Default 1.0.
+ """
+
+ def __init__(self,
+ center_radius=2.5,
+ candidate_topk=10,
+ iou_weight=3.0,
+ cls_weight=1.0):
+ self.center_radius = center_radius
+ self.candidate_topk = candidate_topk
+ self.iou_weight = iou_weight
+ self.cls_weight = cls_weight
+
+ def assign(self,
+ pred_scores,
+ priors,
+ decoded_bboxes,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ eps=1e-7):
+ """Assign gt to priors using SimOTA. It will switch to CPU mode when
+ GPU is out of memory.
+ Args:
+ pred_scores (Tensor): Classification scores of one image,
+ a 2D-Tensor with shape [num_priors, num_classes]
+ priors (Tensor): All priors of one image, a 2D-Tensor with shape
+                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
+ decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
+ [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
+ gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
+ with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (Tensor): Ground truth labels of one image, a Tensor
+ with shape [num_gts].
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ eps (float): A value added to the denominator for numerical
+ stability. Default 1e-7.
+ Returns:
+            :obj:`AssignResult`: The assigned result.
+ """
+ try:
+ assign_result = self._assign(pred_scores, priors, decoded_bboxes,
+ gt_bboxes, gt_labels,
+ gt_bboxes_ignore, eps)
+ return assign_result
+ except RuntimeError:
+ origin_device = pred_scores.device
+ warnings.warn('OOM RuntimeError is raised due to the huge memory '
+ 'cost during label assignment. CPU mode is applied '
+ 'in this batch. If you want to avoid this issue, '
+ 'try to reduce the batch size or image size.')
+ torch.cuda.empty_cache()
+
+ pred_scores = pred_scores.cpu()
+ priors = priors.cpu()
+ decoded_bboxes = decoded_bboxes.cpu()
+ gt_bboxes = gt_bboxes.cpu().float()
+ gt_labels = gt_labels.cpu()
+
+ assign_result = self._assign(pred_scores, priors, decoded_bboxes,
+ gt_bboxes, gt_labels,
+ gt_bboxes_ignore, eps)
+ assign_result.gt_inds = assign_result.gt_inds.to(origin_device)
+ assign_result.max_overlaps = assign_result.max_overlaps.to(
+ origin_device)
+ assign_result.labels = assign_result.labels.to(origin_device)
+
+ return assign_result
+
+ def _assign(self,
+ pred_scores,
+ priors,
+ decoded_bboxes,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ eps=1e-7):
+ """Assign gt to priors using SimOTA.
+ Args:
+ pred_scores (Tensor): Classification scores of one image,
+ a 2D-Tensor with shape [num_priors, num_classes]
+ priors (Tensor): All priors of one image, a 2D-Tensor with shape
+                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
+ decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
+ [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
+ gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
+ with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (Tensor): Ground truth labels of one image, a Tensor
+ with shape [num_gts].
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ eps (float): A value added to the denominator for numerical
+ stability. Default 1e-7.
+ Returns:
+ :obj:`AssignResult`: The assigned result.
+ """
+ INF = 100000.0
+ num_gt = gt_bboxes.size(0)
+ num_bboxes = decoded_bboxes.size(0)
+
+ # assign 0 by default
+ assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ),
+ 0,
+ dtype=torch.long)
+ valid_mask, is_in_boxes_and_center = self.get_in_gt_and_in_center_info(
+ priors, gt_bboxes)
+ valid_decoded_bbox = decoded_bboxes[valid_mask]
+ valid_pred_scores = pred_scores[valid_mask]
+ num_valid = valid_decoded_bbox.size(0)
+
+ if num_gt == 0 or num_bboxes == 0 or num_valid == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = decoded_bboxes.new_zeros((num_bboxes, ))
+ if num_gt == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = decoded_bboxes.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+ pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes)
+ iou_cost = -torch.log(pairwise_ious + eps)
+
+ gt_onehot_label = (
+ F.one_hot(gt_labels.to(torch.int64),
+ pred_scores.shape[-1]).float().unsqueeze(0).repeat(
+ num_valid, 1, 1))
+
+ valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
+ cls_cost = (
+ F.binary_cross_entropy(
+ valid_pred_scores.to(dtype=torch.float32).sqrt_(),
+ gt_onehot_label,
+ reduction='none',
+ ).sum(-1).to(dtype=valid_pred_scores.dtype))
+
+ cost_matrix = (
+ cls_cost * self.cls_weight + iou_cost * self.iou_weight +
+ (~is_in_boxes_and_center) * INF)
+
+ matched_pred_ious, matched_gt_inds = \
+ self.dynamic_k_matching(
+ cost_matrix, pairwise_ious, num_gt, valid_mask)
+
+ # convert to AssignResult format
+ assigned_gt_inds[valid_mask] = matched_gt_inds + 1
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
+ max_overlaps = assigned_gt_inds.new_full((num_bboxes, ),
+ -INF,
+ dtype=torch.float32)
+ max_overlaps[valid_mask] = matched_pred_ious
+ return AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+ def get_in_gt_and_in_center_info(self, priors, gt_bboxes):
+ num_gt = gt_bboxes.size(0)
+
+ repeated_x = priors[:, 0].unsqueeze(1).repeat(1, num_gt)
+ repeated_y = priors[:, 1].unsqueeze(1).repeat(1, num_gt)
+ repeated_stride_x = priors[:, 2].unsqueeze(1).repeat(1, num_gt)
+ repeated_stride_y = priors[:, 3].unsqueeze(1).repeat(1, num_gt)
+
+ # is prior centers in gt bboxes, shape: [n_prior, n_gt]
+ l_ = repeated_x - gt_bboxes[:, 0]
+ t_ = repeated_y - gt_bboxes[:, 1]
+ r_ = gt_bboxes[:, 2] - repeated_x
+ b_ = gt_bboxes[:, 3] - repeated_y
+
+ deltas = torch.stack([l_, t_, r_, b_], dim=1)
+ is_in_gts = deltas.min(dim=1).values > 0
+ is_in_gts_all = is_in_gts.sum(dim=1) > 0
+
+ # is prior centers in gt centers
+ gt_cxs = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
+ gt_cys = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
+ ct_box_l = gt_cxs - self.center_radius * repeated_stride_x
+ ct_box_t = gt_cys - self.center_radius * repeated_stride_y
+ ct_box_r = gt_cxs + self.center_radius * repeated_stride_x
+ ct_box_b = gt_cys + self.center_radius * repeated_stride_y
+
+ cl_ = repeated_x - ct_box_l
+ ct_ = repeated_y - ct_box_t
+ cr_ = ct_box_r - repeated_x
+ cb_ = ct_box_b - repeated_y
+
+ ct_deltas = torch.stack([cl_, ct_, cr_, cb_], dim=1)
+ is_in_cts = ct_deltas.min(dim=1).values > 0
+ is_in_cts_all = is_in_cts.sum(dim=1) > 0
+
+ # in boxes or in centers, shape: [num_priors]
+ is_in_gts_or_centers = is_in_gts_all | is_in_cts_all
+
+ # both in boxes and centers, shape: [num_fg, num_gt]
+ is_in_boxes_and_centers = (
+ is_in_gts[is_in_gts_or_centers, :]
+ & is_in_cts[is_in_gts_or_centers, :])
+ return is_in_gts_or_centers, is_in_boxes_and_centers
+
+ def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
+ matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+ # select candidate topk ious for dynamic-k calculation
+ candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
+ topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
+ # calculate dynamic k for each gt
+ dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
+ for gt_idx in range(num_gt):
+ _, pos_idx = torch.topk(
+ cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
+ matching_matrix[:, gt_idx][pos_idx] = 1
+
+ del topk_ious, dynamic_ks, pos_idx
+
+ prior_match_gt_mask = matching_matrix.sum(1) > 1
+ if prior_match_gt_mask.sum() > 0:
+ cost_min, cost_argmin = torch.min(
+ cost[prior_match_gt_mask, :], dim=1)
+ matching_matrix[prior_match_gt_mask, :] *= 0
+ matching_matrix[prior_match_gt_mask, cost_argmin] = 1
+ # get foreground mask inside box and center prior
+ fg_mask_inboxes = matching_matrix.sum(1) > 0
+ valid_mask[valid_mask.clone()] = fg_mask_inboxes
+
+ matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
+ matched_pred_ious = (matching_matrix *
+ pairwise_ious).sum(1)[fg_mask_inboxes]
+ return matched_pred_ious, matched_gt_inds
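+
+
+# A minimal smoke-test sketch for ``SimOTAAssigner.assign`` with made-up
+# priors, scores and a single gt box. It assumes mmdet is installed and is
+# meant to be run as ``python -m mmdet.core.bbox.assigners.sim_ota_assigner``.
+if __name__ == '__main__':
+    torch.manual_seed(0)
+    num_priors, num_classes, stride = 16, 3, 8.
+    # 4x4 grid of prior centers in [cx, cy, stride_w, stride_h] format.
+    cxs = (torch.arange(num_priors) % 4).float() * stride + stride / 2
+    cys = (torch.arange(num_priors) // 4).float() * stride + stride / 2
+    priors = torch.stack([cxs, cys,
+                          torch.full_like(cxs, stride),
+                          torch.full_like(cys, stride)], dim=1)
+    # Fake sigmoid scores and 12x12 predicted boxes centred on each prior.
+    pred_scores = torch.rand(num_priors, num_classes)
+    decoded_bboxes = torch.cat([priors[:, :2] - 6, priors[:, :2] + 6], dim=1)
+    gt_bboxes = torch.tensor([[2., 2., 20., 20.]])
+    gt_labels = torch.tensor([1])
+    assigner = SimOTAAssigner(center_radius=2.5)
+    result = assigner.assign(pred_scores, priors, decoded_bboxes, gt_bboxes,
+                             gt_labels)
+    print(result.gt_inds)  # 1 for priors matched to the gt box, 0 otherwise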
diff --git a/mmdet/core/bbox/assigners/task_aligned_assigner.py b/mmdet/core/bbox/assigners/task_aligned_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1872de4a780ab1e7c6b4632e576f8e0644743ca2
--- /dev/null
+++ b/mmdet/core/bbox/assigners/task_aligned_assigner.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+INF = 100000000
+
+
+@BBOX_ASSIGNERS.register_module()
+class TaskAlignedAssigner(BaseAssigner):
+ """Task aligned assigner used in the paper:
+    `TOOD: Task-aligned One-stage Object Detection
+    <https://arxiv.org/abs/2108.07755>`_.
+
+ Assign a corresponding gt bbox or background to each predicted bbox.
+ Each bbox will be assigned with `0` or a positive integer
+ indicating the ground truth index.
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+        topk (int): Number of bbox candidates selected for each gt.
+ iou_calculator (dict): Config dict for iou calculator.
+ Default: dict(type='BboxOverlaps2D')
+ """
+
+ def __init__(self, topk, iou_calculator=dict(type='BboxOverlaps2D')):
+ assert topk >= 1
+ self.topk = topk
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self,
+ pred_scores,
+ decode_bboxes,
+ anchors,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ gt_labels=None,
+ alpha=1,
+ beta=6):
+ """Assign gt to bboxes.
+
+ The assignment is done in following steps
+
+ 1. compute alignment metric between all bbox (bbox of all pyramid
+ levels) and gt
+ 2. select top-k bbox as candidates for each gt
+ 3. limit the positive sample's center in gt (because the anchor-free
+           detector can only predict positive distances)
+
+
+ Args:
+ pred_scores (Tensor): predicted class probability,
+ shape(n, num_classes)
+ decode_bboxes (Tensor): predicted bounding boxes, shape(n, 4)
+ anchors (Tensor): pre-defined anchors, shape(n, 4).
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`TaskAlignedAssignResult`: The assign result.
+ """
+ anchors = anchors[:, :4]
+ num_gt, num_bboxes = gt_bboxes.size(0), anchors.size(0)
+ # compute alignment metric between all bbox and gt
+ overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach()
+ bbox_scores = pred_scores[:, gt_labels].detach()
+ # assign 0 by default
+ assigned_gt_inds = anchors.new_full((num_bboxes, ),
+ 0,
+ dtype=torch.long)
+ assign_metrics = anchors.new_zeros((num_bboxes, ))
+
+ if num_gt == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = anchors.new_zeros((num_bboxes, ))
+ if num_gt == 0:
+ # No gt boxes, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = anchors.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ assign_result = AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+ assign_result.assign_metrics = assign_metrics
+ return assign_result
+
+ # select top-k bboxes as candidates for each gt
+ alignment_metrics = bbox_scores**alpha * overlaps**beta
+ topk = min(self.topk, alignment_metrics.size(0))
+ _, candidate_idxs = alignment_metrics.topk(topk, dim=0, largest=True)
+ candidate_metrics = alignment_metrics[candidate_idxs,
+ torch.arange(num_gt)]
+ is_pos = candidate_metrics > 0
+
+ # limit the positive sample's center in gt
+ anchors_cx = (anchors[:, 0] + anchors[:, 2]) / 2.0
+ anchors_cy = (anchors[:, 1] + anchors[:, 3]) / 2.0
+ for gt_idx in range(num_gt):
+ candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+ ep_anchors_cx = anchors_cx.view(1, -1).expand(
+ num_gt, num_bboxes).contiguous().view(-1)
+ ep_anchors_cy = anchors_cy.view(1, -1).expand(
+ num_gt, num_bboxes).contiguous().view(-1)
+ candidate_idxs = candidate_idxs.view(-1)
+
+ # calculate the left, top, right, bottom distance between positive
+ # bbox center and gt side
+ l_ = ep_anchors_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
+ t_ = ep_anchors_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
+ r_ = gt_bboxes[:, 2] - ep_anchors_cx[candidate_idxs].view(-1, num_gt)
+ b_ = gt_bboxes[:, 3] - ep_anchors_cy[candidate_idxs].view(-1, num_gt)
+ is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
+ is_pos = is_pos & is_in_gts
+
+ # if an anchor box is assigned to multiple gts,
+ # the one with the highest iou will be selected.
+ overlaps_inf = torch.full_like(overlaps,
+ -INF).t().contiguous().view(-1)
+ index = candidate_idxs.view(-1)[is_pos.view(-1)]
+ overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
+ overlaps_inf = overlaps_inf.view(num_gt, -1).t()
+
+ max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
+ assigned_gt_inds[
+ max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
+ assign_metrics[max_overlaps != -INF] = alignment_metrics[
+ max_overlaps != -INF, argmax_overlaps[max_overlaps != -INF]]
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+ assign_result = AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+ assign_result.assign_metrics = assign_metrics
+ return assign_result
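+
+
+# A minimal usage sketch for ``TaskAlignedAssigner.assign`` on a row of
+# anchors and two gt boxes. All tensors are made-up illustrations; run as
+# ``python -m mmdet.core.bbox.assigners.task_aligned_assigner`` with mmdet
+# installed.
+if __name__ == '__main__':
+    torch.manual_seed(0)
+    num_anchors, num_classes = 12, 4
+    # 8x8 anchors laid out along the x-axis; predictions reuse the anchors.
+    anchors = torch.tensor(
+        [[i * 8., 0., i * 8. + 8., 8.] for i in range(num_anchors)])
+    decode_bboxes = anchors.clone()
+    pred_scores = torch.rand(num_anchors, num_classes)
+    gt_bboxes = torch.tensor([[0., 0., 24., 8.], [40., 0., 72., 8.]])
+    gt_labels = torch.tensor([2, 0])
+    assigner = TaskAlignedAssigner(topk=5)
+    result = assigner.assign(pred_scores, decode_bboxes, anchors, gt_bboxes,
+                             gt_labels=gt_labels)
+    print(result.gt_inds)        # 1-based gt index per anchor, 0 = background
+    print(result.assign_metrics)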
diff --git a/mmdet/core/bbox/assigners/uniform_assigner.py b/mmdet/core/bbox/assigners/uniform_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..70294fc45f32b2611c6c1521de14f57e4ec446f0
--- /dev/null
+++ b/mmdet/core/bbox/assigners/uniform_assigner.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from ..transforms import bbox_xyxy_to_cxcywh
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class UniformAssigner(BaseAssigner):
+ """Uniform Matching between the anchors and gt boxes, which can achieve
+ balance in positive anchors, and gt_bboxes_ignore was not considered for
+ now.
+
+ Args:
+        pos_ignore_thr (float): The threshold to ignore positive anchors.
+        neg_ignore_thr (float): The threshold to ignore negative anchors.
+        match_times (int): Number of positive anchors for each gt box.
+            Default 4.
+        iou_calculator (dict): Config dict for the iou calculator.
+ """
+
+ def __init__(self,
+ pos_ignore_thr,
+ neg_ignore_thr,
+ match_times=4,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.match_times = match_times
+ self.pos_ignore_thr = pos_ignore_thr
+ self.neg_ignore_thr = neg_ignore_thr
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self,
+ bbox_pred,
+ anchor,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ gt_labels=None):
+ num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
+
+        # 1. assign 0 (background) by default
+ assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
+ 0,
+ dtype=torch.long)
+ assigned_labels = bbox_pred.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ if num_gts == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ assign_result = AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
+ assign_result.set_extra_property(
+ 'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
+ assign_result.set_extra_property('pos_predicted_boxes',
+ bbox_pred.new_empty((0, 4)))
+ assign_result.set_extra_property('target_boxes',
+ bbox_pred.new_empty((0, 4)))
+ return assign_result
+
+ # 2. Compute the L1 cost between boxes
+        # Note that we use both anchors and predicted boxes
+ cost_bbox = torch.cdist(
+ bbox_xyxy_to_cxcywh(bbox_pred),
+ bbox_xyxy_to_cxcywh(gt_bboxes),
+ p=1)
+ cost_bbox_anchors = torch.cdist(
+ bbox_xyxy_to_cxcywh(anchor), bbox_xyxy_to_cxcywh(gt_bboxes), p=1)
+
+        # We found that the topk function gives different results on CPU and
+        # CUDA. To stay consistent with the source code, topk is run in CPU
+        # mode here.
+ # TODO: Check whether the performance of cpu and cuda are the same.
+ C = cost_bbox.cpu()
+ C1 = cost_bbox_anchors.cpu()
+
+ # self.match_times x n
+ index = torch.topk(
+            C,  # shape (num_bboxes, num_gts)
+ k=self.match_times,
+ dim=0,
+ largest=False)[1]
+
+ # self.match_times x n
+ index1 = torch.topk(C1, k=self.match_times, dim=0, largest=False)[1]
+ # (self.match_times*2) x n
+ indexes = torch.cat((index, index1),
+ dim=1).reshape(-1).to(bbox_pred.device)
+
+ pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
+ anchor_overlaps = self.iou_calculator(anchor, gt_bboxes)
+ pred_max_overlaps, _ = pred_overlaps.max(dim=1)
+ anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)
+
+        # 3. Compute the ignore indexes using gt_bboxes and predicted boxes
+ ignore_idx = pred_max_overlaps > self.neg_ignore_thr
+ assigned_gt_inds[ignore_idx] = -1
+
+        # 4. Compute the ignore indexes of positive samples using anchors
+        # and predicted boxes
+ pos_gt_index = torch.arange(
+ 0, C1.size(1),
+ device=bbox_pred.device).repeat(self.match_times * 2)
+ pos_ious = anchor_overlaps[indexes, pos_gt_index]
+ pos_ignore_idx = pos_ious < self.pos_ignore_thr
+
+ pos_gt_index_with_ignore = pos_gt_index + 1
+ pos_gt_index_with_ignore[pos_ignore_idx] = -1
+ assigned_gt_inds[indexes] = pos_gt_index_with_ignore
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+
+ assign_result = AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ anchor_max_overlaps,
+ labels=assigned_labels)
+ assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
+ assign_result.set_extra_property('pos_predicted_boxes',
+ bbox_pred[indexes])
+ assign_result.set_extra_property('target_boxes',
+ gt_bboxes[pos_gt_index])
+ return assign_result
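+
+
+# A minimal usage sketch for ``UniformAssigner.assign`` with jittered anchors
+# standing in for predicted boxes. The thresholds follow common YOLOF settings
+# but are assumptions here; run as
+# ``python -m mmdet.core.bbox.assigners.uniform_assigner`` with mmdet
+# installed.
+if __name__ == '__main__':
+    torch.manual_seed(0)
+    num_anchors = 20
+    anchor = torch.tensor(
+        [[i * 10., 0., i * 10. + 16., 16.] for i in range(num_anchors)])
+    bbox_pred = anchor + torch.randn(num_anchors, 4) * 2
+    gt_bboxes = torch.tensor([[5., 0., 25., 16.], [100., 0., 130., 16.]])
+    gt_labels = torch.tensor([1, 3])
+    assigner = UniformAssigner(pos_ignore_thr=0.15, neg_ignore_thr=0.7)
+    result = assigner.assign(bbox_pred, anchor, gt_bboxes, gt_labels=gt_labels)
+    # gt_inds: -1 = ignore, 0 = background, >0 = 1-based index of matched gt.
+    print(result.gt_inds)
+    print(result.get_extra_property('pos_idx').sum())  # kept positive matches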
diff --git a/mmdet/core/bbox/builder.py b/mmdet/core/bbox/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cfa055b5df8cb73d84580ea1f23b82f5393ca8e
--- /dev/null
+++ b/mmdet/core/bbox/builder.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry, build_from_cfg
+
+BBOX_ASSIGNERS = Registry('bbox_assigner')
+BBOX_SAMPLERS = Registry('bbox_sampler')
+BBOX_CODERS = Registry('bbox_coder')
+
+
+def build_assigner(cfg, **default_args):
+ """Builder of box assigner."""
+ return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)
+
+
+def build_sampler(cfg, **default_args):
+ """Builder of box sampler."""
+ return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)
+
+
+def build_bbox_coder(cfg, **default_args):
+ """Builder of box coder."""
+ return build_from_cfg(cfg, BBOX_CODERS, default_args)
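+
+
+# A minimal usage sketch: the registry-based builders above are normally
+# driven from config dicts. The import below is only there so that the
+# assigner classes register themselves; it assumes mmdet is installed and that
+# this runs as ``python -m mmdet.core.bbox.builder``.
+if __name__ == '__main__':
+    from mmdet.core.bbox import assigners  # noqa: F401  (triggers registration)
+    assigner = build_assigner(dict(type='SimOTAAssigner', center_radius=2.5))
+    print(type(assigner).__name__)  # SimOTAAssigner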
diff --git a/mmdet/core/bbox/coder/__init__.py b/mmdet/core/bbox/coder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12fd64e12b5e76a014da9bd724f1b6f50b488c4
--- /dev/null
+++ b/mmdet/core/bbox/coder/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_bbox_coder import BaseBBoxCoder
+from .bucketing_bbox_coder import BucketingBBoxCoder
+from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
+from .distance_point_bbox_coder import DistancePointBBoxCoder
+from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
+from .pseudo_bbox_coder import PseudoBBoxCoder
+from .tblr_bbox_coder import TBLRBBoxCoder
+from .yolo_bbox_coder import YOLOBBoxCoder
+
+__all__ = [
+ 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
+ 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
+ 'BucketingBBoxCoder', 'DistancePointBBoxCoder'
+]
diff --git a/mmdet/core/bbox/coder/base_bbox_coder.py b/mmdet/core/bbox/coder/base_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7ed041a456e59282c1bf72eaec76bc2c0d1b990
--- /dev/null
+++ b/mmdet/core/bbox/coder/base_bbox_coder.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseBBoxCoder(metaclass=ABCMeta):
+ """Base bounding box coder."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ @abstractmethod
+ def encode(self, bboxes, gt_bboxes):
+ """Encode deltas between bboxes and ground truth boxes."""
+
+ @abstractmethod
+ def decode(self, bboxes, bboxes_pred):
+ """Decode the predicted bboxes according to prediction and base
+ boxes."""
diff --git a/mmdet/core/bbox/coder/bucketing_bbox_coder.py b/mmdet/core/bbox/coder/bucketing_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4be0ada04b410017035443fdfc15d898ed9a0e4b
--- /dev/null
+++ b/mmdet/core/bbox/coder/bucketing_bbox_coder.py
@@ -0,0 +1,351 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ..builder import BBOX_CODERS
+from ..transforms import bbox_rescale
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class BucketingBBoxCoder(BaseBBoxCoder):
+ """Bucketing BBox Coder for Side-Aware Boundary Localization (SABL).
+
+ Boundary Localization with Bucketing and Bucketing Guided Rescoring
+ are implemented here.
+
+ Please refer to https://arxiv.org/abs/1912.04260 for more details.
+
+ Args:
+ num_buckets (int): Number of buckets.
+ scale_factor (int): Scale factor of proposals to generate buckets.
+ offset_topk (int): Topk buckets are used to generate
+ bucket fine regression targets. Defaults to 2.
+ offset_upperbound (float): Offset upperbound to generate
+ bucket fine regression targets.
+ To avoid too large offset displacements. Defaults to 1.0.
+        cls_ignore_neighbor (bool): Whether to ignore the second nearest
+            bucket. Defaults to True.
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+ """
+
+ def __init__(self,
+ num_buckets,
+ scale_factor,
+ offset_topk=2,
+ offset_upperbound=1.0,
+ cls_ignore_neighbor=True,
+ clip_border=True):
+ super(BucketingBBoxCoder, self).__init__()
+ self.num_buckets = num_buckets
+ self.scale_factor = scale_factor
+ self.offset_topk = offset_topk
+ self.offset_upperbound = offset_upperbound
+ self.cls_ignore_neighbor = cls_ignore_neighbor
+ self.clip_border = clip_border
+
+ def encode(self, bboxes, gt_bboxes):
+ """Get bucketing estimation and fine regression targets during
+ training.
+
+ Args:
+ bboxes (torch.Tensor): source boxes, e.g., object proposals.
+ gt_bboxes (torch.Tensor): target of the transformation, e.g.,
+ ground truth boxes.
+
+ Returns:
+ encoded_bboxes(tuple[Tensor]): bucketing estimation
+ and fine regression targets and weights
+ """
+
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets,
+ self.scale_factor, self.offset_topk,
+ self.offset_upperbound,
+ self.cls_ignore_neighbor)
+ return encoded_bboxes
+
+ def decode(self, bboxes, pred_bboxes, max_shape=None):
+ """Apply transformation `pred_bboxes` to `boxes`.
+ Args:
+            bboxes (torch.Tensor): Basic boxes.
+            pred_bboxes (tuple[torch.Tensor]): Predictions for bucketing
+                estimation and fine regression.
+ max_shape (tuple[int], optional): Maximum shape of boxes.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+ assert len(pred_bboxes) == 2
+ cls_preds, offset_preds = pred_bboxes
+ assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size(
+ 0) == bboxes.size(0)
+ decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
+ self.num_buckets, self.scale_factor,
+ max_shape, self.clip_border)
+
+ return decoded_bboxes
+
+
+@mmcv.jit(coderize=True)
+def generat_buckets(proposals, num_buckets, scale_factor=1.0):
+ """Generate buckets w.r.t bucket number and scale factor of proposals.
+
+ Args:
+ proposals (Tensor): Shape (n, 4)
+ num_buckets (int): Number of buckets.
+ scale_factor (float): Scale factor to rescale proposals.
+
+ Returns:
+ tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
+ t_buckets, d_buckets)
+
+ - bucket_w: Width of buckets on x-axis. Shape (n, ).
+ - bucket_h: Height of buckets on y-axis. Shape (n, ).
+ - l_buckets: Left buckets. Shape (n, ceil(side_num/2)).
+ - r_buckets: Right buckets. Shape (n, ceil(side_num/2)).
+ - t_buckets: Top buckets. Shape (n, ceil(side_num/2)).
+ - d_buckets: Down buckets. Shape (n, ceil(side_num/2)).
+ """
+ proposals = bbox_rescale(proposals, scale_factor)
+
+ # number of buckets in each side
+ side_num = int(np.ceil(num_buckets / 2.0))
+ pw = proposals[..., 2] - proposals[..., 0]
+ ph = proposals[..., 3] - proposals[..., 1]
+ px1 = proposals[..., 0]
+ py1 = proposals[..., 1]
+ px2 = proposals[..., 2]
+ py2 = proposals[..., 3]
+
+ bucket_w = pw / num_buckets
+ bucket_h = ph / num_buckets
+
+ # left buckets
+ l_buckets = px1[:, None] + (0.5 + torch.arange(
+ 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
+ # right buckets
+ r_buckets = px2[:, None] - (0.5 + torch.arange(
+ 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
+ # top buckets
+ t_buckets = py1[:, None] + (0.5 + torch.arange(
+ 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
+ # down buckets
+ d_buckets = py2[:, None] - (0.5 + torch.arange(
+ 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
+ return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets
+
+
+@mmcv.jit(coderize=True)
+def bbox2bucket(proposals,
+ gt,
+ num_buckets,
+ scale_factor,
+ offset_topk=2,
+ offset_upperbound=1.0,
+ cls_ignore_neighbor=True):
+ """Generate buckets estimation and fine regression targets.
+
+ Args:
+ proposals (Tensor): Shape (n, 4)
+ gt (Tensor): Shape (n, 4)
+ num_buckets (int): Number of buckets.
+ scale_factor (float): Scale factor to rescale proposals.
+ offset_topk (int): Topk buckets are used to generate
+ bucket fine regression targets. Defaults to 2.
+ offset_upperbound (float): Offset allowance to generate
+ bucket fine regression targets.
+ To avoid too large offset displacements. Defaults to 1.0.
+        cls_ignore_neighbor (bool): Whether to ignore the second nearest
+            bucket. Defaults to True.
+
+ Returns:
+ tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).
+
+ - offsets: Fine regression targets. \
+ Shape (n, num_buckets*2).
+ - offsets_weights: Fine regression weights. \
+ Shape (n, num_buckets*2).
+ - bucket_labels: Bucketing estimation labels. \
+ Shape (n, num_buckets*2).
+ - cls_weights: Bucketing estimation weights. \
+ Shape (n, num_buckets*2).
+ """
+ assert proposals.size() == gt.size()
+
+ # generate buckets
+ proposals = proposals.float()
+ gt = gt.float()
+ (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
+ d_buckets) = generat_buckets(proposals, num_buckets, scale_factor)
+
+ gx1 = gt[..., 0]
+ gy1 = gt[..., 1]
+ gx2 = gt[..., 2]
+ gy2 = gt[..., 3]
+
+ # generate offset targets and weights
+ # offsets from buckets to gts
+ l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None]
+ r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None]
+ t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None]
+ d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None]
+
+ # select top-k nearest buckets
+ l_topk, l_label = l_offsets.abs().topk(
+ offset_topk, dim=1, largest=False, sorted=True)
+ r_topk, r_label = r_offsets.abs().topk(
+ offset_topk, dim=1, largest=False, sorted=True)
+ t_topk, t_label = t_offsets.abs().topk(
+ offset_topk, dim=1, largest=False, sorted=True)
+ d_topk, d_label = d_offsets.abs().topk(
+ offset_topk, dim=1, largest=False, sorted=True)
+
+ offset_l_weights = l_offsets.new_zeros(l_offsets.size())
+ offset_r_weights = r_offsets.new_zeros(r_offsets.size())
+ offset_t_weights = t_offsets.new_zeros(t_offsets.size())
+ offset_d_weights = d_offsets.new_zeros(d_offsets.size())
+ inds = torch.arange(0, proposals.size(0)).to(proposals).long()
+
+ # generate offset weights of top-k nearest buckets
+ for k in range(offset_topk):
+ if k >= 1:
+ offset_l_weights[inds, l_label[:,
+ k]] = (l_topk[:, k] <
+ offset_upperbound).float()
+ offset_r_weights[inds, r_label[:,
+ k]] = (r_topk[:, k] <
+ offset_upperbound).float()
+ offset_t_weights[inds, t_label[:,
+ k]] = (t_topk[:, k] <
+ offset_upperbound).float()
+ offset_d_weights[inds, d_label[:,
+ k]] = (d_topk[:, k] <
+ offset_upperbound).float()
+ else:
+ offset_l_weights[inds, l_label[:, k]] = 1.0
+ offset_r_weights[inds, r_label[:, k]] = 1.0
+ offset_t_weights[inds, t_label[:, k]] = 1.0
+ offset_d_weights[inds, d_label[:, k]] = 1.0
+
+ offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1)
+ offsets_weights = torch.cat([
+ offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights
+ ],
+ dim=-1)
+
+ # generate bucket labels and weight
+ side_num = int(np.ceil(num_buckets / 2.0))
+ labels = torch.stack(
+ [l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1)
+
+ batch_size = labels.size(0)
+ bucket_labels = F.one_hot(labels.view(-1), side_num).view(batch_size,
+ -1).float()
+ bucket_cls_l_weights = (l_offsets.abs() < 1).float()
+ bucket_cls_r_weights = (r_offsets.abs() < 1).float()
+ bucket_cls_t_weights = (t_offsets.abs() < 1).float()
+ bucket_cls_d_weights = (d_offsets.abs() < 1).float()
+ bucket_cls_weights = torch.cat([
+ bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights,
+ bucket_cls_d_weights
+ ],
+ dim=-1)
+ # ignore second nearest buckets for cls if necessary
+ if cls_ignore_neighbor:
+ bucket_cls_weights = (~((bucket_cls_weights == 1) &
+ (bucket_labels == 0))).float()
+ else:
+ bucket_cls_weights[:] = 1.0
+ return offsets, offsets_weights, bucket_labels, bucket_cls_weights
+
+
+@mmcv.jit(coderize=True)
+def bucket2bbox(proposals,
+ cls_preds,
+ offset_preds,
+ num_buckets,
+ scale_factor=1.0,
+ max_shape=None,
+ clip_border=True):
+ """Apply bucketing estimation (cls preds) and fine regression (offset
+ preds) to generate det bboxes.
+
+ Args:
+ proposals (Tensor): Boxes to be transformed. Shape (n, 4)
+ cls_preds (Tensor): bucketing estimation. Shape (n, num_buckets*2).
+ offset_preds (Tensor): fine regression. Shape (n, num_buckets*2).
+ num_buckets (int): Number of buckets.
+ scale_factor (float): Scale factor to rescale proposals.
+        max_shape (tuple[int, int]): Maximum bounds for boxes, specifies (H, W).
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+
+ Returns:
+ tuple[Tensor]: (bboxes, loc_confidence).
+
+ - bboxes: predicted bboxes. Shape (n, 4)
+ - loc_confidence: localization confidence of predicted bboxes.
+ Shape (n,).
+ """
+
+ side_num = int(np.ceil(num_buckets / 2.0))
+ cls_preds = cls_preds.view(-1, side_num)
+ offset_preds = offset_preds.view(-1, side_num)
+
+ scores = F.softmax(cls_preds, dim=1)
+ score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)
+
+ rescaled_proposals = bbox_rescale(proposals, scale_factor)
+
+ pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
+ ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
+ px1 = rescaled_proposals[..., 0]
+ py1 = rescaled_proposals[..., 1]
+ px2 = rescaled_proposals[..., 2]
+ py2 = rescaled_proposals[..., 3]
+
+ bucket_w = pw / num_buckets
+ bucket_h = ph / num_buckets
+
+ score_inds_l = score_label[0::4, 0]
+ score_inds_r = score_label[1::4, 0]
+ score_inds_t = score_label[2::4, 0]
+ score_inds_d = score_label[3::4, 0]
+ l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
+ r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
+ t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
+ d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h
+
+ offsets = offset_preds.view(-1, 4, side_num)
+ inds = torch.arange(proposals.size(0)).to(proposals).long()
+ l_offsets = offsets[:, 0, :][inds, score_inds_l]
+ r_offsets = offsets[:, 1, :][inds, score_inds_r]
+ t_offsets = offsets[:, 2, :][inds, score_inds_t]
+ d_offsets = offsets[:, 3, :][inds, score_inds_d]
+
+ x1 = l_buckets - l_offsets * bucket_w
+ x2 = r_buckets - r_offsets * bucket_w
+ y1 = t_buckets - t_offsets * bucket_h
+ y2 = d_buckets - d_offsets * bucket_h
+
+ if clip_border and max_shape is not None:
+ x1 = x1.clamp(min=0, max=max_shape[1] - 1)
+ y1 = y1.clamp(min=0, max=max_shape[0] - 1)
+ x2 = x2.clamp(min=0, max=max_shape[1] - 1)
+ y2 = y2.clamp(min=0, max=max_shape[0] - 1)
+ bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
+ dim=-1)
+
+ # bucketing guided rescoring
+ loc_confidence = score_topk[:, 0]
+ top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
+ loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
+ loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)
+
+ return bboxes, loc_confidence
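+
+
+# A shape-level round-trip sketch through ``BucketingBBoxCoder``. Feeding the
+# one-hot bucket labels back in as ``cls_preds`` is only a sanity check, not
+# how SABL is trained; the box coordinates and bucket settings are made up.
+if __name__ == '__main__':
+    torch.manual_seed(0)
+    coder = BucketingBBoxCoder(num_buckets=8, scale_factor=1.0)
+    proposals = torch.tensor([[10., 10., 50., 60.], [0., 0., 32., 32.]])
+    gts = torch.tensor([[12., 8., 48., 62.], [2., 1., 30., 33.]])
+    offsets, offset_weights, bucket_labels, cls_weights = coder.encode(
+        proposals, gts)
+    print(offsets.shape, bucket_labels.shape)  # (2, 16) for both
+    decoded, loc_conf = coder.decode(
+        proposals, (bucket_labels, offsets), max_shape=(64, 64))
+    print(decoded.shape, loc_conf.shape)  # (2, 4) and (2,)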
diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f1c62fa7bde9280f9edcb4926cd77bfdd3a0b4
--- /dev/null
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -0,0 +1,392 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import mmcv
+import numpy as np
+import torch
+
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class DeltaXYWHBBoxCoder(BaseBBoxCoder):
+ """Delta XYWH BBox coder.
+
+    Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
+ this coder encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh) and
+ decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2).
+
+ Args:
+ target_means (Sequence[float]): Denormalizing means of target for
+ delta coordinates
+ target_stds (Sequence[float]): Denormalizing standard deviation of
+ target for delta coordinates
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+        add_ctr_clamp (bool): Whether to add center clamp. When added, the
+            predicted box is clamped if its center is too far away from
+            the original anchor's center. Only used by YOLOF. Default False.
+ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
+ Default 32.
+ """
+
+ def __init__(self,
+ target_means=(0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1.),
+ clip_border=True,
+ add_ctr_clamp=False,
+ ctr_clamp=32):
+ super(BaseBBoxCoder, self).__init__()
+ self.means = target_means
+ self.stds = target_stds
+ self.clip_border = clip_border
+ self.add_ctr_clamp = add_ctr_clamp
+ self.ctr_clamp = ctr_clamp
+
+ def encode(self, bboxes, gt_bboxes):
+ """Get box regression transformation deltas that can be used to
+ transform the ``bboxes`` into the ``gt_bboxes``.
+
+ Args:
+ bboxes (torch.Tensor): Source boxes, e.g., object proposals.
+ gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
+ ground-truth boxes.
+
+ Returns:
+ torch.Tensor: Box transformation deltas
+ """
+
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds)
+ return encoded_bboxes
+
+ def decode(self,
+ bboxes,
+ pred_bboxes,
+ max_shape=None,
+ wh_ratio_clip=16 / 1000):
+ """Apply transformation `pred_bboxes` to `boxes`.
+
+ Args:
+ bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
+ pred_bboxes (Tensor): Encoded offsets with respect to each roi.
+ Has shape (B, N, num_classes * 4) or (B, N, 4) or
+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
+                when rois is a grid of anchors. Offset encoding follows [1]_.
+            max_shape (Sequence[int] or torch.Tensor or Sequence[
+                Sequence[int]], optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+ wh_ratio_clip (float, optional): The allowed ratio between
+ width and height.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+
+ assert pred_bboxes.size(0) == bboxes.size(0)
+ if pred_bboxes.ndim == 3:
+ assert pred_bboxes.size(1) == bboxes.size(1)
+
+ if pred_bboxes.ndim == 2 and not torch.onnx.is_in_onnx_export():
+ # single image decode
+ decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means,
+ self.stds, max_shape, wh_ratio_clip,
+ self.clip_border, self.add_ctr_clamp,
+ self.ctr_clamp)
+ else:
+ if pred_bboxes.ndim == 3 and not torch.onnx.is_in_onnx_export():
+ warnings.warn(
+ 'DeprecationWarning: onnx_delta2bbox is deprecated '
+ 'in the case of batch decoding and non-ONNX, '
+ 'please use “delta2bbox” instead. In order to improve '
+ 'the decoding speed, the batch function will no '
+ 'longer be supported. ')
+ decoded_bboxes = onnx_delta2bbox(bboxes, pred_bboxes, self.means,
+ self.stds, max_shape,
+ wh_ratio_clip, self.clip_border,
+ self.add_ctr_clamp,
+ self.ctr_clamp)
+
+ return decoded_bboxes
+
+
+@mmcv.jit(coderize=True)
+def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
+ """Compute deltas of proposals w.r.t. gt.
+
+ We usually compute the deltas of x, y, w, h of proposals w.r.t ground
+ truth bboxes to get regression target.
+ This is the inverse function of :func:`delta2bbox`.
+
+ Args:
+ proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
+ gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
+ means (Sequence[float]): Denormalizing means for delta coordinates
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates
+
+ Returns:
+ Tensor: deltas with shape (N, 4), where columns represent dx, dy,
+ dw, dh.
+ """
+ assert proposals.size() == gt.size()
+
+ proposals = proposals.float()
+ gt = gt.float()
+ px = (proposals[..., 0] + proposals[..., 2]) * 0.5
+ py = (proposals[..., 1] + proposals[..., 3]) * 0.5
+ pw = proposals[..., 2] - proposals[..., 0]
+ ph = proposals[..., 3] - proposals[..., 1]
+
+ gx = (gt[..., 0] + gt[..., 2]) * 0.5
+ gy = (gt[..., 1] + gt[..., 3]) * 0.5
+ gw = gt[..., 2] - gt[..., 0]
+ gh = gt[..., 3] - gt[..., 1]
+
+ dx = (gx - px) / pw
+ dy = (gy - py) / ph
+ dw = torch.log(gw / pw)
+ dh = torch.log(gh / ph)
+ deltas = torch.stack([dx, dy, dw, dh], dim=-1)
+
+ means = deltas.new_tensor(means).unsqueeze(0)
+ stds = deltas.new_tensor(stds).unsqueeze(0)
+ deltas = deltas.sub_(means).div_(stds)
+
+ return deltas
+
+
+@mmcv.jit(coderize=True)
+def delta2bbox(rois,
+ deltas,
+ means=(0., 0., 0., 0.),
+ stds=(1., 1., 1., 1.),
+ max_shape=None,
+ wh_ratio_clip=16 / 1000,
+ clip_border=True,
+ add_ctr_clamp=False,
+ ctr_clamp=32):
+ """Apply deltas to shift/scale base boxes.
+
+ Typically the rois are anchor or proposed bounding boxes and the deltas are
+ network outputs used to shift/scale those boxes.
+ This is the inverse function of :func:`bbox2delta`.
+
+ Args:
+ rois (Tensor): Boxes to be transformed. Has shape (N, 4).
+ deltas (Tensor): Encoded offsets relative to each roi.
+ Has shape (N, num_classes * 4) or (N, 4). Note
+ N = num_base_anchors * W * H, when rois is a grid of
+ anchors. Offset encoding follows [1]_.
+ means (Sequence[float]): Denormalizing means for delta coordinates.
+ Default (0., 0., 0., 0.).
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates. Default (1., 1., 1., 1.).
+ max_shape (tuple[int, int]): Maximum bounds for boxes, specifies
+ (H, W). Default None.
+ wh_ratio_clip (float): Maximum aspect ratio for boxes. Default
+ 16 / 1000.
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Default True.
+ add_ctr_clamp (bool): Whether to add center clamp. When set to True,
+ the center of the prediction bounding box will be clamped to
+ avoid being too far away from the center of the anchor.
+ Only used by YOLOF. Default False.
+ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
+ Default 32.
+
+ Returns:
+ Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4
+ represent tl_x, tl_y, br_x, br_y.
+
+ References:
+ .. [1] https://arxiv.org/abs/1311.2524
+
+ Example:
+ >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 5., 5., 5., 5.]])
+ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
+ >>> [ 1., 1., 1., 1.],
+ >>> [ 0., 0., 2., -1.],
+ >>> [ 0.7, -1.9, -0.5, 0.3]])
+ >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
+ tensor([[0.0000, 0.0000, 1.0000, 1.0000],
+ [0.1409, 0.1409, 2.8591, 2.8591],
+ [0.0000, 0.3161, 4.1945, 0.6839],
+ [5.0000, 5.0000, 5.0000, 5.0000]])
+ """
+ num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4
+ if num_bboxes == 0:
+ return deltas
+
+ deltas = deltas.reshape(-1, 4)
+
+ means = deltas.new_tensor(means).view(1, -1)
+ stds = deltas.new_tensor(stds).view(1, -1)
+ denorm_deltas = deltas * stds + means
+
+ dxy = denorm_deltas[:, :2]
+ dwh = denorm_deltas[:, 2:]
+
+ # Compute width/height of each roi
+ rois_ = rois.repeat(1, num_classes).reshape(-1, 4)
+ pxy = ((rois_[:, :2] + rois_[:, 2:]) * 0.5)
+ pwh = (rois_[:, 2:] - rois_[:, :2])
+
+ dxy_wh = pwh * dxy
+
+ max_ratio = np.abs(np.log(wh_ratio_clip))
+ if add_ctr_clamp:
+ dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
+ dwh = torch.clamp(dwh, max=max_ratio)
+ else:
+ dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
+
+ gxy = pxy + dxy_wh
+ gwh = pwh * dwh.exp()
+ x1y1 = gxy - (gwh * 0.5)
+ x2y2 = gxy + (gwh * 0.5)
+ bboxes = torch.cat([x1y1, x2y2], dim=-1)
+ if clip_border and max_shape is not None:
+ bboxes[..., 0::2].clamp_(min=0, max=max_shape[1])
+ bboxes[..., 1::2].clamp_(min=0, max=max_shape[0])
+ bboxes = bboxes.reshape(num_bboxes, -1)
+ return bboxes
+
+
+def onnx_delta2bbox(rois,
+ deltas,
+ means=(0., 0., 0., 0.),
+ stds=(1., 1., 1., 1.),
+ max_shape=None,
+ wh_ratio_clip=16 / 1000,
+ clip_border=True,
+ add_ctr_clamp=False,
+ ctr_clamp=32):
+ """Apply deltas to shift/scale base boxes.
+
+ Typically the rois are anchor or proposed bounding boxes and the deltas are
+ network outputs used to shift/scale those boxes.
+ This is the inverse function of :func:`bbox2delta`.
+
+ Args:
+ rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
+ deltas (Tensor): Encoded offsets with respect to each roi.
+ Has shape (B, N, num_classes * 4) or (B, N, 4) or
+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
+            when rois is a grid of anchors. Offset encoding follows [1]_.
+ means (Sequence[float]): Denormalizing means for delta coordinates.
+ Default (0., 0., 0., 0.).
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates. Default (1., 1., 1., 1.).
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+            Sequence[int]], optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If rois shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B. Default None.
+ wh_ratio_clip (float): Maximum aspect ratio for boxes.
+ Default 16 / 1000.
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Default True.
+        add_ctr_clamp (bool): Whether to add center clamp. When added, the
+            predicted box is clamped if its center is too far away from
+ the original anchor's center. Only used by YOLOF. Default False.
+ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF.
+ Default 32.
+
+ Returns:
+ Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
+ (N, num_classes * 4) or (N, 4), where 4 represent
+ tl_x, tl_y, br_x, br_y.
+
+ References:
+ .. [1] https://arxiv.org/abs/1311.2524
+
+ Example:
+ >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 5., 5., 5., 5.]])
+ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
+ >>> [ 1., 1., 1., 1.],
+ >>> [ 0., 0., 2., -1.],
+ >>> [ 0.7, -1.9, -0.5, 0.3]])
+ >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
+ tensor([[0.0000, 0.0000, 1.0000, 1.0000],
+ [0.1409, 0.1409, 2.8591, 2.8591],
+ [0.0000, 0.3161, 4.1945, 0.6839],
+ [5.0000, 5.0000, 5.0000, 5.0000]])
+ """
+ means = deltas.new_tensor(means).view(1,
+ -1).repeat(1,
+ deltas.size(-1) // 4)
+ stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
+ denorm_deltas = deltas * stds + means
+ dx = denorm_deltas[..., 0::4]
+ dy = denorm_deltas[..., 1::4]
+ dw = denorm_deltas[..., 2::4]
+ dh = denorm_deltas[..., 3::4]
+
+ x1, y1 = rois[..., 0], rois[..., 1]
+ x2, y2 = rois[..., 2], rois[..., 3]
+ # Compute center of each roi
+ px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
+ py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
+ # Compute width/height of each roi
+ pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
+ ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
+
+ dx_width = pw * dx
+ dy_height = ph * dy
+
+ max_ratio = np.abs(np.log(wh_ratio_clip))
+ if add_ctr_clamp:
+ dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp)
+ dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp)
+ dw = torch.clamp(dw, max=max_ratio)
+ dh = torch.clamp(dh, max=max_ratio)
+ else:
+ dw = dw.clamp(min=-max_ratio, max=max_ratio)
+ dh = dh.clamp(min=-max_ratio, max=max_ratio)
+ # Use exp(network energy) to enlarge/shrink each roi
+ gw = pw * dw.exp()
+ gh = ph * dh.exp()
+ # Use network energy to shift the center of each roi
+ gx = px + dx_width
+ gy = py + dy_height
+ # Convert center-xy/width/height to top-left, bottom-right
+ x1 = gx - gw * 0.5
+ y1 = gy - gh * 0.5
+ x2 = gx + gw * 0.5
+ y2 = gy + gh * 0.5
+
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
+
+ if clip_border and max_shape is not None:
+ # clip bboxes with dynamic `min` and `max` for onnx
+ if torch.onnx.is_in_onnx_export():
+ from mmdet.core.export import dynamic_clip_for_onnx
+ x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
+ return bboxes
+ if not isinstance(max_shape, torch.Tensor):
+ max_shape = x1.new_tensor(max_shape)
+ max_shape = max_shape[..., :2].type_as(x1)
+ if max_shape.ndim == 2:
+ assert bboxes.ndim == 3
+ assert max_shape.size(0) == bboxes.size(0)
+
+ min_xy = x1.new_tensor(0)
+ max_xy = torch.cat(
+ [max_shape] * (deltas.size(-1) // 2),
+ dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ return bboxes
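+
+
+# A minimal round-trip sketch: encode followed by decode with
+# ``DeltaXYWHBBoxCoder`` recovers the gt boxes (up to float error) as long as
+# nothing is clipped. The stds below are common values but are only an
+# illustration here.
+if __name__ == '__main__':
+    coder = DeltaXYWHBBoxCoder(target_means=(0., 0., 0., 0.),
+                               target_stds=(0.1, 0.1, 0.2, 0.2))
+    rois = torch.tensor([[0., 0., 10., 10.], [5., 5., 25., 35.]])
+    gts = torch.tensor([[1., 1., 11., 12.], [6., 4., 26., 38.]])
+    deltas = coder.encode(rois, gts)
+    recovered = coder.decode(rois, deltas, max_shape=(64, 64))
+    assert torch.allclose(recovered, gts, atol=1e-4)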
diff --git a/mmdet/core/bbox/coder/distance_point_bbox_coder.py b/mmdet/core/bbox/coder/distance_point_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f308a8419c8ec1c483784599deaf04beae6aa7e
--- /dev/null
+++ b/mmdet/core/bbox/coder/distance_point_bbox_coder.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import BBOX_CODERS
+from ..transforms import bbox2distance, distance2bbox
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class DistancePointBBoxCoder(BaseBBoxCoder):
+ """Distance Point BBox coder.
+
+    This coder encodes gt bboxes (x1, y1, x2, y2) into point-to-boundary
+    distances (left, top, right, bottom) and decodes them back to the
+    original boxes.
+
+ Args:
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+ """
+
+ def __init__(self, clip_border=True):
+ super(BaseBBoxCoder, self).__init__()
+ self.clip_border = clip_border
+
+ def encode(self, points, gt_bboxes, max_dis=None, eps=0.1):
+ """Encode bounding box to distances.
+
+ Args:
+ points (Tensor): Shape (N, 2), The format is [x, y].
+ gt_bboxes (Tensor): Shape (N, 4), The format is "xyxy"
+ max_dis (float): Upper bound of the distance. Default None.
+ eps (float): a small value to ensure target < max_dis, instead <=.
+ Default 0.1.
+
+ Returns:
+ Tensor: Box transformation deltas. The shape is (N, 4).
+ """
+ assert points.size(0) == gt_bboxes.size(0)
+ assert points.size(-1) == 2
+ assert gt_bboxes.size(-1) == 4
+ return bbox2distance(points, gt_bboxes, max_dis, eps)
+
+ def decode(self, points, pred_bboxes, max_shape=None):
+ """Decode distance prediction to bounding box.
+
+ Args:
+ points (Tensor): Shape (B, N, 2) or (N, 2).
+ pred_bboxes (Tensor): Distance from the given point to 4
+ boundaries (left, top, right, bottom). Shape (B, N, 4)
+ or (N, 4)
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+                Sequence[int]], optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If priors shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]],
+ and the length of max_shape should also be B.
+ Default None.
+ Returns:
+ Tensor: Boxes with shape (N, 4) or (B, N, 4)
+ """
+ assert points.size(0) == pred_bboxes.size(0)
+ assert points.size(-1) == 2
+ assert pred_bboxes.size(-1) == 4
+ if self.clip_border is False:
+ max_shape = None
+ return distance2bbox(points, pred_bboxes, max_shape)
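+
+
+# A minimal round-trip sketch: for points that lie inside their gt boxes,
+# ``DistancePointBBoxCoder`` encode/decode is an exact inverse. The
+# coordinates below are made up for illustration.
+if __name__ == '__main__':
+    import torch
+
+    coder = DistancePointBBoxCoder()
+    points = torch.tensor([[10., 10.], [30., 20.]])
+    gts = torch.tensor([[4., 2., 16., 18.], [22., 12., 44., 36.]])
+    distances = coder.encode(points, gts)  # (left, top, right, bottom)
+    restored = coder.decode(points, distances, max_shape=(64, 64))
+    assert torch.allclose(restored, gts)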
diff --git a/mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fa348b2d1868342a16c13b7a93a2d7d01007bd4
--- /dev/null
+++ b/mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py
@@ -0,0 +1,216 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import torch
+
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class LegacyDeltaXYWHBBoxCoder(BaseBBoxCoder):
+ """Legacy Delta XYWH BBox coder used in MMDet V1.x.
+
+ Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2,
+ y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh)
+ back to original bbox (x1, y1, x2, y2).
+
+ Note:
+        The main difference between :class:`LegacyDeltaXYWHBBoxCoder` and
+        :class:`DeltaXYWHBBoxCoder` is whether ``+ 1`` is used during width and
+        height calculation. We suggest using this coder only when testing with
+        MMDet V1.x models.
+
+ References:
+ .. [1] https://arxiv.org/abs/1311.2524
+
+ Args:
+ target_means (Sequence[float]): denormalizing means of target for
+ delta coordinates
+ target_stds (Sequence[float]): denormalizing standard deviation of
+ target for delta coordinates
+ """
+
+ def __init__(self,
+ target_means=(0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1.)):
+ super(BaseBBoxCoder, self).__init__()
+ self.means = target_means
+ self.stds = target_stds
+
+ def encode(self, bboxes, gt_bboxes):
+ """Get box regression transformation deltas that can be used to
+ transform the ``bboxes`` into the ``gt_bboxes``.
+
+ Args:
+ bboxes (torch.Tensor): source boxes, e.g., object proposals.
+ gt_bboxes (torch.Tensor): target of the transformation, e.g.,
+ ground-truth boxes.
+
+ Returns:
+ torch.Tensor: Box transformation deltas
+ """
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ encoded_bboxes = legacy_bbox2delta(bboxes, gt_bboxes, self.means,
+ self.stds)
+ return encoded_bboxes
+
+ def decode(self,
+ bboxes,
+ pred_bboxes,
+ max_shape=None,
+ wh_ratio_clip=16 / 1000):
+ """Apply transformation `pred_bboxes` to `boxes`.
+
+ Args:
+            bboxes (torch.Tensor): Basic boxes.
+            pred_bboxes (torch.Tensor): Encoded offsets with respect to each
+                roi, with shape (N, 4) or (N, num_classes * 4).
+ max_shape (tuple[int], optional): Maximum shape of boxes.
+ Defaults to None.
+ wh_ratio_clip (float, optional): The allowed ratio between
+ width and height.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+ assert pred_bboxes.size(0) == bboxes.size(0)
+ decoded_bboxes = legacy_delta2bbox(bboxes, pred_bboxes, self.means,
+ self.stds, max_shape, wh_ratio_clip)
+
+ return decoded_bboxes
+
+
+@mmcv.jit(coderize=True)
+def legacy_bbox2delta(proposals,
+ gt,
+ means=(0., 0., 0., 0.),
+ stds=(1., 1., 1., 1.)):
+ """Compute deltas of proposals w.r.t. gt in the MMDet V1.x manner.
+
+ We usually compute the deltas of x, y, w, h of proposals w.r.t ground
+ truth bboxes to get regression target.
+ This is the inverse function of `delta2bbox()`
+
+ Args:
+ proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
+ gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
+ means (Sequence[float]): Denormalizing means for delta coordinates
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates
+
+ Returns:
+ Tensor: deltas with shape (N, 4), where columns represent dx, dy,
+ dw, dh.
+ """
+ assert proposals.size() == gt.size()
+
+ proposals = proposals.float()
+ gt = gt.float()
+ px = (proposals[..., 0] + proposals[..., 2]) * 0.5
+ py = (proposals[..., 1] + proposals[..., 3]) * 0.5
+ pw = proposals[..., 2] - proposals[..., 0] + 1.0
+ ph = proposals[..., 3] - proposals[..., 1] + 1.0
+
+ gx = (gt[..., 0] + gt[..., 2]) * 0.5
+ gy = (gt[..., 1] + gt[..., 3]) * 0.5
+ gw = gt[..., 2] - gt[..., 0] + 1.0
+ gh = gt[..., 3] - gt[..., 1] + 1.0
+
+ dx = (gx - px) / pw
+ dy = (gy - py) / ph
+ dw = torch.log(gw / pw)
+ dh = torch.log(gh / ph)
+ deltas = torch.stack([dx, dy, dw, dh], dim=-1)
+
+ means = deltas.new_tensor(means).unsqueeze(0)
+ stds = deltas.new_tensor(stds).unsqueeze(0)
+ deltas = deltas.sub_(means).div_(stds)
+
+ return deltas
+
+
+@mmcv.jit(coderize=True)
+def legacy_delta2bbox(rois,
+ deltas,
+ means=(0., 0., 0., 0.),
+ stds=(1., 1., 1., 1.),
+ max_shape=None,
+ wh_ratio_clip=16 / 1000):
+ """Apply deltas to shift/scale base boxes in the MMDet V1.x manner.
+
+ Typically the rois are anchor or proposed bounding boxes and the deltas are
+ network outputs used to shift/scale those boxes.
+ This is the inverse function of `bbox2delta()`
+
+ Args:
+ rois (Tensor): Boxes to be transformed. Has shape (N, 4)
+ deltas (Tensor): Encoded offsets with respect to each roi.
+ Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when
+ rois is a grid of anchors. Offset encoding follows [1]_.
+ means (Sequence[float]): Denormalizing means for delta coordinates
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates
+ max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W)
+ wh_ratio_clip (float): Maximum aspect ratio for boxes.
+
+ Returns:
+ Tensor: Boxes with shape (N, 4), where columns represent
+ tl_x, tl_y, br_x, br_y.
+
+ References:
+ .. [1] https://arxiv.org/abs/1311.2524
+
+ Example:
+ >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 5., 5., 5., 5.]])
+ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
+ >>> [ 1., 1., 1., 1.],
+ >>> [ 0., 0., 2., -1.],
+ >>> [ 0.7, -1.9, -0.5, 0.3]])
+ >>> legacy_delta2bbox(rois, deltas, max_shape=(32, 32))
+ tensor([[0.0000, 0.0000, 1.5000, 1.5000],
+ [0.0000, 0.0000, 5.2183, 5.2183],
+ [0.0000, 0.1321, 7.8891, 0.8679],
+ [5.3967, 2.4251, 6.0033, 3.7749]])
+ """
+ means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
+ stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
+ denorm_deltas = deltas * stds + means
+ dx = denorm_deltas[:, 0::4]
+ dy = denorm_deltas[:, 1::4]
+ dw = denorm_deltas[:, 2::4]
+ dh = denorm_deltas[:, 3::4]
+ max_ratio = np.abs(np.log(wh_ratio_clip))
+ dw = dw.clamp(min=-max_ratio, max=max_ratio)
+ dh = dh.clamp(min=-max_ratio, max=max_ratio)
+ # Compute center of each roi
+ px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
+ py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
+ # Compute width/height of each roi
+ pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
+ ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
+ # Use exp(network energy) to enlarge/shrink each roi
+ gw = pw * dw.exp()
+ gh = ph * dh.exp()
+ # Use network energy to shift the center of each roi
+ gx = px + pw * dx
+ gy = py + ph * dy
+ # Convert center-xy/width/height to top-left, bottom-right
+
+ # The true legacy box coder should +- 0.5 here.
+ # However, current implementation improves the performance when testing
+ # the models trained in MMDetection 1.X (~0.5 bbox AP, 0.2 mask AP)
+ x1 = gx - gw * 0.5
+ y1 = gy - gh * 0.5
+ x2 = gx + gw * 0.5
+ y2 = gy + gh * 0.5
+ if max_shape is not None:
+ x1 = x1.clamp(min=0, max=max_shape[1] - 1)
+ y1 = y1.clamp(min=0, max=max_shape[0] - 1)
+ x2 = x2.clamp(min=0, max=max_shape[1] - 1)
+ y2 = y2.clamp(min=0, max=max_shape[0] - 1)
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
+ return bboxes
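
To make the legacy "+1" convention concrete, here is a minimal, self-contained sketch (plain torch, independent of the mmdet API and not part of the patch) that reproduces the encode/decode round trip for a single box; the 0.5-pixel offset in the output mirrors the omitted +/-0.5 correction noted in the comment above.

```python
import torch


def encode_legacy(prop, gt):
    # Widths/heights use the legacy "+1" convention.
    pw = prop[:, 2] - prop[:, 0] + 1.0
    ph = prop[:, 3] - prop[:, 1] + 1.0
    px = (prop[:, 0] + prop[:, 2]) * 0.5
    py = (prop[:, 1] + prop[:, 3]) * 0.5
    gw = gt[:, 2] - gt[:, 0] + 1.0
    gh = gt[:, 3] - gt[:, 1] + 1.0
    gx = (gt[:, 0] + gt[:, 2]) * 0.5
    gy = (gt[:, 1] + gt[:, 3]) * 0.5
    return torch.stack([(gx - px) / pw, (gy - py) / ph,
                        torch.log(gw / pw), torch.log(gh / ph)], dim=-1)


def decode_legacy(prop, deltas):
    pw = prop[:, 2] - prop[:, 0] + 1.0
    ph = prop[:, 3] - prop[:, 1] + 1.0
    px = (prop[:, 0] + prop[:, 2]) * 0.5
    py = (prop[:, 1] + prop[:, 3]) * 0.5
    gw = pw * deltas[:, 2].exp()
    gh = ph * deltas[:, 3].exp()
    gx = px + pw * deltas[:, 0]
    gy = py + ph * deltas[:, 1]
    # Like the patched code, the true legacy +/-0.5 offset is omitted here.
    return torch.stack([gx - gw * 0.5, gy - gh * 0.5,
                        gx + gw * 0.5, gy + gh * 0.5], dim=-1)


props = torch.tensor([[0., 0., 10., 10.]])
gts = torch.tensor([[2., 2., 12., 14.]])
print(decode_legacy(props, encode_legacy(props, gts)))
# ~tensor([[ 1.5000,  1.5000, 12.5000, 14.5000]]) -- 0.5 px offset, as expected
```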
diff --git a/mmdet/core/bbox/coder/pseudo_bbox_coder.py b/mmdet/core/bbox/coder/pseudo_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe71f369cf18dc06ce2a81c9d23f32d7e9d93449
--- /dev/null
+++ b/mmdet/core/bbox/coder/pseudo_bbox_coder.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class PseudoBBoxCoder(BaseBBoxCoder):
+ """Pseudo bounding box coder."""
+
+ def __init__(self, **kwargs):
+        super(PseudoBBoxCoder, self).__init__(**kwargs)
+
+ def encode(self, bboxes, gt_bboxes):
+ """torch.Tensor: return the given ``bboxes``"""
+ return gt_bboxes
+
+ def decode(self, bboxes, pred_bboxes):
+ """torch.Tensor: return the given ``pred_bboxes``"""
+ return pred_bboxes
diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb4206636f5b3704b465c5507d1f25492f11cf5c
--- /dev/null
+++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class TBLRBBoxCoder(BaseBBoxCoder):
+ """TBLR BBox coder.
+
+    Following the practice in `FSAF <https://arxiv.org/abs/1903.00621>`_,
+    this coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
+    right) and decodes it back to the original.
+
+ Args:
+ normalizer (list | float): Normalization factor to be
+ divided with when coding the coordinates. If it is a list, it should
+ have length of 4 indicating normalization factor in tblr dims.
+ Otherwise it is a unified float factor for all dims. Default: 4.0
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+ """
+
+ def __init__(self, normalizer=4.0, clip_border=True):
+        super(TBLRBBoxCoder, self).__init__()
+ self.normalizer = normalizer
+ self.clip_border = clip_border
+
+ def encode(self, bboxes, gt_bboxes):
+ """Get box regression transformation deltas that can be used to
+        transform the ``bboxes`` into the ``gt_bboxes`` in the (top, bottom,
+        left, right) order.
+
+ Args:
+ bboxes (torch.Tensor): source boxes, e.g., object proposals.
+ gt_bboxes (torch.Tensor): target of the transformation, e.g.,
+ ground truth boxes.
+
+ Returns:
+ torch.Tensor: Box transformation deltas
+ """
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ encoded_bboxes = bboxes2tblr(
+ bboxes, gt_bboxes, normalizer=self.normalizer)
+ return encoded_bboxes
+
+ def decode(self, bboxes, pred_bboxes, max_shape=None):
+ """Apply transformation `pred_bboxes` to `boxes`.
+
+ Args:
+            bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
+ pred_bboxes (torch.Tensor): Encoded boxes with shape
+ (B, N, 4) or (N, 4)
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+ decoded_bboxes = tblr2bboxes(
+ bboxes,
+ pred_bboxes,
+ normalizer=self.normalizer,
+ max_shape=max_shape,
+ clip_border=self.clip_border)
+
+ return decoded_bboxes
+
+
+@mmcv.jit(coderize=True)
+def bboxes2tblr(priors, gts, normalizer=4.0, normalize_by_wh=True):
+ """Encode ground truth boxes to tblr coordinate.
+
+    It first converts the gt coordinates to the tblr format,
+ (top, bottom, left, right), relative to prior box centers.
+ The tblr coordinate may be normalized by the side length of prior bboxes
+ if `normalize_by_wh` is specified as True, and it is then normalized by
+ the `normalizer` factor.
+
+ Args:
+ priors (Tensor): Prior boxes in point form
+ Shape: (num_proposals,4).
+ gts (Tensor): Coords of ground truth for each prior in point-form
+ Shape: (num_proposals, 4).
+ normalizer (Sequence[float] | float): normalization parameter of
+ encoded boxes. If it is a list, it has to have length = 4.
+ Default: 4.0
+ normalize_by_wh (bool): Whether to normalize tblr coordinate by the
+ side length (wh) of prior bboxes.
+
+ Return:
+ encoded boxes (Tensor), Shape: (num_proposals, 4)
+ """
+
+ # dist b/t match center and prior's center
+ if not isinstance(normalizer, float):
+ normalizer = torch.tensor(normalizer, device=priors.device)
+ assert len(normalizer) == 4, 'Normalizer must have length = 4'
+ assert priors.size(0) == gts.size(0)
+ prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
+ xmin, ymin, xmax, ymax = gts.split(1, dim=1)
+ top = prior_centers[:, 1].unsqueeze(1) - ymin
+ bottom = ymax - prior_centers[:, 1].unsqueeze(1)
+ left = prior_centers[:, 0].unsqueeze(1) - xmin
+ right = xmax - prior_centers[:, 0].unsqueeze(1)
+ loc = torch.cat((top, bottom, left, right), dim=1)
+ if normalize_by_wh:
+ # Normalize tblr by anchor width and height
+ wh = priors[:, 2:4] - priors[:, 0:2]
+ w, h = torch.split(wh, 1, dim=1)
+ loc[:, :2] /= h # tb is normalized by h
+ loc[:, 2:] /= w # lr is normalized by w
+ # Normalize tblr by the given normalization factor
+ return loc / normalizer
+
+
+@mmcv.jit(coderize=True)
+def tblr2bboxes(priors,
+ tblr,
+ normalizer=4.0,
+ normalize_by_wh=True,
+ max_shape=None,
+ clip_border=True):
+ """Decode tblr outputs to prediction boxes.
+
+ The process includes 3 steps: 1) De-normalize tblr coordinates by
+ multiplying it with `normalizer`; 2) De-normalize tblr coordinates by the
+ prior bbox width and height if `normalize_by_wh` is `True`; 3) Convert
+ tblr (top, bottom, left, right) pair relative to the center of priors back
+ to (xmin, ymin, xmax, ymax) coordinate.
+
+ Args:
+ priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
+ Shape: (N,4) or (B, N, 4).
+ tblr (Tensor): Coords of network output in tblr form
+ Shape: (N, 4) or (B, N, 4).
+ normalizer (Sequence[float] | float): Normalization parameter of
+ encoded boxes. By list, it represents the normalization factors at
+ tblr dims. By float, it is the unified normalization factor at all
+ dims. Default: 4.0
+ normalize_by_wh (bool): Whether the tblr coordinates have been
+ normalized by the side length (wh) of prior bboxes.
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If priors shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+
+ Return:
+        decoded boxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
+ """
+ if not isinstance(normalizer, float):
+ normalizer = torch.tensor(normalizer, device=priors.device)
+ assert len(normalizer) == 4, 'Normalizer must have length = 4'
+ assert priors.size(0) == tblr.size(0)
+ if priors.ndim == 3:
+ assert priors.size(1) == tblr.size(1)
+
+ loc_decode = tblr * normalizer
+ prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
+ if normalize_by_wh:
+ wh = priors[..., 2:4] - priors[..., 0:2]
+ w, h = torch.split(wh, 1, dim=-1)
+ # Inplace operation with slice would failed for exporting to ONNX
+ th = h * loc_decode[..., :2] # tb
+ tw = w * loc_decode[..., 2:] # lr
+ loc_decode = torch.cat([th, tw], dim=-1)
+ # Cannot be exported using onnx when loc_decode.split(1, dim=-1)
+ top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
+ xmin = prior_centers[..., 0].unsqueeze(-1) - left
+ xmax = prior_centers[..., 0].unsqueeze(-1) + right
+ ymin = prior_centers[..., 1].unsqueeze(-1) - top
+ ymax = prior_centers[..., 1].unsqueeze(-1) + bottom
+
+ bboxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
+
+ if clip_border and max_shape is not None:
+ # clip bboxes with dynamic `min` and `max` for onnx
+ if torch.onnx.is_in_onnx_export():
+ from mmdet.core.export import dynamic_clip_for_onnx
+ xmin, ymin, xmax, ymax = dynamic_clip_for_onnx(
+ xmin, ymin, xmax, ymax, max_shape)
+ bboxes = torch.cat([xmin, ymin, xmax, ymax], dim=-1)
+ return bboxes
+ if not isinstance(max_shape, torch.Tensor):
+ max_shape = priors.new_tensor(max_shape)
+ max_shape = max_shape[..., :2].type_as(priors)
+ if max_shape.ndim == 2:
+ assert bboxes.ndim == 3
+ assert max_shape.size(0) == bboxes.size(0)
+
+ min_xy = priors.new_tensor(0)
+ max_xy = torch.cat([max_shape, max_shape],
+ dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ return bboxes
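
As a sanity check on the TBLR math, here is a small hand-worked sketch (plain torch, a single hypothetical prior/gt pair, assuming the defaults `normalizer=4.0` and `normalize_by_wh=True`); it only illustrates the arithmetic and is not part of the patch.

```python
import torch

prior = torch.tensor([10., 10., 30., 50.])   # (x1, y1, x2, y2): w=20, h=40
gt = torch.tensor([12., 8., 28., 60.])
normalizer = 4.0

cx, cy = (prior[0] + prior[2]) / 2, (prior[1] + prior[3]) / 2
w, h = prior[2] - prior[0], prior[3] - prior[1]

top = (cy - gt[1]) / h / normalizer       # distance from prior center to gt top
bottom = (gt[3] - cy) / h / normalizer
left = (cx - gt[0]) / w / normalizer
right = (gt[2] - cx) / w / normalizer
print(torch.stack([top, bottom, left, right]))
# tensor([0.1375, 0.1875, 0.1000, 0.1000])
# Decoding reverses the steps: multiply by the normalizer, scale tb by h and
# lr by w, then offset from the prior center to recover (xmin, ymin, xmax, ymax).
```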
diff --git a/mmdet/core/bbox/coder/yolo_bbox_coder.py b/mmdet/core/bbox/coder/yolo_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2852eca7541769cc2dff872665bc1d54a5b81b5a
--- /dev/null
+++ b/mmdet/core/bbox/coder/yolo_bbox_coder.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class YOLOBBoxCoder(BaseBBoxCoder):
+ """YOLO BBox coder.
+
+    Following `YOLO <https://arxiv.org/abs/1506.02640>`_, this coder divides
+    the image into grids and encodes bbox (x1, y1, x2, y2) into (cx, cy, dw,
+    dh). cx, cy in [0., 1.] denote the relative center position w.r.t. the
+    center of bboxes. dw, dh are the same as :obj:`DeltaXYWHBBoxCoder`.
+
+ Args:
+ eps (float): Min value of cx, cy when encoding.
+ """
+
+ def __init__(self, eps=1e-6):
+        super(YOLOBBoxCoder, self).__init__()
+ self.eps = eps
+
+ @mmcv.jit(coderize=True)
+ def encode(self, bboxes, gt_bboxes, stride):
+ """Get box regression transformation deltas that can be used to
+ transform the ``bboxes`` into the ``gt_bboxes``.
+
+ Args:
+ bboxes (torch.Tensor): Source boxes, e.g., anchors.
+ gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
+ ground-truth boxes.
+ stride (torch.Tensor | int): Stride of bboxes.
+
+ Returns:
+ torch.Tensor: Box transformation deltas
+ """
+
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5
+ y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5
+ w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0]
+ h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1]
+ x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
+ y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
+ w = bboxes[..., 2] - bboxes[..., 0]
+ h = bboxes[..., 3] - bboxes[..., 1]
+ w_target = torch.log((w_gt / w).clamp(min=self.eps))
+ h_target = torch.log((h_gt / h).clamp(min=self.eps))
+ x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp(
+ self.eps, 1 - self.eps)
+ y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp(
+ self.eps, 1 - self.eps)
+ encoded_bboxes = torch.stack(
+ [x_center_target, y_center_target, w_target, h_target], dim=-1)
+ return encoded_bboxes
+
+ @mmcv.jit(coderize=True)
+ def decode(self, bboxes, pred_bboxes, stride):
+ """Apply transformation `pred_bboxes` to `boxes`.
+
+ Args:
+            bboxes (torch.Tensor): Basic boxes, e.g. anchors.
+            pred_bboxes (torch.Tensor): Encoded boxes with shape (..., 4).
+ stride (torch.Tensor | int): Strides of bboxes.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+ assert pred_bboxes.size(-1) == bboxes.size(-1) == 4
+ xy_centers = (bboxes[..., :2] + bboxes[..., 2:]) * 0.5 + (
+ pred_bboxes[..., :2] - 0.5) * stride
+ whs = (bboxes[..., 2:] -
+ bboxes[..., :2]) * 0.5 * pred_bboxes[..., 2:].exp()
+ decoded_bboxes = torch.stack(
+ (xy_centers[..., 0] - whs[..., 0], xy_centers[..., 1] -
+ whs[..., 1], xy_centers[..., 0] + whs[..., 0],
+ xy_centers[..., 1] + whs[..., 1]),
+ dim=-1)
+ return decoded_bboxes
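
A minimal sketch of the decode step above (hypothetical anchor and prediction values, plain torch, outside the patch): the predicted center in [0, 1] shifts the anchor center by up to half a stride, and exp(dw), exp(dh) rescale the anchor size.

```python
import torch

# Hypothetical anchor and raw prediction; stride matches the feature level.
anchor = torch.tensor([16., 16., 48., 48.])      # 32x32 anchor box
pred = torch.tensor([0.75, 0.25, 0.2, -0.2])     # (cx, cy, dw, dh)
stride = 32

center = (anchor[:2] + anchor[2:]) * 0.5 + (pred[:2] - 0.5) * stride
half_wh = (anchor[2:] - anchor[:2]) * 0.5 * pred[2:].exp()
box = torch.cat([center - half_wh, center + half_wh])
print(box)   # decoded (x1, y1, x2, y2) around the shifted center
```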
diff --git a/mmdet/core/bbox/demodata.py b/mmdet/core/bbox/demodata.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb24b34b640d3f333c1ec568f96ec795b560ab86
--- /dev/null
+++ b/mmdet/core/bbox/demodata.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from mmdet.utils.util_random import ensure_rng
+
+
+def random_boxes(num=1, scale=1, rng=None):
+ """Simple version of ``kwimage.Boxes.random``
+
+ Returns:
+ Tensor: shape (n, 4) in x1, y1, x2, y2 format.
+
+ References:
+ https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
+
+ Example:
+ >>> num = 3
+ >>> scale = 512
+ >>> rng = 0
+ >>> boxes = random_boxes(num, scale, rng)
+ >>> print(boxes)
+ tensor([[280.9925, 278.9802, 308.6148, 366.1769],
+ [216.9113, 330.6978, 224.0446, 456.5878],
+ [405.3632, 196.3221, 493.3953, 270.7942]])
+ """
+ rng = ensure_rng(rng)
+
+ tlbr = rng.rand(num, 4).astype(np.float32)
+
+ tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
+ tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
+ br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
+ br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
+
+ tlbr[:, 0] = tl_x * scale
+ tlbr[:, 1] = tl_y * scale
+ tlbr[:, 2] = br_x * scale
+ tlbr[:, 3] = br_y * scale
+
+ boxes = torch.from_numpy(tlbr)
+ return boxes
diff --git a/mmdet/core/bbox/iou_calculators/__init__.py b/mmdet/core/bbox/iou_calculators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..04ba925b448d8e4c99ac1434a7d7b909ace1d65f
--- /dev/null
+++ b/mmdet/core/bbox/iou_calculators/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import build_iou_calculator
+from .iou2d_calculator import BboxOverlaps2D, bbox_overlaps
+
+__all__ = ['build_iou_calculator', 'BboxOverlaps2D', 'bbox_overlaps']
diff --git a/mmdet/core/bbox/iou_calculators/builder.py b/mmdet/core/bbox/iou_calculators/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..378ee269f3616d40e6687ad1a2d27ad5234e1784
--- /dev/null
+++ b/mmdet/core/bbox/iou_calculators/builder.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry, build_from_cfg
+
+IOU_CALCULATORS = Registry('IoU calculator')
+
+
+def build_iou_calculator(cfg, default_args=None):
+ """Builder of IoU calculator."""
+ return build_from_cfg(cfg, IOU_CALCULATORS, default_args)
diff --git a/mmdet/core/bbox/iou_calculators/iou2d_calculator.py b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71a5557ea129aaf72e39305524236e4419c3327
--- /dev/null
+++ b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py
@@ -0,0 +1,260 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .builder import IOU_CALCULATORS
+
+
+def cast_tensor_type(x, scale=1., dtype=None):
+ if dtype == 'fp16':
+ # scale is for preventing overflows
+ x = (x / scale).half()
+ return x
+
+
+def fp16_clamp(x, min=None, max=None):
+ if not x.is_cuda and x.dtype == torch.float16:
+ # clamp for cpu float16, tensor fp16 has no clamp implementation
+ return x.float().clamp(min, max).half()
+
+ return x.clamp(min, max)
+
+
+@IOU_CALCULATORS.register_module()
+class BboxOverlaps2D:
+ """2D Overlaps (e.g. IoUs, GIoUs) Calculator."""
+
+ def __init__(self, scale=1., dtype=None):
+ self.scale = scale
+ self.dtype = dtype
+
+ def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False):
+ """Calculate IoU between 2D bboxes.
+
+ Args:
+            bboxes1 (Tensor): bboxes have shape (m, 4) in <x1, y1, x2, y2>
+                format, or shape (m, 5) in <x1, y1, x2, y2, score> format.
+            bboxes2 (Tensor): bboxes have shape (n, 4) in <x1, y1, x2, y2>
+                format, shape (n, 5) in <x1, y1, x2, y2, score> format, or be
+                empty.
+ mode (str): "iou" (intersection over union), "iof" (intersection
+ over foreground), or "giou" (generalized intersection over
+ union).
+ is_aligned (bool, optional): If True, then m and n must be equal.
+ Default False.
+
+ Returns:
+            Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
+ """
+ assert bboxes1.size(-1) in [0, 4, 5]
+ assert bboxes2.size(-1) in [0, 4, 5]
+ if bboxes2.size(-1) == 5:
+ bboxes2 = bboxes2[..., :4]
+ if bboxes1.size(-1) == 5:
+ bboxes1 = bboxes1[..., :4]
+
+ if self.dtype == 'fp16':
+ # change tensor type to save cpu and cuda memory and keep speed
+ bboxes1 = cast_tensor_type(bboxes1, self.scale, self.dtype)
+ bboxes2 = cast_tensor_type(bboxes2, self.scale, self.dtype)
+ overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
+ if not overlaps.is_cuda and overlaps.dtype == torch.float16:
+ # resume cpu float32
+ overlaps = overlaps.float()
+ return overlaps
+
+ return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
+
+ def __repr__(self):
+ """str: a string describing the module"""
+ repr_str = self.__class__.__name__ + f'(' \
+ f'scale={self.scale}, dtype={self.dtype})'
+ return repr_str
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
+ """Calculate overlap between two set of bboxes.
+
+ FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889
+ Note:
+        Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou',
+        there are some newly generated variables when calculating IoU
+        using the bbox_overlaps function:
+
+ 1) is_aligned is False
+ area1: M x 1
+ area2: N x 1
+ lt: M x N x 2
+ rb: M x N x 2
+ wh: M x N x 2
+ overlap: M x N x 1
+ union: M x N x 1
+ ious: M x N x 1
+
+ Total memory:
+ S = (9 x N x M + N + M) * 4 Byte,
+
+ When using FP16, we can reduce:
+ R = (9 x N x M + N + M) * 4 / 2 Byte
+            R larger than (N + M) * 4 * 2 always holds when N and M >= 1.
+            Obviously, N + M <= N * M < 3 * N * M when N >= 2 and M >= 2,
+            and N + 1 < 3 * N when N or M is 1.
+
+            Given M = 40 (ground truths), N = 400000 (three anchor boxes
+            per grid, FPN, R-CNNs),
+            R = 275 MB
+
+ A special case (dense detection), M = 512 (ground truth),
+ R = 3516 MB = 3.43 GB
+
+ When the batch size is B, reduce:
+ B x R
+
+ Therefore, CUDA memory runs out frequently.
+
+ Experiments on GeForce RTX 2080Ti (11019 MiB):
+
+ | dtype | M | N | Use | Real | Ideal |
+ |:----:|:----:|:----:|:----:|:----:|:----:|
+ | FP32 | 512 | 400000 | 8020 MiB | -- | -- |
+ | FP16 | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
+ | FP32 | 40 | 400000 | 1540 MiB | -- | -- |
+        | FP16 | 40 | 400000 | 1264 MiB | 276 MiB | 275 MiB |
+
+ 2) is_aligned is True
+ area1: N x 1
+ area2: N x 1
+ lt: N x 2
+ rb: N x 2
+ wh: N x 2
+ overlap: N x 1
+ union: N x 1
+ ious: N x 1
+
+ Total memory:
+ S = 11 x N * 4 Byte
+
+ When using FP16, we can reduce:
+ R = 11 x N * 4 / 2 Byte
+
+        The same holds for 'giou' (which uses more memory than 'iou').
+
+        Time-wise, FP16 is generally faster than FP32.
+
+        When gpu_assign_thr is not -1, it takes more time on CPU
+        but does not reduce memory.
+        Therefore, we can reduce half the memory and keep the speed.
+
+ If ``is_aligned`` is ``False``, then calculate the overlaps between each
+ bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
+ pair of bboxes1 and bboxes2.
+
+ Args:
+        bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
+        bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
+            B indicates the batch dim, in shape (B1, B2, ..., Bn).
+ If ``is_aligned`` is ``True``, then m and n must be equal.
+ mode (str): "iou" (intersection over union), "iof" (intersection over
+ foreground) or "giou" (generalized intersection over union).
+ Default "iou".
+ is_aligned (bool, optional): If True, then m and n must be equal.
+ Default False.
+ eps (float, optional): A value added to the denominator for numerical
+ stability. Default 1e-6.
+
+ Returns:
+ Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
+
+ Example:
+ >>> bboxes1 = torch.FloatTensor([
+ >>> [0, 0, 10, 10],
+ >>> [10, 10, 20, 20],
+ >>> [32, 32, 38, 42],
+ >>> ])
+ >>> bboxes2 = torch.FloatTensor([
+ >>> [0, 0, 10, 20],
+ >>> [0, 10, 10, 19],
+ >>> [10, 10, 20, 20],
+ >>> ])
+ >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
+ >>> assert overlaps.shape == (3, 3)
+ >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
+ >>> assert overlaps.shape == (3, )
+
+ Example:
+ >>> empty = torch.empty(0, 4)
+ >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
+ >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+ >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+ >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
+ """
+
+ assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
+ # Either the boxes are empty or the length of boxes' last dimension is 4
+ assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
+ assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
+
+ # Batch dim must be the same
+ # Batch dim: (B1, B2, ... Bn)
+ assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
+ batch_shape = bboxes1.shape[:-2]
+
+ rows = bboxes1.size(-2)
+ cols = bboxes2.size(-2)
+ if is_aligned:
+ assert rows == cols
+
+ if rows * cols == 0:
+ if is_aligned:
+ return bboxes1.new(batch_shape + (rows, ))
+ else:
+ return bboxes1.new(batch_shape + (rows, cols))
+
+ area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
+ bboxes1[..., 3] - bboxes1[..., 1])
+ area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
+ bboxes2[..., 3] - bboxes2[..., 1])
+
+ if is_aligned:
+ lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2]
+ rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2]
+
+ wh = fp16_clamp(rb - lt, min=0)
+ overlap = wh[..., 0] * wh[..., 1]
+
+ if mode in ['iou', 'giou']:
+ union = area1 + area2 - overlap
+ else:
+ union = area1
+ if mode == 'giou':
+ enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
+ enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
+ else:
+ lt = torch.max(bboxes1[..., :, None, :2],
+ bboxes2[..., None, :, :2]) # [B, rows, cols, 2]
+ rb = torch.min(bboxes1[..., :, None, 2:],
+ bboxes2[..., None, :, 2:]) # [B, rows, cols, 2]
+
+ wh = fp16_clamp(rb - lt, min=0)
+ overlap = wh[..., 0] * wh[..., 1]
+
+ if mode in ['iou', 'giou']:
+ union = area1[..., None] + area2[..., None, :] - overlap
+ else:
+ union = area1[..., None]
+ if mode == 'giou':
+ enclosed_lt = torch.min(bboxes1[..., :, None, :2],
+ bboxes2[..., None, :, :2])
+ enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
+ bboxes2[..., None, :, 2:])
+
+ eps = union.new_tensor([eps])
+ union = torch.max(union, eps)
+ ious = overlap / union
+ if mode in ['iou', 'iof']:
+ return ious
+ # calculate gious
+ enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
+ enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
+ enclose_area = torch.max(enclose_area, eps)
+ gious = ious - (enclose_area - union) / enclose_area
+ return gious
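
For reference, a small standalone sketch (plain torch, no registry, not part of the patch) of the broadcasting pattern used above: inserting a `None` axis gives the full (m, n) IoU matrix, while aligned mode corresponds to its diagonal when m == n.

```python
import torch

b1 = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
b2 = torch.tensor([[0., 0., 10., 10.], [10., 10., 20., 20.]])


def iou_pairwise(a, b):
    # Broadcast (m, 1, 2) against (1, n, 2) to get every pairing.
    lt = torch.max(a[:, None, :2], b[None, :, :2])
    rb = torch.min(a[:, None, 2:], b[None, :, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / (area_a[:, None] + area_b[None, :] - inter)


print(iou_pairwise(b1, b2))          # [[1.0000, 0.0000], [0.1429, 0.1429]]
print(iou_pairwise(b1, b2).diag())   # equals the is_aligned=True result here
```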
diff --git a/mmdet/core/bbox/match_costs/__init__.py b/mmdet/core/bbox/match_costs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b636795082cf7b731e3125f7ae36b51e4bfb5a3
--- /dev/null
+++ b/mmdet/core/bbox/match_costs/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import build_match_cost
+from .match_cost import (BBoxL1Cost, ClassificationCost, CrossEntropyLossCost,
+ DiceCost, FocalLossCost, IoUCost)
+
+__all__ = [
+ 'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost',
+ 'FocalLossCost', 'DiceCost', 'CrossEntropyLossCost'
+]
diff --git a/mmdet/core/bbox/match_costs/builder.py b/mmdet/core/bbox/match_costs/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea086adff23c5adbc35d448d5a93daf1a04bdc53
--- /dev/null
+++ b/mmdet/core/bbox/match_costs/builder.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import Registry, build_from_cfg
+
+MATCH_COST = Registry('Match Cost')
+
+
+def build_match_cost(cfg, default_args=None):
+    """Builder of Match Cost."""
+ return build_from_cfg(cfg, MATCH_COST, default_args)
diff --git a/mmdet/core/bbox/match_costs/match_cost.py b/mmdet/core/bbox/match_costs/match_cost.py
new file mode 100644
index 0000000000000000000000000000000000000000..4342b024588663b602d7dc1b82a1e708cc8aea91
--- /dev/null
+++ b/mmdet/core/bbox/match_costs/match_cost.py
@@ -0,0 +1,359 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+
+from mmdet.core.bbox.iou_calculators import bbox_overlaps
+from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
+from .builder import MATCH_COST
+
+
+@MATCH_COST.register_module()
+class BBoxL1Cost:
+ """BBoxL1Cost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+ box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost
+ >>> import torch
+ >>> self = BBoxL1Cost()
+ >>> bbox_pred = torch.rand(1, 4)
+ >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(bbox_pred, gt_bboxes, factor)
+ tensor([[1.6172, 1.6422]])
+ """
+
+ def __init__(self, weight=1., box_format='xyxy'):
+ self.weight = weight
+ assert box_format in ['xyxy', 'xywh']
+ self.box_format = box_format
+
+ def __call__(self, bbox_pred, gt_bboxes):
+ """
+ Args:
+ bbox_pred (Tensor): Predicted boxes with normalized coordinates
+ (cx, cy, w, h), which are all in range [0, 1]. Shape
+ (num_query, 4).
+ gt_bboxes (Tensor): Ground truth boxes with normalized
+ coordinates (x1, y1, x2, y2). Shape (num_gt, 4).
+
+ Returns:
+ torch.Tensor: bbox_cost value with weight
+ """
+ if self.box_format == 'xywh':
+ gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
+ elif self.box_format == 'xyxy':
+ bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
+ bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
+ return bbox_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class FocalLossCost:
+ """FocalLossCost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+ alpha (int | float, optional): focal_loss alpha
+ gamma (int | float, optional): focal_loss gamma
+ eps (float, optional): default 1e-12
+ binary_input (bool, optional): Whether the input is binary,
+ default False.
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost
+ >>> import torch
+ >>> self = FocalLossCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3236, -0.3364, -0.2699],
+ [-0.3439, -0.3209, -0.4807],
+ [-0.4099, -0.3795, -0.2929],
+ [-0.1950, -0.1207, -0.2626]])
+ """
+
+ def __init__(self,
+ weight=1.,
+ alpha=0.25,
+ gamma=2,
+ eps=1e-12,
+ binary_input=False):
+ self.weight = weight
+ self.alpha = alpha
+ self.gamma = gamma
+ self.eps = eps
+ self.binary_input = binary_input
+
+ def _focal_loss_cost(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits, shape
+ (num_query, num_class).
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+ Returns:
+ torch.Tensor: cls_cost value with weight
+ """
+ cls_pred = cls_pred.sigmoid()
+ neg_cost = -(1 - cls_pred + self.eps).log() * (
+ 1 - self.alpha) * cls_pred.pow(self.gamma)
+ pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
+ 1 - cls_pred).pow(self.gamma)
+
+ cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
+ return cls_cost * self.weight
+
+ def _mask_focal_loss_cost(self, cls_pred, gt_labels):
+ """
+ Args:
+            cls_pred (Tensor): Predicted classification logits
+ in shape (num_query, d1, ..., dn), dtype=torch.float32.
+ gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn),
+ dtype=torch.long. Labels should be binary.
+
+ Returns:
+ Tensor: Focal cost matrix with weight in shape\
+ (num_query, num_gt).
+ """
+ cls_pred = cls_pred.flatten(1)
+ gt_labels = gt_labels.flatten(1).float()
+ n = cls_pred.shape[1]
+ cls_pred = cls_pred.sigmoid()
+ neg_cost = -(1 - cls_pred + self.eps).log() * (
+ 1 - self.alpha) * cls_pred.pow(self.gamma)
+ pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
+ 1 - cls_pred).pow(self.gamma)
+
+ cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \
+ torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels))
+ return cls_cost / n * self.weight
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+            cls_pred (Tensor): Predicted classification logits.
+            gt_labels (Tensor): Labels.
+
+ Returns:
+ Tensor: Focal cost matrix with weight in shape\
+ (num_query, num_gt).
+ """
+ if self.binary_input:
+ return self._mask_focal_loss_cost(cls_pred, gt_labels)
+ else:
+ return self._focal_loss_cost(cls_pred, gt_labels)
+
+
+@MATCH_COST.register_module()
+class ClassificationCost:
+ """ClsSoftmaxCost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import \
+ ... ClassificationCost
+ >>> import torch
+ >>> self = ClassificationCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3430, -0.3525, -0.3045],
+ [-0.3077, -0.2931, -0.3992],
+ [-0.3664, -0.3455, -0.2881],
+ [-0.3343, -0.2701, -0.3956]])
+ """
+
+ def __init__(self, weight=1.):
+ self.weight = weight
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits, shape
+ (num_query, num_class).
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+ Returns:
+ torch.Tensor: cls_cost value with weight
+ """
+        # Following the official DETR repo, contrary to the loss in which
+        # NLL is used, we approximate it by 1 - cls_score[gt_label].
+ # The 1 is a constant that doesn't change the matching,
+ # so it can be omitted.
+ cls_score = cls_pred.softmax(-1)
+ cls_cost = -cls_score[:, gt_labels]
+ return cls_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class IoUCost:
+ """IoUCost.
+
+ Args:
+ iou_mode (str, optional): iou mode such as 'iou' | 'giou'
+ weight (int | float, optional): loss weight
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import IoUCost
+ >>> import torch
+ >>> self = IoUCost()
+ >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
+ >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
+ >>> self(bboxes, gt_bboxes)
+ tensor([[-0.1250, 0.1667],
+ [ 0.1667, -0.5000]])
+ """
+
+ def __init__(self, iou_mode='giou', weight=1.):
+ self.weight = weight
+ self.iou_mode = iou_mode
+
+ def __call__(self, bboxes, gt_bboxes):
+ """
+ Args:
+ bboxes (Tensor): Predicted boxes with unnormalized coordinates
+ (x1, y1, x2, y2). Shape (num_query, 4).
+ gt_bboxes (Tensor): Ground truth boxes with unnormalized
+ coordinates (x1, y1, x2, y2). Shape (num_gt, 4).
+
+ Returns:
+ torch.Tensor: iou_cost value with weight
+ """
+ # overlaps: [num_bboxes, num_gt]
+ overlaps = bbox_overlaps(
+ bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
+ # The 1 is a constant that doesn't change the matching, so omitted.
+ iou_cost = -overlaps
+ return iou_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class DiceCost:
+ """Cost of mask assignments based on dice losses.
+
+ Args:
+ weight (int | float, optional): loss_weight. Defaults to 1.
+ pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
+ Defaults to False.
+ eps (float, optional): default 1e-12.
+ naive_dice (bool, optional): If True, use the naive dice loss
+ in which the power of the number in the denominator is
+            the first power. If False, use the second power that
+ is adopted by K-Net and SOLO.
+ Defaults to True.
+ """
+
+ def __init__(self, weight=1., pred_act=False, eps=1e-3, naive_dice=True):
+ self.weight = weight
+ self.pred_act = pred_act
+ self.eps = eps
+ self.naive_dice = naive_dice
+
+ def binary_mask_dice_loss(self, mask_preds, gt_masks):
+ """
+ Args:
+ mask_preds (Tensor): Mask prediction in shape (num_query, *).
+ gt_masks (Tensor): Ground truth in shape (num_gt, *)
+ store 0 or 1, 0 for negative class and 1 for
+ positive class.
+
+ Returns:
+ Tensor: Dice cost matrix in shape (num_query, num_gt).
+ """
+ mask_preds = mask_preds.flatten(1)
+ gt_masks = gt_masks.flatten(1).float()
+ numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks)
+ if self.naive_dice:
+ denominator = mask_preds.sum(-1)[:, None] + \
+ gt_masks.sum(-1)[None, :]
+ else:
+ denominator = mask_preds.pow(2).sum(1)[:, None] + \
+ gt_masks.pow(2).sum(1)[None, :]
+ loss = 1 - (numerator + self.eps) / (denominator + self.eps)
+ return loss
+
+ def __call__(self, mask_preds, gt_masks):
+ """
+ Args:
+ mask_preds (Tensor): Mask prediction logits in shape (num_query, *)
+ gt_masks (Tensor): Ground truth in shape (num_gt, *)
+
+ Returns:
+ Tensor: Dice cost matrix with weight in shape (num_query, num_gt).
+ """
+ if self.pred_act:
+ mask_preds = mask_preds.sigmoid()
+ dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks)
+ return dice_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class CrossEntropyLossCost:
+ """CrossEntropyLossCost.
+
+ Args:
+ weight (int | float, optional): loss weight. Defaults to 1.
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            or softmax. Defaults to True.
+ Examples:
+ >>> from mmdet.core.bbox.match_costs import CrossEntropyLossCost
+ >>> import torch
+ >>> bce = CrossEntropyLossCost(use_sigmoid=True)
+ >>> cls_pred = torch.tensor([[7.6, 1.2], [-1.3, 10]])
+ >>> gt_labels = torch.tensor([[1, 1], [1, 0]])
+ >>> print(bce(cls_pred, gt_labels))
+ """
+
+ def __init__(self, weight=1., use_sigmoid=True):
+ assert use_sigmoid, 'use_sigmoid = False is not supported yet.'
+ self.weight = weight
+ self.use_sigmoid = use_sigmoid
+
+ def _binary_cross_entropy(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
+ (num_query, *).
+ gt_labels (Tensor): The learning label of prediction with
+ shape (num_gt, *).
+
+ Returns:
+ Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
+ """
+ cls_pred = cls_pred.flatten(1).float()
+ gt_labels = gt_labels.flatten(1).float()
+ n = cls_pred.shape[1]
+ pos = F.binary_cross_entropy_with_logits(
+ cls_pred, torch.ones_like(cls_pred), reduction='none')
+ neg = F.binary_cross_entropy_with_logits(
+ cls_pred, torch.zeros_like(cls_pred), reduction='none')
+ cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \
+ torch.einsum('nc,mc->nm', neg, 1 - gt_labels)
+ cls_cost = cls_cost / n
+
+ return cls_cost
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits.
+ gt_labels (Tensor): Labels.
+
+ Returns:
+ Tensor: Cross entropy cost matrix with weight in
+ shape (num_query, num_gt).
+ """
+ if self.use_sigmoid:
+ cls_cost = self._binary_cross_entropy(cls_pred, gt_labels)
+ else:
+ raise NotImplementedError
+
+ return cls_cost * self.weight
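
The costs defined above are typically summed into one matrix and fed to Hungarian matching. The sketch below (hypothetical shapes and weights, plain torch plus scipy, bypassing the registry and the cxcywh/xyxy conversion) shows that end-to-end flow in miniature.

```python
import torch
from scipy.optimize import linear_sum_assignment

num_query, num_cls = 5, 3
cls_pred = torch.rand(num_query, num_cls)            # raw classification logits
bbox_pred = torch.rand(num_query, 4)                 # normalized cxcywh in [0, 1]
gt_labels = torch.tensor([0, 2])
gt_bboxes = torch.tensor([[0.2, 0.2, 0.4, 0.4],
                          [0.6, 0.6, 0.9, 0.9]])     # normalized cxcywh

cls_cost = -cls_pred.softmax(-1)[:, gt_labels]       # as in ClassificationCost
reg_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)    # as in BBoxL1Cost
cost = 1.0 * cls_cost + 5.0 * reg_cost               # hypothetical weights
row, col = linear_sum_assignment(cost.numpy())
print(list(zip(row.tolist(), col.tolist())))         # query i matched to gt j
```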
diff --git a/mmdet/core/bbox/samplers/__init__.py b/mmdet/core/bbox/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f58505b59dca744e489328a39fdabb02a893fb51
--- /dev/null
+++ b/mmdet/core/bbox/samplers/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_sampler import BaseSampler
+from .combined_sampler import CombinedSampler
+from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
+from .iou_balanced_neg_sampler import IoUBalancedNegSampler
+from .mask_pseudo_sampler import MaskPseudoSampler
+from .mask_sampling_result import MaskSamplingResult
+from .ohem_sampler import OHEMSampler
+from .pseudo_sampler import PseudoSampler
+from .random_sampler import RandomSampler
+from .sampling_result import SamplingResult
+from .score_hlr_sampler import ScoreHLRSampler
+
+__all__ = [
+ 'BaseSampler', 'PseudoSampler', 'RandomSampler',
+ 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+ 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'MaskPseudoSampler',
+ 'MaskSamplingResult'
+]
diff --git a/mmdet/core/bbox/samplers/base_sampler.py b/mmdet/core/bbox/samplers/base_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd15c7c643bdf52a39fd2f35e8d26a64de813b4b
--- /dev/null
+++ b/mmdet/core/bbox/samplers/base_sampler.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+from .sampling_result import SamplingResult
+
+
+class BaseSampler(metaclass=ABCMeta):
+ """Base class of samplers."""
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ **kwargs):
+ self.num = num
+ self.pos_fraction = pos_fraction
+ self.neg_pos_ub = neg_pos_ub
+ self.add_gt_as_proposals = add_gt_as_proposals
+ self.pos_sampler = self
+ self.neg_sampler = self
+
+ @abstractmethod
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Sample positive samples."""
+ pass
+
+ @abstractmethod
+ def _sample_neg(self, assign_result, num_expected, **kwargs):
+ """Sample negative samples."""
+ pass
+
+ def sample(self,
+ assign_result,
+ bboxes,
+ gt_bboxes,
+ gt_labels=None,
+ **kwargs):
+ """Sample positive and negative bboxes.
+
+ This is a simple implementation of bbox sampling given candidates,
+ assigning results and ground truth bboxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Bbox assigning results.
+ bboxes (Tensor): Boxes to be sampled from.
+ gt_bboxes (Tensor): Ground truth bboxes.
+ gt_labels (Tensor, optional): Class labels of ground truth bboxes.
+
+ Returns:
+ :obj:`SamplingResult`: Sampling result.
+
+ Example:
+ >>> from mmdet.core.bbox import RandomSampler
+ >>> from mmdet.core.bbox import AssignResult
+ >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
+ >>> rng = ensure_rng(None)
+ >>> assign_result = AssignResult.random(rng=rng)
+ >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
+ >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
+ >>> gt_labels = None
+ >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
+ >>> add_gt_as_proposals=False)
+ >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+ """
+ if len(bboxes.shape) < 2:
+ bboxes = bboxes[None, :]
+
+ bboxes = bboxes[:, :4]
+
+ gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
+ if self.add_gt_as_proposals and len(gt_bboxes) > 0:
+ if gt_labels is None:
+ raise ValueError(
+ 'gt_labels must be given when add_gt_as_proposals is True')
+ bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+ assign_result.add_gt_(gt_labels)
+ gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
+ gt_flags = torch.cat([gt_ones, gt_flags])
+
+ num_expected_pos = int(self.num * self.pos_fraction)
+ pos_inds = self.pos_sampler._sample_pos(
+ assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
+ # We found that sampled indices have duplicated items occasionally.
+ # (may be a bug of PyTorch)
+ pos_inds = pos_inds.unique()
+ num_sampled_pos = pos_inds.numel()
+ num_expected_neg = self.num - num_sampled_pos
+ if self.neg_pos_ub >= 0:
+ _pos = max(1, num_sampled_pos)
+ neg_upper_bound = int(self.neg_pos_ub * _pos)
+ if num_expected_neg > neg_upper_bound:
+ num_expected_neg = neg_upper_bound
+ neg_inds = self.neg_sampler._sample_neg(
+ assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
+ neg_inds = neg_inds.unique()
+
+ sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+ assign_result, gt_flags)
+ return sampling_result
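
The budget arithmetic in ``sample`` is easy to miss; this tiny sketch (plain Python, hypothetical numbers, not part of the patch) traces it: positives are capped at ``num * pos_fraction``, negatives back-fill the remainder, and ``neg_pos_ub`` optionally caps the negative-to-positive ratio.

```python
# All numbers are hypothetical; this mirrors the arithmetic in
# BaseSampler.sample without any tensors.
num, pos_fraction, neg_pos_ub = 512, 0.25, 3

num_expected_pos = int(num * pos_fraction)             # at most 128 positives
num_sampled_pos = 40                                   # positives actually found
num_expected_neg = num - num_sampled_pos               # back-fill with negatives
if neg_pos_ub >= 0:
    neg_upper_bound = neg_pos_ub * max(1, num_sampled_pos)
    num_expected_neg = min(num_expected_neg, neg_upper_bound)
print(num_expected_pos, num_expected_neg)              # 128 120
```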
diff --git a/mmdet/core/bbox/samplers/combined_sampler.py b/mmdet/core/bbox/samplers/combined_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f6d86ff26e1fbcecb31a671bf18a40e362feb57
--- /dev/null
+++ b/mmdet/core/bbox/samplers/combined_sampler.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import BBOX_SAMPLERS, build_sampler
+from .base_sampler import BaseSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class CombinedSampler(BaseSampler):
+ """A sampler that combines positive sampler and negative sampler."""
+
+ def __init__(self, pos_sampler, neg_sampler, **kwargs):
+ super(CombinedSampler, self).__init__(**kwargs)
+ self.pos_sampler = build_sampler(pos_sampler, **kwargs)
+ self.neg_sampler = build_sampler(neg_sampler, **kwargs)
+
+ def _sample_pos(self, **kwargs):
+ """Sample positive samples."""
+ raise NotImplementedError
+
+ def _sample_neg(self, **kwargs):
+ """Sample negative samples."""
+ raise NotImplementedError
diff --git a/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e0d9cc0e0a2dcd687d23c2f08c94fe4bf127d3a
--- /dev/null
+++ b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .random_sampler import RandomSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class InstanceBalancedPosSampler(RandomSampler):
+ """Instance balanced sampler that samples equal number of positive samples
+ for each instance."""
+
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Sample positive boxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): The assigned results of boxes.
+ num_expected (int): The number of expected positive samples
+
+ Returns:
+ Tensor or ndarray: sampled indices.
+ """
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+ if pos_inds.numel() != 0:
+ pos_inds = pos_inds.squeeze(1)
+ if pos_inds.numel() <= num_expected:
+ return pos_inds
+ else:
+ unique_gt_inds = assign_result.gt_inds[pos_inds].unique()
+ num_gts = len(unique_gt_inds)
+ num_per_gt = int(round(num_expected / float(num_gts)) + 1)
+ sampled_inds = []
+ for i in unique_gt_inds:
+ inds = torch.nonzero(
+ assign_result.gt_inds == i.item(), as_tuple=False)
+ if inds.numel() != 0:
+ inds = inds.squeeze(1)
+ else:
+ continue
+ if len(inds) > num_per_gt:
+ inds = self.random_choice(inds, num_per_gt)
+ sampled_inds.append(inds)
+ sampled_inds = torch.cat(sampled_inds)
+ if len(sampled_inds) < num_expected:
+ num_extra = num_expected - len(sampled_inds)
+ extra_inds = np.array(
+ list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
+ if len(extra_inds) > num_extra:
+ extra_inds = self.random_choice(extra_inds, num_extra)
+ extra_inds = torch.from_numpy(extra_inds).to(
+ assign_result.gt_inds.device).long()
+ sampled_inds = torch.cat([sampled_inds, extra_inds])
+ elif len(sampled_inds) > num_expected:
+ sampled_inds = self.random_choice(sampled_inds, num_expected)
+ return sampled_inds
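
A short sketch (plain torch, hypothetical assignment vector, outside the patch) of the per-instance quota computed above: each ground truth receives roughly ``num_expected / num_gts`` positives, plus one to cover rounding, before the leftover pool refills any shortfall.

```python
import torch

# Hypothetical assignment: values > 0 are (1-based) gt indices, 0 is negative.
gt_inds = torch.tensor([1, 1, 1, 1, 2, 2, 3, 0, 0])
num_expected = 6

pos_inds = (gt_inds > 0).nonzero(as_tuple=False).squeeze(1)
unique_gts = gt_inds[pos_inds].unique()
num_per_gt = int(round(num_expected / float(len(unique_gts))) + 1)
print(len(pos_inds), len(unique_gts), num_per_gt)      # 7 positives, 3 gts, 3 each
```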
diff --git a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e2874a47566b740899b0cdc3f311c02f83ad50
--- /dev/null
+++ b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
@@ -0,0 +1,158 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .random_sampler import RandomSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class IoUBalancedNegSampler(RandomSampler):
+ """IoU Balanced Sampling.
+
+ arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
+
+    Sampling proposals according to their IoU. `floor_fraction` of needed RoIs
+    are sampled randomly from proposals whose IoU is lower than `floor_thr`.
+    The others are sampled from proposals whose IoU is higher than
+    `floor_thr`; these are drawn evenly from `num_bins` bins that evenly
+    split the IoU range above `floor_thr`.
+
+ Args:
+ num (int): number of proposals.
+ pos_fraction (float): fraction of positive proposals.
+ floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
+ set to -1 if all using IoU balanced sampling.
+ floor_fraction (float): sampling fraction of proposals under floor_thr.
+ num_bins (int): number of bins in IoU balanced sampling.
+ """
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ floor_thr=-1,
+ floor_fraction=0,
+ num_bins=3,
+ **kwargs):
+ super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
+ **kwargs)
+ assert floor_thr >= 0 or floor_thr == -1
+ assert 0 <= floor_fraction <= 1
+ assert num_bins >= 1
+
+ self.floor_thr = floor_thr
+ self.floor_fraction = floor_fraction
+ self.num_bins = num_bins
+
+ def sample_via_interval(self, max_overlaps, full_set, num_expected):
+ """Sample according to the iou interval.
+
+ Args:
+ max_overlaps (torch.Tensor): IoU between bounding boxes and ground
+ truth boxes.
+            full_set (set(int)): A full set of indices of boxes.
+            num_expected (int): Number of expected samples.
+
+ Returns:
+ np.ndarray: Indices of samples
+ """
+ max_iou = max_overlaps.max()
+ iou_interval = (max_iou - self.floor_thr) / self.num_bins
+ per_num_expected = int(num_expected / self.num_bins)
+
+ sampled_inds = []
+ for i in range(self.num_bins):
+ start_iou = self.floor_thr + i * iou_interval
+ end_iou = self.floor_thr + (i + 1) * iou_interval
+ tmp_set = set(
+ np.where(
+ np.logical_and(max_overlaps >= start_iou,
+ max_overlaps < end_iou))[0])
+ tmp_inds = list(tmp_set & full_set)
+ if len(tmp_inds) > per_num_expected:
+ tmp_sampled_set = self.random_choice(tmp_inds,
+ per_num_expected)
+ else:
+                tmp_sampled_set = np.array(tmp_inds, dtype=np.int64)
+ sampled_inds.append(tmp_sampled_set)
+
+ sampled_inds = np.concatenate(sampled_inds)
+ if len(sampled_inds) < num_expected:
+ num_extra = num_expected - len(sampled_inds)
+ extra_inds = np.array(list(full_set - set(sampled_inds)))
+ if len(extra_inds) > num_extra:
+ extra_inds = self.random_choice(extra_inds, num_extra)
+ sampled_inds = np.concatenate([sampled_inds, extra_inds])
+
+ return sampled_inds
+
+ def _sample_neg(self, assign_result, num_expected, **kwargs):
+ """Sample negative boxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): The assigned results of boxes.
+ num_expected (int): The number of expected negative samples
+
+ Returns:
+ Tensor or ndarray: sampled indices.
+ """
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+ if neg_inds.numel() != 0:
+ neg_inds = neg_inds.squeeze(1)
+ if len(neg_inds) <= num_expected:
+ return neg_inds
+ else:
+ max_overlaps = assign_result.max_overlaps.cpu().numpy()
+ # balance sampling for negative samples
+ neg_set = set(neg_inds.cpu().numpy())
+
+ if self.floor_thr > 0:
+ floor_set = set(
+ np.where(
+ np.logical_and(max_overlaps >= 0,
+ max_overlaps < self.floor_thr))[0])
+ iou_sampling_set = set(
+ np.where(max_overlaps >= self.floor_thr)[0])
+ elif self.floor_thr == 0:
+ floor_set = set(np.where(max_overlaps == 0)[0])
+ iou_sampling_set = set(
+ np.where(max_overlaps > self.floor_thr)[0])
+ else:
+ floor_set = set()
+ iou_sampling_set = set(
+ np.where(max_overlaps > self.floor_thr)[0])
+ # for sampling interval calculation
+ self.floor_thr = 0
+
+ floor_neg_inds = list(floor_set & neg_set)
+ iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
+ num_expected_iou_sampling = int(num_expected *
+ (1 - self.floor_fraction))
+ if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
+ if self.num_bins >= 2:
+ iou_sampled_inds = self.sample_via_interval(
+ max_overlaps, set(iou_sampling_neg_inds),
+ num_expected_iou_sampling)
+ else:
+ iou_sampled_inds = self.random_choice(
+ iou_sampling_neg_inds, num_expected_iou_sampling)
+ else:
+ iou_sampled_inds = np.array(
+                    iou_sampling_neg_inds, dtype=np.int64)
+ num_expected_floor = num_expected - len(iou_sampled_inds)
+ if len(floor_neg_inds) > num_expected_floor:
+ sampled_floor_inds = self.random_choice(
+ floor_neg_inds, num_expected_floor)
+ else:
+                sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int64)
+ sampled_inds = np.concatenate(
+ (sampled_floor_inds, iou_sampled_inds))
+ if len(sampled_inds) < num_expected:
+ num_extra = num_expected - len(sampled_inds)
+ extra_inds = np.array(list(neg_set - set(sampled_inds)))
+ if len(extra_inds) > num_extra:
+ extra_inds = self.random_choice(extra_inds, num_extra)
+ sampled_inds = np.concatenate((sampled_inds, extra_inds))
+ sampled_inds = torch.from_numpy(sampled_inds).long().to(
+ assign_result.gt_inds.device)
+ return sampled_inds
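
To illustrate the interval sampling in ``sample_via_interval``, here is a numpy-only sketch with synthetic overlaps (all values hypothetical, not part of the patch); it splits the IoU range above ``floor_thr`` into equal bins and draws evenly from each.

```python
import numpy as np

# Synthetic overlaps; thresholds and counts below are hypothetical.
rng = np.random.default_rng(0)
max_overlaps = rng.uniform(0.0, 0.5, size=1000)    # IoU of negatives with gts
floor_thr, num_bins, num_expected = 0.1, 3, 60

candidates = np.where(max_overlaps >= floor_thr)[0]
max_iou = max_overlaps[candidates].max()
bin_width = (max_iou - floor_thr) / num_bins
per_bin = num_expected // num_bins
sampled = []
for i in range(num_bins):
    lo, hi = floor_thr + i * bin_width, floor_thr + (i + 1) * bin_width
    in_bin = candidates[(max_overlaps[candidates] >= lo)
                        & (max_overlaps[candidates] < hi)]
    take = min(per_bin, len(in_bin))
    sampled.append(rng.choice(in_bin, take, replace=False))
sampled = np.concatenate(sampled)
print(len(sampled))                                 # at most num_expected
```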
diff --git a/mmdet/core/bbox/samplers/mask_pseudo_sampler.py b/mmdet/core/bbox/samplers/mask_pseudo_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5f69658d02808fd67adf54d2acf5f7fc28d2e6e
--- /dev/null
+++ b/mmdet/core/bbox/samplers/mask_pseudo_sampler.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""copy from
+https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
+
+import torch
+
+from mmdet.core.bbox.builder import BBOX_SAMPLERS
+from .base_sampler import BaseSampler
+from .mask_sampling_result import MaskSamplingResult
+
+
+@BBOX_SAMPLERS.register_module()
+class MaskPseudoSampler(BaseSampler):
+ """A pseudo sampler that does not do sampling actually."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def _sample_pos(self, **kwargs):
+ """Sample positive samples."""
+ raise NotImplementedError
+
+ def _sample_neg(self, **kwargs):
+ """Sample negative samples."""
+ raise NotImplementedError
+
+ def sample(self, assign_result, masks, gt_masks, **kwargs):
+ """Directly returns the positive and negative indices of samples.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+            masks (torch.Tensor): Predicted masks
+            gt_masks (torch.Tensor): Ground truth masks
+ Returns:
+ :obj:`SamplingResult`: sampler results
+ """
+ pos_inds = torch.nonzero(
+ assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+ neg_inds = torch.nonzero(
+ assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+ gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)
+ sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks,
+ gt_masks, assign_result, gt_flags)
+ return sampling_result
diff --git a/mmdet/core/bbox/samplers/mask_sampling_result.py b/mmdet/core/bbox/samplers/mask_sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d109432260089b8f494d0e5b78bab7280cc2e0d
--- /dev/null
+++ b/mmdet/core/bbox/samplers/mask_sampling_result.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""copy from
+https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py."""
+
+import torch
+
+from .sampling_result import SamplingResult
+
+
+class MaskSamplingResult(SamplingResult):
+ """Mask sampling result."""
+
+ def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result,
+ gt_flags):
+ self.pos_inds = pos_inds
+ self.neg_inds = neg_inds
+ self.pos_masks = masks[pos_inds]
+ self.neg_masks = masks[neg_inds]
+ self.pos_is_gt = gt_flags[pos_inds]
+
+ self.num_gts = gt_masks.shape[0]
+ self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+ if gt_masks.numel() == 0:
+ # hack for index error case
+ assert self.pos_assigned_gt_inds.numel() == 0
+ self.pos_gt_masks = torch.empty_like(gt_masks)
+ else:
+ self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]
+
+ if assign_result.labels is not None:
+ self.pos_gt_labels = assign_result.labels[pos_inds]
+ else:
+ self.pos_gt_labels = None
+
+ @property
+ def masks(self):
+        """torch.Tensor: concatenated positive and negative masks"""
+ return torch.cat([self.pos_masks, self.neg_masks])
+
+ def __nice__(self):
+ data = self.info.copy()
+ data['pos_masks'] = data.pop('pos_masks').shape
+ data['neg_masks'] = data.pop('neg_masks').shape
+ parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
+ body = ' ' + ',\n '.join(parts)
+ return '{\n' + body + '\n}'
+
+ @property
+ def info(self):
+ """Returns a dictionary of info about the object."""
+ return {
+ 'pos_inds': self.pos_inds,
+ 'neg_inds': self.neg_inds,
+ 'pos_masks': self.pos_masks,
+ 'neg_masks': self.neg_masks,
+ 'pos_is_gt': self.pos_is_gt,
+ 'num_gts': self.num_gts,
+ 'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
+ }
diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb066633809ff8d70240062c2dacd0e7283a1c5
--- /dev/null
+++ b/mmdet/core/bbox/samplers/ohem_sampler.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from ..transforms import bbox2roi
+from .base_sampler import BaseSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class OHEMSampler(BaseSampler):
+ r"""Online Hard Example Mining Sampler described in `Training Region-based
+    Object Detectors with Online Hard Example Mining
+    <https://arxiv.org/abs/1604.03540>`_.
+ """
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ context,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ loss_key='loss_cls',
+ **kwargs):
+ super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
+ add_gt_as_proposals)
+ self.context = context
+ if not hasattr(self.context, 'num_stages'):
+ self.bbox_head = self.context.bbox_head
+ else:
+ self.bbox_head = self.context.bbox_head[self.context.current_stage]
+
+ self.loss_key = loss_key
+
+ def hard_mining(self, inds, num_expected, bboxes, labels, feats):
+ with torch.no_grad():
+ rois = bbox2roi([bboxes])
+ if not hasattr(self.context, 'num_stages'):
+ bbox_results = self.context._bbox_forward(feats, rois)
+ else:
+ bbox_results = self.context._bbox_forward(
+ self.context.current_stage, feats, rois)
+ cls_score = bbox_results['cls_score']
+ loss = self.bbox_head.loss(
+ cls_score=cls_score,
+ bbox_pred=None,
+ rois=rois,
+ labels=labels,
+ label_weights=cls_score.new_ones(cls_score.size(0)),
+ bbox_targets=None,
+ bbox_weights=None,
+ reduction_override='none')[self.loss_key]
+ _, topk_loss_inds = loss.topk(num_expected)
+ return inds[topk_loss_inds]
+
+ def _sample_pos(self,
+ assign_result,
+ num_expected,
+ bboxes=None,
+ feats=None,
+ **kwargs):
+ """Sample positive boxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+ num_expected (int): Number of expected positive samples
+ bboxes (torch.Tensor, optional): Boxes. Defaults to None.
+ feats (list[torch.Tensor], optional): Multi-level features.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: Indices of positive samples
+ """
+ # Sample some hard positive samples
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+ if pos_inds.numel() != 0:
+ pos_inds = pos_inds.squeeze(1)
+ if pos_inds.numel() <= num_expected:
+ return pos_inds
+ else:
+ return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
+ assign_result.labels[pos_inds], feats)
+
+ def _sample_neg(self,
+ assign_result,
+ num_expected,
+ bboxes=None,
+ feats=None,
+ **kwargs):
+ """Sample negative boxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+ num_expected (int): Number of expected negative samples
+ bboxes (torch.Tensor, optional): Boxes. Defaults to None.
+ feats (list[torch.Tensor], optional): Multi-level features.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: Indices of negative samples
+ """
+ # Sample some hard negative samples
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+ if neg_inds.numel() != 0:
+ neg_inds = neg_inds.squeeze(1)
+ if len(neg_inds) <= num_expected:
+ return neg_inds
+ else:
+ neg_labels = assign_result.labels.new_empty(
+ neg_inds.size(0)).fill_(self.bbox_head.num_classes)
+ return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
+ neg_labels, feats)
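
For reviewers unfamiliar with OHEM: the selection rule in `hard_mining` above boils down to ranking candidates by their unreduced classification loss and keeping the top-k. A minimal standalone sketch in plain PyTorch, with a hypothetical `fake_loss` standing in for the per-sample `loss_cls` returned with `reduction_override='none'`:

```python
import torch

# Hypothetical per-sample classification losses for 8 negative candidates.
fake_loss = torch.tensor([0.1, 2.3, 0.4, 1.7, 0.05, 0.9, 3.1, 0.2])
inds = torch.arange(8)   # candidate indices into the original proposal set
num_expected = 3         # how many hard examples to keep

# Same selection rule as OHEMSampler.hard_mining: keep the top-k loss indices.
_, topk_loss_inds = fake_loss.topk(num_expected)
hard_inds = inds[topk_loss_inds]
print(hard_inds)  # tensor([6, 1, 3])
```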
diff --git a/mmdet/core/bbox/samplers/pseudo_sampler.py b/mmdet/core/bbox/samplers/pseudo_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5ce298ed01a327daa12167a20cb14b48c14d4e0
--- /dev/null
+++ b/mmdet/core/bbox/samplers/pseudo_sampler.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .base_sampler import BaseSampler
+from .sampling_result import SamplingResult
+
+
+@BBOX_SAMPLERS.register_module()
+class PseudoSampler(BaseSampler):
+ """A pseudo sampler that does not actually perform sampling."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def _sample_pos(self, **kwargs):
+ """Sample positive samples."""
+ raise NotImplementedError
+
+ def _sample_neg(self, **kwargs):
+ """Sample negative samples."""
+ raise NotImplementedError
+
+ def sample(self, assign_result, bboxes, gt_bboxes, *args, **kwargs):
+ """Directly returns the positive and negative indices of samples.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+ bboxes (torch.Tensor): Bounding boxes
+ gt_bboxes (torch.Tensor): Ground truth boxes
+
+ Returns:
+ :obj:`SamplingResult`: sampler results
+ """
+ pos_inds = torch.nonzero(
+ assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+ neg_inds = torch.nonzero(
+ assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+ gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
+ sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+ assign_result, gt_flags)
+ return sampling_result
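
The index split performed by `PseudoSampler.sample` can be reproduced with plain tensor ops; a small sketch assuming the `AssignResult` convention used above (0 means background, positive values are 1-based GT indices):

```python
import torch

# gt_inds convention: 0 = negative/background, >0 = assigned GT index (1-based).
gt_inds = torch.tensor([0, 2, 0, 1, 0, 2])

pos_inds = torch.nonzero(gt_inds > 0, as_tuple=False).squeeze(-1).unique()
neg_inds = torch.nonzero(gt_inds == 0, as_tuple=False).squeeze(-1).unique()
print(pos_inds)  # tensor([1, 3, 5])
print(neg_inds)  # tensor([0, 2, 4])
```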
diff --git a/mmdet/core/bbox/samplers/random_sampler.py b/mmdet/core/bbox/samplers/random_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d3effcb7802df98aeff4282594d2b7464643621
--- /dev/null
+++ b/mmdet/core/bbox/samplers/random_sampler.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .base_sampler import BaseSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class RandomSampler(BaseSampler):
+ """Random sampler.
+
+ Args:
+ num (int): Number of samples
+ pos_fraction (float): Fraction of positive samples
+ neg_pos_ub (int, optional): Upper bound number of negative and
+ positive samples. Defaults to -1.
+ add_gt_as_proposals (bool, optional): Whether to add ground truth
+ boxes as proposals. Defaults to True.
+ """
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ **kwargs):
+ from mmdet.core.bbox import demodata
+ super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
+ add_gt_as_proposals)
+ self.rng = demodata.ensure_rng(kwargs.get('rng', None))
+
+ def random_choice(self, gallery, num):
+ """Randomly select some elements from the gallery.
+
+ If `gallery` is a Tensor, the returned indices will be a Tensor;
+ If `gallery` is a ndarray or list, the returned indices will be a
+ ndarray.
+
+ Args:
+ gallery (Tensor | ndarray | list): indices pool.
+ num (int): expected sample num.
+
+ Returns:
+ Tensor or ndarray: sampled indices.
+ """
+ assert len(gallery) >= num
+
+ is_tensor = isinstance(gallery, torch.Tensor)
+ if not is_tensor:
+ if torch.cuda.is_available():
+ device = torch.cuda.current_device()
+ else:
+ device = 'cpu'
+ gallery = torch.tensor(gallery, dtype=torch.long, device=device)
+ # This is a temporary fix. We can revert the following code
+ # when PyTorch fixes the abnormal return of torch.randperm.
+ # See: https://github.com/open-mmlab/mmdetection/pull/5014
+ perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
+ rand_inds = gallery[perm]
+ if not is_tensor:
+ rand_inds = rand_inds.cpu().numpy()
+ return rand_inds
+
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Randomly sample some positive samples."""
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+ if pos_inds.numel() != 0:
+ pos_inds = pos_inds.squeeze(1)
+ if pos_inds.numel() <= num_expected:
+ return pos_inds
+ else:
+ return self.random_choice(pos_inds, num_expected)
+
+ def _sample_neg(self, assign_result, num_expected, **kwargs):
+ """Randomly sample some negative samples."""
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+ if neg_inds.numel() != 0:
+ neg_inds = neg_inds.squeeze(1)
+ if len(neg_inds) <= num_expected:
+ return neg_inds
+ else:
+ return self.random_choice(neg_inds, num_expected)
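
A quick sketch of the `random_choice` pattern above on a toy gallery; the permute-then-move step mirrors the workaround for the `torch.randperm` issue referenced in the comment:

```python
import torch

gallery = torch.tensor([3, 7, 11, 15, 19])
num = 3

# Same pattern as RandomSampler.random_choice: permute, slice, then index.
perm = torch.randperm(gallery.numel())[:num].to(device=gallery.device)
rand_inds = gallery[perm]
print(rand_inds)  # e.g. tensor([11, 3, 19]); exact values depend on the RNG state
```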
diff --git a/mmdet/core/bbox/samplers/sampling_result.py b/mmdet/core/bbox/samplers/sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..11a02c5d95a4d633dfea26df7fb3e440494a8be7
--- /dev/null
+++ b/mmdet/core/bbox/samplers/sampling_result.py
@@ -0,0 +1,153 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.utils import util_mixins
+
+
+class SamplingResult(util_mixins.NiceRepr):
+ """Bbox sampling result.
+
+ Example:
+ >>> # xdoctest: +IGNORE_WANT
+ >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
+ >>> self = SamplingResult.random(rng=10)
+ >>> print(f'self = {self}')
+ self =
+ """
+
+ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
+ gt_flags):
+ self.pos_inds = pos_inds
+ self.neg_inds = neg_inds
+ self.pos_bboxes = bboxes[pos_inds]
+ self.neg_bboxes = bboxes[neg_inds]
+ self.pos_is_gt = gt_flags[pos_inds]
+
+ self.num_gts = gt_bboxes.shape[0]
+ self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+ if gt_bboxes.numel() == 0:
+ # hack for index error case
+ assert self.pos_assigned_gt_inds.numel() == 0
+ self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
+ else:
+ if len(gt_bboxes.shape) < 2:
+ gt_bboxes = gt_bboxes.view(-1, 4)
+
+ self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :]
+
+ if assign_result.labels is not None:
+ self.pos_gt_labels = assign_result.labels[pos_inds]
+ else:
+ self.pos_gt_labels = None
+
+ @property
+ def bboxes(self):
+ """torch.Tensor: concatenated positive and negative boxes"""
+ return torch.cat([self.pos_bboxes, self.neg_bboxes])
+
+ def to(self, device):
+ """Change the device of the data inplace.
+
+ Example:
+ >>> self = SamplingResult.random()
+ >>> print(f'self = {self.to(None)}')
+ >>> # xdoctest: +REQUIRES(--gpu)
+ >>> print(f'self = {self.to(0)}')
+ """
+ _dict = self.__dict__
+ for key, value in _dict.items():
+ if isinstance(value, torch.Tensor):
+ _dict[key] = value.to(device)
+ return self
+
+ def __nice__(self):
+ data = self.info.copy()
+ data['pos_bboxes'] = data.pop('pos_bboxes').shape
+ data['neg_bboxes'] = data.pop('neg_bboxes').shape
+ parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
+ body = ' ' + ',\n '.join(parts)
+ return '{\n' + body + '\n}'
+
+ @property
+ def info(self):
+ """Returns a dictionary of info about the object."""
+ return {
+ 'pos_inds': self.pos_inds,
+ 'neg_inds': self.neg_inds,
+ 'pos_bboxes': self.pos_bboxes,
+ 'neg_bboxes': self.neg_bboxes,
+ 'pos_is_gt': self.pos_is_gt,
+ 'num_gts': self.num_gts,
+ 'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
+ }
+
+ @classmethod
+ def random(cls, rng=None, **kwargs):
+ """
+ Args:
+ rng (None | int | numpy.random.RandomState): seed or state.
+ kwargs (keyword arguments):
+ - num_preds: number of predicted boxes
+ - num_gts: number of true boxes
+ - p_ignore (float): probability of a predicted box assigned to \
+ an ignored truth.
+ - p_assigned (float): probability of a predicted box not being \
+ assigned.
+ - p_use_label (float | bool): with labels or not.
+
+ Returns:
+ :obj:`SamplingResult`: Randomly generated sampling result.
+
+ Example:
+ >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
+ >>> self = SamplingResult.random()
+ >>> print(self.__dict__)
+ """
+ from mmdet.core.bbox import demodata
+ from mmdet.core.bbox.assigners.assign_result import AssignResult
+ from mmdet.core.bbox.samplers.random_sampler import RandomSampler
+ rng = demodata.ensure_rng(rng)
+
+ # make probabilistic?
+ num = 32
+ pos_fraction = 0.5
+ neg_pos_ub = -1
+
+ assign_result = AssignResult.random(rng=rng, **kwargs)
+
+ # Note we could just compute an assignment
+ bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
+ gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
+
+ if rng.rand() > 0.2:
+ # sometimes algorithms squeeze their data, be robust to that
+ gt_bboxes = gt_bboxes.squeeze()
+ bboxes = bboxes.squeeze()
+
+ if assign_result.labels is None:
+ gt_labels = None
+ else:
+ gt_labels = None # todo
+
+ if gt_labels is None:
+ add_gt_as_proposals = False
+ else:
+ add_gt_as_proposals = True # make probabilistic?
+
+ sampler = RandomSampler(
+ num,
+ pos_fraction,
+ neg_pos_ub=neg_pos_ub,
+ add_gt_as_proposals=add_gt_as_proposals,
+ rng=rng)
+ self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+ return self
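
The easiest part of `SamplingResult` to trip over is `pos_assigned_gt_inds`: `assign_result.gt_inds` is 1-based for positives, so the constructor subtracts 1 before indexing `gt_bboxes`. A toy sketch of that mapping with made-up boxes and assignments:

```python
import torch

gt_bboxes = torch.tensor([[0., 0., 10., 10.],
                          [5., 5., 20., 20.]])
gt_inds = torch.tensor([0, 2, 1, 0])  # 0 = background, >0 = 1-based GT index
pos_inds = torch.nonzero(gt_inds > 0, as_tuple=False).squeeze(1)

pos_assigned_gt_inds = gt_inds[pos_inds] - 1        # shift to 0-based indices
pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds.long(), :]
print(pos_gt_bboxes)  # GT box matched to each positive sample, in pos_inds order
```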
diff --git a/mmdet/core/bbox/samplers/score_hlr_sampler.py b/mmdet/core/bbox/samplers/score_hlr_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4be9b8cfefff7bd59242de1ab5b6a9e37fa7943
--- /dev/null
+++ b/mmdet/core/bbox/samplers/score_hlr_sampler.py
@@ -0,0 +1,265 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.ops import nms_match
+
+from ..builder import BBOX_SAMPLERS
+from ..transforms import bbox2roi
+from .base_sampler import BaseSampler
+from .sampling_result import SamplingResult
+
+
+@BBOX_SAMPLERS.register_module()
+class ScoreHLRSampler(BaseSampler):
+ r"""Importance-based Sample Reweighting (ISR_N), described in `Prime Sample
+ Attention in Object Detection <https://arxiv.org/abs/1904.04821>`_.
+
+ Score hierarchical local rank (HLR) differs from RandomSampler in the
+ negative part. It first computes the Score-HLR in a two-step way,
+ then linearly maps the Score-HLR to the loss weights.
+
+ Args:
+ num (int): Total number of sampled RoIs.
+ pos_fraction (float): Fraction of positive samples.
+ context (:class:`BaseRoIHead`): RoI head that the sampler belongs to.
+ neg_pos_ub (int): Upper bound of the ratio of num negative to num
+ positive, -1 means no upper bound.
+ add_gt_as_proposals (bool): Whether to add ground truth as proposals.
+ k (float): Power of the non-linear mapping.
+ bias (float): Shift of the non-linear mapping.
+ score_thr (float): Minimum score for a negative sample to be
+ considered a valid bbox.
+ """
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ context,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ k=0.5,
+ bias=0,
+ score_thr=0.05,
+ iou_thr=0.5,
+ **kwargs):
+ super().__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals)
+ self.k = k
+ self.bias = bias
+ self.score_thr = score_thr
+ self.iou_thr = iou_thr
+ self.context = context
+ # context of cascade detectors is a list, so distinguish them here.
+ if not hasattr(context, 'num_stages'):
+ self.bbox_roi_extractor = context.bbox_roi_extractor
+ self.bbox_head = context.bbox_head
+ self.with_shared_head = context.with_shared_head
+ if self.with_shared_head:
+ self.shared_head = context.shared_head
+ else:
+ self.bbox_roi_extractor = context.bbox_roi_extractor[
+ context.current_stage]
+ self.bbox_head = context.bbox_head[context.current_stage]
+
+ @staticmethod
+ def random_choice(gallery, num):
+ """Randomly select some elements from the gallery.
+
+ If `gallery` is a Tensor, the returned indices will be a Tensor;
+ If `gallery` is a ndarray or list, the returned indices will be a
+ ndarray.
+
+ Args:
+ gallery (Tensor | ndarray | list): indices pool.
+ num (int): expected sample num.
+
+ Returns:
+ Tensor or ndarray: sampled indices.
+ """
+ assert len(gallery) >= num
+
+ is_tensor = isinstance(gallery, torch.Tensor)
+ if not is_tensor:
+ if torch.cuda.is_available():
+ device = torch.cuda.current_device()
+ else:
+ device = 'cpu'
+ gallery = torch.tensor(gallery, dtype=torch.long, device=device)
+ perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
+ rand_inds = gallery[perm]
+ if not is_tensor:
+ rand_inds = rand_inds.cpu().numpy()
+ return rand_inds
+
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Randomly sample some positive samples."""
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0).flatten()
+ if pos_inds.numel() <= num_expected:
+ return pos_inds
+ else:
+ return self.random_choice(pos_inds, num_expected)
+
+ def _sample_neg(self,
+ assign_result,
+ num_expected,
+ bboxes,
+ feats=None,
+ img_meta=None,
+ **kwargs):
+ """Sample negative samples.
+
+ Score-HLR sampling is done in the following steps:
+ 1. Take the maximum positive-class score prediction of each negative
+ sample as s_i.
+ 2. Filter out negative samples whose s_i <= score_thr; the remaining
+ samples are called valid samples.
+ 3. Use NMS-Match to divide valid samples into different groups;
+ samples in the same group greatly overlap with each other.
+ 4. Rank the matched samples in two steps to get the Score-HLR:
+ (1) within each group, rank samples by their scores;
+ (2) across groups, rank samples at the same within-group rank
+ by their scores again.
+ 5. Linearly map the Score-HLR to the final label weights.
+
+ Args:
+ assign_result (:obj:`AssignResult`): result of assigner.
+ num_expected (int): Expected number of samples.
+ bboxes (Tensor): bbox to be sampled.
+ feats (Tensor): Features come from FPN.
+ img_meta (dict): Meta information dictionary.
+ """
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0).flatten()
+ num_neg = neg_inds.size(0)
+ if num_neg == 0:
+ return neg_inds, None
+ with torch.no_grad():
+ neg_bboxes = bboxes[neg_inds]
+ neg_rois = bbox2roi([neg_bboxes])
+ bbox_result = self.context._bbox_forward(feats, neg_rois)
+ cls_score, bbox_pred = bbox_result['cls_score'], bbox_result[
+ 'bbox_pred']
+
+ ori_loss = self.bbox_head.loss(
+ cls_score=cls_score,
+ bbox_pred=None,
+ rois=None,
+ labels=neg_inds.new_full((num_neg, ),
+ self.bbox_head.num_classes),
+ label_weights=cls_score.new_ones(num_neg),
+ bbox_targets=None,
+ bbox_weights=None,
+ reduction_override='none')['loss_cls']
+
+ # filter out samples with the max score lower than score_thr
+ max_score, argmax_score = cls_score.softmax(-1)[:, :-1].max(-1)
+ valid_inds = (max_score > self.score_thr).nonzero().view(-1)
+ invalid_inds = (max_score <= self.score_thr).nonzero().view(-1)
+ num_valid = valid_inds.size(0)
+ num_invalid = invalid_inds.size(0)
+
+ num_expected = min(num_neg, num_expected)
+ num_hlr = min(num_valid, num_expected)
+ num_rand = num_expected - num_hlr
+ if num_valid > 0:
+ valid_rois = neg_rois[valid_inds]
+ valid_max_score = max_score[valid_inds]
+ valid_argmax_score = argmax_score[valid_inds]
+ valid_bbox_pred = bbox_pred[valid_inds]
+
+ # valid_bbox_pred shape: [num_valid, #num_classes, 4]
+ valid_bbox_pred = valid_bbox_pred.view(
+ valid_bbox_pred.size(0), -1, 4)
+ selected_bbox_pred = valid_bbox_pred[range(num_valid),
+ valid_argmax_score]
+ pred_bboxes = self.bbox_head.bbox_coder.decode(
+ valid_rois[:, 1:], selected_bbox_pred)
+ pred_bboxes_with_score = torch.cat(
+ [pred_bboxes, valid_max_score[:, None]], -1)
+ group = nms_match(pred_bboxes_with_score, self.iou_thr)
+
+ # imp: importance
+ imp = cls_score.new_zeros(num_valid)
+ for g in group:
+ g_score = valid_max_score[g]
+ # g_score is already sorted in descending order within each group
+ rank = g_score.new_tensor(range(g_score.size(0)))
+ imp[g] = num_valid - rank + g_score
+ _, imp_rank_inds = imp.sort(descending=True)
+ _, imp_rank = imp_rank_inds.sort()
+ hlr_inds = imp_rank_inds[:num_expected]
+
+ if num_rand > 0:
+ rand_inds = torch.randperm(num_invalid)[:num_rand]
+ select_inds = torch.cat(
+ [valid_inds[hlr_inds], invalid_inds[rand_inds]])
+ else:
+ select_inds = valid_inds[hlr_inds]
+
+ neg_label_weights = cls_score.new_ones(num_expected)
+
+ up_bound = max(num_expected, num_valid)
+ imp_weights = (up_bound -
+ imp_rank[hlr_inds].float()) / up_bound
+ neg_label_weights[:num_hlr] = imp_weights
+ neg_label_weights[num_hlr:] = imp_weights.min()
+ neg_label_weights = (self.bias +
+ (1 - self.bias) * neg_label_weights).pow(
+ self.k)
+ ori_selected_loss = ori_loss[select_inds]
+ new_loss = ori_selected_loss * neg_label_weights
+ norm_ratio = ori_selected_loss.sum() / new_loss.sum()
+ neg_label_weights *= norm_ratio
+ else:
+ neg_label_weights = cls_score.new_ones(num_expected)
+ select_inds = torch.randperm(num_neg)[:num_expected]
+
+ return neg_inds[select_inds], neg_label_weights
+
+ def sample(self,
+ assign_result,
+ bboxes,
+ gt_bboxes,
+ gt_labels=None,
+ img_meta=None,
+ **kwargs):
+ """Sample positive and negative bboxes.
+
+ This is a simple implementation of bbox sampling given candidates,
+ assigning results and ground truth bboxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Bbox assigning results.
+ bboxes (Tensor): Boxes to be sampled from.
+ gt_bboxes (Tensor): Ground truth bboxes.
+ gt_labels (Tensor, optional): Class labels of ground truth bboxes.
+
+ Returns:
+ tuple[:obj:`SamplingResult`, Tensor]: Sampling result and negative
+ label weights.
+ """
+ bboxes = bboxes[:, :4]
+
+ gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
+ if self.add_gt_as_proposals:
+ bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+ assign_result.add_gt_(gt_labels)
+ gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
+ gt_flags = torch.cat([gt_ones, gt_flags])
+
+ num_expected_pos = int(self.num * self.pos_fraction)
+ pos_inds = self.pos_sampler._sample_pos(
+ assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
+ num_sampled_pos = pos_inds.numel()
+ num_expected_neg = self.num - num_sampled_pos
+ if self.neg_pos_ub >= 0:
+ _pos = max(1, num_sampled_pos)
+ neg_upper_bound = int(self.neg_pos_ub * _pos)
+ if num_expected_neg > neg_upper_bound:
+ num_expected_neg = neg_upper_bound
+ neg_inds, neg_label_weights = self.neg_sampler._sample_neg(
+ assign_result,
+ num_expected_neg,
+ bboxes,
+ img_meta=img_meta,
+ **kwargs)
+
+ return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+ assign_result, gt_flags), neg_label_weights
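
The final weighting step above maps the hierarchical local rank to a label weight via `(bias + (1 - bias) * w) ** k` and then rescales so the weighted loss keeps its original magnitude. A toy numeric sketch with made-up ranks and losses:

```python
import torch

k, bias = 0.5, 0.0
num_expected, num_valid = 4, 6
imp_rank = torch.tensor([0, 2, 5, 1])          # toy HLR ranks of selected samples
ori_loss = torch.tensor([1.0, 0.5, 0.2, 0.8])  # toy unreduced losses, same samples

up_bound = max(num_expected, num_valid)
weights = (up_bound - imp_rank.float()) / up_bound      # smaller rank -> larger weight
weights = (bias + (1 - bias) * weights).pow(k)          # non-linear mapping
weights *= ori_loss.sum() / (ori_loss * weights).sum()  # keep the total loss magnitude
print(weights)
```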
diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d72076a5621c5b59c081a8a190b4c8d167c26a5
--- /dev/null
+++ b/mmdet/core/bbox/transforms.py
@@ -0,0 +1,270 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+
+def find_inside_bboxes(bboxes, img_h, img_w):
+ """Find bboxes that are at least partially inside the image.
+
+ Args:
+ bboxes (Tensor): Shape (N, 4).
+ img_h (int): Image height.
+ img_w (int): Image width.
+
+ Returns:
+ Tensor: Bool mask of the bboxes to keep.
+ """
+ inside_inds = (bboxes[:, 0] < img_w) & (bboxes[:, 2] > 0) \
+ & (bboxes[:, 1] < img_h) & (bboxes[:, 3] > 0)
+ return inside_inds
+
+
+def bbox_flip(bboxes, img_shape, direction='horizontal'):
+ """Flip bboxes horizontally or vertically.
+
+ Args:
+ bboxes (Tensor): Shape (..., 4*k)
+ img_shape (tuple): Image shape.
+ direction (str): Flip direction, options are "horizontal", "vertical",
+ "diagonal". Default: "horizontal"
+
+ Returns:
+ Tensor: Flipped bboxes.
+ """
+ assert bboxes.shape[-1] % 4 == 0
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ flipped = bboxes.clone()
+ if direction == 'horizontal':
+ flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
+ flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
+ elif direction == 'vertical':
+ flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
+ flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
+ else:
+ flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
+ flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
+ flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
+ flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
+ return flipped
+
+
+def bbox_mapping(bboxes,
+ img_shape,
+ scale_factor,
+ flip,
+ flip_direction='horizontal'):
+ """Map bboxes from the original image scale to testing scale."""
+ new_bboxes = bboxes * bboxes.new_tensor(scale_factor)
+ if flip:
+ new_bboxes = bbox_flip(new_bboxes, img_shape, flip_direction)
+ return new_bboxes
+
+
+def bbox_mapping_back(bboxes,
+ img_shape,
+ scale_factor,
+ flip,
+ flip_direction='horizontal'):
+ """Map bboxes from testing scale to original image scale."""
+ new_bboxes = bbox_flip(bboxes, img_shape,
+ flip_direction) if flip else bboxes
+ new_bboxes = new_bboxes.view(-1, 4) / new_bboxes.new_tensor(scale_factor)
+ return new_bboxes.view(bboxes.shape)
+
+
+def bbox2roi(bbox_list):
+ """Convert a list of bboxes to roi format.
+
+ Args:
+ bbox_list (list[Tensor]): a list of bboxes corresponding to a batch
+ of images.
+
+ Returns:
+ Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
+ """
+ rois_list = []
+ for img_id, bboxes in enumerate(bbox_list):
+ if bboxes.size(0) > 0:
+ img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
+ rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1)
+ else:
+ rois = bboxes.new_zeros((0, 5))
+ rois_list.append(rois)
+ rois = torch.cat(rois_list, 0)
+ return rois
+
+
+def roi2bbox(rois):
+ """Convert rois to bounding box format.
+
+ Args:
+ rois (torch.Tensor): RoIs with the shape (n, 5) where the first
+ column indicates batch id of each RoI.
+
+ Returns:
+ list[torch.Tensor]: Converted boxes of corresponding rois.
+ """
+ bbox_list = []
+ img_ids = torch.unique(rois[:, 0].cpu(), sorted=True)
+ for img_id in img_ids:
+ inds = (rois[:, 0] == img_id.item())
+ bbox = rois[inds, 1:]
+ bbox_list.append(bbox)
+ return bbox_list
+
+
+def bbox2result(bboxes, labels, num_classes):
+ """Convert detection results to a list of numpy arrays.
+
+ Args:
+ bboxes (torch.Tensor | np.ndarray): shape (n, 5)
+ labels (torch.Tensor | np.ndarray): shape (n, )
+ num_classes (int): class number, including background class
+
+ Returns:
+ list(ndarray): bbox results of each class
+ """
+ if bboxes.shape[0] == 0:
+ return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
+ else:
+ if isinstance(bboxes, torch.Tensor):
+ bboxes = bboxes.detach().cpu().numpy()
+ labels = labels.detach().cpu().numpy()
+ return [bboxes[labels == i, :] for i in range(num_classes)]
+
+
+def distance2bbox(points, distance, max_shape=None):
+ """Decode distance prediction to bounding box.
+
+ Args:
+ points (Tensor): Shape (B, N, 2) or (N, 2).
+ distance (Tensor): Distance from the given point to 4
+ boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4)
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If priors shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+
+ Returns:
+ Tensor: Boxes with shape (N, 4) or (B, N, 4)
+ """
+
+ x1 = points[..., 0] - distance[..., 0]
+ y1 = points[..., 1] - distance[..., 1]
+ x2 = points[..., 0] + distance[..., 2]
+ y2 = points[..., 1] + distance[..., 3]
+
+ bboxes = torch.stack([x1, y1, x2, y2], -1)
+
+ if max_shape is not None:
+ if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
+ # speed up
+ bboxes[:, 0::2].clamp_(min=0, max=max_shape[1])
+ bboxes[:, 1::2].clamp_(min=0, max=max_shape[0])
+ return bboxes
+
+ # clip bboxes with dynamic `min` and `max` for onnx
+ if torch.onnx.is_in_onnx_export():
+ from mmdet.core.export import dynamic_clip_for_onnx
+ x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ return bboxes
+ if not isinstance(max_shape, torch.Tensor):
+ max_shape = x1.new_tensor(max_shape)
+ max_shape = max_shape[..., :2].type_as(x1)
+ if max_shape.ndim == 2:
+ assert bboxes.ndim == 3
+ assert max_shape.size(0) == bboxes.size(0)
+
+ min_xy = x1.new_tensor(0)
+ max_xy = torch.cat([max_shape, max_shape],
+ dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ return bboxes
+
+
+def bbox2distance(points, bbox, max_dis=None, eps=0.1):
+ """Compute the distances from points to the four boundaries of bboxes.
+
+ Args:
+ points (Tensor): Shape (n, 2), [x, y].
+ bbox (Tensor): Shape (n, 4), "xyxy" format
+ max_dis (float): Upper bound of the distance.
+ eps (float): a small value to ensure target < max_dis instead of <=.
+
+ Returns:
+ Tensor: Distances (left, top, right, bottom) with shape (n, 4).
+ """
+ left = points[:, 0] - bbox[:, 0]
+ top = points[:, 1] - bbox[:, 1]
+ right = bbox[:, 2] - points[:, 0]
+ bottom = bbox[:, 3] - points[:, 1]
+ if max_dis is not None:
+ left = left.clamp(min=0, max=max_dis - eps)
+ top = top.clamp(min=0, max=max_dis - eps)
+ right = right.clamp(min=0, max=max_dis - eps)
+ bottom = bottom.clamp(min=0, max=max_dis - eps)
+ return torch.stack([left, top, right, bottom], -1)
+
+
+def bbox_rescale(bboxes, scale_factor=1.0):
+ """Rescale bounding box w.r.t. scale_factor.
+
+ Args:
+ bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois
+ scale_factor (float): rescale factor
+
+ Returns:
+ Tensor: Rescaled bboxes.
+ """
+ if bboxes.size(1) == 5:
+ bboxes_ = bboxes[:, 1:]
+ inds_ = bboxes[:, 0]
+ else:
+ bboxes_ = bboxes
+ cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5
+ cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5
+ w = bboxes_[:, 2] - bboxes_[:, 0]
+ h = bboxes_[:, 3] - bboxes_[:, 1]
+ w = w * scale_factor
+ h = h * scale_factor
+ x1 = cx - 0.5 * w
+ x2 = cx + 0.5 * w
+ y1 = cy - 0.5 * h
+ y2 = cy + 0.5 * h
+ if bboxes.size(1) == 5:
+ rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1)
+ else:
+ rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ return rescaled_bboxes
+
+
+def bbox_cxcywh_to_xyxy(bbox):
+ """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).
+
+ Args:
+ bbox (Tensor): Shape (n, 4) for bboxes.
+
+ Returns:
+ Tensor: Converted bboxes.
+ """
+ cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
+ return torch.cat(bbox_new, dim=-1)
+
+
+def bbox_xyxy_to_cxcywh(bbox):
+ """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).
+
+ Args:
+ bbox (Tensor): Shape (n, 4) for bboxes.
+
+ Returns:
+ Tensor: Converted bboxes.
+ """
+ x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)]
+ return torch.cat(bbox_new, dim=-1)
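
A small self-check of `bbox_flip`'s horizontal case: the flip is an involution, so applying the same index arithmetic twice restores the original boxes (the assignments below restate the lines from the function rather than importing it):

```python
import torch

img_shape = (480, 640)                        # (H, W)
boxes = torch.tensor([[10., 20., 50., 80.]])  # xyxy

# Horizontal flip: x1' = W - x2, x2' = W - x1 (y coordinates unchanged).
flipped = boxes.clone()
flipped[..., 0::4] = img_shape[1] - boxes[..., 2::4]
flipped[..., 2::4] = img_shape[1] - boxes[..., 0::4]

# Flipping again restores the original boxes.
restored = flipped.clone()
restored[..., 0::4] = img_shape[1] - flipped[..., 2::4]
restored[..., 2::4] = img_shape[1] - flipped[..., 0::4]
assert torch.allclose(restored, boxes)
```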
diff --git a/mmdet/core/data_structures/__init__.py b/mmdet/core/data_structures/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..11ab96c565da484ad11533c3535e25abcc212c32
--- /dev/null
+++ b/mmdet/core/data_structures/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .general_data import GeneralData
+from .instance_data import InstanceData
+
+__all__ = ['GeneralData', 'InstanceData']
diff --git a/mmdet/core/data_structures/general_data.py b/mmdet/core/data_structures/general_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..978fdfd7460dda68bc1bfc81cdd9aef493d445b3
--- /dev/null
+++ b/mmdet/core/data_structures/general_data.py
@@ -0,0 +1,336 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import numpy as np
+import torch
+
+from mmdet.utils.util_mixins import NiceRepr
+
+
+class GeneralData(NiceRepr):
+ """A general data structure of OpenMMLab.
+
+ A data structure that stores the meta information,
+ the annotations of the images or the model predictions,
+ which can be used in communication between components.
+
+ The attributes in `GeneralData` are divided into two parts,
+ the `meta_info_fields` and the `data_fields` respectively.
+
+ - `meta_info_fields`: Usually contains the
+ information about the image such as filename,
+ image_shape, pad_shape, etc. All attributes in
+ it are immutable once set,
+ but the user can add new meta information with the
+ `set_meta_info` function. All information can be accessed
+ with methods `meta_info_keys`, `meta_info_values`,
+ `meta_info_items`.
+
+ - `data_fields`: Annotations or model predictions are
+ stored. The attributes can be accessed or modified by
+ dict-like or object-like operations, such as
+ `.`, `[]`, `in`, `del`, `pop(str)`, `get(str)`, `keys()`,
+ `values()`, `items()`. Users can also apply tensor-like methods
+ to all :obj:`torch.Tensor` in the `data_fields`,
+ such as `.cuda()`, `.cpu()`, `.numpy()`, `.device`, `.to()`,
+ `.detach()`.
+
+ Args:
+ meta_info (dict, optional): A dict contains the meta information
+ of single image. such as `img_shape`, `scale_factor`, etc.
+ Default: None.
+ data (dict, optional): A dict contains annotations of single image or
+ model predictions. Default: None.
+
+ Examples:
+ >>> from mmdet.core import GeneralData
+ >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
+ >>> instance_data = GeneralData(meta_info=img_meta)
+ >>> 'img_shape' in instance_data
+ True
+ >>> instance_data.det_labels = torch.LongTensor([0, 1, 2, 3])
+ >>> instance_data["det_scores"] = torch.Tensor([0.01, 0.1, 0.2, 0.3])
+ >>> print(instance_data)
+
+ >>> instance_data.det_scores
+ tensor([0.0100, 0.1000, 0.2000, 0.3000])
+ >>> instance_data.det_labels
+ tensor([0, 1, 2, 3])
+ >>> instance_data['det_labels']
+ tensor([0, 1, 2, 3])
+ >>> 'det_labels' in instance_data
+ True
+ >>> instance_data.img_shape
+ (800, 1196, 3)
+ >>> 'det_scores' in instance_data
+ True
+ >>> del instance_data.det_scores
+ >>> 'det_scores' in instance_data
+ False
+ >>> det_labels = instance_data.pop('det_labels', None)
+ >>> det_labels
+ tensor([0, 1, 2, 3])
+ >>> 'det_labels' in instance_data
+ False
+ """
+
+ def __init__(self, meta_info=None, data=None):
+
+ self._meta_info_fields = set()
+ self._data_fields = set()
+
+ if meta_info is not None:
+ self.set_meta_info(meta_info=meta_info)
+ if data is not None:
+ self.set_data(data)
+
+ def set_meta_info(self, meta_info):
+ """Add meta information.
+
+ Args:
+ meta_info (dict): A dict contains the meta information
+ of image. such as `img_shape`, `scale_factor`, etc.
+ Default: None.
+ """
+ assert isinstance(meta_info,
+ dict), f'meta_info should be a `dict` but got {meta_info}'
+ meta = copy.deepcopy(meta_info)
+ for k, v in meta.items():
+ # should be consistent with original meta_info
+ if k in self._meta_info_fields:
+ ori_value = getattr(self, k)
+ if isinstance(ori_value, (torch.Tensor, np.ndarray)):
+ if (ori_value == v).all():
+ continue
+ else:
+ raise KeyError(
+ f'img_meta_info {k} has been set as '
+ f'{getattr(self, k)} before, which is immutable ')
+ elif ori_value == v:
+ continue
+ else:
+ raise KeyError(
+ f'img_meta_info {k} has been set as '
+ f'{getattr(self, k)} before, which is immutable ')
+ else:
+ self._meta_info_fields.add(k)
+ self.__dict__[k] = v
+
+ def set_data(self, data):
+ """Update `data_fields` with a dict.
+
+ Args:
+ data (dict): A dict contains annotations of image or
+ model predictions. Default: None.
+ """
+ assert isinstance(data,
+ dict), f'data should be a `dict` but got {data}'
+ for k, v in data.items():
+ self.__setattr__(k, v)
+
+ def new(self, meta_info=None, data=None):
+ """Return a new data object with the same image meta information.
+
+ Args:
+ meta_info (dict, optional): A dict contains the meta information
+ of image. such as `img_shape`, `scale_factor`, etc.
+ Default: None.
+ data (dict, optional): A dict contains annotations of image or
+ model predictions. Default: None.
+ """
+ new_data = self.__class__()
+ new_data.set_meta_info(dict(self.meta_info_items()))
+ if meta_info is not None:
+ new_data.set_meta_info(meta_info)
+ if data is not None:
+ new_data.set_data(data)
+ return new_data
+
+ def keys(self):
+ """
+ Returns:
+ list: Contains all keys in data_fields.
+ """
+ return [key for key in self._data_fields]
+
+ def meta_info_keys(self):
+ """
+ Returns:
+ list: Contains all keys in meta_info_fields.
+ """
+ return [key for key in self._meta_info_fields]
+
+ def values(self):
+ """
+ Returns:
+ list: Contains all values in data_fields.
+ """
+ return [getattr(self, k) for k in self.keys()]
+
+ def meta_info_values(self):
+ """
+ Returns:
+ list: Contains all values in meta_info_fields.
+ """
+ return [getattr(self, k) for k in self.meta_info_keys()]
+
+ def items(self):
+ for k in self.keys():
+ yield (k, getattr(self, k))
+
+ def meta_info_items(self):
+ for k in self.meta_info_keys():
+ yield (k, getattr(self, k))
+
+ def __setattr__(self, name, val):
+ if name in ('_meta_info_fields', '_data_fields'):
+ if not hasattr(self, name):
+ super().__setattr__(name, val)
+ else:
+ raise AttributeError(
+ f'{name} has been used as a '
+ f'private attribute, which is immutable. ')
+ else:
+ if name in self._meta_info_fields:
+ raise AttributeError(f'`{name}` is used in meta information,'
+ f'which is immutable')
+
+ self._data_fields.add(name)
+ super().__setattr__(name, val)
+
+ def __delattr__(self, item):
+
+ if item in ('_meta_info_fields', '_data_fields'):
+ raise AttributeError(f'{item} has been used as a '
+ f'private attribute, which is immutable. ')
+
+ if item in self._meta_info_fields:
+ raise KeyError(f'{item} is used in meta information, '
+ f'which is immutable.')
+ super().__delattr__(item)
+ if item in self._data_fields:
+ self._data_fields.remove(item)
+
+ # dict-like methods
+ __setitem__ = __setattr__
+ __delitem__ = __delattr__
+
+ def __getitem__(self, name):
+ return getattr(self, name)
+
+ def get(self, *args):
+ assert len(args) < 3, '`get` accepts at most 2 arguments'
+ return self.__dict__.get(*args)
+
+ def pop(self, *args):
+ assert len(args) < 3, '`pop` accepts at most 2 arguments'
+ name = args[0]
+ if name in self._meta_info_fields:
+ raise KeyError(f'{name} is a key in meta information, '
+ f'which is immutable')
+
+ if args[0] in self._data_fields:
+ self._data_fields.remove(args[0])
+ return self.__dict__.pop(*args)
+
+ # with default value
+ elif len(args) == 2:
+ return args[1]
+ else:
+ raise KeyError(f'{args[0]}')
+
+ def __contains__(self, item):
+ return item in self._data_fields or \
+ item in self._meta_info_fields
+
+ # Tensor-like methods
+ def to(self, *args, **kwargs):
+ """Apply same name function to all tensors in data_fields."""
+ new_data = self.new()
+ for k, v in self.items():
+ if hasattr(v, 'to'):
+ v = v.to(*args, **kwargs)
+ new_data[k] = v
+ return new_data
+
+ # Tensor-like methods
+ def cpu(self):
+ """Apply same name function to all tensors in data_fields."""
+ new_data = self.new()
+ for k, v in self.items():
+ if isinstance(v, torch.Tensor):
+ v = v.cpu()
+ new_data[k] = v
+ return new_data
+
+ # Tensor-like methods
+ def npu(self):
+ """Apply same name function to all tensors in data_fields."""
+ new_data = self.new()
+ for k, v in self.items():
+ if isinstance(v, torch.Tensor):
+ v = v.npu()
+ new_data[k] = v
+ return new_data
+
+ # Tensor-like methods
+ def mlu(self):
+ """Apply same name function to all tensors in data_fields."""
+ new_data = self.new()
+ for k, v in self.items():
+ if isinstance(v, torch.Tensor):
+ v = v.mlu()
+ new_data[k] = v
+ return new_data
+
+ # Tensor-like methods
+ def cuda(self):
+ """Apply same name function to all tensors in data_fields."""
+ new_data = self.new()
+ for k, v in self.items():
+ if isinstance(v, torch.Tensor):
+ v = v.cuda()
+ new_data[k] = v
+ return new_data
+
+ # Tensor-like methods
+ def detach(self):
+ """Apply same name function to all tensors in data_fields."""
+ new_data = self.new()
+ for k, v in self.items():
+ if isinstance(v, torch.Tensor):
+ v = v.detach()
+ new_data[k] = v
+ return new_data
+
+ # Tensor-like methods
+ def numpy(self):
+ """Apply same name function to all tensors in data_fields."""
+ new_data = self.new()
+ for k, v in self.items():
+ if isinstance(v, torch.Tensor):
+ v = v.detach().cpu().numpy()
+ new_data[k] = v
+ return new_data
+
+ def __nice__(self):
+ repr = '\n \n META INFORMATION \n'
+ for k, v in self.meta_info_items():
+ repr += f'{k}: {v} \n'
+ repr += '\n DATA FIELDS \n'
+ for k, v in self.items():
+ if isinstance(v, (torch.Tensor, np.ndarray)):
+ repr += f'shape of {k}: {v.shape} \n'
+ else:
+ repr += f'{k}: {v} \n'
+ return repr + '\n'
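
A short usage sketch of the meta/data split, assuming `GeneralData` is importable from `mmdet.core` as the docstring above shows: meta keys become immutable once set, while data keys can be added, moved across devices, or deleted freely.

```python
import torch
from mmdet.core import GeneralData  # import path taken from the docstring above

meta = dict(img_shape=(800, 1196, 3), scale_factor=1.5)
data = GeneralData(meta_info=meta)
data.det_labels = torch.LongTensor([0, 1, 2])      # stored in data_fields
print('det_labels' in data, 'img_shape' in data)   # True True

data_cpu = data.cpu()   # tensor-like methods return a new container
del data.det_labels     # data fields can be removed; meta fields cannot
```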
diff --git a/mmdet/core/data_structures/instance_data.py b/mmdet/core/data_structures/instance_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..eef2065c831541f1eea723a54c93bb551f9d7579
--- /dev/null
+++ b/mmdet/core/data_structures/instance_data.py
@@ -0,0 +1,188 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+
+import numpy as np
+import torch
+
+from .general_data import GeneralData
+
+
+class InstanceData(GeneralData):
+ """Data structure for instance-level annotations or predictions.
+
+ Subclass of :class:`GeneralData`. All values in `data_fields`
+ should have the same length. This design refers to
+ https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
+
+ Examples:
+ >>> from mmdet.core import InstanceData
+ >>> import numpy as np
+ >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
+ >>> results = InstanceData(img_meta)
+ >>> 'img_shape' in results
+ True
+ >>> results.det_labels = torch.LongTensor([0, 1, 2, 3])
+ >>> results["det_scores"] = torch.Tensor([0.01, 0.7, 0.6, 0.3])
+ >>> results["det_masks"] = np.random.rand(4, 2, 2)
+ >>> len(results)
+ 4
+ >>> print(results)
+
+ >>> sorted_results = results[results.det_scores.sort().indices]
+ >>> sorted_results.det_scores
+ tensor([0.0100, 0.3000, 0.6000, 0.7000])
+ >>> sorted_results.det_labels
+ tensor([0, 3, 2, 1])
+ >>> print(results[results.det_scores > 0.5])
+
+ >>> results[results.det_scores > 0.5].det_labels
+ tensor([1, 2])
+ >>> results[results.det_scores > 0.5].det_scores
+ tensor([0.7000, 0.6000])
+ """
+
+ def __setattr__(self, name, value):
+
+ if name in ('_meta_info_fields', '_data_fields'):
+ if not hasattr(self, name):
+ super().__setattr__(name, value)
+ else:
+ raise AttributeError(
+ f'{name} has been used as a '
+ f'private attribute, which is immutable. ')
+
+ else:
+ assert isinstance(value, (torch.Tensor, np.ndarray, list)), \
+ f'Cannot set {type(value)}, only support' \
+ f' {(torch.Tensor, np.ndarray, list)}'
+
+ if self._data_fields:
+ assert len(value) == len(self), f'the length of ' \
+ f'values {len(value)} is ' \
+ f'not consistent with' \
+ f' the length ' \
+ f'of this :obj:`InstanceData` ' \
+ f'{len(self)} '
+ super().__setattr__(name, value)
+
+ def __getitem__(self, item):
+ """
+ Args:
+ item (str, obj:`slice`,
+ obj`torch.LongTensor`, obj:`torch.BoolTensor`):
+ get the corresponding values according to item.
+
+ Returns:
+ obj:`InstanceData`: Corresponding values.
+ """
+ assert len(self), 'This is an empty instance'
+
+ assert isinstance(
+ item, (str, slice, int, torch.LongTensor, torch.BoolTensor))
+
+ if isinstance(item, str):
+ return getattr(self, item)
+
+ if type(item) == int:
+ if item >= len(self) or item < -len(self):
+ raise IndexError(f'Index {item} out of range!')
+ else:
+ # keep the dimension
+ item = slice(item, None, len(self))
+
+ new_data = self.new()
+ if isinstance(item, (torch.Tensor)):
+ assert item.dim() == 1, 'Only support indexing along' \
+ ' the first dimension.'
+ if isinstance(item, torch.BoolTensor):
+ assert len(item) == len(self), f'The length of the ' \
+ f'input (BoolTensor) ' \
+ f'{len(item)} ' \
+ f'does not match the length ' \
+ f'of the indexed tensors ' \
+ f'in data_fields ' \
+ f'{len(self)} at the ' \
+ f'first dimension.'
+
+ for k, v in self.items():
+ if isinstance(v, torch.Tensor):
+ new_data[k] = v[item]
+ elif isinstance(v, np.ndarray):
+ new_data[k] = v[item.cpu().numpy()]
+ elif isinstance(v, list):
+ r_list = []
+ # convert to indexes from boolTensor
+ if isinstance(item, torch.BoolTensor):
+ indexes = torch.nonzero(item).view(-1)
+ else:
+ indexes = item
+ for index in indexes:
+ r_list.append(v[index])
+ new_data[k] = r_list
+ else:
+ # item is a slice
+ for k, v in self.items():
+ new_data[k] = v[item]
+ return new_data
+
+ @staticmethod
+ def cat(instances_list):
+ """Concat the predictions of all :obj:`InstanceData` in the list.
+
+ Args:
+ instances_list (list[:obj:`InstanceData`]): A list
+ of :obj:`InstanceData`.
+
+ Returns:
+ :obj:`InstanceData`: Concatenated instance data.
+ """
+ assert all(
+ isinstance(results, InstanceData) for results in instances_list)
+ assert len(instances_list) > 0
+ if len(instances_list) == 1:
+ return instances_list[0]
+
+ new_data = instances_list[0].new()
+ for k in instances_list[0]._data_fields:
+ values = [results[k] for results in instances_list]
+ v0 = values[0]
+ if isinstance(v0, torch.Tensor):
+ values = torch.cat(values, dim=0)
+ elif isinstance(v0, np.ndarray):
+ values = np.concatenate(values, axis=0)
+ elif isinstance(v0, list):
+ values = list(itertools.chain(*values))
+ else:
+ raise ValueError(
+ f'Can not concat the {k} which is a {type(v0)}')
+ new_data[k] = values
+ return new_data
+
+ def __len__(self):
+ if len(self._data_fields):
+ for v in self.values():
+ return len(v)
+ else:
+ raise AssertionError('This is an empty `InstanceData`.')
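
A sketch of the per-instance indexing that `__getitem__` provides, again assuming the `mmdet.core` import from the docstring: because every data field has the same length, one boolean mask filters all fields consistently.

```python
import torch
from mmdet.core import InstanceData  # import path taken from the docstring above

results = InstanceData(dict(img_shape=(800, 1196, 3)))
results.det_labels = torch.LongTensor([0, 1, 2, 3])
results.det_scores = torch.tensor([0.01, 0.7, 0.6, 0.3])

keep = results.det_scores > 0.5   # BoolTensor mask over the 4 instances
print(results[keep].det_labels)   # tensor([1, 2])
print(results[keep].det_scores)   # tensor([0.7000, 0.6000])
```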
diff --git a/mmdet/core/evaluation/__init__.py b/mmdet/core/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b488a71e6a10874e0df7f0de03c579e663eb830
--- /dev/null
+++ b/mmdet/core/evaluation/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .class_names import (cityscapes_classes, coco_classes, dataset_aliases,
+ get_classes, imagenet_det_classes,
+ imagenet_vid_classes, objects365v1_classes,
+ objects365v2_classes, oid_challenge_classes,
+ oid_v6_classes, voc_classes)
+from .eval_hooks import DistEvalHook, EvalHook
+from .mean_ap import average_precision, eval_map, print_map_summary
+from .panoptic_utils import INSTANCE_OFFSET
+from .recall import (eval_recalls, plot_iou_recall, plot_num_recall,
+ print_recall_summary)
+
+__all__ = [
+ 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
+ 'coco_classes', 'cityscapes_classes', 'dataset_aliases', 'get_classes',
+ 'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map',
+ 'print_map_summary', 'eval_recalls', 'print_recall_summary',
+ 'plot_num_recall', 'plot_iou_recall', 'oid_v6_classes',
+ 'oid_challenge_classes', 'objects365v1_classes', 'objects365v2_classes',
+ 'INSTANCE_OFFSET'
+]
diff --git a/mmdet/core/evaluation/bbox_overlaps.py b/mmdet/core/evaluation/bbox_overlaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d6eb82fcfc8d5444dd2a13b7d95b978f8206a55
--- /dev/null
+++ b/mmdet/core/evaluation/bbox_overlaps.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+
+def bbox_overlaps(bboxes1,
+ bboxes2,
+ mode='iou',
+ eps=1e-6,
+ use_legacy_coordinate=False):
+ """Calculate the ious between each bbox of bboxes1 and bboxes2.
+
+ Args:
+ bboxes1 (ndarray): Shape (n, 4)
+ bboxes2 (ndarray): Shape (k, 4)
+ mode (str): IOU (intersection over union) or IOF (intersection
+ over foreground)
+ use_legacy_coordinate (bool): Whether to use the coordinate system of
+ mmdet v1.x, which means width and height should be
+ calculated as 'x2 - x1 + 1' and 'y2 - y1 + 1' respectively.
+ Note when function is used in `VOCDataset`, it should be
+ True to align with the official implementation
+ `http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar`
+ Default: False.
+
+ Returns:
+ ious (ndarray): Shape (n, k)
+ """
+
+ assert mode in ['iou', 'iof']
+ if not use_legacy_coordinate:
+ extra_length = 0.
+ else:
+ extra_length = 1.
+ bboxes1 = bboxes1.astype(np.float32)
+ bboxes2 = bboxes2.astype(np.float32)
+ rows = bboxes1.shape[0]
+ cols = bboxes2.shape[0]
+ ious = np.zeros((rows, cols), dtype=np.float32)
+ if rows * cols == 0:
+ return ious
+ exchange = False
+ if bboxes1.shape[0] > bboxes2.shape[0]:
+ bboxes1, bboxes2 = bboxes2, bboxes1
+ ious = np.zeros((cols, rows), dtype=np.float32)
+ exchange = True
+ area1 = (bboxes1[:, 2] - bboxes1[:, 0] + extra_length) * (
+ bboxes1[:, 3] - bboxes1[:, 1] + extra_length)
+ area2 = (bboxes2[:, 2] - bboxes2[:, 0] + extra_length) * (
+ bboxes2[:, 3] - bboxes2[:, 1] + extra_length)
+ for i in range(bboxes1.shape[0]):
+ x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
+ y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
+ x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
+ y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
+ overlap = np.maximum(x_end - x_start + extra_length, 0) * np.maximum(
+ y_end - y_start + extra_length, 0)
+ if mode == 'iou':
+ union = area1[i] + area2 - overlap
+ else:
+ union = area1[i] if not exchange else area2
+ union = np.maximum(union, eps)
+ ious[i, :] = overlap / union
+ if exchange:
+ ious = ious.T
+ return ious
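
A quick numeric check of the IoU above, assuming the module is importable at the path added in this diff: two 10x10 boxes overlapping by half of their width share an intersection of 50 and a union of 150, i.e. IoU 1/3.

```python
import numpy as np
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps  # path added above

b1 = np.array([[0., 0., 10., 10.]])
b2 = np.array([[5., 0., 15., 10.]])

ious = bbox_overlaps(b1, b2, mode='iou')
print(ious)  # ~[[0.3333]]: intersection 50 / union 150
```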
diff --git a/mmdet/core/evaluation/class_names.py b/mmdet/core/evaluation/class_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..c015c5d90f69abf8ae044438c3bbb3657792e809
--- /dev/null
+++ b/mmdet/core/evaluation/class_names.py
@@ -0,0 +1,476 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+
+
+def wider_face_classes():
+ return ['face']
+
+
+def voc_classes():
+ return [
+ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
+ 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
+ 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
+ ]
+
+
+def imagenet_det_classes():
+ return [
+ 'accordion', 'airplane', 'ant', 'antelope', 'apple', 'armadillo',
+ 'artichoke', 'axe', 'baby_bed', 'backpack', 'bagel', 'balance_beam',
+ 'banana', 'band_aid', 'banjo', 'baseball', 'basketball', 'bathing_cap',
+ 'beaker', 'bear', 'bee', 'bell_pepper', 'bench', 'bicycle', 'binder',
+ 'bird', 'bookshelf', 'bow_tie', 'bow', 'bowl', 'brassiere', 'burrito',
+ 'bus', 'butterfly', 'camel', 'can_opener', 'car', 'cart', 'cattle',
+ 'cello', 'centipede', 'chain_saw', 'chair', 'chime', 'cocktail_shaker',
+ 'coffee_maker', 'computer_keyboard', 'computer_mouse', 'corkscrew',
+ 'cream', 'croquet_ball', 'crutch', 'cucumber', 'cup_or_mug', 'diaper',
+ 'digital_clock', 'dishwasher', 'dog', 'domestic_cat', 'dragonfly',
+ 'drum', 'dumbbell', 'electric_fan', 'elephant', 'face_powder', 'fig',
+ 'filing_cabinet', 'flower_pot', 'flute', 'fox', 'french_horn', 'frog',
+ 'frying_pan', 'giant_panda', 'goldfish', 'golf_ball', 'golfcart',
+ 'guacamole', 'guitar', 'hair_dryer', 'hair_spray', 'hamburger',
+ 'hammer', 'hamster', 'harmonica', 'harp', 'hat_with_a_wide_brim',
+ 'head_cabbage', 'helmet', 'hippopotamus', 'horizontal_bar', 'horse',
+ 'hotdog', 'iPod', 'isopod', 'jellyfish', 'koala_bear', 'ladle',
+ 'ladybug', 'lamp', 'laptop', 'lemon', 'lion', 'lipstick', 'lizard',
+ 'lobster', 'maillot', 'maraca', 'microphone', 'microwave', 'milk_can',
+ 'miniskirt', 'monkey', 'motorcycle', 'mushroom', 'nail', 'neck_brace',
+ 'oboe', 'orange', 'otter', 'pencil_box', 'pencil_sharpener', 'perfume',
+ 'person', 'piano', 'pineapple', 'ping-pong_ball', 'pitcher', 'pizza',
+ 'plastic_bag', 'plate_rack', 'pomegranate', 'popsicle', 'porcupine',
+ 'power_drill', 'pretzel', 'printer', 'puck', 'punching_bag', 'purse',
+ 'rabbit', 'racket', 'ray', 'red_panda', 'refrigerator',
+ 'remote_control', 'rubber_eraser', 'rugby_ball', 'ruler',
+ 'salt_or_pepper_shaker', 'saxophone', 'scorpion', 'screwdriver',
+ 'seal', 'sheep', 'ski', 'skunk', 'snail', 'snake', 'snowmobile',
+ 'snowplow', 'soap_dispenser', 'soccer_ball', 'sofa', 'spatula',
+ 'squirrel', 'starfish', 'stethoscope', 'stove', 'strainer',
+ 'strawberry', 'stretcher', 'sunglasses', 'swimming_trunks', 'swine',
+ 'syringe', 'table', 'tape_player', 'tennis_ball', 'tick', 'tie',
+ 'tiger', 'toaster', 'traffic_light', 'train', 'trombone', 'trumpet',
+ 'turtle', 'tv_or_monitor', 'unicycle', 'vacuum', 'violin',
+ 'volleyball', 'waffle_iron', 'washer', 'water_bottle', 'watercraft',
+ 'whale', 'wine_bottle', 'zebra'
+ ]
+
+
+def imagenet_vid_classes():
+ return [
+ 'airplane', 'antelope', 'bear', 'bicycle', 'bird', 'bus', 'car',
+ 'cattle', 'dog', 'domestic_cat', 'elephant', 'fox', 'giant_panda',
+ 'hamster', 'horse', 'lion', 'lizard', 'monkey', 'motorcycle', 'rabbit',
+ 'red_panda', 'sheep', 'snake', 'squirrel', 'tiger', 'train', 'turtle',
+ 'watercraft', 'whale', 'zebra'
+ ]
+
+
+def coco_classes():
+ return [
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+ 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
+ 'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+ 'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
+ 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+ 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
+ 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+ 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
+ ]
+
+
+def cityscapes_classes():
+ return [
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle'
+ ]
+
+
+def oid_challenge_classes():
+ return [
+ 'Footwear', 'Jeans', 'House', 'Tree', 'Woman', 'Man', 'Land vehicle',
+ 'Person', 'Wheel', 'Bus', 'Human face', 'Bird', 'Dress', 'Girl',
+ 'Vehicle', 'Building', 'Cat', 'Car', 'Belt', 'Elephant', 'Dessert',
+ 'Butterfly', 'Train', 'Guitar', 'Poster', 'Book', 'Boy', 'Bee',
+ 'Flower', 'Window', 'Hat', 'Human head', 'Dog', 'Human arm', 'Drink',
+ 'Human mouth', 'Human hair', 'Human nose', 'Human hand', 'Table',
+ 'Marine invertebrates', 'Fish', 'Sculpture', 'Rose', 'Street light',
+ 'Glasses', 'Fountain', 'Skyscraper', 'Swimwear', 'Brassiere', 'Drum',
+ 'Duck', 'Countertop', 'Furniture', 'Ball', 'Human leg', 'Boat',
+ 'Balloon', 'Bicycle helmet', 'Goggles', 'Door', 'Human eye', 'Shirt',
+ 'Toy', 'Teddy bear', 'Pasta', 'Tomato', 'Human ear',
+ 'Vehicle registration plate', 'Microphone', 'Musical keyboard',
+ 'Tower', 'Houseplant', 'Flowerpot', 'Fruit', 'Vegetable',
+ 'Musical instrument', 'Suit', 'Motorcycle', 'Bagel', 'French fries',
+ 'Hamburger', 'Chair', 'Salt and pepper shakers', 'Snail', 'Airplane',
+ 'Horse', 'Laptop', 'Computer keyboard', 'Football helmet', 'Cocktail',
+ 'Juice', 'Tie', 'Computer monitor', 'Human beard', 'Bottle',
+ 'Saxophone', 'Lemon', 'Mouse', 'Sock', 'Cowboy hat', 'Sun hat',
+ 'Football', 'Porch', 'Sunglasses', 'Lobster', 'Crab', 'Picture frame',
+ 'Van', 'Crocodile', 'Surfboard', 'Shorts', 'Helicopter', 'Helmet',
+ 'Sports uniform', 'Taxi', 'Swan', 'Goose', 'Coat', 'Jacket', 'Handbag',
+ 'Flag', 'Skateboard', 'Television', 'Tire', 'Spoon', 'Palm tree',
+ 'Stairs', 'Salad', 'Castle', 'Oven', 'Microwave oven', 'Wine',
+ 'Ceiling fan', 'Mechanical fan', 'Cattle', 'Truck', 'Box', 'Ambulance',
+ 'Desk', 'Wine glass', 'Reptile', 'Tank', 'Traffic light', 'Billboard',
+ 'Tent', 'Insect', 'Spider', 'Treadmill', 'Cupboard', 'Shelf',
+ 'Seat belt', 'Human foot', 'Bicycle', 'Bicycle wheel', 'Couch',
+ 'Bookcase', 'Fedora', 'Backpack', 'Bench', 'Oyster',
+ 'Moths and butterflies', 'Lavender', 'Waffle', 'Fork', 'Animal',
+ 'Accordion', 'Mobile phone', 'Plate', 'Coffee cup', 'Saucer',
+ 'Platter', 'Dagger', 'Knife', 'Bull', 'Tortoise', 'Sea turtle', 'Deer',
+ 'Weapon', 'Apple', 'Ski', 'Taco', 'Traffic sign', 'Beer', 'Necklace',
+ 'Sunflower', 'Piano', 'Organ', 'Harpsichord', 'Bed', 'Cabinetry',
+ 'Nightstand', 'Curtain', 'Chest of drawers', 'Drawer', 'Parrot',
+ 'Sandal', 'High heels', 'Tableware', 'Cart', 'Mushroom', 'Kite',
+ 'Missile', 'Seafood', 'Camera', 'Paper towel', 'Toilet paper',
+ 'Sombrero', 'Radish', 'Lighthouse', 'Segway', 'Pig', 'Watercraft',
+ 'Golf cart', 'studio couch', 'Dolphin', 'Whale', 'Earrings', 'Otter',
+ 'Sea lion', 'Whiteboard', 'Monkey', 'Gondola', 'Zebra',
+ 'Baseball glove', 'Scarf', 'Adhesive tape', 'Trousers', 'Scoreboard',
+ 'Lily', 'Carnivore', 'Power plugs and sockets', 'Office building',
+ 'Sandwich', 'Swimming pool', 'Headphones', 'Tin can', 'Crown', 'Doll',
+ 'Cake', 'Frog', 'Beetle', 'Ant', 'Gas stove', 'Canoe', 'Falcon',
+ 'Blue jay', 'Egg', 'Fire hydrant', 'Raccoon', 'Muffin', 'Wall clock',
+ 'Coffee', 'Mug', 'Tea', 'Bear', 'Waste container', 'Home appliance',
+ 'Candle', 'Lion', 'Mirror', 'Starfish', 'Marine mammal', 'Wheelchair',
+ 'Umbrella', 'Alpaca', 'Violin', 'Cello', 'Brown bear', 'Canary', 'Bat',
+ 'Ruler', 'Plastic bag', 'Penguin', 'Watermelon', 'Harbor seal', 'Pen',
+ 'Pumpkin', 'Harp', 'Kitchen appliance', 'Roller skates', 'Bust',
+ 'Coffee table', 'Tennis ball', 'Tennis racket', 'Ladder', 'Boot',
+ 'Bowl', 'Stop sign', 'Volleyball', 'Eagle', 'Paddle', 'Chicken',
+ 'Skull', 'Lamp', 'Beehive', 'Maple', 'Sink', 'Goldfish', 'Tripod',
+ 'Coconut', 'Bidet', 'Tap', 'Bathroom cabinet', 'Toilet',
+ 'Filing cabinet', 'Pretzel', 'Table tennis racket', 'Bronze sculpture',
+ 'Rocket', 'Mouse', 'Hamster', 'Lizard', 'Lifejacket', 'Goat',
+ 'Washing machine', 'Trumpet', 'Horn', 'Trombone', 'Sheep',
+ 'Tablet computer', 'Pillow', 'Kitchen & dining room table',
+ 'Parachute', 'Raven', 'Glove', 'Loveseat', 'Christmas tree',
+ 'Shellfish', 'Rifle', 'Shotgun', 'Sushi', 'Sparrow', 'Bread',
+ 'Toaster', 'Watch', 'Asparagus', 'Artichoke', 'Suitcase', 'Antelope',
+ 'Broccoli', 'Ice cream', 'Racket', 'Banana', 'Cookie', 'Cucumber',
+ 'Dragonfly', 'Lynx', 'Caterpillar', 'Light bulb', 'Office supplies',
+ 'Miniskirt', 'Skirt', 'Fireplace', 'Potato', 'Light switch',
+ 'Croissant', 'Cabbage', 'Ladybug', 'Handgun', 'Luggage and bags',
+ 'Window blind', 'Snowboard', 'Baseball bat', 'Digital clock',
+ 'Serving tray', 'Infant bed', 'Sofa bed', 'Guacamole', 'Fox', 'Pizza',
+ 'Snowplow', 'Jet ski', 'Refrigerator', 'Lantern', 'Convenience store',
+ 'Sword', 'Rugby ball', 'Owl', 'Ostrich', 'Pancake', 'Strawberry',
+ 'Carrot', 'Tart', 'Dice', 'Turkey', 'Rabbit', 'Invertebrate', 'Vase',
+ 'Stool', 'Swim cap', 'Shower', 'Clock', 'Jellyfish', 'Aircraft',
+ 'Chopsticks', 'Orange', 'Snake', 'Sewing machine', 'Kangaroo', 'Mixer',
+ 'Food processor', 'Shrimp', 'Towel', 'Porcupine', 'Jaguar', 'Cannon',
+ 'Limousine', 'Mule', 'Squirrel', 'Kitchen knife', 'Tiara', 'Tiger',
+ 'Bow and arrow', 'Candy', 'Rhinoceros', 'Shark', 'Cricket ball',
+ 'Doughnut', 'Plumbing fixture', 'Camel', 'Polar bear', 'Coin',
+ 'Printer', 'Blender', 'Giraffe', 'Billiard table', 'Kettle',
+ 'Dinosaur', 'Pineapple', 'Zucchini', 'Jug', 'Barge', 'Teapot',
+ 'Golf ball', 'Binoculars', 'Scissors', 'Hot dog', 'Door handle',
+ 'Seahorse', 'Bathtub', 'Leopard', 'Centipede', 'Grapefruit', 'Snowman',
+ 'Cheetah', 'Alarm clock', 'Grape', 'Wrench', 'Wok', 'Bell pepper',
+ 'Cake stand', 'Barrel', 'Woodpecker', 'Flute', 'Corded phone',
+ 'Willow', 'Punching bag', 'Pomegranate', 'Telephone', 'Pear',
+ 'Common fig', 'Bench', 'Wood-burning stove', 'Burrito', 'Nail',
+ 'Turtle', 'Submarine sandwich', 'Drinking straw', 'Peach', 'Popcorn',
+ 'Frying pan', 'Picnic basket', 'Honeycomb', 'Envelope', 'Mango',
+ 'Cutting board', 'Pitcher', 'Stationary bicycle', 'Dumbbell',
+ 'Personal care', 'Dog bed', 'Snowmobile', 'Oboe', 'Briefcase',
+ 'Squash', 'Tick', 'Slow cooker', 'Coffeemaker', 'Measuring cup',
+ 'Crutch', 'Stretcher', 'Screwdriver', 'Flashlight', 'Spatula',
+ 'Pressure cooker', 'Ring binder', 'Beaker', 'Torch', 'Winter melon'
+ ]
+
+
+def oid_v6_classes():
+ return [
+ 'Tortoise', 'Container', 'Magpie', 'Sea turtle', 'Football',
+ 'Ambulance', 'Ladder', 'Toothbrush', 'Syringe', 'Sink', 'Toy',
+ 'Organ (Musical Instrument)', 'Cassette deck', 'Apple', 'Human eye',
+ 'Cosmetics', 'Paddle', 'Snowman', 'Beer', 'Chopsticks', 'Human beard',
+ 'Bird', 'Parking meter', 'Traffic light', 'Croissant', 'Cucumber',
+ 'Radish', 'Towel', 'Doll', 'Skull', 'Washing machine', 'Glove', 'Tick',
+ 'Belt', 'Sunglasses', 'Banjo', 'Cart', 'Ball', 'Backpack', 'Bicycle',
+ 'Home appliance', 'Centipede', 'Boat', 'Surfboard', 'Boot',
+ 'Headphones', 'Hot dog', 'Shorts', 'Fast food', 'Bus', 'Boy',
+ 'Screwdriver', 'Bicycle wheel', 'Barge', 'Laptop', 'Miniskirt',
+ 'Drill (Tool)', 'Dress', 'Bear', 'Waffle', 'Pancake', 'Brown bear',
+ 'Woodpecker', 'Blue jay', 'Pretzel', 'Bagel', 'Tower', 'Teapot',
+ 'Person', 'Bow and arrow', 'Swimwear', 'Beehive', 'Brassiere', 'Bee',
+ 'Bat (Animal)', 'Starfish', 'Popcorn', 'Burrito', 'Chainsaw',
+ 'Balloon', 'Wrench', 'Tent', 'Vehicle registration plate', 'Lantern',
+ 'Toaster', 'Flashlight', 'Billboard', 'Tiara', 'Limousine', 'Necklace',
+ 'Carnivore', 'Scissors', 'Stairs', 'Computer keyboard', 'Printer',
+ 'Traffic sign', 'Chair', 'Shirt', 'Poster', 'Cheese', 'Sock',
+ 'Fire hydrant', 'Land vehicle', 'Earrings', 'Tie', 'Watercraft',
+ 'Cabinetry', 'Suitcase', 'Muffin', 'Bidet', 'Snack', 'Snowmobile',
+ 'Clock', 'Medical equipment', 'Cattle', 'Cello', 'Jet ski', 'Camel',
+ 'Coat', 'Suit', 'Desk', 'Cat', 'Bronze sculpture', 'Juice', 'Gondola',
+ 'Beetle', 'Cannon', 'Computer mouse', 'Cookie', 'Office building',
+ 'Fountain', 'Coin', 'Calculator', 'Cocktail', 'Computer monitor',
+ 'Box', 'Stapler', 'Christmas tree', 'Cowboy hat', 'Hiking equipment',
+ 'Studio couch', 'Drum', 'Dessert', 'Wine rack', 'Drink', 'Zucchini',
+ 'Ladle', 'Human mouth', 'Dairy Product', 'Dice', 'Oven', 'Dinosaur',
+ 'Ratchet (Device)', 'Couch', 'Cricket ball', 'Winter melon', 'Spatula',
+ 'Whiteboard', 'Pencil sharpener', 'Door', 'Hat', 'Shower', 'Eraser',
+ 'Fedora', 'Guacamole', 'Dagger', 'Scarf', 'Dolphin', 'Sombrero',
+ 'Tin can', 'Mug', 'Tap', 'Harbor seal', 'Stretcher', 'Can opener',
+ 'Goggles', 'Human body', 'Roller skates', 'Coffee cup',
+ 'Cutting board', 'Blender', 'Plumbing fixture', 'Stop sign',
+ 'Office supplies', 'Volleyball (Ball)', 'Vase', 'Slow cooker',
+ 'Wardrobe', 'Coffee', 'Whisk', 'Paper towel', 'Personal care', 'Food',
+ 'Sun hat', 'Tree house', 'Flying disc', 'Skirt', 'Gas stove',
+ 'Salt and pepper shakers', 'Mechanical fan', 'Face powder', 'Fax',
+ 'Fruit', 'French fries', 'Nightstand', 'Barrel', 'Kite', 'Tart',
+ 'Treadmill', 'Fox', 'Flag', 'French horn', 'Window blind',
+ 'Human foot', 'Golf cart', 'Jacket', 'Egg (Food)', 'Street light',
+ 'Guitar', 'Pillow', 'Human leg', 'Isopod', 'Grape', 'Human ear',
+ 'Power plugs and sockets', 'Panda', 'Giraffe', 'Woman', 'Door handle',
+ 'Rhinoceros', 'Bathtub', 'Goldfish', 'Houseplant', 'Goat',
+ 'Baseball bat', 'Baseball glove', 'Mixing bowl',
+ 'Marine invertebrates', 'Kitchen utensil', 'Light switch', 'House',
+ 'Horse', 'Stationary bicycle', 'Hammer', 'Ceiling fan', 'Sofa bed',
+ 'Adhesive tape', 'Harp', 'Sandal', 'Bicycle helmet', 'Saucer',
+ 'Harpsichord', 'Human hair', 'Heater', 'Harmonica', 'Hamster',
+ 'Curtain', 'Bed', 'Kettle', 'Fireplace', 'Scale', 'Drinking straw',
+ 'Insect', 'Hair dryer', 'Kitchenware', 'Indoor rower', 'Invertebrate',
+ 'Food processor', 'Bookcase', 'Refrigerator', 'Wood-burning stove',
+ 'Punching bag', 'Common fig', 'Cocktail shaker', 'Jaguar (Animal)',
+ 'Golf ball', 'Fashion accessory', 'Alarm clock', 'Filing cabinet',
+ 'Artichoke', 'Table', 'Tableware', 'Kangaroo', 'Koala', 'Knife',
+ 'Bottle', 'Bottle opener', 'Lynx', 'Lavender (Plant)', 'Lighthouse',
+ 'Dumbbell', 'Human head', 'Bowl', 'Humidifier', 'Porch', 'Lizard',
+ 'Billiard table', 'Mammal', 'Mouse', 'Motorcycle',
+ 'Musical instrument', 'Swim cap', 'Frying pan', 'Snowplow',
+ 'Bathroom cabinet', 'Missile', 'Bust', 'Man', 'Waffle iron', 'Milk',
+ 'Ring binder', 'Plate', 'Mobile phone', 'Baked goods', 'Mushroom',
+ 'Crutch', 'Pitcher (Container)', 'Mirror', 'Personal flotation device',
+ 'Table tennis racket', 'Pencil case', 'Musical keyboard', 'Scoreboard',
+ 'Briefcase', 'Kitchen knife', 'Nail (Construction)', 'Tennis ball',
+ 'Plastic bag', 'Oboe', 'Chest of drawers', 'Ostrich', 'Piano', 'Girl',
+ 'Plant', 'Potato', 'Hair spray', 'Sports equipment', 'Pasta',
+ 'Penguin', 'Pumpkin', 'Pear', 'Infant bed', 'Polar bear', 'Mixer',
+ 'Cupboard', 'Jacuzzi', 'Pizza', 'Digital clock', 'Pig', 'Reptile',
+ 'Rifle', 'Lipstick', 'Skateboard', 'Raven', 'High heels', 'Red panda',
+ 'Rose', 'Rabbit', 'Sculpture', 'Saxophone', 'Shotgun', 'Seafood',
+ 'Submarine sandwich', 'Snowboard', 'Sword', 'Picture frame', 'Sushi',
+ 'Loveseat', 'Ski', 'Squirrel', 'Tripod', 'Stethoscope', 'Submarine',
+ 'Scorpion', 'Segway', 'Training bench', 'Snake', 'Coffee table',
+ 'Skyscraper', 'Sheep', 'Television', 'Trombone', 'Tea', 'Tank', 'Taco',
+ 'Telephone', 'Torch', 'Tiger', 'Strawberry', 'Trumpet', 'Tree',
+ 'Tomato', 'Train', 'Tool', 'Picnic basket', 'Cooking spray',
+ 'Trousers', 'Bowling equipment', 'Football helmet', 'Truck',
+ 'Measuring cup', 'Coffeemaker', 'Violin', 'Vehicle', 'Handbag',
+ 'Paper cutter', 'Wine', 'Weapon', 'Wheel', 'Worm', 'Wok', 'Whale',
+ 'Zebra', 'Auto part', 'Jug', 'Pizza cutter', 'Cream', 'Monkey', 'Lion',
+ 'Bread', 'Platter', 'Chicken', 'Eagle', 'Helicopter', 'Owl', 'Duck',
+ 'Turtle', 'Hippopotamus', 'Crocodile', 'Toilet', 'Toilet paper',
+ 'Squid', 'Clothing', 'Footwear', 'Lemon', 'Spider', 'Deer', 'Frog',
+ 'Banana', 'Rocket', 'Wine glass', 'Countertop', 'Tablet computer',
+ 'Waste container', 'Swimming pool', 'Dog', 'Book', 'Elephant', 'Shark',
+ 'Candle', 'Leopard', 'Axe', 'Hand dryer', 'Soap dispenser',
+ 'Porcupine', 'Flower', 'Canary', 'Cheetah', 'Palm tree', 'Hamburger',
+ 'Maple', 'Building', 'Fish', 'Lobster', 'Garden Asparagus',
+ 'Furniture', 'Hedgehog', 'Airplane', 'Spoon', 'Otter', 'Bull',
+ 'Oyster', 'Horizontal bar', 'Convenience store', 'Bomb', 'Bench',
+ 'Ice cream', 'Caterpillar', 'Butterfly', 'Parachute', 'Orange',
+ 'Antelope', 'Beaker', 'Moths and butterflies', 'Window', 'Closet',
+ 'Castle', 'Jellyfish', 'Goose', 'Mule', 'Swan', 'Peach', 'Coconut',
+ 'Seat belt', 'Raccoon', 'Chisel', 'Fork', 'Lamp', 'Camera',
+ 'Squash (Plant)', 'Racket', 'Human face', 'Human arm', 'Vegetable',
+ 'Diaper', 'Unicycle', 'Falcon', 'Chime', 'Snail', 'Shellfish',
+ 'Cabbage', 'Carrot', 'Mango', 'Jeans', 'Flowerpot', 'Pineapple',
+ 'Drawer', 'Stool', 'Envelope', 'Cake', 'Dragonfly', 'Common sunflower',
+ 'Microwave oven', 'Honeycomb', 'Marine mammal', 'Sea lion', 'Ladybug',
+ 'Shelf', 'Watch', 'Candy', 'Salad', 'Parrot', 'Handgun', 'Sparrow',
+ 'Van', 'Grinder', 'Spice rack', 'Light bulb', 'Corded phone',
+ 'Sports uniform', 'Tennis racket', 'Wall clock', 'Serving tray',
+ 'Kitchen & dining room table', 'Dog bed', 'Cake stand',
+ 'Cat furniture', 'Bathroom accessory', 'Facial tissue holder',
+ 'Pressure cooker', 'Kitchen appliance', 'Tire', 'Ruler',
+ 'Luggage and bags', 'Microphone', 'Broccoli', 'Umbrella', 'Pastry',
+ 'Grapefruit', 'Band-aid', 'Animal', 'Bell pepper', 'Turkey', 'Lily',
+ 'Pomegranate', 'Doughnut', 'Glasses', 'Human nose', 'Pen', 'Ant',
+ 'Car', 'Aircraft', 'Human hand', 'Skunk', 'Teddy bear', 'Watermelon',
+ 'Cantaloupe', 'Dishwasher', 'Flute', 'Balance beam', 'Sandwich',
+ 'Shrimp', 'Sewing machine', 'Binoculars', 'Rays and skates', 'Ipod',
+ 'Accordion', 'Willow', 'Crab', 'Crown', 'Seahorse', 'Perfume',
+ 'Alpaca', 'Taxi', 'Canoe', 'Remote control', 'Wheelchair',
+ 'Rugby ball', 'Armadillo', 'Maracas', 'Helmet'
+ ]
+
+
+def objects365v1_classes():
+ return [
+ 'person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle',
+ 'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk',
+ 'handbag', 'street lights', 'book', 'plate', 'helmet', 'leather shoes',
+ 'pillow', 'glove', 'potted plant', 'bracelet', 'flower', 'tv',
+ 'storage box', 'vase', 'bench', 'wine glass', 'boots', 'bowl',
+ 'dining table', 'umbrella', 'boat', 'flag', 'speaker', 'trash bin/can',
+ 'stool', 'backpack', 'couch', 'belt', 'carpet', 'basket',
+ 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table', 'suv',
+ 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil', 'microphone',
+ 'sandals', 'canned', 'necklace', 'mirror', 'faucet', 'bicycle',
+ 'bread', 'high heels', 'ring', 'van', 'watch', 'sink', 'horse', 'fish',
+ 'apple', 'camera', 'candle', 'teddy bear', 'cake', 'motorcycle',
+ 'wild bird', 'laptop', 'knife', 'traffic sign', 'cell phone', 'paddle',
+ 'truck', 'cow', 'power outlet', 'clock', 'drum', 'fork', 'bus',
+ 'hanger', 'nightstand', 'pot/pan', 'sheep', 'guitar', 'traffic cone',
+ 'tea pot', 'keyboard', 'tripod', 'hockey', 'fan', 'dog', 'spoon',
+ 'blackboard/whiteboard', 'balloon', 'air conditioner', 'cymbal',
+ 'mouse', 'telephone', 'pickup truck', 'orange', 'banana', 'airplane',
+ 'luggage', 'skis', 'soccer', 'trolley', 'oven', 'remote',
+ 'baseball glove', 'paper towel', 'refrigerator', 'train', 'tomato',
+ 'machinery vehicle', 'tent', 'shampoo/shower gel', 'head phone',
+ 'lantern', 'donut', 'cleaning products', 'sailboat', 'tangerine',
+ 'pizza', 'kite', 'computer box', 'elephant', 'toiletries', 'gas stove',
+ 'broccoli', 'toilet', 'stroller', 'shovel', 'baseball bat',
+ 'microwave', 'skateboard', 'surfboard', 'surveillance camera', 'gun',
+ 'life saver', 'cat', 'lemon', 'liquid soap', 'zebra', 'duck',
+ 'sports car', 'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator',
+ 'converter', 'tissue ', 'carrot', 'washing machine', 'vent', 'cookies',
+ 'cutting/chopping board', 'tennis racket', 'candy',
+ 'skating and skiing shoes', 'scissors', 'folder', 'baseball',
+ 'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine',
+ 'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear',
+ 'american football', 'basketball', 'potato', 'paint brush', 'printer',
+ 'billiards', 'fire hydrant', 'goose', 'projector', 'sausage',
+ 'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball',
+ 'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee',
+ 'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender',
+ 'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango',
+ 'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion',
+ 'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale',
+ 'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple',
+ 'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle',
+ 'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar',
+ 'earphone', 'penguin', 'hurdle', 'swing', 'radio', 'CD',
+ 'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado',
+ 'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear',
+ 'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn',
+ 'key', 'screwdriver', 'globe', 'broom', 'pliers', 'volleyball',
+ 'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice',
+ 'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel',
+ 'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste', 'antelope',
+ 'shrimp', 'rickshaw', 'trombone', 'pomegranate', 'coconut',
+ 'jellyfish', 'mushroom', 'calculator', 'treadmill', 'butterfly',
+ 'egg tart', 'cheese', 'pig', 'pomelo', 'race car', 'rice cooker',
+ 'tuba', 'crosswalk sign', 'papaya', 'hair drier', 'green onion',
+ 'chips', 'dolphin', 'sushi', 'urinal', 'donkey', 'electric drill',
+ 'spring rolls', 'tortoise/turtle', 'parrot', 'flute', 'measuring cup',
+ 'shark', 'steak', 'poker card', 'binoculars', 'llama', 'radish',
+ 'noodles', 'yak', 'mop', 'crab', 'microscope', 'barbell', 'bread/bun',
+ 'baozi', 'lion', 'red cabbage', 'polar bear', 'lighter', 'seal',
+ 'mangosteen', 'comb', 'eraser', 'pitaya', 'scallop', 'pencil case',
+ 'saw', 'table tennis paddle', 'okra', 'starfish', 'eagle', 'monkey',
+ 'durian', 'game board', 'rabbit', 'french horn', 'ambulance',
+ 'asparagus', 'hoverboard', 'pasta', 'target', 'hotair balloon',
+ 'chainsaw', 'lobster', 'iron', 'flashlight'
+ ]
+
+
+def objects365v2_classes():
+ return [
+ 'Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp',
+ 'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf',
+ 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet',
+ 'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower',
+ 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots',
+ 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt',
+ 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker',
+ 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool',
+ 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum',
+ 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle', 'Guitar',
+ 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck',
+ 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy',
+ 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed', 'Faucet', 'Tent',
+ 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple', 'Air Conditioner',
+ 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck', 'Fork',
+ 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon', 'Clock', 'Pot',
+ 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger',
+ 'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine',
+ 'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle',
+ 'Fan', 'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane',
+ 'Mouse', 'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage',
+ 'Nightstand', 'Tea pot', 'Telephone', 'Trolley', 'Head Phone',
+ 'Sports Car', 'Stop Sign', 'Dessert', 'Scooter', 'Stroller', 'Crane',
+ 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck', 'Baseball Bat',
+ 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli', 'Piano', 'Pizza',
+ 'Elephant', 'Skateboard', 'Surfboard', 'Gun',
+ 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot',
+ 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper',
+ 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks',
+ 'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board',
+ 'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder',
+ 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball',
+ 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin',
+ 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck', 'Billards',
+ 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club', 'Briefcase',
+ 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear', 'Heavy Truck',
+ 'Hamburger', 'Extractor', 'Extention Cord', 'Tong', 'Tennis Racket',
+ 'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', 'Tennis',
+ 'Ship', 'Swing', 'Coffee Machine', 'Slide', 'Carriage', 'Onion',
+ 'Green beans', 'Projector', 'Frisbee',
+ 'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon',
+ 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon',
+ 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog',
+ 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer',
+ 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple',
+ 'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle',
+ 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone',
+ 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion',
+ 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom',
+ 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit',
+ 'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese',
+ 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue',
+ 'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap',
+ 'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut',
+ 'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak',
+ 'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate',
+ 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker', 'Tuba',
+ 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', 'Buttefly',
+ 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 'Electric Drill',
+ 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill', 'Lighter',
+ 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi', 'Target',
+ 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case', 'Yak',
+ 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop',
+ 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle',
+ 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster',
+ 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling',
+ 'Table Tennis '
+ ]
+
+
+dataset_aliases = {
+ 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'],
+ 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'],
+ 'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'],
+ 'coco': ['coco', 'mscoco', 'ms_coco'],
+ 'wider_face': ['WIDERFaceDataset', 'wider_face', 'WIDERFace'],
+ 'cityscapes': ['cityscapes'],
+ 'oid_challenge': ['oid_challenge', 'openimages_challenge'],
+ 'oid_v6': ['oid_v6', 'openimages_v6'],
+ 'objects365v1': ['objects365v1', 'obj365v1'],
+ 'objects365v2': ['objects365v2', 'obj365v2']
+}
+
+
+def get_classes(dataset):
+ """Get class names of a dataset."""
+ alias2name = {}
+ for name, aliases in dataset_aliases.items():
+ for alias in aliases:
+ alias2name[alias] = name
+
+ if mmcv.is_str(dataset):
+ if dataset in alias2name:
+ labels = eval(alias2name[dataset] + '_classes()')
+ else:
+ raise ValueError(f'Unrecognized dataset: {dataset}')
+ else:
+ raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+ return labels
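+
+
+# Illustrative usage of the helpers above (a minimal sketch; only the alias
+# strings registered in `dataset_aliases` are accepted):
+#   get_classes('coco')      # same as get_classes('mscoco'), 80 class names
+#   get_classes('oid_v6')    # same as get_classes('openimages_v6')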
diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..98856c18ce65625fa1ac68beee3a1ea584ffec9d
--- /dev/null
+++ b/mmdet/core/evaluation/eval_hooks.py
@@ -0,0 +1,140 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import bisect
+import os.path as osp
+
+import mmcv
+import torch.distributed as dist
+from mmcv.runner import DistEvalHook as BaseDistEvalHook
+from mmcv.runner import EvalHook as BaseEvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+def _calc_dynamic_intervals(start_interval, dynamic_interval_list):
+ assert mmcv.is_list_of(dynamic_interval_list, tuple)
+
+ dynamic_milestones = [0]
+ dynamic_milestones.extend(
+ [dynamic_interval[0] for dynamic_interval in dynamic_interval_list])
+ dynamic_intervals = [start_interval]
+ dynamic_intervals.extend(
+ [dynamic_interval[1] for dynamic_interval in dynamic_interval_list])
+ return dynamic_milestones, dynamic_intervals
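+
+# A quick sketch of the helper above with toy numbers: start_interval=5 and
+# dynamic_interval_list=[(8, 2), (11, 1)] yield milestones [0, 8, 11] and
+# intervals [5, 2, 1], i.e. evaluate every 5 epochs at first, every 2 epochs
+# from the 8th epoch on, and every epoch from the 11th epoch on.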
+
+
+class EvalHook(BaseEvalHook):
+
+ def __init__(self, *args, dynamic_intervals=None, **kwargs):
+ super(EvalHook, self).__init__(*args, **kwargs)
+ self.latest_results = None
+
+ self.use_dynamic_intervals = dynamic_intervals is not None
+ if self.use_dynamic_intervals:
+ self.dynamic_milestones, self.dynamic_intervals = \
+ _calc_dynamic_intervals(self.interval, dynamic_intervals)
+
+ def _decide_interval(self, runner):
+ if self.use_dynamic_intervals:
+ progress = runner.epoch if self.by_epoch else runner.iter
+ step = bisect.bisect(self.dynamic_milestones, (progress + 1))
+ # Dynamically modify the evaluation interval
+ self.interval = self.dynamic_intervals[step - 1]
+
+ def before_train_epoch(self, runner):
+ """Evaluate the model only at the start of training by epoch."""
+ self._decide_interval(runner)
+ super().before_train_epoch(runner)
+
+ def before_train_iter(self, runner):
+ self._decide_interval(runner)
+ super().before_train_iter(runner)
+
+ def _do_evaluate(self, runner):
+ """perform evaluation and save ckpt."""
+ if not self._should_evaluate(runner):
+ return
+
+ from mmdet.apis import single_gpu_test
+
+ # Store the results in self.latest_results so that MMDetWandbHook
+ # can access the evaluation results and log them to wandb.
+ results = single_gpu_test(runner.model, self.dataloader, show=False)
+ self.latest_results = results
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+ key_score = self.evaluate(runner, results)
+ # key_score may be `None`, in which case we skip saving the best
+ # checkpoint
+ if self.save_best and key_score:
+ self._save_ckpt(runner, key_score)
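+
+# In a training config this hook is typically enabled via something like the
+# following sketch (the exact keys depend on the surrounding mmdet config):
+#   evaluation = dict(interval=5, dynamic_intervals=[(8, 1)])
+# i.e. evaluate every 5 epochs at first and every epoch from the 8th epoch on.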
+
+
+# Note: MMCV's EvalHook updated its interface in v1.3.16. To avoid a strong
+# version dependency, we do not inherit from that EvalHook directly but from
+# BaseDistEvalHook.
+class DistEvalHook(BaseDistEvalHook):
+
+ def __init__(self, *args, dynamic_intervals=None, **kwargs):
+ super(DistEvalHook, self).__init__(*args, **kwargs)
+ self.latest_results = None
+
+ self.use_dynamic_intervals = dynamic_intervals is not None
+ if self.use_dynamic_intervals:
+ self.dynamic_milestones, self.dynamic_intervals = \
+ _calc_dynamic_intervals(self.interval, dynamic_intervals)
+
+ def _decide_interval(self, runner):
+ if self.use_dynamic_intervals:
+ progress = runner.epoch if self.by_epoch else runner.iter
+ step = bisect.bisect(self.dynamic_milestones, (progress + 1))
+ # Dynamically modify the evaluation interval
+ self.interval = self.dynamic_intervals[step - 1]
+
+ def before_train_epoch(self, runner):
+ """Evaluate the model only at the start of training by epoch."""
+ self._decide_interval(runner)
+ super().before_train_epoch(runner)
+
+ def before_train_iter(self, runner):
+ self._decide_interval(runner)
+ super().before_train_iter(runner)
+
+ def _do_evaluate(self, runner):
+ """perform evaluation and save ckpt."""
+ # Synchronization of BatchNorm's buffers (running_mean and
+ # running_var) is not supported in PyTorch DDP, which may cause
+ # inconsistent performance of the model on different ranks, so we
+ # broadcast the BatchNorm buffers of rank 0 to the other ranks to
+ # avoid this.
+ if self.broadcast_bn_buffer:
+ model = runner.model
+ for name, module in model.named_modules():
+ if isinstance(module,
+ _BatchNorm) and module.track_running_stats:
+ dist.broadcast(module.running_var, 0)
+ dist.broadcast(module.running_mean, 0)
+
+ if not self._should_evaluate(runner):
+ return
+
+ tmpdir = self.tmpdir
+ if tmpdir is None:
+ tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
+ from mmdet.apis import multi_gpu_test
+
+ # Store the results in self.latest_results so that MMDetWandbHook
+ # can access the evaluation results and log them to wandb.
+ results = multi_gpu_test(
+ runner.model,
+ self.dataloader,
+ tmpdir=tmpdir,
+ gpu_collect=self.gpu_collect)
+ self.latest_results = results
+ if runner.rank == 0:
+ print('\n')
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+ key_score = self.evaluate(runner, results)
+
+ # key_score may be `None`, in which case we skip
+ # saving the best checkpoint
+ if self.save_best and key_score:
+ self._save_ckpt(runner, key_score)
diff --git a/mmdet/core/evaluation/mean_ap.py b/mmdet/core/evaluation/mean_ap.py
new file mode 100644
index 0000000000000000000000000000000000000000..95689129230cab5ecc892d8c08a1804f82b77490
--- /dev/null
+++ b/mmdet/core/evaluation/mean_ap.py
@@ -0,0 +1,782 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from multiprocessing import Pool
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from .bbox_overlaps import bbox_overlaps
+from .class_names import get_classes
+
+
+def average_precision(recalls, precisions, mode='area'):
+ """Calculate average precision (for single or multiple scales).
+
+ Args:
+ recalls (ndarray): shape (num_scales, num_dets) or (num_dets, )
+ precisions (ndarray): shape (num_scales, num_dets) or (num_dets, )
+ mode (str): 'area' or '11points', 'area' means calculating the area
+ under precision-recall curve, '11points' means calculating
+ the average precision of recalls at [0, 0.1, ..., 1]
+
+ Returns:
+ float or ndarray: calculated average precision
+ """
+ no_scale = False
+ if recalls.ndim == 1:
+ no_scale = True
+ recalls = recalls[np.newaxis, :]
+ precisions = precisions[np.newaxis, :]
+ assert recalls.shape == precisions.shape and recalls.ndim == 2
+ num_scales = recalls.shape[0]
+ ap = np.zeros(num_scales, dtype=np.float32)
+ if mode == 'area':
+ zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
+ ones = np.ones((num_scales, 1), dtype=recalls.dtype)
+ mrec = np.hstack((zeros, recalls, ones))
+ mpre = np.hstack((zeros, precisions, zeros))
+ for i in range(mpre.shape[1] - 1, 0, -1):
+ mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i])
+ for i in range(num_scales):
+ ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0]
+ ap[i] = np.sum(
+ (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1])
+ elif mode == '11points':
+ for i in range(num_scales):
+ for thr in np.arange(0, 1 + 1e-3, 0.1):
+ precs = precisions[i, recalls[i, :] >= thr]
+ prec = precs.max() if precs.size > 0 else 0
+ ap[i] += prec
+ ap /= 11
+ else:
+ raise ValueError(
+ 'Unrecognized mode, only "area" and "11points" are supported')
+ if no_scale:
+ ap = ap[0]
+ return ap
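+
+# Worked toy example for the function above (illustrative only): with
+# recalls = np.array([0.5, 1.0]) and precisions = np.array([1.0, 0.5]),
+# mode='area' integrates the interpolated precision-recall curve and returns
+# 0.5 * 1.0 + 0.5 * 0.5 = 0.75, while mode='11points' averages the maximum
+# precision at the recall thresholds 0.0, 0.1, ..., 1.0.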
+
+
+def tpfp_imagenet(det_bboxes,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ default_iou_thr=0.5,
+ area_ranges=None,
+ use_legacy_coordinate=False,
+ **kwargs):
+ """Check if detected bboxes are true positive or false positive.
+
+ Args:
+ det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
+ gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
+ gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+ of shape (k, 4). Default: None
+ default_iou_thr (float): IoU threshold to be considered as matched for
+ medium and large bboxes (small ones have special rules).
+ Default: 0.5.
+ area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
+ in the format [(min1, max1), (min2, max2), ...]. Default: None.
+ use_legacy_coordinate (bool): Whether to use the coordinate system
+ of mmdet v1.x, in which width and height are calculated as
+ `x2 - x1 + 1` and `y2 - y1 + 1` respectively. Default: False.
+
+ Returns:
+ tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
+ each array is (num_scales, m).
+ """
+
+ if not use_legacy_coordinate:
+ extra_length = 0.
+ else:
+ extra_length = 1.
+
+ # an indicator of ignored gts
+ gt_ignore_inds = np.concatenate(
+ (np.zeros(gt_bboxes.shape[0],
+ dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+ # stack gt_bboxes and gt_bboxes_ignore for convenience
+ gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
+ num_dets = det_bboxes.shape[0]
+ num_gts = gt_bboxes.shape[0]
+ if area_ranges is None:
+ area_ranges = [(None, None)]
+ num_scales = len(area_ranges)
+ # tp and fp are of shape (num_scales, num_dets), each row is tp or fp
+ # of a certain scale.
+ tp = np.zeros((num_scales, num_dets), dtype=np.float32)
+ fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+ if gt_bboxes.shape[0] == 0:
+ if area_ranges == [(None, None)]:
+ fp[...] = 1
+ else:
+ det_areas = (
+ det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
+ det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
+ for i, (min_area, max_area) in enumerate(area_ranges):
+ fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
+ return tp, fp
+ ious = bbox_overlaps(
+ det_bboxes, gt_bboxes - 1, use_legacy_coordinate=use_legacy_coordinate)
+ gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length
+ gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length
+ iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)),
+ default_iou_thr)
+ # sort all detections by scores in descending order
+ sort_inds = np.argsort(-det_bboxes[:, -1])
+ for k, (min_area, max_area) in enumerate(area_ranges):
+ gt_covered = np.zeros(num_gts, dtype=bool)
+ # if no area range is specified, gt_area_ignore is all False
+ if min_area is None:
+ gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
+ else:
+ gt_areas = gt_w * gt_h
+ gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
+ for i in sort_inds:
+ max_iou = -1
+ matched_gt = -1
+ # find best overlapped available gt
+ for j in range(num_gts):
+ # different from PASCAL VOC: allow finding other gts if the
+ # best overlapped ones are already matched by other det bboxes
+ if gt_covered[j]:
+ continue
+ elif ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou:
+ max_iou = ious[i, j]
+ matched_gt = j
+ # there are 4 cases for a det bbox:
+ # 1. it matches a gt, tp = 1, fp = 0
+ # 2. it matches an ignored gt, tp = 0, fp = 0
+ # 3. it matches no gt and within area range, tp = 0, fp = 1
+ # 4. it matches no gt but is beyond area range, tp = 0, fp = 0
+ if matched_gt >= 0:
+ gt_covered[matched_gt] = 1
+ if not (gt_ignore_inds[matched_gt]
+ or gt_area_ignore[matched_gt]):
+ tp[k, i] = 1
+ elif min_area is None:
+ fp[k, i] = 1
+ else:
+ bbox = det_bboxes[i, :4]
+ area = (bbox[2] - bbox[0] + extra_length) * (
+ bbox[3] - bbox[1] + extra_length)
+ if area >= min_area and area < max_area:
+ fp[k, i] = 1
+ return tp, fp
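+
+# Note on the adaptive threshold above (as implemented): a gt of size w x h
+# uses an effective IoU threshold of min(w * h / ((w + 10) * (h + 10)),
+# default_iou_thr), so e.g. a 20x20 gt only needs an IoU of about 0.44 while
+# large gts fall back to the default 0.5.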
+
+
+def tpfp_default(det_bboxes,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ iou_thr=0.5,
+ area_ranges=None,
+ use_legacy_coordinate=False,
+ **kwargs):
+ """Check if detected bboxes are true positive or false positive.
+
+ Args:
+ det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
+ gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
+ gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+ of shape (k, 4). Default: None
+ iou_thr (float): IoU threshold to be considered as matched.
+ Default: 0.5.
+ area_ranges (list[tuple] | None): Range of bbox areas to be
+ evaluated, in the format [(min1, max1), (min2, max2), ...].
+ Default: None.
+ use_legacy_coordinate (bool): Whether to use the coordinate system
+ of mmdet v1.x, in which width and height are calculated as
+ `x2 - x1 + 1` and `y2 - y1 + 1` respectively. Default: False.
+
+ Returns:
+ tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
+ each array is (num_scales, m).
+ """
+
+ if not use_legacy_coordinate:
+ extra_length = 0.
+ else:
+ extra_length = 1.
+
+ # an indicator of ignored gts
+ gt_ignore_inds = np.concatenate(
+ (np.zeros(gt_bboxes.shape[0],
+ dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+ # stack gt_bboxes and gt_bboxes_ignore for convenience
+ gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
+ num_dets = det_bboxes.shape[0]
+ num_gts = gt_bboxes.shape[0]
+ if area_ranges is None:
+ area_ranges = [(None, None)]
+ num_scales = len(area_ranges)
+ # tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
+ # a certain scale
+ tp = np.zeros((num_scales, num_dets), dtype=np.float32)
+ fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+
+ # if there is no gt bboxes in this image, then all det bboxes
+ # within area range are false positives
+ if gt_bboxes.shape[0] == 0:
+ if area_ranges == [(None, None)]:
+ fp[...] = 1
+ else:
+ det_areas = (
+ det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
+ det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
+ for i, (min_area, max_area) in enumerate(area_ranges):
+ fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
+ return tp, fp
+
+ ious = bbox_overlaps(
+ det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate)
+ # for each det, the max iou with all gts
+ ious_max = ious.max(axis=1)
+ # for each det, which gt overlaps most with it
+ ious_argmax = ious.argmax(axis=1)
+ # sort all dets in descending order by scores
+ sort_inds = np.argsort(-det_bboxes[:, -1])
+ for k, (min_area, max_area) in enumerate(area_ranges):
+ gt_covered = np.zeros(num_gts, dtype=bool)
+ # if no area range is specified, gt_area_ignore is all False
+ if min_area is None:
+ gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
+ else:
+ gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * (
+ gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length)
+ gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
+ for i in sort_inds:
+ if ious_max[i] >= iou_thr:
+ matched_gt = ious_argmax[i]
+ if not (gt_ignore_inds[matched_gt]
+ or gt_area_ignore[matched_gt]):
+ if not gt_covered[matched_gt]:
+ gt_covered[matched_gt] = True
+ tp[k, i] = 1
+ else:
+ fp[k, i] = 1
+ # otherwise ignore this detected bbox, tp = 0, fp = 0
+ elif min_area is None:
+ fp[k, i] = 1
+ else:
+ bbox = det_bboxes[i, :4]
+ area = (bbox[2] - bbox[0] + extra_length) * (
+ bbox[3] - bbox[1] + extra_length)
+ if area >= min_area and area < max_area:
+ fp[k, i] = 1
+ return tp, fp
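+
+# Minimal sketch of the matching logic above (toy inputs, with an empty
+# (0, 4) gt_bboxes_ignore array): a single detection [0, 0, 10, 10, 0.9]
+# against a single gt [0, 0, 10, 10] with iou_thr=0.5 gives tp = [[1.]] and
+# fp = [[0.]]; a second, lower-scored detection on the same gt would be a
+# false positive, since the gt is already covered by the higher-scored one.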
+
+
+def tpfp_openimages(det_bboxes,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ iou_thr=0.5,
+ area_ranges=None,
+ use_legacy_coordinate=False,
+ gt_bboxes_group_of=None,
+ use_group_of=True,
+ ioa_thr=0.5,
+ **kwargs):
+ """Check if detected bboxes are true positive or false positive.
+
+ Args:
+ det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
+ gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
+ gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+ of shape (k, 4). Default: None
+ iou_thr (float): IoU threshold to be considered as matched.
+ Default: 0.5.
+ area_ranges (list[tuple] | None): Range of bbox areas to be
+ evaluated, in the format [(min1, max1), (min2, max2), ...].
+ Default: None.
+ use_legacy_coordinate (bool): Whether to use the coordinate system
+ of mmdet v1.x, in which width and height are calculated as
+ `x2 - x1 + 1` and `y2 - y1 + 1` respectively. Default: False.
+ gt_bboxes_group_of (ndarray): Group-of flags of the gts in this
+ image, of shape (k, 1). Default: None.
+ use_group_of (bool): Whether to use group-of boxes when calculating
+ TP and FP, only used in OpenImages evaluation. Default: True.
+ ioa_thr (float | None): IoA threshold to be considered as matched,
+ only used in OpenImages evaluation. Default: 0.5.
+
+ Returns:
+ tuple[np.ndarray]: A tuple (tp, fp, det_bboxes). The elements of
+ `tp` and `fp` are 0 and 1, and each array has shape
+ (num_scales, m). `det_bboxes` contains the detections after
+ merging those matched to group-of gts, which is only needed in
+ OpenImages evaluation.
+ """
+
+ if not use_legacy_coordinate:
+ extra_length = 0.
+ else:
+ extra_length = 1.
+
+ # an indicator of ignored gts
+ gt_ignore_inds = np.concatenate(
+ (np.zeros(gt_bboxes.shape[0],
+ dtype=bool), np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+ # stack gt_bboxes and gt_bboxes_ignore for convenience
+ gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
+ num_dets = det_bboxes.shape[0]
+ num_gts = gt_bboxes.shape[0]
+ if area_ranges is None:
+ area_ranges = [(None, None)]
+ num_scales = len(area_ranges)
+ # tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
+ # a certain scale
+ tp = np.zeros((num_scales, num_dets), dtype=np.float32)
+ fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+
+ # if there is no gt bboxes in this image, then all det bboxes
+ # within area range are false positives
+ if gt_bboxes.shape[0] == 0:
+ if area_ranges == [(None, None)]:
+ fp[...] = 1
+ else:
+ det_areas = (
+ det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
+ det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
+ for i, (min_area, max_area) in enumerate(area_ranges):
+ fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
+ return tp, fp, det_bboxes
+
+ if gt_bboxes_group_of is not None and use_group_of:
+ # When handling group-of boxes, divide the gt boxes into two parts,
+ # non-group-of and group-of, then compute IoUs against the
+ # non-group-of gts and IoAs against the group-of gts respectively.
+ # This is only used in OpenImages evaluation.
+ assert gt_bboxes_group_of.shape[0] == gt_bboxes.shape[0]
+ non_group_gt_bboxes = gt_bboxes[~gt_bboxes_group_of]
+ group_gt_bboxes = gt_bboxes[gt_bboxes_group_of]
+ num_gts_group = group_gt_bboxes.shape[0]
+ ious = bbox_overlaps(det_bboxes, non_group_gt_bboxes)
+ ioas = bbox_overlaps(det_bboxes, group_gt_bboxes, mode='iof')
+ else:
+ # If group-of boxes are not considered, only compute IoUs against gt boxes
+ ious = bbox_overlaps(
+ det_bboxes, gt_bboxes, use_legacy_coordinate=use_legacy_coordinate)
+ ioas = None
+
+ if ious.shape[1] > 0:
+ # for each det, the max iou with all gts
+ ious_max = ious.max(axis=1)
+ # for each det, which gt overlaps most with it
+ ious_argmax = ious.argmax(axis=1)
+ # sort all dets in descending order by scores
+ sort_inds = np.argsort(-det_bboxes[:, -1])
+ for k, (min_area, max_area) in enumerate(area_ranges):
+ gt_covered = np.zeros(num_gts, dtype=bool)
+ # if no area range is specified, gt_area_ignore is all False
+ if min_area is None:
+ gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
+ else:
+ gt_areas = (
+ gt_bboxes[:, 2] - gt_bboxes[:, 0] + extra_length) * (
+ gt_bboxes[:, 3] - gt_bboxes[:, 1] + extra_length)
+ gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
+ for i in sort_inds:
+ if ious_max[i] >= iou_thr:
+ matched_gt = ious_argmax[i]
+ if not (gt_ignore_inds[matched_gt]
+ or gt_area_ignore[matched_gt]):
+ if not gt_covered[matched_gt]:
+ gt_covered[matched_gt] = True
+ tp[k, i] = 1
+ else:
+ fp[k, i] = 1
+ # otherwise ignore this detected bbox, tp = 0, fp = 0
+ elif min_area is None:
+ fp[k, i] = 1
+ else:
+ bbox = det_bboxes[i, :4]
+ area = (bbox[2] - bbox[0] + extra_length) * (
+ bbox[3] - bbox[1] + extra_length)
+ if area >= min_area and area < max_area:
+ fp[k, i] = 1
+ else:
+ # if there are no non-group-of gt bboxes in this image,
+ # then all det bboxes within area range are false positives.
+ # Only used in OpenImages evaluation.
+ if area_ranges == [(None, None)]:
+ fp[...] = 1
+ else:
+ det_areas = (
+ det_bboxes[:, 2] - det_bboxes[:, 0] + extra_length) * (
+ det_bboxes[:, 3] - det_bboxes[:, 1] + extra_length)
+ for i, (min_area, max_area) in enumerate(area_ranges):
+ fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
+
+ if ioas is None or ioas.shape[1] <= 0:
+ return tp, fp, det_bboxes
+ else:
+ # The evaluation of group-of TP and FP is done in two stages:
+ # 1. All detections are first matched to non-group-of boxes; true
+ # positives are determined.
+ # 2. Detections determined to be false positives are then matched
+ # against group-of boxes, and the group-of TP and FP are calculated.
+ # Only used in OpenImages evaluation.
+ det_bboxes_group = np.zeros(
+ (num_scales, ioas.shape[1], det_bboxes.shape[1]), dtype=float)
+ match_group_of = np.zeros((num_scales, num_dets), dtype=bool)
+ tp_group = np.zeros((num_scales, num_gts_group), dtype=np.float32)
+ ioas_max = ioas.max(axis=1)
+ # for each det, which gt overlaps most with it
+ ioas_argmax = ioas.argmax(axis=1)
+ # sort all dets in descending order by scores
+ sort_inds = np.argsort(-det_bboxes[:, -1])
+ for k, (min_area, max_area) in enumerate(area_ranges):
+ box_is_covered = tp[k]
+ # if no area range is specified, gt_area_ignore is all False
+ if min_area is None:
+ gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
+ else:
+ gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
+ gt_bboxes[:, 3] - gt_bboxes[:, 1])
+ gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
+ for i in sort_inds:
+ matched_gt = ioas_argmax[i]
+ if not box_is_covered[i]:
+ if ioas_max[i] >= ioa_thr:
+ if not (gt_ignore_inds[matched_gt]
+ or gt_area_ignore[matched_gt]):
+ if not tp_group[k, matched_gt]:
+ tp_group[k, matched_gt] = 1
+ match_group_of[k, i] = True
+ else:
+ match_group_of[k, i] = True
+
+ if det_bboxes_group[k, matched_gt, -1] < \
+ det_bboxes[i, -1]:
+ det_bboxes_group[k, matched_gt] = \
+ det_bboxes[i]
+
+ fp_group = (tp_group <= 0).astype(float)
+ tps = []
+ fps = []
+ # Concatenate tp, fp and the det boxes that did not match group-of
+ # gt boxes with tp_group, fp_group and det_bboxes_group, which did
+ # match group-of boxes, respectively.
+ for i in range(num_scales):
+ tps.append(
+ np.concatenate((tp[i][~match_group_of[i]], tp_group[i])))
+ fps.append(
+ np.concatenate((fp[i][~match_group_of[i]], fp_group[i])))
+ det_bboxes = np.concatenate(
+ (det_bboxes[~match_group_of[i]], det_bboxes_group[i]))
+
+ tp = np.vstack(tps)
+ fp = np.vstack(fps)
+ return tp, fp, det_bboxes
+
+
+def get_cls_results(det_results, annotations, class_id):
+ """Get det results and gt information of a certain class.
+
+ Args:
+ det_results (list[list]): Same as `eval_map()`.
+ annotations (list[dict]): Same as `eval_map()`.
+ class_id (int): ID of a specific class.
+
+ Returns:
+ tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes
+ """
+ cls_dets = [img_res[class_id] for img_res in det_results]
+ cls_gts = []
+ cls_gts_ignore = []
+ for ann in annotations:
+ gt_inds = ann['labels'] == class_id
+ cls_gts.append(ann['bboxes'][gt_inds, :])
+
+ if ann.get('labels_ignore', None) is not None:
+ ignore_inds = ann['labels_ignore'] == class_id
+ cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
+ else:
+ cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32))
+
+ return cls_dets, cls_gts, cls_gts_ignore
+
+
+def get_cls_group_ofs(annotations, class_id):
+ """Get `gt_group_of` of a certain class, which is used in Open Images.
+
+ Args:
+ annotations (list[dict]): Same as `eval_map()`.
+ class_id (int): ID of a specific class.
+
+ Returns:
+ list[np.ndarray]: `gt_group_of` of a certain class.
+ """
+ gt_group_ofs = []
+ for ann in annotations:
+ gt_inds = ann['labels'] == class_id
+ if ann.get('gt_is_group_ofs', None) is not None:
+ gt_group_ofs.append(ann['gt_is_group_ofs'][gt_inds])
+ else:
+ gt_group_ofs.append(np.empty((0, 1), dtype=bool))
+
+ return gt_group_ofs
+
+
+def eval_map(det_results,
+ annotations,
+ scale_ranges=None,
+ iou_thr=0.5,
+ ioa_thr=None,
+ dataset=None,
+ logger=None,
+ tpfp_fn=None,
+ nproc=4,
+ use_legacy_coordinate=False,
+ use_group_of=False):
+ """Evaluate mAP of a dataset.
+
+ Args:
+ det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
+ The outer list indicates images, and the inner list indicates
+ per-class detected bboxes.
+ annotations (list[dict]): Ground truth annotations where each item of
+ the list indicates an image. Keys of annotations are:
+
+ - `bboxes`: numpy array of shape (n, 4)
+ - `labels`: numpy array of shape (n, )
+ - `bboxes_ignore` (optional): numpy array of shape (k, 4)
+ - `labels_ignore` (optional): numpy array of shape (k, )
+ scale_ranges (list[tuple] | None): Range of scales to be evaluated,
+ in the format [(min1, max1), (min2, max2), ...]. A range of
+ (32, 64) means the area range between (32**2, 64**2).
+ Default: None.
+ iou_thr (float): IoU threshold to be considered as matched.
+ Default: 0.5.
+ ioa_thr (float | None): IoA threshold to be considered as matched,
+ only used in OpenImages evaluation. Default: None.
+ dataset (list[str] | str | None): Dataset name or dataset classes;
+ there are minor differences in metrics for different datasets,
+ e.g. "voc07", "imagenet_det", etc. Default: None.
+ logger (logging.Logger | str | None): The way to print the mAP
+ summary. See `mmcv.utils.print_log()` for details. Default: None.
+ tpfp_fn (callable | None): The function used to determine true/
+ false positives. If None, :func:`tpfp_default` is used as default
+ unless dataset is 'det' or 'vid' (:func:`tpfp_imagenet` in this
+ case). If it is given as a function, then this function is used
+ to evaluate tp & fp. Default None.
+ nproc (int): Processes used for computing TP and FP.
+ Default: 4.
+ use_legacy_coordinate (bool): Whether to use the coordinate system
+ of mmdet v1.x, in which width and height are calculated as
+ `x2 - x1 + 1` and `y2 - y1 + 1` respectively. Default: False.
+ use_group_of (bool): Whether to use group-of boxes when calculating
+ TP and FP, only used in OpenImages evaluation. Default: False.
+
+ Returns:
+ tuple: (mAP, [dict, dict, ...])
+ """
+ assert len(det_results) == len(annotations)
+ if not use_legacy_coordinate:
+ extra_length = 0.
+ else:
+ extra_length = 1.
+
+ num_imgs = len(det_results)
+ num_scales = len(scale_ranges) if scale_ranges is not None else 1
+ num_classes = len(det_results[0]) # positive class num
+ area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
+ if scale_ranges is not None else None)
+
+ # There is no need to use multiple processes when num_imgs == 1.
+ if num_imgs > 1:
+ assert nproc > 0, 'nproc must be at least one.'
+ nproc = min(nproc, num_imgs)
+ pool = Pool(nproc)
+
+ eval_results = []
+ for i in range(num_classes):
+ # get gt and det bboxes of this class
+ cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
+ det_results, annotations, i)
+ # choose proper function according to datasets to compute tp and fp
+ if tpfp_fn is None:
+ if dataset in ['det', 'vid']:
+ tpfp_fn = tpfp_imagenet
+ elif dataset in ['oid_challenge', 'oid_v6'] \
+ or use_group_of is True:
+ tpfp_fn = tpfp_openimages
+ else:
+ tpfp_fn = tpfp_default
+ if not callable(tpfp_fn):
+ raise ValueError(
+ f'tpfp_fn has to be a function or None, but got {tpfp_fn}')
+
+ if num_imgs > 1:
+ # compute tp and fp for each image with multiple processes
+ args = []
+ if use_group_of:
+ # used in Open Images Dataset evaluation
+ gt_group_ofs = get_cls_group_ofs(annotations, i)
+ args.append(gt_group_ofs)
+ args.append([use_group_of for _ in range(num_imgs)])
+ if ioa_thr is not None:
+ args.append([ioa_thr for _ in range(num_imgs)])
+
+ tpfp = pool.starmap(
+ tpfp_fn,
+ zip(cls_dets, cls_gts, cls_gts_ignore,
+ [iou_thr for _ in range(num_imgs)],
+ [area_ranges for _ in range(num_imgs)],
+ [use_legacy_coordinate for _ in range(num_imgs)], *args))
+ else:
+ tpfp = tpfp_fn(
+ cls_dets[0],
+ cls_gts[0],
+ cls_gts_ignore[0],
+ iou_thr,
+ area_ranges,
+ use_legacy_coordinate,
+ gt_bboxes_group_of=(get_cls_group_ofs(annotations, i)[0]
+ if use_group_of else None),
+ use_group_of=use_group_of,
+ ioa_thr=ioa_thr)
+ tpfp = [tpfp]
+
+ if use_group_of:
+ tp, fp, cls_dets = tuple(zip(*tpfp))
+ else:
+ tp, fp = tuple(zip(*tpfp))
+ # calculate gt number of each scale
+ # ignored gts or gts beyond the specific scale are not counted
+ num_gts = np.zeros(num_scales, dtype=int)
+ for j, bbox in enumerate(cls_gts):
+ if area_ranges is None:
+ num_gts[0] += bbox.shape[0]
+ else:
+ gt_areas = (bbox[:, 2] - bbox[:, 0] + extra_length) * (
+ bbox[:, 3] - bbox[:, 1] + extra_length)
+ for k, (min_area, max_area) in enumerate(area_ranges):
+ num_gts[k] += np.sum((gt_areas >= min_area)
+ & (gt_areas < max_area))
+ # sort all det bboxes by score, also sort tp and fp
+ cls_dets = np.vstack(cls_dets)
+ num_dets = cls_dets.shape[0]
+ sort_inds = np.argsort(-cls_dets[:, -1])
+ tp = np.hstack(tp)[:, sort_inds]
+ fp = np.hstack(fp)[:, sort_inds]
+ # calculate recall and precision with tp and fp
+ tp = np.cumsum(tp, axis=1)
+ fp = np.cumsum(fp, axis=1)
+ eps = np.finfo(np.float32).eps
+ recalls = tp / np.maximum(num_gts[:, np.newaxis], eps)
+ precisions = tp / np.maximum((tp + fp), eps)
+ # calculate AP
+ if scale_ranges is None:
+ recalls = recalls[0, :]
+ precisions = precisions[0, :]
+ num_gts = num_gts.item()
+ mode = 'area' if dataset != 'voc07' else '11points'
+ ap = average_precision(recalls, precisions, mode)
+ eval_results.append({
+ 'num_gts': num_gts,
+ 'num_dets': num_dets,
+ 'recall': recalls,
+ 'precision': precisions,
+ 'ap': ap
+ })
+
+ if num_imgs > 1:
+ pool.close()
+
+ if scale_ranges is not None:
+ # shape (num_classes, num_scales)
+ all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results])
+ all_num_gts = np.vstack(
+ [cls_result['num_gts'] for cls_result in eval_results])
+ mean_ap = []
+ for i in range(num_scales):
+ if np.any(all_num_gts[:, i] > 0):
+ mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean())
+ else:
+ mean_ap.append(0.0)
+ else:
+ aps = []
+ for cls_result in eval_results:
+ if cls_result['num_gts'] > 0:
+ aps.append(cls_result['ap'])
+ mean_ap = np.array(aps).mean().item() if aps else 0.0
+
+ print_map_summary(
+ mean_ap, eval_results, dataset, area_ranges, logger=logger)
+
+ return mean_ap, eval_results
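+
+# Illustrative call with hypothetical toy data (one image, one class):
+#   det_results = [[np.array([[10, 10, 50, 50, 0.9]])]]
+#   annotations = [dict(bboxes=np.array([[10., 10., 50., 50.]]),
+#                       labels=np.array([0]))]
+#   mean_ap, _ = eval_map(det_results, annotations, iou_thr=0.5)
+# which should give mean_ap == 1.0, since the single detection matches the
+# single gt box exactly.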
+
+
+def print_map_summary(mean_ap,
+ results,
+ dataset=None,
+ scale_ranges=None,
+ logger=None):
+ """Print mAP and results of each class.
+
+ A table will be printed to show the gts/dets/recall/AP of each class and
+ the mAP.
+
+ Args:
+ mean_ap (float): Calculated from `eval_map()`.
+ results (list[dict]): Calculated from `eval_map()`.
+ dataset (list[str] | str | None): Dataset name or dataset classes.
+ scale_ranges (list[tuple] | None): Range of scales to be evaluated.
+ logger (logging.Logger | str | None): The way to print the mAP
+ summary. See `mmcv.utils.print_log()` for details. Default: None.
+ """
+
+ if logger == 'silent':
+ return
+
+ if isinstance(results[0]['ap'], np.ndarray):
+ num_scales = len(results[0]['ap'])
+ else:
+ num_scales = 1
+
+ if scale_ranges is not None:
+ assert len(scale_ranges) == num_scales
+
+ num_classes = len(results)
+
+ recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
+ aps = np.zeros((num_scales, num_classes), dtype=np.float32)
+ num_gts = np.zeros((num_scales, num_classes), dtype=int)
+ for i, cls_result in enumerate(results):
+ if cls_result['recall'].size > 0:
+ recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
+ aps[:, i] = cls_result['ap']
+ num_gts[:, i] = cls_result['num_gts']
+
+ if dataset is None:
+ label_names = [str(i) for i in range(num_classes)]
+ elif mmcv.is_str(dataset):
+ label_names = get_classes(dataset)
+ else:
+ label_names = dataset
+
+ if not isinstance(mean_ap, list):
+ mean_ap = [mean_ap]
+
+ header = ['class', 'gts', 'dets', 'recall', 'ap']
+ for i in range(num_scales):
+ if scale_ranges is not None:
+ print_log(f'Scale range {scale_ranges[i]}', logger=logger)
+ table_data = [header]
+ for j in range(num_classes):
+ row_data = [
+ label_names[j], num_gts[i, j], results[j]['num_dets'],
+ f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}'
+ ]
+ table_data.append(row_data)
+ table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}'])
+ table = AsciiTable(table_data)
+ table.inner_footing_row_border = True
+ print_log('\n' + table.table, logger=logger)
diff --git a/mmdet/core/evaluation/panoptic_utils.py b/mmdet/core/evaluation/panoptic_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..10c9ad934e0c9047ccdcfbf0d429ab13b8527d88
--- /dev/null
+++ b/mmdet/core/evaluation/panoptic_utils.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# A custom value to distinguish instance ID and category ID; it needs to
+# be greater than the number of categories.
+# For a pixel in the panoptic result map:
+# pan_id = ins_id * INSTANCE_OFFSET + cat_id
+INSTANCE_OFFSET = 1000
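+
+# For example (illustrative): a pixel belonging to instance 3 of category 17
+# is encoded as pan_id = 3 * 1000 + 17 = 3017; the category and instance ids
+# can be recovered as pan_id % INSTANCE_OFFSET and pan_id // INSTANCE_OFFSET.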
diff --git a/mmdet/core/evaluation/recall.py b/mmdet/core/evaluation/recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b3c909b82fad29d6d5147c562a674e5db7c14c
--- /dev/null
+++ b/mmdet/core/evaluation/recall.py
@@ -0,0 +1,197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Sequence
+
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from .bbox_overlaps import bbox_overlaps
+
+
+def _recalls(all_ious, proposal_nums, thrs):
+
+ img_num = all_ious.shape[0]
+ total_gt_num = sum([ious.shape[0] for ious in all_ious])
+
+ _ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32)
+ for k, proposal_num in enumerate(proposal_nums):
+ tmp_ious = np.zeros(0)
+ for i in range(img_num):
+ ious = all_ious[i][:, :proposal_num].copy()
+ gt_ious = np.zeros((ious.shape[0]))
+ if ious.size == 0:
+ tmp_ious = np.hstack((tmp_ious, gt_ious))
+ continue
+ for j in range(ious.shape[0]):
+ gt_max_overlaps = ious.argmax(axis=1)
+ max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps]
+ gt_idx = max_ious.argmax()
+ gt_ious[j] = max_ious[gt_idx]
+ box_idx = gt_max_overlaps[gt_idx]
+ ious[gt_idx, :] = -1
+ ious[:, box_idx] = -1
+ tmp_ious = np.hstack((tmp_ious, gt_ious))
+ _ious[k, :] = tmp_ious
+
+ _ious = np.fliplr(np.sort(_ious, axis=1))
+ recalls = np.zeros((proposal_nums.size, thrs.size))
+ for i, thr in enumerate(thrs):
+ recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num)
+
+ return recalls
+
+
+def set_recall_param(proposal_nums, iou_thrs):
+ """Check proposal_nums and iou_thrs and set correct format."""
+ if isinstance(proposal_nums, Sequence):
+ _proposal_nums = np.array(proposal_nums)
+ elif isinstance(proposal_nums, int):
+ _proposal_nums = np.array([proposal_nums])
+ else:
+ _proposal_nums = proposal_nums
+
+ if iou_thrs is None:
+ _iou_thrs = np.array([0.5])
+ elif isinstance(iou_thrs, Sequence):
+ _iou_thrs = np.array(iou_thrs)
+ elif isinstance(iou_thrs, float):
+ _iou_thrs = np.array([iou_thrs])
+ else:
+ _iou_thrs = iou_thrs
+
+ return _proposal_nums, _iou_thrs
+
+
+def eval_recalls(gts,
+ proposals,
+ proposal_nums=None,
+ iou_thrs=0.5,
+ logger=None,
+ use_legacy_coordinate=False):
+ """Calculate recalls.
+
+ Args:
+ gts (list[ndarray]): a list of arrays of shape (n, 4)
+ proposals (list[ndarray]): a list of arrays of shape (k, 4) or (k, 5)
+ proposal_nums (int | Sequence[int]): Top N proposals to be evaluated.
+ iou_thrs (float | Sequence[float]): IoU thresholds. Default: 0.5.
+ logger (logging.Logger | str | None): The way to print the recall
+ summary. See `mmcv.utils.print_log()` for details. Default: None.
+ use_legacy_coordinate (bool): Whether to use the coordinate system
+ of mmdet v1.x, where 1 is added to both height and width, i.e.
+ w and h are computed as `x2 - x1 + 1` and `y2 - y1 + 1`.
+ Default: False.
+
+ Returns:
+ ndarray: recalls of different ious and proposal nums
+ """
+
+ img_num = len(gts)
+ assert img_num == len(proposals)
+ proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs)
+ all_ious = []
+ for i in range(img_num):
+ if proposals[i].ndim == 2 and proposals[i].shape[1] == 5:
+ scores = proposals[i][:, 4]
+ sort_idx = np.argsort(scores)[::-1]
+ img_proposal = proposals[i][sort_idx, :]
+ else:
+ img_proposal = proposals[i]
+ prop_num = min(img_proposal.shape[0], proposal_nums[-1])
+ if gts[i] is None or gts[i].shape[0] == 0:
+ ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32)
+ else:
+ ious = bbox_overlaps(
+ gts[i],
+ img_proposal[:prop_num, :4],
+ use_legacy_coordinate=use_legacy_coordinate)
+ all_ious.append(ious)
+ all_ious = np.array(all_ious)
+ recalls = _recalls(all_ious, proposal_nums, iou_thrs)
+
+ print_recall_summary(recalls, proposal_nums, iou_thrs, logger=logger)
+ return recalls
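+
+# Sketch of a typical call (hypothetical toy data, one image with one gt box
+# and two scored proposals):
+#   gts = [np.array([[0., 0., 10., 10.]])]
+#   proposals = [np.array([[0, 0, 10, 10, 0.9], [20, 20, 30, 30, 0.8]])]
+#   eval_recalls(gts, proposals, proposal_nums=[1, 2], iou_thrs=[0.5])
+# returns a (2, 1) array, here [[1.], [1.]], since the top-scored proposal
+# already covers the gt.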
+
+
+def print_recall_summary(recalls,
+ proposal_nums,
+ iou_thrs,
+ row_idxs=None,
+ col_idxs=None,
+ logger=None):
+ """Print recalls in a table.
+
+ Args:
+ recalls (ndarray): calculated from `eval_recalls`
+ proposal_nums (ndarray or list): top N proposals
+ iou_thrs (ndarray or list): iou thresholds
+ row_idxs (ndarray): which rows (proposal nums) to print
+ col_idxs (ndarray): which cols (iou thresholds) to print
+ logger (logging.Logger | str | None): The way to print the recall
+ summary. See `mmcv.utils.print_log()` for details. Default: None.
+ """
+ proposal_nums = np.array(proposal_nums, dtype=np.int32)
+ iou_thrs = np.array(iou_thrs)
+ if row_idxs is None:
+ row_idxs = np.arange(proposal_nums.size)
+ if col_idxs is None:
+ col_idxs = np.arange(iou_thrs.size)
+ row_header = [''] + iou_thrs[col_idxs].tolist()
+ table_data = [row_header]
+ for i, num in enumerate(proposal_nums[row_idxs]):
+ row = [f'{val:.3f}' for val in recalls[row_idxs[i], col_idxs].tolist()]
+ row.insert(0, num)
+ table_data.append(row)
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+
+def plot_num_recall(recalls, proposal_nums):
+ """Plot Proposal_num-Recalls curve.
+
+ Args:
+ recalls(ndarray or list): shape (k,)
+ proposal_nums(ndarray or list): same shape as `recalls`
+ """
+ if isinstance(proposal_nums, np.ndarray):
+ _proposal_nums = proposal_nums.tolist()
+ else:
+ _proposal_nums = proposal_nums
+ if isinstance(recalls, np.ndarray):
+ _recalls = recalls.tolist()
+ else:
+ _recalls = recalls
+
+ import matplotlib.pyplot as plt
+ f = plt.figure()
+ plt.plot([0] + _proposal_nums, [0] + _recalls)
+ plt.xlabel('Proposal num')
+ plt.ylabel('Recall')
+ plt.axis([0, proposal_nums.max(), 0, 1])
+ f.show()
+
+
+def plot_iou_recall(recalls, iou_thrs):
+ """Plot IoU-Recalls curve.
+
+ Args:
+ recalls(ndarray or list): shape (k,)
+ iou_thrs(ndarray or list): same shape as `recalls`
+ """
+ if isinstance(iou_thrs, np.ndarray):
+ _iou_thrs = iou_thrs.tolist()
+ else:
+ _iou_thrs = iou_thrs
+ if isinstance(recalls, np.ndarray):
+ _recalls = recalls.tolist()
+ else:
+ _recalls = recalls
+
+ import matplotlib.pyplot as plt
+ f = plt.figure()
+ plt.plot(_iou_thrs + [1.0], _recalls + [0.])
+ plt.xlabel('IoU')
+ plt.ylabel('Recall')
+ plt.axis([iou_thrs.min(), 1, 0, 1])
+ f.show()
diff --git a/mmdet/core/export/__init__.py b/mmdet/core/export/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8179c93642dcfaa780c5beccd3f1f104f32d4ae
--- /dev/null
+++ b/mmdet/core/export/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .onnx_helper import (add_dummy_nms_for_onnx, dynamic_clip_for_onnx,
+ get_k_for_topk)
+from .pytorch2onnx import (build_model_from_cfg,
+ generate_inputs_and_wrap_model,
+ preprocess_example_input)
+
+__all__ = [
+ 'build_model_from_cfg', 'generate_inputs_and_wrap_model',
+ 'preprocess_example_input', 'get_k_for_topk', 'add_dummy_nms_for_onnx',
+ 'dynamic_clip_for_onnx'
+]
diff --git a/mmdet/core/export/model_wrappers.py b/mmdet/core/export/model_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7be2df724a08acdb0c8de1ce19108dc75d5c2e3
--- /dev/null
+++ b/mmdet/core/export/model_wrappers.py
@@ -0,0 +1,183 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+
+import numpy as np
+import torch
+
+from mmdet.core import bbox2result
+from mmdet.models import BaseDetector
+
+
+class DeployBaseDetector(BaseDetector):
+ """DeployBaseDetector."""
+
+ def __init__(self, class_names, device_id):
+ super(DeployBaseDetector, self).__init__()
+ self.CLASSES = class_names
+ self.device_id = device_id
+
+ def simple_test(self, img, img_metas, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def aug_test(self, imgs, img_metas, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def extract_feat(self, imgs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def forward_train(self, imgs, img_metas, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def val_step(self, data, optimizer):
+ raise NotImplementedError('This method is not implemented.')
+
+ def train_step(self, data, optimizer):
+ raise NotImplementedError('This method is not implemented.')
+
+ def forward_test(self, *, img, img_metas, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def async_simple_test(self, img, img_metas, **kwargs):
+ raise NotImplementedError('This method is not implemented.')
+
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ outputs = self.forward_test(img, img_metas, **kwargs)
+ batch_dets, batch_labels = outputs[:2]
+ batch_masks = outputs[2] if len(outputs) == 3 else None
+ batch_size = img[0].shape[0]
+ img_metas = img_metas[0]
+ results = []
+ rescale = kwargs.get('rescale', True)
+ for i in range(batch_size):
+ dets, labels = batch_dets[i], batch_labels[i]
+ if rescale:
+ scale_factor = img_metas[i]['scale_factor']
+
+ if isinstance(scale_factor, (list, tuple, np.ndarray)):
+ assert len(scale_factor) == 4
+ scale_factor = np.array(scale_factor)[None, :] # [1,4]
+ dets[:, :4] /= scale_factor
+
+ if 'border' in img_metas[i]:
+ # 'border' is the pixel offset of the top-left corner between the
+ # original image and the padded/enlarged image; it is used when
+ # exporting CornerNet and CentripetalNet to onnx
+ x_off = img_metas[i]['border'][2]
+ y_off = img_metas[i]['border'][0]
+ dets[:, [0, 2]] -= x_off
+ dets[:, [1, 3]] -= y_off
+ dets[:, :4] *= (dets[:, :4] > 0).astype(dets.dtype)
+
+ dets_results = bbox2result(dets, labels, len(self.CLASSES))
+
+ if batch_masks is not None:
+ masks = batch_masks[i]
+ img_h, img_w = img_metas[i]['img_shape'][:2]
+ ori_h, ori_w = img_metas[i]['ori_shape'][:2]
+ masks = masks[:, :img_h, :img_w]
+ if rescale:
+ masks = masks.astype(np.float32)
+ masks = torch.from_numpy(masks)
+ masks = torch.nn.functional.interpolate(
+ masks.unsqueeze(0), size=(ori_h, ori_w))
+ masks = masks.squeeze(0).detach().numpy()
+ if masks.dtype != bool:
+ masks = masks >= 0.5
+ segms_results = [[] for _ in range(len(self.CLASSES))]
+ for j in range(len(dets)):
+ segms_results[labels[j]].append(masks[j])
+ results.append((dets_results, segms_results))
+ else:
+ results.append(dets_results)
+ return results
+
+
+class ONNXRuntimeDetector(DeployBaseDetector):
+ """Wrapper for detector's inference with ONNXRuntime."""
+
+ def __init__(self, onnx_file, class_names, device_id):
+ super(ONNXRuntimeDetector, self).__init__(class_names, device_id)
+ import onnxruntime as ort
+
+ # get the custom op path
+ ort_custom_op_path = ''
+ try:
+ from mmcv.ops import get_onnxruntime_op_path
+ ort_custom_op_path = get_onnxruntime_op_path()
+ except (ImportError, ModuleNotFoundError):
+ warnings.warn('If input model has custom op from mmcv, '
+ 'you may have to build mmcv with ONNXRuntime from source.')
+ session_options = ort.SessionOptions()
+ # register custom op for onnxruntime
+ if osp.exists(ort_custom_op_path):
+ session_options.register_custom_ops_library(ort_custom_op_path)
+ sess = ort.InferenceSession(onnx_file, session_options)
+ providers = ['CPUExecutionProvider']
+ options = [{}]
+ is_cuda_available = ort.get_device() == 'GPU'
+ if is_cuda_available:
+ providers.insert(0, 'CUDAExecutionProvider')
+ options.insert(0, {'device_id': device_id})
+
+ sess.set_providers(providers, options)
+
+ self.sess = sess
+ self.io_binding = sess.io_binding()
+ self.output_names = [_.name for _ in sess.get_outputs()]
+ self.is_cuda_available = is_cuda_available
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ input_data = imgs[0]
+ # set io binding for inputs/outputs
+ device_type = 'cuda' if self.is_cuda_available else 'cpu'
+ if not self.is_cuda_available:
+ input_data = input_data.cpu()
+ self.io_binding.bind_input(
+ name='input',
+ device_type=device_type,
+ device_id=self.device_id,
+ element_type=np.float32,
+ shape=input_data.shape,
+ buffer_ptr=input_data.data_ptr())
+
+ for name in self.output_names:
+ self.io_binding.bind_output(name)
+ # run session to get outputs
+ self.sess.run_with_iobinding(self.io_binding)
+ ort_outputs = self.io_binding.copy_outputs_to_cpu()
+ return ort_outputs
+
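+# A hedged usage sketch; the ONNX file name, class names and shapes below
+# are placeholders, not values fixed by this module:
+# >>> import torch
+# >>> detector = ONNXRuntimeDetector('tmp.onnx', class_names=('person', ),
+# ...                                device_id=0)
+# >>> img = torch.randn(1, 3, 800, 1216)
+# >>> img_metas = [[dict(scale_factor=1.0, img_shape=(800, 1216, 3),
+# ...                    ori_shape=(800, 1216, 3))]]
+# >>> results = detector(img=[img], img_metas=img_metas, return_loss=False)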
+
+class TensorRTDetector(DeployBaseDetector):
+ """Wrapper for detector's inference with TensorRT."""
+
+ def __init__(self, engine_file, class_names, device_id, output_names=None):
+ super(TensorRTDetector, self).__init__(class_names, device_id)
+ warnings.warn('`output_names` is deprecated and will be removed in '
+ 'future releases.')
+ from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
+ try:
+ load_tensorrt_plugin()
+ except (ImportError, ModuleNotFoundError):
+ warnings.warn('If input model has custom op from mmcv, '
+ 'you may have to build mmcv with TensorRT from source.')
+
+ output_names = ['dets', 'labels']
+ model = TRTWraper(engine_file, ['input'], output_names)
+ with_masks = False
+ # if the TensorRT engine has 4 bindings (inputs/outputs) in total,
+ # the detector should also have a `masks` output.
+ if len(model.engine) == 4:
+ model.output_names = output_names + ['masks']
+ with_masks = True
+ self.model = model
+ self.with_masks = with_masks
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ input_data = imgs[0].contiguous()
+ with torch.cuda.device(self.device_id), torch.no_grad():
+ outputs = self.model({'input': input_data})
+ outputs = [outputs[name] for name in self.model.output_names]
+ outputs = [out.detach().cpu().numpy() for out in outputs]
+ return outputs
diff --git a/mmdet/core/export/onnx_helper.py b/mmdet/core/export/onnx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f6b9a012a621be616fe9c086740fb9367ec2311
--- /dev/null
+++ b/mmdet/core/export/onnx_helper.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+import torch
+
+
+def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape):
+ """Clip boxes dynamically for onnx.
+
+ Since torch.clamp cannot have dynamic `min` and `max`, we scale the
+ boxes by 1/max_shape and clamp in the range [0, 1].
+
+ Args:
+ x1 (Tensor): The x1 for bounding boxes.
+ y1 (Tensor): The y1 for bounding boxes.
+ x2 (Tensor): The x2 for bounding boxes.
+ y2 (Tensor): The y2 for bounding boxes.
+ max_shape (Tensor or torch.Size): The (H,W) of original image.
+ Returns:
+ tuple(Tensor): The clipped x1, y1, x2, y2.
+ """
+ assert isinstance(
+ max_shape,
+ torch.Tensor), '`max_shape` should be tensor of (h,w) for onnx'
+
+ # scale by 1/max_shape
+ x1 = x1 / max_shape[1]
+ y1 = y1 / max_shape[0]
+ x2 = x2 / max_shape[1]
+ y2 = y2 / max_shape[0]
+
+ # clamp [0, 1]
+ x1 = torch.clamp(x1, 0, 1)
+ y1 = torch.clamp(y1, 0, 1)
+ x2 = torch.clamp(x2, 0, 1)
+ y2 = torch.clamp(y2, 0, 1)
+
+ # scale back
+ x1 = x1 * max_shape[1]
+ y1 = y1 * max_shape[0]
+ x2 = x2 * max_shape[1]
+ y2 = y2 * max_shape[0]
+ return x1, y1, x2, y2
+
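+# A minimal numerical sketch, assuming boxes from an 800x1216 image; the
+# clip range stays dynamic because `max_shape` is a tensor, e.g. taken from
+# the image shape recorded during export:
+# >>> import torch
+# >>> x1, y1 = torch.tensor([-5.]), torch.tensor([10.])
+# >>> x2, y2 = torch.tensor([1500.]), torch.tensor([400.])
+# >>> max_shape = torch.tensor([800, 1216])  # (H, W)
+# >>> x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape)
+# >>> # x1 is clipped to 0 and x2 to the image width (1216); y1, y2 unchanged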
+
+def get_k_for_topk(k, size):
+ """Get k of TopK for onnx exporting.
+
+ The K of TopK in TensorRT should not be a Tensor, while in ONNX Runtime
+ it could be a Tensor. Due to the dynamic shape feature, we have to decide
+ whether to do TopK and what K it should be while exporting to ONNX.
+ If the returned K is less than zero, it means we do not need to do
+ the TopK operation.
+
+ Args:
+ k (int or Tensor): The set k value for nms from config file.
+ size (Tensor or torch.Size): The number of elements of \
+ TopK's input tensor
+ Returns:
+ int or Tensor: The final K for TopK.
+ """
+ ret_k = -1
+ if k <= 0 or size <= 0:
+ return ret_k
+ if torch.onnx.is_in_onnx_export():
+ is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+ if is_trt_backend:
+ # TensorRT does not support dynamic K with TopK op
+ if 0 < k < size:
+ ret_k = k
+ else:
+ # Always keep topk op for dynamic input in onnx for ONNX Runtime
+ ret_k = torch.where(k < size, k, size)
+ elif k < size:
+ ret_k = k
+ else:
+ # ret_k is -1
+ pass
+ return ret_k
+
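+# A behaviour sketch outside of ONNX export (the numbers are arbitrary):
+# >>> get_k_for_topk(1000, 5000)  # k is smaller than the input size -> keep k
+# 1000
+# >>> get_k_for_topk(1000, 200)  # topk is unnecessary -> -1, caller skips it
+# -1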
+
+def add_dummy_nms_for_onnx(boxes,
+ scores,
+ max_output_boxes_per_class=1000,
+ iou_threshold=0.5,
+ score_threshold=0.05,
+ pre_top_k=-1,
+ after_top_k=-1,
+ labels=None):
+ """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
+
+ This function helps exporting to onnx with batch and multiclass NMS op.
+ It only supports class-agnostic detection results. That is, the scores
+ are of shape (N, num_boxes, num_classes) and the boxes are of shape
+ (N, num_boxes, 4).
+
+ Args:
+ boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]
+ scores (Tensor): The detection scores of shape
+ [N, num_boxes, num_classes]
+ max_output_boxes_per_class (int): Maximum number of output
+ boxes per class of nms. Defaults to 1000.
+ iou_threshold (float): IOU threshold of nms. Defaults to 0.5
+ score_threshold (float): score threshold of nms.
+ Defaults to 0.05.
+ pre_top_k (int): Number of top K boxes to keep before nms.
+ Defaults to -1.
+ after_top_k (int): Number of top K boxes to keep after nms.
+ Defaults to -1.
+ labels (Tensor, optional): If not None, explicit labels would be used.
+ Otherwise, labels would be automatically generated using
+ num_classes. Defaults to None.
+
+ Returns:
+ tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
+ and class labels of shape [N, num_det].
+ """
+ max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
+ iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
+ score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
+ batch_size = scores.shape[0]
+ num_class = scores.shape[2]
+
+ nms_pre = torch.tensor(pre_top_k, device=scores.device, dtype=torch.long)
+ nms_pre = get_k_for_topk(nms_pre, boxes.shape[1])
+
+ if nms_pre > 0:
+ max_scores, _ = scores.max(-1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds).long()
+ # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+ transformed_inds = boxes.shape[1] * batch_inds + topk_inds
+ boxes = boxes.reshape(-1, 4)[transformed_inds, :].reshape(
+ batch_size, -1, 4)
+ scores = scores.reshape(-1, num_class)[transformed_inds, :].reshape(
+ batch_size, -1, num_class)
+ if labels is not None:
+ labels = labels.reshape(-1, 1)[transformed_inds].reshape(
+ batch_size, -1)
+
+ scores = scores.permute(0, 2, 1)
+ num_box = boxes.shape[1]
+ # turn off tracing to create a dummy output of nms
+ state = torch._C._get_tracing_state()
+ # dummy indices of nms's output
+ num_fake_det = 2
+ batch_inds = torch.randint(batch_size, (num_fake_det, 1))
+ cls_inds = torch.randint(num_class, (num_fake_det, 1))
+ box_inds = torch.randint(num_box, (num_fake_det, 1))
+ indices = torch.cat([batch_inds, cls_inds, box_inds], dim=1)
+ output = indices
+ setattr(DummyONNXNMSop, 'output', output)
+
+ # open tracing
+ torch._C._set_tracing_state(state)
+ selected_indices = DummyONNXNMSop.apply(boxes, scores,
+ max_output_boxes_per_class,
+ iou_threshold, score_threshold)
+
+ batch_inds, cls_inds = selected_indices[:, 0], selected_indices[:, 1]
+ box_inds = selected_indices[:, 2]
+ if labels is None:
+ labels = torch.arange(num_class, dtype=torch.long).to(scores.device)
+ labels = labels.view(1, num_class, 1).expand_as(scores)
+ scores = scores.reshape(-1, 1)
+ boxes = boxes.reshape(batch_size, -1).repeat(1, num_class).reshape(-1, 4)
+ pos_inds = (num_class * batch_inds + cls_inds) * num_box + box_inds
+ mask = scores.new_zeros(scores.shape)
+ # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+ # PyTorch style code: mask[batch_inds, box_inds] += 1
+ mask[pos_inds, :] += 1
+ scores = scores * mask
+ boxes = boxes * mask
+
+ scores = scores.reshape(batch_size, -1)
+ boxes = boxes.reshape(batch_size, -1, 4)
+ labels = labels.reshape(batch_size, -1)
+
+ nms_after = torch.tensor(
+ after_top_k, device=scores.device, dtype=torch.long)
+ nms_after = get_k_for_topk(nms_after, num_box * num_class)
+
+ if nms_after > 0:
+ _, topk_inds = scores.topk(nms_after)
+ batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds)
+ # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+ transformed_inds = scores.shape[1] * batch_inds + topk_inds
+ scores = scores.reshape(-1, 1)[transformed_inds, :].reshape(
+ batch_size, -1)
+ boxes = boxes.reshape(-1, 4)[transformed_inds, :].reshape(
+ batch_size, -1, 4)
+ labels = labels.reshape(-1, 1)[transformed_inds, :].reshape(
+ batch_size, -1)
+
+ scores = scores.unsqueeze(2)
+ dets = torch.cat([boxes, scores], dim=2)
+ return dets, labels
+
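+# A shape-level sketch, assuming this helper is called from a bbox head while
+# tracing with ``torch.onnx.export``; the shapes below are placeholders:
+# >>> N, num_boxes, num_classes = 1, 2000, 80
+# >>> boxes = torch.rand(N, num_boxes, 4)
+# >>> scores = torch.rand(N, num_boxes, num_classes)
+# >>> dets, labels = add_dummy_nms_for_onnx(
+# ...     boxes, scores, max_output_boxes_per_class=100, iou_threshold=0.5,
+# ...     score_threshold=0.05, pre_top_k=-1, after_top_k=100)
+# >>> # dets: (N, num_det, 5) as (x1, y1, x2, y2, score); labels: (N, num_det)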
+
+class DummyONNXNMSop(torch.autograd.Function):
+ """DummyONNXNMSop.
+
+ This class is only for creating onnx::NonMaxSuppression.
+ """
+
+ @staticmethod
+ def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold,
+ score_threshold):
+
+ return DummyONNXNMSop.output
+
+ @staticmethod
+ def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold,
+ score_threshold):
+ return g.op(
+ 'NonMaxSuppression',
+ boxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ outputs=1)
diff --git a/mmdet/core/export/pytorch2onnx.py b/mmdet/core/export/pytorch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8261eed9e81c99db7f49ea929fce3d3ac1c0ca0
--- /dev/null
+++ b/mmdet/core/export/pytorch2onnx.py
@@ -0,0 +1,159 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.runner import load_checkpoint
+
+
+def generate_inputs_and_wrap_model(config_path,
+ checkpoint_path,
+ input_config,
+ cfg_options=None):
+ """Prepare sample input and wrap model for ONNX export.
+
+ The ONNX export API only accepts args, and all inputs should be
+ torch.Tensor or corresponding types (such as a tuple of tensors).
+ So we should call this function before exporting. This function will:
+
+ 1. Generate corresponding inputs which are used to execute the model.
+ 2. Wrap the model's forward function.
+
+ For example, the MMDet models' forward function has a parameter
+ ``return_loss: bool``. We want to set it to False, but the export API
+ supports neither bool type nor kwargs, so we have to replace the forward
+ method with ``model.forward = partial(model.forward, return_loss=False)``.
+
+ Args:
+ config_path (str): the OpenMMLab config for the model we want to
+ export to ONNX
+ checkpoint_path (str): Path to the corresponding checkpoint
+ input_config (dict): the exact data in this dict depends on the
+ framework. For MMSeg, we can just declare the input shape,
+ and generate the dummy data accordingly. However, for MMDet,
+ we may pass a real image path, otherwise the NMS will return
+ None as there is no legal bbox.
+
+ Returns:
+ tuple: (model, tensor_data) wrapped model which can be called by
+ ``model(*tensor_data)`` and a list of inputs which are used to
+ execute the model while exporting.
+ """
+
+ model = build_model_from_cfg(
+ config_path, checkpoint_path, cfg_options=cfg_options)
+ one_img, one_meta = preprocess_example_input(input_config)
+ tensor_data = [one_img]
+ model.forward = partial(
+ model.forward, img_metas=[[one_meta]], return_loss=False)
+
+ # PyTorch 1.3 has some ONNX export bugs, which we work around
+ # by replacing the existing symbolic ops
+ opset_version = 11
+ # put the import within the function so it will not cause an import error
+ # when this function is not used
+ try:
+ from mmcv.onnx.symbolic import register_extra_symbolics
+ except ModuleNotFoundError:
+ raise NotImplementedError('please update mmcv to version>=v1.0.4')
+ register_extra_symbolics(opset_version)
+
+ return model, tensor_data
+
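+# A hedged end-to-end sketch; the config, checkpoint and output paths are
+# placeholders. Wrap the model, then hand the generated tensors to the
+# standard ``torch.onnx.export`` call.
+# >>> input_config = {
+# ...     'input_shape': (1, 3, 800, 1216),
+# ...     'input_path': 'demo/demo.jpg',
+# ...     'normalize_cfg': {
+# ...         'mean': (123.675, 116.28, 103.53),
+# ...         'std': (58.395, 57.12, 57.375)}}
+# >>> model, tensor_data = generate_inputs_and_wrap_model(
+# ...     'configs/some_config.py', 'checkpoints/some_ckpt.pth', input_config)
+# >>> torch.onnx.export(model, tuple(tensor_data), 'tmp.onnx',
+# ...                   opset_version=11, do_constant_folding=True)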
+
+def build_model_from_cfg(config_path, checkpoint_path, cfg_options=None):
+ """Build a model from config and load the given checkpoint.
+
+ Args:
+ config_path (str): the OpenMMLab config for the model we want to
+ export to ONNX
+ checkpoint_path (str): Path to the corresponding checkpoint
+
+ Returns:
+ torch.nn.Module: the built model
+ """
+ from mmdet.models import build_detector
+
+ cfg = mmcv.Config.fromfile(config_path)
+ if cfg_options is not None:
+ cfg.merge_from_dict(cfg_options)
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+ cfg.model.pretrained = None
+ cfg.data.test.test_mode = True
+
+ # build the model
+ cfg.model.train_cfg = None
+ model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
+ checkpoint = load_checkpoint(model, checkpoint_path, map_location='cpu')
+ if 'CLASSES' in checkpoint.get('meta', {}):
+ model.CLASSES = checkpoint['meta']['CLASSES']
+ else:
+ from mmdet.datasets import DATASETS
+ dataset = DATASETS.get(cfg.data.test['type'])
+ assert (dataset is not None)
+ model.CLASSES = dataset.CLASSES
+ model.cpu().eval()
+ return model
+
+
+def preprocess_example_input(input_config):
+ """Prepare an example input image for ``generate_inputs_and_wrap_model``.
+
+ Args:
+ input_config (dict): customized config describing the example input.
+
+ Returns:
+ tuple: (one_img, one_meta), tensor of the example input image and \
+ meta information for the example input image.
+
+ Examples:
+ >>> from mmdet.core.export import preprocess_example_input
+ >>> input_config = {
+ >>> 'input_shape': (1,3,224,224),
+ >>> 'input_path': 'demo/demo.jpg',
+ >>> 'normalize_cfg': {
+ >>> 'mean': (123.675, 116.28, 103.53),
+ >>> 'std': (58.395, 57.12, 57.375)
+ >>> }
+ >>> }
+ >>> one_img, one_meta = preprocess_example_input(input_config)
+ >>> print(one_img.shape)
+ torch.Size([1, 3, 224, 224])
+ >>> print(one_meta)
+ {'img_shape': (224, 224, 3),
+ 'ori_shape': (224, 224, 3),
+ 'pad_shape': (224, 224, 3),
+ 'filename': '.png',
+ 'scale_factor': 1.0,
+ 'flip': False}
+ """
+ input_path = input_config['input_path']
+ input_shape = input_config['input_shape']
+ one_img = mmcv.imread(input_path)
+ one_img = mmcv.imresize(one_img, input_shape[2:][::-1])
+ show_img = one_img.copy()
+ if 'normalize_cfg' in input_config.keys():
+ normalize_cfg = input_config['normalize_cfg']
+ mean = np.array(normalize_cfg['mean'], dtype=np.float32)
+ std = np.array(normalize_cfg['std'], dtype=np.float32)
+ to_rgb = normalize_cfg.get('to_rgb', True)
+ one_img = mmcv.imnormalize(one_img, mean, std, to_rgb=to_rgb)
+ one_img = one_img.transpose(2, 0, 1)
+ one_img = torch.from_numpy(one_img).unsqueeze(0).float().requires_grad_(
+ True)
+ (_, C, H, W) = input_shape
+ one_meta = {
+ 'img_shape': (H, W, C),
+ 'ori_shape': (H, W, C),
+ 'pad_shape': (H, W, C),
+ 'filename': '.png',
+ 'scale_factor': np.ones(4, dtype=np.float32),
+ 'flip': False,
+ 'show_img': show_img,
+ 'flip_direction': None
+ }
+
+ return one_img, one_meta
diff --git a/mmdet/core/hook/__init__.py b/mmdet/core/hook/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b9ac9ff3efcff73c44d34dd9ce699da5c009534
--- /dev/null
+++ b/mmdet/core/hook/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .checkloss_hook import CheckInvalidLossHook
+from .ema import ExpMomentumEMAHook, LinearMomentumEMAHook
+from .memory_profiler_hook import MemoryProfilerHook
+from .set_epoch_info_hook import SetEpochInfoHook
+from .sync_norm_hook import SyncNormHook
+from .sync_random_size_hook import SyncRandomSizeHook
+from .wandblogger_hook import MMDetWandbHook
+from .yolox_lrupdater_hook import YOLOXLrUpdaterHook
+from .yolox_mode_switch_hook import YOLOXModeSwitchHook
+
+__all__ = [
+ 'SyncRandomSizeHook', 'YOLOXModeSwitchHook', 'SyncNormHook',
+ 'ExpMomentumEMAHook', 'LinearMomentumEMAHook', 'YOLOXLrUpdaterHook',
+ 'CheckInvalidLossHook', 'SetEpochInfoHook', 'MemoryProfilerHook',
+ 'MMDetWandbHook'
+]
diff --git a/mmdet/core/hook/checkloss_hook.py b/mmdet/core/hook/checkloss_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..754e61bef87dd074f4b7a06943b7db7060d5f1e6
--- /dev/null
+++ b/mmdet/core/hook/checkloss_hook.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.runner.hooks import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class CheckInvalidLossHook(Hook):
+ """Check invalid loss hook.
+
+ This hook will regularly check whether the loss is valid
+ during training.
+
+ Args:
+ interval (int): Checking interval (every k iterations).
+ Default: 50.
+ """
+
+ def __init__(self, interval=50):
+ self.interval = interval
+
+ def after_train_iter(self, runner):
+ if self.every_n_iters(runner, self.interval):
+ assert torch.isfinite(runner.outputs['loss']), \
+ runner.logger.info('loss became infinite or NaN!')
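+
+
+# A config-level sketch; the interval value is just an example:
+# custom_hooks = [dict(type='CheckInvalidLossHook', interval=50)]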
diff --git a/mmdet/core/hook/ema.py b/mmdet/core/hook/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff7bfbabe0284db6f7396dbaa66656f3b7bfc9ba
--- /dev/null
+++ b/mmdet/core/hook/ema.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from mmcv.parallel import is_module_wrapper
+from mmcv.runner.hooks import HOOKS, Hook
+
+
+class BaseEMAHook(Hook):
+ """Exponential Moving Average Hook.
+
+ Use Exponential Moving Average on all parameters of model in training
+ process. All parameters have a ema backup, which update by the formula
+ as below. EMAHook takes priority over EvalHook and CheckpointHook. Note,
+ the original model parameters are actually saved in ema field after train.
+
+ Args:
+ momentum (float): The momentum used for updating ema parameter.
+ Ema's parameter are updated with the formula:
+ `ema_param = (1-momentum) * ema_param + momentum * cur_param`.
+ Defaults to 0.0002.
+ skip_buffers (bool): Whether to skip the model buffers, such as
+ batchnorm running stats (running_mean, running_var), it does not
+ perform the ema operation. Default to False.
+ interval (int): Update ema parameter every interval iteration.
+ Defaults to 1.
+ resume_from (str, optional): The checkpoint path. Defaults to None.
+ momentum_fun (func, optional): The function to change momentum
+ during early iteration (also warmup) to help early training.
+ It uses `momentum` as a constant. Defaults to None.
+ """
+
+ def __init__(self,
+ momentum=0.0002,
+ interval=1,
+ skip_buffers=False,
+ resume_from=None,
+ momentum_fun=None):
+ assert 0 < momentum < 1
+ self.momentum = momentum
+ self.skip_buffers = skip_buffers
+ self.interval = interval
+ self.checkpoint = resume_from
+ self.momentum_fun = momentum_fun
+
+ def before_run(self, runner):
+ """To resume model with it's ema parameters more friendly.
+
+ Register ema parameter as ``named_buffer`` to model.
+ """
+ model = runner.model
+ if is_module_wrapper(model):
+ model = model.module
+ self.param_ema_buffer = {}
+ if self.skip_buffers:
+ self.model_parameters = dict(model.named_parameters())
+ else:
+ self.model_parameters = model.state_dict()
+ for name, value in self.model_parameters.items():
+ # "." is not allowed in module's buffer name
+ buffer_name = f"ema_{name.replace('.', '_')}"
+ self.param_ema_buffer[name] = buffer_name
+ model.register_buffer(buffer_name, value.data.clone())
+ self.model_buffers = dict(model.named_buffers())
+ if self.checkpoint is not None:
+ runner.resume(self.checkpoint)
+
+ def get_momentum(self, runner):
+ return self.momentum_fun(runner.iter) if self.momentum_fun else \
+ self.momentum
+
+ def after_train_iter(self, runner):
+ """Update ema parameter every self.interval iterations."""
+ if (runner.iter + 1) % self.interval != 0:
+ return
+ momentum = self.get_momentum(runner)
+ for name, parameter in self.model_parameters.items():
+ # exclude non-floating-point buffers such as num_batches_tracked
+ if parameter.dtype.is_floating_point:
+ buffer_name = self.param_ema_buffer[name]
+ buffer_parameter = self.model_buffers[buffer_name]
+ buffer_parameter.mul_(1 - momentum).add_(
+ parameter.data, alpha=momentum)
+
+ def after_train_epoch(self, runner):
+ """We load parameter values from ema backup to model before the
+ EvalHook."""
+ self._swap_ema_parameters()
+
+ def before_train_epoch(self, runner):
+ """We recover model's parameter from ema backup after last epoch's
+ EvalHook."""
+ self._swap_ema_parameters()
+
+ def _swap_ema_parameters(self):
+ """Swap the parameter of model with parameter in ema_buffer."""
+ for name, value in self.model_parameters.items():
+ temp = value.data.clone()
+ ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
+ value.data.copy_(ema_buffer.data)
+ ema_buffer.data.copy_(temp)
+
+
+@HOOKS.register_module()
+class ExpMomentumEMAHook(BaseEMAHook):
+ """EMAHook using exponential momentum strategy.
+
+ Args:
+ total_iter (int): The total number of iterations of EMA momentum.
+ Defaults to 2000.
+ """
+
+ def __init__(self, total_iter=2000, **kwargs):
+ super(ExpMomentumEMAHook, self).__init__(**kwargs)
+ self.momentum_fun = lambda x: (1 - self.momentum) * math.exp(-(
+ 1 + x) / total_iter) + self.momentum
+
+
+@HOOKS.register_module()
+class LinearMomentumEMAHook(BaseEMAHook):
+ """EMAHook using linear momentum strategy.
+
+ Args:
+ warm_up (int): During the first `warm_up` steps, a smaller momentum
+ is used to update the ema parameters more slowly. Defaults to 100.
+ """
+
+ def __init__(self, warm_up=100, **kwargs):
+ super(LinearMomentumEMAHook, self).__init__(**kwargs)
+ self.momentum_fun = lambda x: min(self.momentum**self.interval,
+ (1 + x) / (warm_up + x))
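+
+
+# A config-level sketch; the values mirror typical YOLOX-style settings and
+# are assumptions, not requirements of this module:
+# custom_hooks = [
+#     dict(type='ExpMomentumEMAHook', momentum=0.0001, priority=49)
+# ]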
diff --git a/mmdet/core/hook/memory_profiler_hook.py b/mmdet/core/hook/memory_profiler_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..a473061b566f92f4bee6280ec33875e2c50a51dd
--- /dev/null
+++ b/mmdet/core/hook/memory_profiler_hook.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.runner.hooks import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class MemoryProfilerHook(Hook):
+ """Memory profiler hook recording memory information including virtual
+ memory, swap memory, and the memory of the current process.
+
+ Args:
+ interval (int): Checking interval (every k iterations).
+ Default: 50.
+ """
+
+ def __init__(self, interval=50):
+ try:
+ from psutil import swap_memory, virtual_memory
+ self._swap_memory = swap_memory
+ self._virtual_memory = virtual_memory
+ except ImportError:
+ raise ImportError('psutil is not installed, please install it by: '
+ 'pip install psutil')
+
+ try:
+ from memory_profiler import memory_usage
+ self._memory_usage = memory_usage
+ except ImportError:
+ raise ImportError(
+ 'memory_profiler is not installed, please install it by: '
+ 'pip install memory_profiler')
+
+ self.interval = interval
+
+ def after_iter(self, runner):
+ if self.every_n_iters(runner, self.interval):
+ # in bytes
+ virtual_memory = self._virtual_memory()
+ swap_memory = self._swap_memory()
+ # in MB
+ process_memory = self._memory_usage()[0]
+ factor = 1024 * 1024
+ runner.logger.info(
+ 'Memory information '
+ 'available_memory: '
+ f'{round(virtual_memory.available / factor)} MB, '
+ 'used_memory: '
+ f'{round(virtual_memory.used / factor)} MB, '
+ f'memory_utilization: {virtual_memory.percent} %, '
+ 'available_swap_memory: '
+ f'{round((swap_memory.total - swap_memory.used) / factor)}'
+ ' MB, '
+ f'used_swap_memory: {round(swap_memory.used / factor)} MB, '
+ f'swap_memory_utilization: {swap_memory.percent} %, '
+ 'current_process_memory: '
+ f'{round(process_memory)} MB')
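+
+
+# A config-level sketch; the interval value is just an example:
+# custom_hooks = [dict(type='MemoryProfilerHook', interval=50)]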
diff --git a/mmdet/core/hook/set_epoch_info_hook.py b/mmdet/core/hook/set_epoch_info_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2b134ceb69856338097cf283f67d7e2c580739f
--- /dev/null
+++ b/mmdet/core/hook/set_epoch_info_hook.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.parallel import is_module_wrapper
+from mmcv.runner import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class SetEpochInfoHook(Hook):
+ """Set runner's epoch information to the model."""
+
+ def before_train_epoch(self, runner):
+ epoch = runner.epoch
+ model = runner.model
+ if is_module_wrapper(model):
+ model = model.module
+ model.set_epoch(epoch)
diff --git a/mmdet/core/hook/sync_norm_hook.py b/mmdet/core/hook/sync_norm_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..82931cef3bcaba0521a0d9c56cff1e5f50fe8db7
--- /dev/null
+++ b/mmdet/core/hook/sync_norm_hook.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+from mmcv.runner import get_dist_info
+from mmcv.runner.hooks import HOOKS, Hook
+from torch import nn
+
+from ..utils.dist_utils import all_reduce_dict
+
+
+def get_norm_states(module):
+ async_norm_states = OrderedDict()
+ for name, child in module.named_modules():
+ if isinstance(child, nn.modules.batchnorm._NormBase):
+ for k, v in child.state_dict().items():
+ async_norm_states['.'.join([name, k])] = v
+ return async_norm_states
+
+
+@HOOKS.register_module()
+class SyncNormHook(Hook):
+ """Synchronize Norm states after training epoch, currently used in YOLOX.
+
+ Args:
+ num_last_epochs (int): The number of epochs at the end of training
+ during which the norm states are synchronized every epoch. Default: 15.
+ interval (int): Interval (in epochs) for synchronizing the norm states.
+ Default: 1.
+ """
+
+ def __init__(self, num_last_epochs=15, interval=1):
+ self.interval = interval
+ self.num_last_epochs = num_last_epochs
+
+ def before_train_epoch(self, runner):
+ epoch = runner.epoch
+ if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
+ # Synchronize norm every epoch.
+ self.interval = 1
+
+ def after_train_epoch(self, runner):
+ """Synchronizing norm."""
+ epoch = runner.epoch
+ module = runner.model
+ if (epoch + 1) % self.interval == 0:
+ _, world_size = get_dist_info()
+ if world_size == 1:
+ return
+ norm_states = get_norm_states(module)
+ if len(norm_states) == 0:
+ return
+ norm_states = all_reduce_dict(norm_states, op='mean')
+ module.load_state_dict(norm_states, strict=False)
diff --git a/mmdet/core/hook/sync_random_size_hook.py b/mmdet/core/hook/sync_random_size_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7e96c6aaf5207faef9bd835806bdded475bd72
--- /dev/null
+++ b/mmdet/core/hook/sync_random_size_hook.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+import warnings
+
+import torch
+from mmcv.runner import get_dist_info
+from mmcv.runner.hooks import HOOKS, Hook
+from torch import distributed as dist
+
+
+@HOOKS.register_module()
+class SyncRandomSizeHook(Hook):
+ """Change and synchronize the random image size across ranks.
+ SyncRandomSizeHook is deprecated, please use Resize pipeline to achieve
+ similar functions. Such as `dict(type='Resize', img_scale=[(448, 448),
+ (832, 832)], multiscale_mode='range', keep_ratio=True)`.
+
+ Note: Due to the multi-process dataloader, its behavior differs from
+ YOLOX's official implementation: the official one changes the size at a
+ fixed iteration interval, while this hook changes it at a fixed epoch
+ interval.
+
+ Args:
+ ratio_range (tuple[int]): Random ratio range. A random value in this
+ range is multiplied by 32 to set the dataset output image size.
+ Default: (14, 26).
+ img_scale (tuple[int]): Size of input image. Default: (640, 640).
+ interval (int): The epoch interval of change image size. Default: 1.
+ device (torch.device | str): device for returned tensors.
+ Default: 'cuda'.
+ """
+
+ def __init__(self,
+ ratio_range=(14, 26),
+ img_scale=(640, 640),
+ interval=1,
+ device='cuda'):
+ warnings.warn('DeprecationWarning: SyncRandomSizeHook is deprecated. '
+ 'Please use Resize pipeline to achieve similar '
+ 'functions. Due to the multi-process dataloader, '
+ 'its behavior is different from YOLOX\'s official '
+ 'implementation, the official is to change the size '
+ 'every fixed iteration interval and what we achieved '
+ 'is a fixed epoch interval.')
+ self.rank, world_size = get_dist_info()
+ self.is_distributed = world_size > 1
+ self.ratio_range = ratio_range
+ self.img_scale = img_scale
+ self.interval = interval
+ self.device = device
+
+ def after_train_epoch(self, runner):
+ """Change the dataset output image size."""
+ if self.ratio_range is not None and (runner.epoch +
+ 1) % self.interval == 0:
+ # Due to DDP and DP get the device behavior inconsistent,
+ # so we did not get the device from runner.model.
+ tensor = torch.LongTensor(2).to(self.device)
+
+ if self.rank == 0:
+ size_factor = self.img_scale[1] * 1. / self.img_scale[0]
+ size = random.randint(*self.ratio_range)
+ size = (int(32 * size), 32 * int(size * size_factor))
+ tensor[0] = size[0]
+ tensor[1] = size[1]
+
+ if self.is_distributed:
+ dist.barrier()
+ dist.broadcast(tensor, 0)
+
+ runner.data_loader.dataset.update_dynamic_scale(
+ (tensor[0].item(), tensor[1].item()))
diff --git a/mmdet/core/hook/wandblogger_hook.py b/mmdet/core/hook/wandblogger_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bf252f0113cf2b59dba843f1c23db5b62faf982
--- /dev/null
+++ b/mmdet/core/hook/wandblogger_hook.py
@@ -0,0 +1,593 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os.path as osp
+import sys
+import warnings
+
+import mmcv
+import numpy as np
+import pycocotools.mask as mask_util
+from mmcv.runner import HOOKS
+from mmcv.runner.dist_utils import master_only
+from mmcv.runner.hooks.checkpoint import CheckpointHook
+from mmcv.runner.hooks.logger.wandb import WandbLoggerHook
+from mmcv.utils import digit_version
+
+from mmdet.core import DistEvalHook, EvalHook
+from mmdet.core.mask.structures import polygon_to_bitmap
+
+
+@HOOKS.register_module()
+class MMDetWandbHook(WandbLoggerHook):
+ """Enhanced Wandb logger hook for MMDetection.
+
+ Compared with :class:`mmcv.runner.WandbLoggerHook`, this hook not only
+ automatically logs all the metrics but also logs the following extra
+ information: it saves model checkpoints as W&B Artifacts and logs model
+ predictions as interactive W&B Tables.
+
+ - Metrics: The MMDetWandbHook will automatically log training
+ and validation metrics along with system metrics (CPU/GPU).
+
+ - Checkpointing: If `log_checkpoint` is True, the checkpoint saved at
+ every checkpoint interval will be saved as W&B Artifacts.
+ This depends on the :class:`mmcv.runner.CheckpointHook` whose priority
+ is higher than this hook. Please refer to
+ https://docs.wandb.ai/guides/artifacts/model-versioning
+ to learn more about model versioning with W&B Artifacts.
+
+ - Checkpoint Metadata: If evaluation results are available for a given
+ checkpoint artifact, it will have a metadata associated with it.
+ The metadata contains the evaluation metrics computed on validation
+ data with that checkpoint along with the current epoch. It depends
+ on `EvalHook` whose priority is higher than MMDetWandbHook.
+
+ - Evaluation: At every evaluation interval, the `MMDetWandbHook` logs the
+ model prediction as interactive W&B Tables. The number of samples
+ logged is given by `num_eval_images`. Currently, the `MMDetWandbHook`
+ logs the predicted bounding boxes along with the ground truth at every
+ evaluation interval. This depends on the `EvalHook` whose priority is
+ higher than `MMDetWandbHook`. Also note that the data is just logged once
+ and subsequent evaluation tables use references to the logged data
+ to save memory usage. Please refer to
+ https://docs.wandb.ai/guides/data-vis to learn more about W&B Tables.
+
+ For more details check out W&B's MMDetection docs:
+ https://docs.wandb.ai/guides/integrations/mmdetection
+
+ ```
+ Example:
+ log_config = dict(
+ ...
+ hooks=[
+ ...,
+ dict(type='MMDetWandbHook',
+ init_kwargs={
+ 'entity': "YOUR_ENTITY",
+ 'project': "YOUR_PROJECT_NAME"
+ },
+ interval=50,
+ log_checkpoint=True,
+ log_checkpoint_metadata=True,
+ num_eval_images=100,
+ bbox_score_thr=0.3)
+ ])
+ ```
+
+ Args:
+ init_kwargs (dict): A dict passed to wandb.init to initialize
+ a W&B run. Please refer to https://docs.wandb.ai/ref/python/init
+ for possible key-value pairs.
+ interval (int): Logging interval (every k iterations). Defaults to 50.
+ log_checkpoint (bool): Save the checkpoint at every checkpoint interval
+ as W&B Artifacts. Use this for model versioning where each version
+ is a checkpoint. Defaults to False.
+ log_checkpoint_metadata (bool): Log the evaluation metrics computed
+ on the validation data with the checkpoint, along with current
+ epoch as a metadata to that checkpoint.
+ Defaults to True.
+ num_eval_images (int): The number of validation images to be logged.
+ If zero, the evaluation won't be logged. Defaults to 100.
+ bbox_score_thr (float): Threshold for bounding box scores.
+ Defaults to 0.3.
+ """
+
+ def __init__(self,
+ init_kwargs=None,
+ interval=50,
+ log_checkpoint=False,
+ log_checkpoint_metadata=False,
+ num_eval_images=100,
+ bbox_score_thr=0.3,
+ **kwargs):
+ super(MMDetWandbHook, self).__init__(init_kwargs, interval, **kwargs)
+
+ self.log_checkpoint = log_checkpoint
+ self.log_checkpoint_metadata = (
+ log_checkpoint and log_checkpoint_metadata)
+ self.num_eval_images = num_eval_images
+ self.bbox_score_thr = bbox_score_thr
+ self.log_evaluation = (num_eval_images > 0)
+ self.ckpt_hook: CheckpointHook = None
+ self.eval_hook: EvalHook = None
+
+ def import_wandb(self):
+ try:
+ import wandb
+ from wandb import init # noqa
+
+ # Fix ResourceWarning when calling wandb.log in wandb v0.12.10.
+ # https://github.com/wandb/client/issues/2837
+ if digit_version(wandb.__version__) < digit_version('0.12.10'):
+ warnings.warn(
+ f'The current wandb version {wandb.__version__} is '
+ 'lower than v0.12.10, which will cause a ResourceWarning '
+ 'when calling wandb.log. Please run '
+ '"pip install --upgrade wandb"')
+
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install "wandb>=0.12.10"" to install wandb')
+ self.wandb = wandb
+
+ @master_only
+ def before_run(self, runner):
+ super(MMDetWandbHook, self).before_run(runner)
+
+ # Save and Log config.
+ if runner.meta is not None and runner.meta.get('exp_name',
+ None) is not None:
+ src_cfg_path = osp.join(runner.work_dir,
+ runner.meta.get('exp_name', None))
+ if osp.exists(src_cfg_path):
+ self.wandb.save(src_cfg_path, base_path=runner.work_dir)
+ self._update_wandb_config(runner)
+ else:
+ runner.logger.warning('No meta information found in the runner. ')
+
+ # Inspect CheckpointHook and EvalHook
+ for hook in runner.hooks:
+ if isinstance(hook, CheckpointHook):
+ self.ckpt_hook = hook
+ if isinstance(hook, (EvalHook, DistEvalHook)):
+ self.eval_hook = hook
+
+ # Check conditions to log checkpoint
+ if self.log_checkpoint:
+ if self.ckpt_hook is None:
+ self.log_checkpoint = False
+ self.log_checkpoint_metadata = False
+ runner.logger.warning(
+ 'To log checkpoint in MMDetWandbHook, `CheckpointHook` is '
+ 'required, please check hooks in the runner.')
+ else:
+ self.ckpt_interval = self.ckpt_hook.interval
+
+ # Check conditions to log evaluation
+ if self.log_evaluation or self.log_checkpoint_metadata:
+ if self.eval_hook is None:
+ self.log_evaluation = False
+ self.log_checkpoint_metadata = False
+ runner.logger.warning(
+ 'To log evaluation or checkpoint metadata in '
+ 'MMDetWandbHook, `EvalHook` or `DistEvalHook` in mmdet '
+ 'is required, please check whether the validation '
+ 'is enabled.')
+ else:
+ self.eval_interval = self.eval_hook.interval
+ self.val_dataset = self.eval_hook.dataloader.dataset
+ # Determine the number of samples to be logged.
+ if self.num_eval_images > len(self.val_dataset):
+ runner.logger.warning(
+ f'The num_eval_images ({self.num_eval_images}) is '
+ 'greater than the total number of validation samples '
+ f'({len(self.val_dataset)}). The complete validation '
+ 'dataset will be logged.')
+ self.num_eval_images = len(self.val_dataset)
+
+ # Check conditions to log checkpoint metadata
+ if self.log_checkpoint_metadata:
+ assert self.ckpt_interval % self.eval_interval == 0, \
+ 'To log checkpoint metadata in MMDetWandbHook, the interval ' \
+ f'of checkpoint saving ({self.ckpt_interval}) should be ' \
+ 'divisible by the interval of evaluation ' \
+ f'({self.eval_interval}).'
+
+ # Initialize evaluation table
+ if self.log_evaluation:
+ # Initialize data table
+ self._init_data_table()
+ # Add data to the data table
+ self._add_ground_truth(runner)
+ # Log ground truth data
+ self._log_data_table()
+
+ @master_only
+ def after_train_epoch(self, runner):
+ super(MMDetWandbHook, self).after_train_epoch(runner)
+
+ if not self.by_epoch:
+ return
+
+ # Log checkpoint and metadata.
+ if (self.log_checkpoint
+ and self.every_n_epochs(runner, self.ckpt_interval)
+ or (self.ckpt_hook.save_last and self.is_last_epoch(runner))):
+ if self.log_checkpoint_metadata and self.eval_hook:
+ metadata = {
+ 'epoch': runner.epoch + 1,
+ **self._get_eval_results()
+ }
+ else:
+ metadata = None
+ aliases = [f'epoch_{runner.epoch + 1}', 'latest']
+ model_path = osp.join(self.ckpt_hook.out_dir,
+ f'epoch_{runner.epoch + 1}.pth')
+ self._log_ckpt_as_artifact(model_path, aliases, metadata)
+
+ # Save prediction table
+ if self.log_evaluation and self.eval_hook._should_evaluate(runner):
+ results = self.eval_hook.latest_results
+ # Initialize evaluation table
+ self._init_pred_table()
+ # Log predictions
+ self._log_predictions(results)
+ # Log the table
+ self._log_eval_table(runner.epoch + 1)
+
+ # for the reason of this double-layered structure, refer to
+ # https://github.com/open-mmlab/mmdetection/issues/8145#issuecomment-1345343076
+ def after_train_iter(self, runner):
+ if self.get_mode(runner) == 'train':
+ # An ugly patch. The iter-based eval hook will call the
+ # `after_train_iter` method of all logger hooks before evaluation.
+ # Use this trick to skip that call.
+ # Don't call super method at first, it will clear the log_buffer
+ return super(MMDetWandbHook, self).after_train_iter(runner)
+ else:
+ super(MMDetWandbHook, self).after_train_iter(runner)
+ self._after_train_iter(runner)
+
+ @master_only
+ def _after_train_iter(self, runner):
+ if self.by_epoch:
+ return
+
+ # Save checkpoint and metadata
+ if (self.log_checkpoint
+ and self.every_n_iters(runner, self.ckpt_interval)
+ or (self.ckpt_hook.save_last and self.is_last_iter(runner))):
+ if self.log_checkpoint_metadata and self.eval_hook:
+ metadata = {
+ 'iter': runner.iter + 1,
+ **self._get_eval_results()
+ }
+ else:
+ metadata = None
+ aliases = [f'iter_{runner.iter + 1}', 'latest']
+ model_path = osp.join(self.ckpt_hook.out_dir,
+ f'iter_{runner.iter + 1}.pth')
+ self._log_ckpt_as_artifact(model_path, aliases, metadata)
+
+ # Save prediction table
+ if self.log_evaluation and self.eval_hook._should_evaluate(runner):
+ results = self.eval_hook.latest_results
+ # Initialize evaluation table
+ self._init_pred_table()
+ # Log predictions
+ self._log_predictions(results)
+ # Log the table
+ self._log_eval_table(runner.iter + 1)
+
+ @master_only
+ def after_run(self, runner):
+ self.wandb.finish()
+
+ def _update_wandb_config(self, runner):
+ """Update wandb config."""
+ # Import the config file.
+ sys.path.append(runner.work_dir)
+ config_filename = runner.meta['exp_name'][:-3]
+ configs = importlib.import_module(config_filename)
+ # Prepare a nested dict of config variables.
+ config_keys = [key for key in dir(configs) if not key.startswith('__')]
+ config_dict = {key: getattr(configs, key) for key in config_keys}
+ # Update the W&B config.
+ self.wandb.config.update(config_dict)
+
+ def _log_ckpt_as_artifact(self, model_path, aliases, metadata=None):
+ """Log model checkpoint as W&B Artifact.
+
+ Args:
+ model_path (str): Path of the checkpoint to log.
+ aliases (list): List of the aliases associated with this artifact.
+ metadata (dict, optional): Metadata associated with this artifact.
+ """
+ model_artifact = self.wandb.Artifact(
+ f'run_{self.wandb.run.id}_model', type='model', metadata=metadata)
+ model_artifact.add_file(model_path)
+ self.wandb.log_artifact(model_artifact, aliases=aliases)
+
+ def _get_eval_results(self):
+ """Get model evaluation results."""
+ results = self.eval_hook.latest_results
+ eval_results = self.val_dataset.evaluate(
+ results, logger='silent', **self.eval_hook.eval_kwargs)
+ return eval_results
+
+ def _init_data_table(self):
+ """Initialize the W&B Tables for validation data."""
+ columns = ['image_name', 'image']
+ self.data_table = self.wandb.Table(columns=columns)
+
+ def _init_pred_table(self):
+ """Initialize the W&B Tables for model evaluation."""
+ columns = ['image_name', 'ground_truth', 'prediction']
+ self.eval_table = self.wandb.Table(columns=columns)
+
+ def _add_ground_truth(self, runner):
+ # Get image loading pipeline
+ from mmdet.datasets.pipelines import LoadImageFromFile
+ img_loader = None
+ for t in self.val_dataset.pipeline.transforms:
+ if isinstance(t, LoadImageFromFile):
+ img_loader = t
+
+ if img_loader is None:
+ self.log_evaluation = False
+ runner.logger.warning(
+ 'LoadImageFromFile is required to add images '
+ 'to W&B Tables.')
+ return
+
+ # Select the images to be logged.
+ self.eval_image_indexs = np.arange(len(self.val_dataset))
+ # Set seed so that same validation set is logged each time.
+ np.random.seed(42)
+ np.random.shuffle(self.eval_image_indexs)
+ self.eval_image_indexs = self.eval_image_indexs[:self.num_eval_images]
+
+ CLASSES = self.val_dataset.CLASSES
+ self.class_id_to_label = {
+ id + 1: name
+ for id, name in enumerate(CLASSES)
+ }
+ self.class_set = self.wandb.Classes([{
+ 'id': id,
+ 'name': name
+ } for id, name in self.class_id_to_label.items()])
+
+ img_prefix = self.val_dataset.img_prefix
+
+ for idx in self.eval_image_indexs:
+ img_info = self.val_dataset.data_infos[idx]
+ image_name = img_info.get('filename', f'img_{idx}')
+ img_height, img_width = img_info['height'], img_info['width']
+
+ img_meta = img_loader(
+ dict(img_info=img_info, img_prefix=img_prefix))
+
+ # Get image and convert from BGR to RGB
+ image = mmcv.bgr2rgb(img_meta['img'])
+
+ data_ann = self.val_dataset.get_ann_info(idx)
+ bboxes = data_ann['bboxes']
+ labels = data_ann['labels']
+ masks = data_ann.get('masks', None)
+
+ # Get dict of bounding boxes to be logged.
+ assert len(bboxes) == len(labels)
+ wandb_boxes = self._get_wandb_bboxes(bboxes, labels)
+
+ # Get dict of masks to be logged.
+ if masks is not None:
+ wandb_masks = self._get_wandb_masks(
+ masks,
+ labels,
+ is_poly_mask=True,
+ height=img_height,
+ width=img_width)
+ else:
+ wandb_masks = None
+ # TODO: Panoramic segmentation visualization.
+
+ # Log a row to the data table.
+ self.data_table.add_data(
+ image_name,
+ self.wandb.Image(
+ image,
+ boxes=wandb_boxes,
+ masks=wandb_masks,
+ classes=self.class_set))
+
+ def _log_predictions(self, results):
+ table_idxs = self.data_table_ref.get_index()
+ assert len(table_idxs) == len(self.eval_image_indexs)
+
+ for ndx, eval_image_index in enumerate(self.eval_image_indexs):
+ # Get the result
+ result = results[eval_image_index]
+ if isinstance(result, tuple):
+ bbox_result, segm_result = result
+ if isinstance(segm_result, tuple):
+ segm_result = segm_result[0] # ms rcnn
+ else:
+ bbox_result, segm_result = result, None
+ assert len(bbox_result) == len(self.class_id_to_label)
+
+ # Get labels
+ bboxes = np.vstack(bbox_result)
+ labels = [
+ np.full(bbox.shape[0], i, dtype=np.int32)
+ for i, bbox in enumerate(bbox_result)
+ ]
+ labels = np.concatenate(labels)
+
+ # Get segmentation mask if available.
+ segms = None
+ if segm_result is not None and len(labels) > 0:
+ segms = mmcv.concat_list(segm_result)
+ segms = mask_util.decode(segms)
+ segms = segms.transpose(2, 0, 1)
+ assert len(segms) == len(labels)
+ # TODO: Panoramic segmentation visualization.
+
+ # Remove bounding boxes and masks with score lower than threshold.
+ if self.bbox_score_thr > 0:
+ assert bboxes is not None and bboxes.shape[1] == 5
+ scores = bboxes[:, -1]
+ inds = scores > self.bbox_score_thr
+ bboxes = bboxes[inds, :]
+ labels = labels[inds]
+ if segms is not None:
+ segms = segms[inds, ...]
+
+ # Get dict of bounding boxes to be logged.
+ wandb_boxes = self._get_wandb_bboxes(bboxes, labels, log_gt=False)
+ # Get dict of masks to be logged.
+ if segms is not None:
+ wandb_masks = self._get_wandb_masks(segms, labels)
+ else:
+ wandb_masks = None
+
+ # Log a row to the eval table.
+ self.eval_table.add_data(
+ self.data_table_ref.data[ndx][0],
+ self.data_table_ref.data[ndx][1],
+ self.wandb.Image(
+ self.data_table_ref.data[ndx][1],
+ boxes=wandb_boxes,
+ masks=wandb_masks,
+ classes=self.class_set))
+
+ def _get_wandb_bboxes(self, bboxes, labels, log_gt=True):
+ """Get list of structured dict for logging bounding boxes to W&B.
+
+ Args:
+ bboxes (list): List of bounding box coordinates in
+ (minX, minY, maxX, maxY) format.
+ labels (list[int] | ndarray): List of label ids.
+ log_gt (bool): Whether to log ground truth or prediction boxes.
+
+ Returns:
+ Dictionary of bounding boxes to be logged.
+ """
+ wandb_boxes = {}
+
+ box_data = []
+ for bbox, label in zip(bboxes, labels):
+ if not isinstance(label, int):
+ label = int(label)
+ label = label + 1
+
+ if len(bbox) == 5:
+ confidence = float(bbox[4])
+ class_name = self.class_id_to_label[label]
+ box_caption = f'{class_name} {confidence:.2f}'
+ else:
+ box_caption = str(self.class_id_to_label[label])
+
+ position = dict(
+ minX=int(bbox[0]),
+ minY=int(bbox[1]),
+ maxX=int(bbox[2]),
+ maxY=int(bbox[3]))
+
+ box_data.append({
+ 'position': position,
+ 'class_id': label,
+ 'box_caption': box_caption,
+ 'domain': 'pixel'
+ })
+
+ wandb_bbox_dict = {
+ 'box_data': box_data,
+ 'class_labels': self.class_id_to_label
+ }
+
+ if log_gt:
+ wandb_boxes['ground_truth'] = wandb_bbox_dict
+ else:
+ wandb_boxes['predictions'] = wandb_bbox_dict
+
+ return wandb_boxes
+
+ def _get_wandb_masks(self,
+ masks,
+ labels,
+ is_poly_mask=False,
+ height=None,
+ width=None):
+ """Get list of structured dict for logging masks to W&B.
+
+ Args:
+ masks (list): List of masks.
+ labels (list[int] | ndarray): List of label ids.
+ is_poly_mask (bool): Whether the mask is polygonal or not.
+ This is true for CocoDataset.
+ height (int): Height of the image.
+ width (int): Width of the image.
+
+ Returns:
+ Dictionary of masks to be logged.
+ """
+ mask_label_dict = dict()
+ for mask, label in zip(masks, labels):
+ label = label + 1
+ # Get bitmap mask from polygon.
+ if is_poly_mask:
+ if height is not None and width is not None:
+ mask = polygon_to_bitmap(mask, height, width)
+ # Create composite masks for each class.
+ if label not in mask_label_dict.keys():
+ mask_label_dict[label] = mask
+ else:
+ mask_label_dict[label] = np.logical_or(mask_label_dict[label],
+ mask)
+
+ wandb_masks = dict()
+ for key, value in mask_label_dict.items():
+ # Create mask for that class.
+ value = value.astype(np.uint8)
+ value[value > 0] = key
+
+ # Create dict of masks for logging.
+ class_name = self.class_id_to_label[key]
+ wandb_masks[class_name] = {
+ 'mask_data': value,
+ 'class_labels': self.class_id_to_label
+ }
+
+ return wandb_masks
+
+ def _log_data_table(self):
+ """Log the W&B Tables for validation data as artifact and calls
+ `use_artifact` on it so that the evaluation table can use the reference
+ of already uploaded images.
+
+ This allows the data to be uploaded just once.
+ """
+ data_artifact = self.wandb.Artifact('val', type='dataset')
+ data_artifact.add(self.data_table, 'val_data')
+
+ if not self.wandb.run.offline:
+ self.wandb.run.use_artifact(data_artifact)
+ data_artifact.wait()
+ self.data_table_ref = data_artifact.get('val_data')
+ else:
+ self.data_table_ref = self.data_table
+
+ def _log_eval_table(self, idx):
+ """Log the W&B Tables for model evaluation.
+
+ The table will be logged multiple times, creating new versions. Use this
+ to compare models at different intervals interactively.
+ """
+ pred_artifact = self.wandb.Artifact(
+ f'run_{self.wandb.run.id}_pred', type='evaluation')
+ pred_artifact.add(self.eval_table, 'eval_data')
+ if self.by_epoch:
+ aliases = ['latest', f'epoch_{idx}']
+ else:
+ aliases = ['latest', f'iter_{idx}']
+ self.wandb.run.log_artifact(pred_artifact, aliases=aliases)
diff --git a/mmdet/core/hook/yolox_lrupdater_hook.py b/mmdet/core/hook/yolox_lrupdater_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecb028ed252047dd07086eef18d3b0e5abc778c0
--- /dev/null
+++ b/mmdet/core/hook/yolox_lrupdater_hook.py
@@ -0,0 +1,67 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.runner.hooks import HOOKS
+from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
+ annealing_cos)
+
+
+@HOOKS.register_module()
+class YOLOXLrUpdaterHook(CosineAnnealingLrUpdaterHook):
+ """YOLOX learning rate scheme.
+
+ There are two main differences between YOLOXLrUpdaterHook
+ and CosineAnnealingLrUpdaterHook.
+
+ 1. When the current running epoch is greater than
+ `max_epochs - num_last_epochs`, a fixed learning rate will be used.
+ 2. The exp warmup scheme is different from that of LrUpdaterHook in MMCV.
+
+ Args:
+ num_last_epochs (int): The number of epochs with a fixed learning rate
+ before the end of the training.
+ """
+
+ def __init__(self, num_last_epochs, **kwargs):
+ self.num_last_epochs = num_last_epochs
+ super(YOLOXLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_warmup_lr(self, cur_iters):
+
+ def _get_warmup_lr(cur_iters, regular_lr):
+ # exp warmup scheme
+ k = self.warmup_ratio * pow(
+ (cur_iters + 1) / float(self.warmup_iters), 2)
+ warmup_lr = [_lr * k for _lr in regular_lr]
+ return warmup_lr
+
+ if isinstance(self.base_lr, dict):
+ lr_groups = {}
+ for key, base_lr in self.base_lr.items():
+ lr_groups[key] = _get_warmup_lr(cur_iters, base_lr)
+ return lr_groups
+ else:
+ return _get_warmup_lr(cur_iters, self.base_lr)
+
+ def get_lr(self, runner, base_lr):
+ last_iter = len(runner.data_loader) * self.num_last_epochs
+
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+
+ progress += 1
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+
+ if progress >= max_progress - last_iter:
+ # fixed learning rate
+ return target_lr
+ else:
+ return annealing_cos(
+ base_lr, target_lr, (progress - self.warmup_iters) /
+ (max_progress - self.warmup_iters - last_iter))
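+
+
+# A config-level sketch; the numbers mirror a typical YOLOX schedule and are
+# assumptions, not values fixed by this hook:
+# lr_config = dict(
+#     policy='YOLOX',
+#     warmup='exp',
+#     by_epoch=False,
+#     warmup_by_epoch=True,
+#     warmup_ratio=1,
+#     warmup_iters=5,  # 5 epochs
+#     num_last_epochs=15,
+#     min_lr_ratio=0.05)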
diff --git a/mmdet/core/hook/yolox_mode_switch_hook.py b/mmdet/core/hook/yolox_mode_switch_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..10834e686af5c7f70c1f01ce1bef0c707740aea5
--- /dev/null
+++ b/mmdet/core/hook/yolox_mode_switch_hook.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.parallel import is_module_wrapper
+from mmcv.runner.hooks import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class YOLOXModeSwitchHook(Hook):
+ """Switch the mode of YOLOX during training.
+
+ This hook turns off the mosaic and mixup data augmentation and switches
+ to use L1 loss in bbox_head.
+
+ Args:
+ num_last_epochs (int): The number of epochs at the end of training
+ during which the data augmentation is closed and the L1 loss is used.
+ Default: 15.
+ skip_type_keys (list[str], optional): Sequence of pipeline type strings
+ to be skipped. Default: ('Mosaic', 'RandomAffine', 'MixUp')
+ """
+
+ def __init__(self,
+ num_last_epochs=15,
+ skip_type_keys=('Mosaic', 'RandomAffine', 'MixUp')):
+ self.num_last_epochs = num_last_epochs
+ self.skip_type_keys = skip_type_keys
+ self._restart_dataloader = False
+
+ def before_train_epoch(self, runner):
+ """Close mosaic and mixup augmentation and switches to use L1 loss."""
+ epoch = runner.epoch
+ train_loader = runner.data_loader
+ model = runner.model
+ if is_module_wrapper(model):
+ model = model.module
+ if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
+ runner.logger.info('No mosaic and mixup aug now!')
+ # The dataset pipeline cannot be updated when persistent_workers
+ # is True, so we need to force the dataloader's multi-process
+ # restart. This is a very hacky approach.
+ train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
+ if hasattr(train_loader, 'persistent_workers'
+ ) and train_loader.persistent_workers is True:
+ train_loader._DataLoader__initialized = False
+ train_loader._iterator = None
+ self._restart_dataloader = True
+ runner.logger.info('Add additional L1 loss now!')
+ model.bbox_head.use_l1 = True
+ else:
+ # Once the restart is complete, we need to restore
+ # the initialization flag.
+ if self._restart_dataloader:
+ train_loader._DataLoader__initialized = True
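+
+
+# A config-level sketch; the values are assumptions that simply match the
+# hook defaults above:
+# custom_hooks = [
+#     dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48)
+# ]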
diff --git a/mmdet/core/mask/__init__.py b/mmdet/core/mask/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..644a9b1d9b4c2a557561da6c048f9056a1090526
--- /dev/null
+++ b/mmdet/core/mask/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .mask_target import mask_target
+from .structures import BaseInstanceMasks, BitmapMasks, PolygonMasks
+from .utils import encode_mask_results, mask2bbox, split_combined_polys
+
+__all__ = [
+ 'split_combined_polys', 'mask_target', 'BaseInstanceMasks', 'BitmapMasks',
+ 'PolygonMasks', 'encode_mask_results', 'mask2bbox'
+]
diff --git a/mmdet/core/mask/mask_target.py b/mmdet/core/mask/mask_target.py
new file mode 100644
index 0000000000000000000000000000000000000000..273e7678fc14cec9f34a88edf6d6cac6c04e30fb
--- /dev/null
+++ b/mmdet/core/mask/mask_target.py
@@ -0,0 +1,127 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+from torch.nn.modules.utils import _pair
+
+
+def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
+ cfg):
+ """Compute mask target for positive proposals in multiple images.
+
+ Args:
+ pos_proposals_list (list[Tensor]): Positive proposals in multiple
+ images.
+ pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each
+            positive proposal.
+ gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of
+ each image.
+ cfg (dict): Config dict that specifies the mask size.
+
+ Returns:
+ list[Tensor]: Mask target of each image.
+
+ Example:
+ >>> import mmcv
+ >>> import mmdet
+ >>> from mmdet.core.mask import BitmapMasks
+ >>> from mmdet.core.mask.mask_target import *
+ >>> H, W = 17, 18
+ >>> cfg = mmcv.Config({'mask_size': (13, 14)})
+ >>> rng = np.random.RandomState(0)
+ >>> # Positive proposals (tl_x, tl_y, br_x, br_y) for each image
+ >>> pos_proposals_list = [
+ >>> torch.Tensor([
+ >>> [ 7.2425, 5.5929, 13.9414, 14.9541],
+ >>> [ 7.3241, 3.6170, 16.3850, 15.3102],
+ >>> ]),
+ >>> torch.Tensor([
+ >>> [ 4.8448, 6.4010, 7.0314, 9.7681],
+ >>> [ 5.9790, 2.6989, 7.4416, 4.8580],
+ >>> [ 0.0000, 0.0000, 0.1398, 9.8232],
+ >>> ]),
+ >>> ]
+ >>> # Corresponding class index for each proposal for each image
+ >>> pos_assigned_gt_inds_list = [
+ >>> torch.LongTensor([7, 0]),
+ >>> torch.LongTensor([5, 4, 1]),
+ >>> ]
+ >>> # Ground truth mask for each true object for each image
+ >>> gt_masks_list = [
+ >>> BitmapMasks(rng.rand(8, H, W), height=H, width=W),
+ >>> BitmapMasks(rng.rand(6, H, W), height=H, width=W),
+ >>> ]
+ >>> mask_targets = mask_target(
+ >>> pos_proposals_list, pos_assigned_gt_inds_list,
+ >>> gt_masks_list, cfg)
+ >>> assert mask_targets.shape == (5,) + cfg['mask_size']
+ """
+ cfg_list = [cfg for _ in range(len(pos_proposals_list))]
+ mask_targets = map(mask_target_single, pos_proposals_list,
+ pos_assigned_gt_inds_list, gt_masks_list, cfg_list)
+ mask_targets = list(mask_targets)
+ if len(mask_targets) > 0:
+ mask_targets = torch.cat(mask_targets)
+ return mask_targets
+
+
+def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg):
+ """Compute mask target for each positive proposal in the image.
+
+ Args:
+ pos_proposals (Tensor): Positive proposals.
+ pos_assigned_gt_inds (Tensor): Assigned GT inds of positive proposals.
+ gt_masks (:obj:`BaseInstanceMasks`): GT masks in the format of Bitmap
+ or Polygon.
+        cfg (dict): Config dict that indicates the mask size.
+
+ Returns:
+        Tensor: Mask targets of the positive proposals in the image.
+
+ Example:
+ >>> import mmcv
+ >>> import mmdet
+ >>> from mmdet.core.mask import BitmapMasks
+ >>> from mmdet.core.mask.mask_target import * # NOQA
+ >>> H, W = 32, 32
+ >>> cfg = mmcv.Config({'mask_size': (7, 11)})
+ >>> rng = np.random.RandomState(0)
+ >>> # Masks for each ground truth box (relative to the image)
+ >>> gt_masks_data = rng.rand(3, H, W)
+ >>> gt_masks = BitmapMasks(gt_masks_data, height=H, width=W)
+ >>> # Predicted positive boxes in one image
+ >>> pos_proposals = torch.FloatTensor([
+ >>> [ 16.2, 5.5, 19.9, 20.9],
+ >>> [ 17.3, 13.6, 19.3, 19.3],
+ >>> [ 14.8, 16.4, 17.0, 23.7],
+ >>> [ 0.0, 0.0, 16.0, 16.0],
+ >>> [ 4.0, 0.0, 20.0, 16.0],
+ >>> ])
+ >>> # For each predicted proposal, its assignment to a gt mask
+ >>> pos_assigned_gt_inds = torch.LongTensor([0, 1, 2, 1, 1])
+ >>> mask_targets = mask_target_single(
+ >>> pos_proposals, pos_assigned_gt_inds, gt_masks, cfg)
+ >>> assert mask_targets.shape == (5,) + cfg['mask_size']
+ """
+ device = pos_proposals.device
+ mask_size = _pair(cfg.mask_size)
+ binarize = not cfg.get('soft_mask_target', False)
+ num_pos = pos_proposals.size(0)
+ if num_pos > 0:
+ proposals_np = pos_proposals.cpu().numpy()
+ maxh, maxw = gt_masks.height, gt_masks.width
+ proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw)
+ proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh)
+ pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
+
+ mask_targets = gt_masks.crop_and_resize(
+ proposals_np,
+ mask_size,
+ device=device,
+ inds=pos_assigned_gt_inds,
+ binarize=binarize).to_ndarray()
+
+ mask_targets = torch.from_numpy(mask_targets).float().to(device)
+ else:
+ mask_targets = pos_proposals.new_zeros((0, ) + mask_size)
+
+ return mask_targets
diff --git a/mmdet/core/mask/structures.py b/mmdet/core/mask/structures.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e730dc52a38f1858428bff603734ebffd7ba6d1
--- /dev/null
+++ b/mmdet/core/mask/structures.py
@@ -0,0 +1,1102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+import cv2
+import mmcv
+import numpy as np
+import pycocotools.mask as maskUtils
+import torch
+from mmcv.ops.roi_align import roi_align
+
+
+class BaseInstanceMasks(metaclass=ABCMeta):
+ """Base class for instance masks."""
+
+ @abstractmethod
+ def rescale(self, scale, interpolation='nearest'):
+ """Rescale masks as large as possible while keeping the aspect ratio.
+        For details, refer to `mmcv.imrescale`.
+
+ Args:
+ scale (tuple[int]): The maximum size (h, w) of rescaled mask.
+ interpolation (str): Same as :func:`mmcv.imrescale`.
+
+ Returns:
+ BaseInstanceMasks: The rescaled masks.
+ """
+
+ @abstractmethod
+ def resize(self, out_shape, interpolation='nearest'):
+ """Resize masks to the given out_shape.
+
+ Args:
+ out_shape: Target (h, w) of resized mask.
+ interpolation (str): See :func:`mmcv.imresize`.
+
+ Returns:
+ BaseInstanceMasks: The resized masks.
+ """
+
+ @abstractmethod
+ def flip(self, flip_direction='horizontal'):
+ """Flip masks alone the given direction.
+
+ Args:
+ flip_direction (str): Either 'horizontal' or 'vertical'.
+
+ Returns:
+ BaseInstanceMasks: The flipped masks.
+ """
+
+ @abstractmethod
+ def pad(self, out_shape, pad_val):
+ """Pad masks to the given size of (h, w).
+
+ Args:
+ out_shape (tuple[int]): Target (h, w) of padded mask.
+ pad_val (int): The padded value.
+
+ Returns:
+ BaseInstanceMasks: The padded masks.
+ """
+
+ @abstractmethod
+ def crop(self, bbox):
+ """Crop each mask by the given bbox.
+
+ Args:
+ bbox (ndarray): Bbox in format [x1, y1, x2, y2], shape (4, ).
+
+ Return:
+ BaseInstanceMasks: The cropped masks.
+ """
+
+ @abstractmethod
+ def crop_and_resize(self,
+ bboxes,
+ out_shape,
+ inds,
+ device,
+ interpolation='bilinear',
+ binarize=True):
+ """Crop and resize masks by the given bboxes.
+
+        This function is mainly used in mask target computation.
+        It first aligns masks to bboxes by assigned_inds, then crops the masks
+        by the assigned bboxes and resizes them to the size of (mask_h, mask_w).
+
+ Args:
+ bboxes (Tensor): Bboxes in format [x1, y1, x2, y2], shape (N, 4)
+ out_shape (tuple[int]): Target (h, w) of resized mask
+ inds (ndarray): Indexes to assign masks to each bbox,
+ shape (N,) and values should be between [0, num_masks - 1].
+ device (str): Device of bboxes
+ interpolation (str): See `mmcv.imresize`
+            binarize (bool): If True, fractional values are rounded to 0 or 1
+                after the resize operation. If False and the mask type does
+                not support soft values, an error will be raised.
+                Defaults to True.
+
+ Return:
+ BaseInstanceMasks: the cropped and resized masks.
+ """
+
+ @abstractmethod
+ def expand(self, expanded_h, expanded_w, top, left):
+ """see :class:`Expand`."""
+
+ @property
+ @abstractmethod
+ def areas(self):
+ """ndarray: areas of each instance."""
+
+ @abstractmethod
+ def to_ndarray(self):
+ """Convert masks to the format of ndarray.
+
+ Return:
+ ndarray: Converted masks in the format of ndarray.
+ """
+
+ @abstractmethod
+ def to_tensor(self, dtype, device):
+ """Convert masks to the format of Tensor.
+
+ Args:
+ dtype (str): Dtype of converted mask.
+ device (torch.device): Device of converted masks.
+
+ Returns:
+ Tensor: Converted masks in the format of Tensor.
+ """
+
+ @abstractmethod
+ def translate(self,
+ out_shape,
+ offset,
+ direction='horizontal',
+ fill_val=0,
+ interpolation='bilinear'):
+ """Translate the masks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ offset (int | float): The offset for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ fill_val (int | float): Border value. Default 0.
+ interpolation (str): Same as :func:`mmcv.imtranslate`.
+
+ Returns:
+ Translated masks.
+ """
+
+ def shear(self,
+ out_shape,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Shear the masks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The shear direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border. Default 0.
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+
+ Returns:
+ ndarray: Sheared masks.
+ """
+
+ @abstractmethod
+ def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
+ """Rotate the masks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ angle (int | float): Rotation angle in degrees. Positive values
+ mean counter-clockwise rotation.
+ center (tuple[float], optional): Center point (w, h) of the
+ rotation in source image. If not specified, the center of
+ the image will be used.
+ scale (int | float): Isotropic scale factor.
+ fill_val (int | float): Border value. Default 0 for masks.
+
+ Returns:
+ Rotated masks.
+ """
+
+
+class BitmapMasks(BaseInstanceMasks):
+ """This class represents masks in the form of bitmaps.
+
+ Args:
+ masks (ndarray): ndarray of masks in shape (N, H, W), where N is
+ the number of objects.
+ height (int): height of masks
+ width (int): width of masks
+
+ Example:
+ >>> from mmdet.core.mask.structures import * # NOQA
+ >>> num_masks, H, W = 3, 32, 32
+ >>> rng = np.random.RandomState(0)
+    >>> masks = (rng.rand(num_masks, H, W) > 0.1).astype(np.uint8)
+ >>> self = BitmapMasks(masks, height=H, width=W)
+
+ >>> # demo crop_and_resize
+ >>> num_boxes = 5
+ >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes)
+ >>> out_shape = (14, 14)
+ >>> inds = torch.randint(0, len(self), size=(num_boxes,))
+ >>> device = 'cpu'
+ >>> interpolation = 'bilinear'
+ >>> new = self.crop_and_resize(
+ ... bboxes, out_shape, inds, device, interpolation)
+ >>> assert len(new) == num_boxes
+    >>> assert (new.height, new.width) == out_shape
+ """
+
+ def __init__(self, masks, height, width):
+ self.height = height
+ self.width = width
+ if len(masks) == 0:
+ self.masks = np.empty((0, self.height, self.width), dtype=np.uint8)
+ else:
+ assert isinstance(masks, (list, np.ndarray))
+ if isinstance(masks, list):
+ assert isinstance(masks[0], np.ndarray)
+ assert masks[0].ndim == 2 # (H, W)
+ else:
+ assert masks.ndim == 3 # (N, H, W)
+
+ self.masks = np.stack(masks).reshape(-1, height, width)
+ assert self.masks.shape[1] == self.height
+ assert self.masks.shape[2] == self.width
+
+ def __getitem__(self, index):
+ """Index the BitmapMask.
+
+ Args:
+ index (int | ndarray): Indices in the format of integer or ndarray.
+
+ Returns:
+ :obj:`BitmapMasks`: Indexed bitmap masks.
+ """
+ masks = self.masks[index].reshape(-1, self.height, self.width)
+ return BitmapMasks(masks, self.height, self.width)
+
+ def __iter__(self):
+ return iter(self.masks)
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += f'num_masks={len(self.masks)}, '
+ s += f'height={self.height}, '
+ s += f'width={self.width})'
+ return s
+
+ def __len__(self):
+ """Number of masks."""
+ return len(self.masks)
+
+ def rescale(self, scale, interpolation='nearest'):
+ """See :func:`BaseInstanceMasks.rescale`."""
+ if len(self.masks) == 0:
+ new_w, new_h = mmcv.rescale_size((self.width, self.height), scale)
+ rescaled_masks = np.empty((0, new_h, new_w), dtype=np.uint8)
+ else:
+ rescaled_masks = np.stack([
+ mmcv.imrescale(mask, scale, interpolation=interpolation)
+ for mask in self.masks
+ ])
+ height, width = rescaled_masks.shape[1:]
+ return BitmapMasks(rescaled_masks, height, width)
+
+ def resize(self, out_shape, interpolation='nearest'):
+ """See :func:`BaseInstanceMasks.resize`."""
+ if len(self.masks) == 0:
+ resized_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ else:
+ resized_masks = np.stack([
+ mmcv.imresize(
+ mask, out_shape[::-1], interpolation=interpolation)
+ for mask in self.masks
+ ])
+ return BitmapMasks(resized_masks, *out_shape)
+
+ def flip(self, flip_direction='horizontal'):
+ """See :func:`BaseInstanceMasks.flip`."""
+ assert flip_direction in ('horizontal', 'vertical', 'diagonal')
+
+ if len(self.masks) == 0:
+ flipped_masks = self.masks
+ else:
+ flipped_masks = np.stack([
+ mmcv.imflip(mask, direction=flip_direction)
+ for mask in self.masks
+ ])
+ return BitmapMasks(flipped_masks, self.height, self.width)
+
+ def pad(self, out_shape, pad_val=0):
+ """See :func:`BaseInstanceMasks.pad`."""
+ if len(self.masks) == 0:
+ padded_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ else:
+ padded_masks = np.stack([
+ mmcv.impad(mask, shape=out_shape, pad_val=pad_val)
+ for mask in self.masks
+ ])
+ return BitmapMasks(padded_masks, *out_shape)
+
+ def crop(self, bbox):
+ """See :func:`BaseInstanceMasks.crop`."""
+ assert isinstance(bbox, np.ndarray)
+ assert bbox.ndim == 1
+
+ # clip the boundary
+ bbox = bbox.copy()
+ bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
+ bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
+ x1, y1, x2, y2 = bbox
+ w = np.maximum(x2 - x1, 1)
+ h = np.maximum(y2 - y1, 1)
+
+ if len(self.masks) == 0:
+ cropped_masks = np.empty((0, h, w), dtype=np.uint8)
+ else:
+ cropped_masks = self.masks[:, y1:y1 + h, x1:x1 + w]
+ return BitmapMasks(cropped_masks, h, w)
+
+ def crop_and_resize(self,
+ bboxes,
+ out_shape,
+ inds,
+ device='cpu',
+ interpolation='bilinear',
+ binarize=True):
+ """See :func:`BaseInstanceMasks.crop_and_resize`."""
+ if len(self.masks) == 0:
+ empty_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ return BitmapMasks(empty_masks, *out_shape)
+
+ # convert bboxes to tensor
+ if isinstance(bboxes, np.ndarray):
+ bboxes = torch.from_numpy(bboxes).to(device=device)
+ if isinstance(inds, np.ndarray):
+ inds = torch.from_numpy(inds).to(device=device)
+
+ num_bbox = bboxes.shape[0]
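+        # roi_align expects rois in (batch_ind, x1, y1, x2, y2) format; each
+        # selected gt mask below acts as its own batch element, so the batch
+        # indices are simply 0..num_bbox - 1.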
+ fake_inds = torch.arange(
+ num_bbox, device=device).to(dtype=bboxes.dtype)[:, None]
+ rois = torch.cat([fake_inds, bboxes], dim=1) # Nx5
+ rois = rois.to(device=device)
+ if num_bbox > 0:
+ gt_masks_th = torch.from_numpy(self.masks).to(device).index_select(
+ 0, inds).to(dtype=rois.dtype)
+ targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape,
+ 1.0, 0, 'avg', True).squeeze(1)
+ if binarize:
+ resized_masks = (targets >= 0.5).cpu().numpy()
+ else:
+ resized_masks = targets.cpu().numpy()
+ else:
+ resized_masks = []
+ return BitmapMasks(resized_masks, *out_shape)
+
+ def expand(self, expanded_h, expanded_w, top, left):
+ """See :func:`BaseInstanceMasks.expand`."""
+ if len(self.masks) == 0:
+ expanded_mask = np.empty((0, expanded_h, expanded_w),
+ dtype=np.uint8)
+ else:
+ expanded_mask = np.zeros((len(self), expanded_h, expanded_w),
+ dtype=np.uint8)
+ expanded_mask[:, top:top + self.height,
+ left:left + self.width] = self.masks
+ return BitmapMasks(expanded_mask, expanded_h, expanded_w)
+
+ def translate(self,
+ out_shape,
+ offset,
+ direction='horizontal',
+ fill_val=0,
+ interpolation='bilinear'):
+ """Translate the BitmapMasks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ offset (int | float): The offset for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ fill_val (int | float): Border value. Default 0 for masks.
+ interpolation (str): Same as :func:`mmcv.imtranslate`.
+
+ Returns:
+ BitmapMasks: Translated BitmapMasks.
+
+ Example:
+ >>> from mmdet.core.mask.structures import BitmapMasks
+ >>> self = BitmapMasks.random(dtype=np.uint8)
+ >>> out_shape = (32, 32)
+ >>> offset = 4
+ >>> direction = 'horizontal'
+ >>> fill_val = 0
+ >>> interpolation = 'bilinear'
+            >>> # Note: there seem to be issues when:
+ >>> # * out_shape is different than self's shape
+ >>> # * the mask dtype is not supported by cv2.AffineWarp
+ >>> new = self.translate(out_shape, offset, direction, fill_val,
+ >>> interpolation)
+ >>> assert len(new) == len(self)
+            >>> assert (new.height, new.width) == out_shape
+ """
+ if len(self.masks) == 0:
+ translated_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ else:
+ translated_masks = mmcv.imtranslate(
+ self.masks.transpose((1, 2, 0)),
+ offset,
+ direction,
+ border_value=fill_val,
+ interpolation=interpolation)
+ if translated_masks.ndim == 2:
+ translated_masks = translated_masks[:, :, None]
+ translated_masks = translated_masks.transpose(
+ (2, 0, 1)).astype(self.masks.dtype)
+ return BitmapMasks(translated_masks, *out_shape)
+
+ def shear(self,
+ out_shape,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Shear the BitmapMasks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The shear direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+
+ Returns:
+ BitmapMasks: The sheared masks.
+ """
+ if len(self.masks) == 0:
+ sheared_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ else:
+ sheared_masks = mmcv.imshear(
+ self.masks.transpose((1, 2, 0)),
+ magnitude,
+ direction,
+ border_value=border_value,
+ interpolation=interpolation)
+ if sheared_masks.ndim == 2:
+ sheared_masks = sheared_masks[:, :, None]
+ sheared_masks = sheared_masks.transpose(
+ (2, 0, 1)).astype(self.masks.dtype)
+ return BitmapMasks(sheared_masks, *out_shape)
+
+ def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
+ """Rotate the BitmapMasks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ angle (int | float): Rotation angle in degrees. Positive values
+ mean counter-clockwise rotation.
+ center (tuple[float], optional): Center point (w, h) of the
+ rotation in source image. If not specified, the center of
+ the image will be used.
+ scale (int | float): Isotropic scale factor.
+ fill_val (int | float): Border value. Default 0 for masks.
+
+ Returns:
+ BitmapMasks: Rotated BitmapMasks.
+ """
+ if len(self.masks) == 0:
+ rotated_masks = np.empty((0, *out_shape), dtype=self.masks.dtype)
+ else:
+ rotated_masks = mmcv.imrotate(
+ self.masks.transpose((1, 2, 0)),
+ angle,
+ center=center,
+ scale=scale,
+ border_value=fill_val)
+ if rotated_masks.ndim == 2:
+ # case when only one mask, (h, w)
+ rotated_masks = rotated_masks[:, :, None] # (h, w, 1)
+ rotated_masks = rotated_masks.transpose(
+ (2, 0, 1)).astype(self.masks.dtype)
+ return BitmapMasks(rotated_masks, *out_shape)
+
+ @property
+ def areas(self):
+ """See :py:attr:`BaseInstanceMasks.areas`."""
+ return self.masks.sum((1, 2))
+
+ def to_ndarray(self):
+ """See :func:`BaseInstanceMasks.to_ndarray`."""
+ return self.masks
+
+ def to_tensor(self, dtype, device):
+ """See :func:`BaseInstanceMasks.to_tensor`."""
+ return torch.tensor(self.masks, dtype=dtype, device=device)
+
+ @classmethod
+ def random(cls,
+ num_masks=3,
+ height=32,
+ width=32,
+ dtype=np.uint8,
+ rng=None):
+ """Generate random bitmap masks for demo / testing purposes.
+
+ Example:
+ >>> from mmdet.core.mask.structures import BitmapMasks
+ >>> self = BitmapMasks.random()
+ >>> print('self = {}'.format(self))
+ self = BitmapMasks(num_masks=3, height=32, width=32)
+ """
+ from mmdet.utils.util_random import ensure_rng
+ rng = ensure_rng(rng)
+ masks = (rng.rand(num_masks, height, width) > 0.1).astype(dtype)
+ self = cls(masks, height=height, width=width)
+ return self
+
+ def get_bboxes(self):
+ num_masks = len(self)
+ boxes = np.zeros((num_masks, 4), dtype=np.float32)
+ x_any = self.masks.any(axis=1)
+ y_any = self.masks.any(axis=2)
+ for idx in range(num_masks):
+ x = np.where(x_any[idx, :])[0]
+ y = np.where(y_any[idx, :])[0]
+ if len(x) > 0 and len(y) > 0:
+ # use +1 for x_max and y_max so that the right and bottom
+ # boundary of instance masks are fully included by the box
+ boxes[idx, :] = np.array([x[0], y[0], x[-1] + 1, y[-1] + 1],
+ dtype=np.float32)
+ return boxes
+
+
+class PolygonMasks(BaseInstanceMasks):
+ """This class represents masks in the form of polygons.
+
+    Polygons are represented as a nested list with three levels: the first
+    level corresponds to objects, the second level to the polys that compose
+    each object, and the third level to the poly coordinates.
+
+ Args:
+ masks (list[list[ndarray]]): The first level of the list
+ corresponds to objects, the second level to the polys that
+ compose the object, the third level to the poly coordinates
+ height (int): height of masks
+ width (int): width of masks
+
+ Example:
+ >>> from mmdet.core.mask.structures import * # NOQA
+ >>> masks = [
+ >>> [ np.array([0, 0, 10, 0, 10, 10., 0, 10, 0, 0]) ]
+ >>> ]
+ >>> height, width = 16, 16
+ >>> self = PolygonMasks(masks, height, width)
+
+ >>> # demo translate
+ >>> new = self.translate((16, 16), 4., direction='horizontal')
+ >>> assert np.all(new.masks[0][0][1::2] == masks[0][0][1::2])
+ >>> assert np.all(new.masks[0][0][0::2] == masks[0][0][0::2] + 4)
+
+ >>> # demo crop_and_resize
+ >>> num_boxes = 3
+ >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes)
+ >>> out_shape = (16, 16)
+ >>> inds = torch.randint(0, len(self), size=(num_boxes,))
+ >>> device = 'cpu'
+ >>> interpolation = 'bilinear'
+ >>> new = self.crop_and_resize(
+ ... bboxes, out_shape, inds, device, interpolation)
+ >>> assert len(new) == num_boxes
+    >>> assert (new.height, new.width) == out_shape
+ """
+
+ def __init__(self, masks, height, width):
+ assert isinstance(masks, list)
+ if len(masks) > 0:
+ assert isinstance(masks[0], list)
+ assert isinstance(masks[0][0], np.ndarray)
+
+ self.height = height
+ self.width = width
+ self.masks = masks
+
+ def __getitem__(self, index):
+ """Index the polygon masks.
+
+ Args:
+ index (ndarray | List): The indices.
+
+ Returns:
+ :obj:`PolygonMasks`: The indexed polygon masks.
+ """
+ if isinstance(index, np.ndarray):
+ index = index.tolist()
+ if isinstance(index, list):
+ masks = [self.masks[i] for i in index]
+ else:
+ try:
+ masks = self.masks[index]
+ except Exception:
+ raise ValueError(
+ f'Unsupported input of type {type(index)} for indexing!')
+ if len(masks) and isinstance(masks[0], np.ndarray):
+ masks = [masks] # ensure a list of three levels
+ return PolygonMasks(masks, self.height, self.width)
+
+ def __iter__(self):
+ return iter(self.masks)
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += f'num_masks={len(self.masks)}, '
+ s += f'height={self.height}, '
+ s += f'width={self.width})'
+ return s
+
+ def __len__(self):
+ """Number of masks."""
+ return len(self.masks)
+
+ def rescale(self, scale, interpolation=None):
+ """see :func:`BaseInstanceMasks.rescale`"""
+ new_w, new_h = mmcv.rescale_size((self.width, self.height), scale)
+ if len(self.masks) == 0:
+ rescaled_masks = PolygonMasks([], new_h, new_w)
+ else:
+ rescaled_masks = self.resize((new_h, new_w))
+ return rescaled_masks
+
+ def resize(self, out_shape, interpolation=None):
+ """see :func:`BaseInstanceMasks.resize`"""
+ if len(self.masks) == 0:
+ resized_masks = PolygonMasks([], *out_shape)
+ else:
+ h_scale = out_shape[0] / self.height
+ w_scale = out_shape[1] / self.width
+ resized_masks = []
+ for poly_per_obj in self.masks:
+ resized_poly = []
+ for p in poly_per_obj:
+ p = p.copy()
+ p[0::2] = p[0::2] * w_scale
+ p[1::2] = p[1::2] * h_scale
+ resized_poly.append(p)
+ resized_masks.append(resized_poly)
+ resized_masks = PolygonMasks(resized_masks, *out_shape)
+ return resized_masks
+
+ def flip(self, flip_direction='horizontal'):
+ """see :func:`BaseInstanceMasks.flip`"""
+ assert flip_direction in ('horizontal', 'vertical', 'diagonal')
+ if len(self.masks) == 0:
+ flipped_masks = PolygonMasks([], self.height, self.width)
+ else:
+ flipped_masks = []
+ for poly_per_obj in self.masks:
+ flipped_poly_per_obj = []
+ for p in poly_per_obj:
+ p = p.copy()
+ if flip_direction == 'horizontal':
+ p[0::2] = self.width - p[0::2]
+ elif flip_direction == 'vertical':
+ p[1::2] = self.height - p[1::2]
+ else:
+ p[0::2] = self.width - p[0::2]
+ p[1::2] = self.height - p[1::2]
+ flipped_poly_per_obj.append(p)
+ flipped_masks.append(flipped_poly_per_obj)
+ flipped_masks = PolygonMasks(flipped_masks, self.height,
+ self.width)
+ return flipped_masks
+
+ def crop(self, bbox):
+ """see :func:`BaseInstanceMasks.crop`"""
+ assert isinstance(bbox, np.ndarray)
+ assert bbox.ndim == 1
+
+ # clip the boundary
+ bbox = bbox.copy()
+ bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
+ bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
+ x1, y1, x2, y2 = bbox
+ w = np.maximum(x2 - x1, 1)
+ h = np.maximum(y2 - y1, 1)
+
+ if len(self.masks) == 0:
+ cropped_masks = PolygonMasks([], h, w)
+ else:
+ cropped_masks = []
+ for poly_per_obj in self.masks:
+ cropped_poly_per_obj = []
+ for p in poly_per_obj:
+ # pycocotools will clip the boundary
+ p = p.copy()
+ p[0::2] = p[0::2] - bbox[0]
+ p[1::2] = p[1::2] - bbox[1]
+ cropped_poly_per_obj.append(p)
+ cropped_masks.append(cropped_poly_per_obj)
+ cropped_masks = PolygonMasks(cropped_masks, h, w)
+ return cropped_masks
+
+ def pad(self, out_shape, pad_val=0):
+ """padding has no effect on polygons`"""
+ return PolygonMasks(self.masks, *out_shape)
+
+ def expand(self, *args, **kwargs):
+ """TODO: Add expand for polygon"""
+ raise NotImplementedError
+
+ def crop_and_resize(self,
+ bboxes,
+ out_shape,
+ inds,
+ device='cpu',
+ interpolation='bilinear',
+ binarize=True):
+ """see :func:`BaseInstanceMasks.crop_and_resize`"""
+ out_h, out_w = out_shape
+ if len(self.masks) == 0:
+ return PolygonMasks([], out_h, out_w)
+
+ if not binarize:
+ raise ValueError('Polygons are always binary, '
+ 'setting binarize=False is unsupported')
+
+ resized_masks = []
+ for i in range(len(bboxes)):
+ mask = self.masks[inds[i]]
+ bbox = bboxes[i, :]
+ x1, y1, x2, y2 = bbox
+ w = np.maximum(x2 - x1, 1)
+ h = np.maximum(y2 - y1, 1)
+ h_scale = out_h / max(h, 0.1) # avoid too large scale
+ w_scale = out_w / max(w, 0.1)
+
+ resized_mask = []
+ for p in mask:
+ p = p.copy()
+ # crop
+ # pycocotools will clip the boundary
+ p[0::2] = p[0::2] - bbox[0]
+ p[1::2] = p[1::2] - bbox[1]
+
+ # resize
+ p[0::2] = p[0::2] * w_scale
+ p[1::2] = p[1::2] * h_scale
+ resized_mask.append(p)
+ resized_masks.append(resized_mask)
+ return PolygonMasks(resized_masks, *out_shape)
+
+ def translate(self,
+ out_shape,
+ offset,
+ direction='horizontal',
+ fill_val=None,
+ interpolation=None):
+ """Translate the PolygonMasks.
+
+ Example:
+            >>> self = PolygonMasks.random(dtype=np.float32)
+ >>> out_shape = (self.height, self.width)
+ >>> new = self.translate(out_shape, 4., direction='horizontal')
+ >>> assert np.all(new.masks[0][0][1::2] == self.masks[0][0][1::2])
+ >>> assert np.all(new.masks[0][0][0::2] == self.masks[0][0][0::2] + 4) # noqa: E501
+ """
+        assert fill_val is None or fill_val == 0, 'Here fill_val is not '\
+            f'used, and by default should be None or 0, but got {fill_val}.'
+ if len(self.masks) == 0:
+ translated_masks = PolygonMasks([], *out_shape)
+ else:
+ translated_masks = []
+ for poly_per_obj in self.masks:
+ translated_poly_per_obj = []
+ for p in poly_per_obj:
+ p = p.copy()
+ if direction == 'horizontal':
+ p[0::2] = np.clip(p[0::2] + offset, 0, out_shape[1])
+ elif direction == 'vertical':
+ p[1::2] = np.clip(p[1::2] + offset, 0, out_shape[0])
+ translated_poly_per_obj.append(p)
+ translated_masks.append(translated_poly_per_obj)
+ translated_masks = PolygonMasks(translated_masks, *out_shape)
+ return translated_masks
+
+ def shear(self,
+ out_shape,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """See :func:`BaseInstanceMasks.shear`."""
+ if len(self.masks) == 0:
+ sheared_masks = PolygonMasks([], *out_shape)
+ else:
+ sheared_masks = []
+ if direction == 'horizontal':
+ shear_matrix = np.stack([[1, magnitude],
+ [0, 1]]).astype(np.float32)
+ elif direction == 'vertical':
+ shear_matrix = np.stack([[1, 0], [magnitude,
+ 1]]).astype(np.float32)
+ for poly_per_obj in self.masks:
+ sheared_poly = []
+ for p in poly_per_obj:
+ p = np.stack([p[0::2], p[1::2]], axis=0) # [2, n]
+ new_coords = np.matmul(shear_matrix, p) # [2, n]
+ new_coords[0, :] = np.clip(new_coords[0, :], 0,
+ out_shape[1])
+ new_coords[1, :] = np.clip(new_coords[1, :], 0,
+ out_shape[0])
+ sheared_poly.append(
+ new_coords.transpose((1, 0)).reshape(-1))
+ sheared_masks.append(sheared_poly)
+ sheared_masks = PolygonMasks(sheared_masks, *out_shape)
+ return sheared_masks
+
+ def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
+ """See :func:`BaseInstanceMasks.rotate`."""
+ if len(self.masks) == 0:
+ rotated_masks = PolygonMasks([], *out_shape)
+ else:
+ rotated_masks = []
+ rotate_matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+ for poly_per_obj in self.masks:
+ rotated_poly = []
+ for p in poly_per_obj:
+ p = p.copy()
+ coords = np.stack([p[0::2], p[1::2]], axis=1) # [n, 2]
+ # pad 1 to convert from format [x, y] to homogeneous
+ # coordinates format [x, y, 1]
+ coords = np.concatenate(
+ (coords, np.ones((coords.shape[0], 1), coords.dtype)),
+ axis=1) # [n, 3]
+ rotated_coords = np.matmul(
+ rotate_matrix[None, :, :],
+ coords[:, :, None])[..., 0] # [n, 2, 1] -> [n, 2]
+ rotated_coords[:, 0] = np.clip(rotated_coords[:, 0], 0,
+ out_shape[1])
+ rotated_coords[:, 1] = np.clip(rotated_coords[:, 1], 0,
+ out_shape[0])
+ rotated_poly.append(rotated_coords.reshape(-1))
+ rotated_masks.append(rotated_poly)
+ rotated_masks = PolygonMasks(rotated_masks, *out_shape)
+ return rotated_masks
+
+ def to_bitmap(self):
+ """convert polygon masks to bitmap masks."""
+ bitmap_masks = self.to_ndarray()
+ return BitmapMasks(bitmap_masks, self.height, self.width)
+
+ @property
+ def areas(self):
+ """Compute areas of masks.
+
+        This function is modified from `detectron2
+        <https://github.com/facebookresearch/detectron2>`_.
+        The function only works with polygons and uses the shoelace formula.
+
+ Return:
+ ndarray: areas of each instance
+ """ # noqa: W501
+ area = []
+ for polygons_per_obj in self.masks:
+ area_per_obj = 0
+ for p in polygons_per_obj:
+ area_per_obj += self._polygon_area(p[0::2], p[1::2])
+ area.append(area_per_obj)
+ return np.asarray(area)
+
+ def _polygon_area(self, x, y):
+ """Compute the area of a component of a polygon.
+
+ Using the shoelace formula:
+ https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
+
+ Args:
+ x (ndarray): x coordinates of the component
+ y (ndarray): y coordinates of the component
+
+ Return:
+            float: the area of the component
+ """ # noqa: 501
+ return 0.5 * np.abs(
+ np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
+
+ def to_ndarray(self):
+ """Convert masks to the format of ndarray."""
+ if len(self.masks) == 0:
+ return np.empty((0, self.height, self.width), dtype=np.uint8)
+ bitmap_masks = []
+ for poly_per_obj in self.masks:
+ bitmap_masks.append(
+ polygon_to_bitmap(poly_per_obj, self.height, self.width))
+ return np.stack(bitmap_masks)
+
+ def to_tensor(self, dtype, device):
+ """See :func:`BaseInstanceMasks.to_tensor`."""
+ if len(self.masks) == 0:
+ return torch.empty((0, self.height, self.width),
+ dtype=dtype,
+ device=device)
+ ndarray_masks = self.to_ndarray()
+ return torch.tensor(ndarray_masks, dtype=dtype, device=device)
+
+ @classmethod
+ def random(cls,
+ num_masks=3,
+ height=32,
+ width=32,
+ n_verts=5,
+ dtype=np.float32,
+ rng=None):
+ """Generate random polygon masks for demo / testing purposes.
+
+ Adapted from [1]_
+
+ References:
+ .. [1] https://gitlab.kitware.com/computer-vision/kwimage/-/blob/928cae35ca8/kwimage/structs/polygon.py#L379 # noqa: E501
+
+ Example:
+ >>> from mmdet.core.mask.structures import PolygonMasks
+ >>> self = PolygonMasks.random()
+ >>> print('self = {}'.format(self))
+ """
+ from mmdet.utils.util_random import ensure_rng
+ rng = ensure_rng(rng)
+
+ def _gen_polygon(n, irregularity, spikeyness):
+ """Creates the polygon by sampling points on a circle around the
+ centre. Random noise is added by varying the angular spacing
+ between sequential points, and by varying the radial distance of
+ each point from the centre.
+
+ Based on original code by Mike Ounsworth
+
+ Args:
+ n (int): number of vertices
+ irregularity (float): [0,1] indicating how much variance there
+ is in the angular spacing of vertices. [0,1] will map to
+ [0, 2pi/numberOfVerts]
+ spikeyness (float): [0,1] indicating how much variance there is
+ in each vertex from the circle of radius aveRadius. [0,1]
+ will map to [0, aveRadius]
+
+ Returns:
+ a list of vertices, in CCW order.
+ """
+ from scipy.stats import truncnorm
+
+ # Generate around the unit circle
+ cx, cy = (0.0, 0.0)
+ radius = 1
+
+ tau = np.pi * 2
+
+ irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / n
+ spikeyness = np.clip(spikeyness, 1e-9, 1)
+
+ # generate n angle steps
+ lower = (tau / n) - irregularity
+ upper = (tau / n) + irregularity
+ angle_steps = rng.uniform(lower, upper, n)
+
+ # normalize the steps so that point 0 and point n+1 are the same
+ k = angle_steps.sum() / (2 * np.pi)
+ angles = (angle_steps / k).cumsum() + rng.uniform(0, tau)
+
+ # Convert high and low values to be wrt the standard normal range
+ # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html
+ low = 0
+ high = 2 * radius
+ mean = radius
+ std = spikeyness
+ a = (low - mean) / std
+ b = (high - mean) / std
+ tnorm = truncnorm(a=a, b=b, loc=mean, scale=std)
+
+ # now generate the points
+ radii = tnorm.rvs(n, random_state=rng)
+ x_pts = cx + radii * np.cos(angles)
+ y_pts = cy + radii * np.sin(angles)
+
+ points = np.hstack([x_pts[:, None], y_pts[:, None]])
+
+ # Scale to 0-1 space
+ points = points - points.min(axis=0)
+ points = points / points.max(axis=0)
+
+ # Randomly place within 0-1 space
+ points = points * (rng.rand() * .8 + .2)
+ min_pt = points.min(axis=0)
+ max_pt = points.max(axis=0)
+
+ high = (1 - max_pt)
+ low = (0 - min_pt)
+ offset = (rng.rand(2) * (high - low)) + low
+ points = points + offset
+ return points
+
+ def _order_vertices(verts):
+ """
+ References:
+ https://stackoverflow.com/questions/1709283/how-can-i-sort-a-coordinate-list-for-a-rectangle-counterclockwise
+ """
+ mlat = verts.T[0].sum() / len(verts)
+ mlng = verts.T[1].sum() / len(verts)
+
+ tau = np.pi * 2
+ angle = (np.arctan2(mlat - verts.T[0], verts.T[1] - mlng) +
+ tau) % tau
+ sortx = angle.argsort()
+ verts = verts.take(sortx, axis=0)
+ return verts
+
+ # Generate a random exterior for each requested mask
+ masks = []
+ for _ in range(num_masks):
+ exterior = _order_vertices(_gen_polygon(n_verts, 0.9, 0.9))
+ exterior = (exterior * [(width, height)]).astype(dtype)
+ masks.append([exterior.ravel()])
+
+ self = cls(masks, height, width)
+ return self
+
+ def get_bboxes(self):
+ num_masks = len(self)
+ boxes = np.zeros((num_masks, 4), dtype=np.float32)
+ for idx, poly_per_obj in enumerate(self.masks):
+ # simply use a number that is big enough for comparison with
+ # coordinates
+ xy_min = np.array([self.width * 2, self.height * 2],
+ dtype=np.float32)
+ xy_max = np.zeros(2, dtype=np.float32)
+ for p in poly_per_obj:
+ xy = np.array(p).reshape(-1, 2).astype(np.float32)
+ xy_min = np.minimum(xy_min, np.min(xy, axis=0))
+ xy_max = np.maximum(xy_max, np.max(xy, axis=0))
+ boxes[idx, :2] = xy_min
+ boxes[idx, 2:] = xy_max
+
+ return boxes
+
+
+def polygon_to_bitmap(polygons, height, width):
+ """Convert masks from the form of polygons to bitmaps.
+
+ Args:
+ polygons (list[ndarray]): masks in polygon representation
+ height (int): mask height
+ width (int): mask width
+
+ Return:
+ ndarray: the converted masks in bitmap representation
+ """
+ rles = maskUtils.frPyObjects(polygons, height, width)
+ rle = maskUtils.merge(rles)
+ bitmap_mask = maskUtils.decode(rle).astype(bool)
+ return bitmap_mask
+
+
+def bitmap_to_polygon(bitmap):
+ """Convert masks from the form of bitmaps to polygons.
+
+ Args:
+ bitmap (ndarray): masks in bitmap representation.
+
+ Return:
+ list[ndarray]: the converted mask in polygon representation.
+ bool: whether the mask has holes.
+ """
+ bitmap = np.ascontiguousarray(bitmap).astype(np.uint8)
+ # cv2.RETR_CCOMP: retrieves all of the contours and organizes them
+ # into a two-level hierarchy. At the top level, there are external
+ # boundaries of the components. At the second level, there are
+ # boundaries of the holes. If there is another contour inside a hole
+ # of a connected component, it is still put at the top level.
+ # cv2.CHAIN_APPROX_NONE: stores absolutely all the contour points.
+ outs = cv2.findContours(bitmap, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
+ contours = outs[-2]
+ hierarchy = outs[-1]
+ if hierarchy is None:
+ return [], False
+ # hierarchy[i]: 4 elements, for the indexes of next, previous,
+ # parent, or nested contours. If there is no corresponding contour,
+ # it will be -1.
+ with_hole = (hierarchy.reshape(-1, 4)[:, 3] >= 0).any()
+ contours = [c.reshape(-1, 2) for c in contours]
+ return contours, with_hole
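+
+
+# A minimal round-trip sketch for the two helpers above (illustrative,
+# doctest-style comment; not executed as part of this module):
+#
+# >>> square = [np.array([1., 1., 5., 1., 5., 5., 1., 5.])]
+# >>> bitmap = polygon_to_bitmap(square, 8, 8)      # boolean (8, 8) mask
+# >>> contours, with_hole = bitmap_to_polygon(bitmap)
+# >>> contours[0].shape[1], with_hole               # (x, y) points, no holes
+# (2, False)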
diff --git a/mmdet/core/mask/utils.py b/mmdet/core/mask/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..90544b34f49aa60ac2a1abae10f1a89cc9fe43f0
--- /dev/null
+++ b/mmdet/core/mask/utils.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import pycocotools.mask as mask_util
+import torch
+
+
+def split_combined_polys(polys, poly_lens, polys_per_mask):
+ """Split the combined 1-D polys into masks.
+
+ A mask is represented as a list of polys, and a poly is represented as
+ a 1-D array. In dataset, all masks are concatenated into a single 1-D
+ tensor. Here we need to split the tensor into original representations.
+
+ Args:
+ polys (list): a list (length = image num) of 1-D tensors
+ poly_lens (list): a list (length = image num) of poly length
+ polys_per_mask (list): a list (length = image num) of poly number
+ of each mask
+
+ Returns:
+ list: a list (length = image num) of list (length = mask num) of \
+ list (length = poly num) of numpy array.
+ """
+ mask_polys_list = []
+ for img_id in range(len(polys)):
+ polys_single = polys[img_id]
+ polys_lens_single = poly_lens[img_id].tolist()
+ polys_per_mask_single = polys_per_mask[img_id].tolist()
+
+ split_polys = mmcv.slice_list(polys_single, polys_lens_single)
+ mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single)
+ mask_polys_list.append(mask_polys)
+ return mask_polys_list
+
+
+# TODO: move this function to more proper place
+def encode_mask_results(mask_results):
+ """Encode bitmap mask to RLE code.
+
+ Args:
+ mask_results (list | tuple[list]): bitmap mask results.
+ In mask scoring rcnn, mask_results is a tuple of (segm_results,
+ segm_cls_score).
+
+ Returns:
+ list | tuple: RLE encoded mask.
+ """
+ if isinstance(mask_results, tuple): # mask scoring
+ cls_segms, cls_mask_scores = mask_results
+ else:
+ cls_segms = mask_results
+ num_classes = len(cls_segms)
+ encoded_mask_results = [[] for _ in range(num_classes)]
+ for i in range(len(cls_segms)):
+ for cls_segm in cls_segms[i]:
+ encoded_mask_results[i].append(
+ mask_util.encode(
+ np.array(
+ cls_segm[:, :, np.newaxis], order='F',
+ dtype='uint8'))[0]) # encoded with RLE
+ if isinstance(mask_results, tuple):
+ return encoded_mask_results, cls_mask_scores
+ else:
+ return encoded_mask_results
+
+
+def mask2bbox(masks):
+ """Obtain tight bounding boxes of binary masks.
+
+ Args:
+ masks (Tensor): Binary mask of shape (n, h, w).
+
+ Returns:
+        Tensor: Bboxes with shape (n, 4) of the \
+            positive regions in the binary masks.
+ """
+ N = masks.shape[0]
+ bboxes = masks.new_zeros((N, 4), dtype=torch.float32)
+ x_any = torch.any(masks, dim=1)
+ y_any = torch.any(masks, dim=2)
+ for i in range(N):
+ x = torch.where(x_any[i, :])[0]
+ y = torch.where(y_any[i, :])[0]
+ if len(x) > 0 and len(y) > 0:
+ bboxes[i, :] = bboxes.new_tensor(
+ [x[0], y[0], x[-1] + 1, y[-1] + 1])
+
+ return bboxes
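+
+
+# A minimal mask2bbox sketch (illustrative, doctest-style comment; not
+# executed as part of this module):
+#
+# >>> masks = torch.zeros((1, 8, 8), dtype=torch.bool)
+# >>> masks[0, 2:5, 3:6] = True          # rows 2-4, cols 3-5
+# >>> mask2bbox(masks)                   # [x1, y1, x2, y2], right/bottom + 1
+# tensor([[3., 2., 6., 5.]])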
diff --git a/mmdet/core/optimizers/__init__.py b/mmdet/core/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e867d0761cb54a6f228a0fb3e0560dea67b67881
--- /dev/null
+++ b/mmdet/core/optimizers/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import OPTIMIZER_BUILDERS, build_optimizer
+from .layer_decay_optimizer_constructor import \
+ LearningRateDecayOptimizerConstructor
+
+__all__ = [
+ 'LearningRateDecayOptimizerConstructor', 'OPTIMIZER_BUILDERS',
+ 'build_optimizer'
+]
diff --git a/mmdet/core/optimizers/builder.py b/mmdet/core/optimizers/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..406dd9b4b7027e9c2254b0d18cf0c80a7161912b
--- /dev/null
+++ b/mmdet/core/optimizers/builder.py
@@ -0,0 +1,33 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
+from mmcv.utils import Registry, build_from_cfg
+
+OPTIMIZER_BUILDERS = Registry(
+ 'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
+
+
+def build_optimizer_constructor(cfg):
+ constructor_type = cfg.get('type')
+ if constructor_type in OPTIMIZER_BUILDERS:
+ return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
+ elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
+ return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
+ else:
+ raise KeyError(f'{constructor_type} is not registered '
+ 'in the optimizer builder registry.')
+
+
+def build_optimizer(model, cfg):
+ optimizer_cfg = copy.deepcopy(cfg)
+ constructor_type = optimizer_cfg.pop('constructor',
+ 'DefaultOptimizerConstructor')
+ paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+ optim_constructor = build_optimizer_constructor(
+ dict(
+ type=constructor_type,
+ optimizer_cfg=optimizer_cfg,
+ paramwise_cfg=paramwise_cfg))
+ optimizer = optim_constructor(model)
+ return optimizer
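+
+
+# A minimal usage sketch (config values are illustrative; ``model`` is an
+# assumed, already-built detector):
+#
+# optimizer_cfg = dict(
+#     type='AdamW', lr=1e-4, weight_decay=0.05,
+#     constructor='LearningRateDecayOptimizerConstructor',
+#     paramwise_cfg=dict(decay_rate=0.7, decay_type='layer_wise',
+#                        num_layers=6))
+# optimizer = build_optimizer(model, optimizer_cfg)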
diff --git a/mmdet/core/optimizers/layer_decay_optimizer_constructor.py b/mmdet/core/optimizers/layer_decay_optimizer_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bc3469e8884a7a1f0a154ab859b8079575b56ff
--- /dev/null
+++ b/mmdet/core/optimizers/layer_decay_optimizer_constructor.py
@@ -0,0 +1,154 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+
+from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
+
+from mmdet.utils import get_root_logger
+from .builder import OPTIMIZER_BUILDERS
+
+
+def get_layer_id_for_convnext(var_name, max_layer_id):
+ """Get the layer id to set the different learning rates in ``layer_wise``
+ decay_type.
+
+ Args:
+ var_name (str): The key of the model.
+ max_layer_id (int): Maximum layer id.
+
+ Returns:
+ int: The id number corresponding to different learning rate in
+ ``LearningRateDecayOptimizerConstructor``.
+ """
+
+ if var_name in ('backbone.cls_token', 'backbone.mask_token',
+ 'backbone.pos_embed'):
+ return 0
+ elif var_name.startswith('backbone.downsample_layers'):
+ stage_id = int(var_name.split('.')[2])
+ if stage_id == 0:
+ layer_id = 0
+ elif stage_id == 1:
+ layer_id = 2
+ elif stage_id == 2:
+ layer_id = 3
+ elif stage_id == 3:
+ layer_id = max_layer_id
+ return layer_id
+ elif var_name.startswith('backbone.stages'):
+ stage_id = int(var_name.split('.')[2])
+ block_id = int(var_name.split('.')[3])
+ if stage_id == 0:
+ layer_id = 1
+ elif stage_id == 1:
+ layer_id = 2
+ elif stage_id == 2:
+ layer_id = 3 + block_id // 3
+ elif stage_id == 3:
+ layer_id = max_layer_id
+ return layer_id
+ else:
+ return max_layer_id + 1
+
+
+def get_stage_id_for_convnext(var_name, max_stage_id):
+ """Get the stage id to set the different learning rates in ``stage_wise``
+ decay_type.
+
+ Args:
+ var_name (str): The key of the model.
+ max_stage_id (int): Maximum stage id.
+
+ Returns:
+ int: The id number corresponding to different learning rate in
+ ``LearningRateDecayOptimizerConstructor``.
+ """
+
+ if var_name in ('backbone.cls_token', 'backbone.mask_token',
+ 'backbone.pos_embed'):
+ return 0
+ elif var_name.startswith('backbone.downsample_layers'):
+ return 0
+ elif var_name.startswith('backbone.stages'):
+ stage_id = int(var_name.split('.')[2])
+ return stage_id + 1
+ else:
+ return max_stage_id - 1
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
+ # Different learning rates are set for different layers of backbone.
+ # Note: Currently, this optimizer constructor is built for ConvNeXt.
+
+ def add_params(self, params, module, **kwargs):
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (list[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ """
+ logger = get_root_logger()
+
+ parameter_groups = {}
+ logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
+ num_layers = self.paramwise_cfg.get('num_layers') + 2
+ decay_rate = self.paramwise_cfg.get('decay_rate')
+ decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
+ logger.info('Build LearningRateDecayOptimizerConstructor '
+ f'{decay_type} {decay_rate} - {num_layers}')
+ weight_decay = self.base_wd
+ for name, param in module.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith('.bias') or name in (
+ 'pos_embed', 'cls_token'):
+ group_name = 'no_decay'
+ this_weight_decay = 0.
+ else:
+ group_name = 'decay'
+ this_weight_decay = weight_decay
+ if 'layer_wise' in decay_type:
+ if 'ConvNeXt' in module.backbone.__class__.__name__:
+ layer_id = get_layer_id_for_convnext(
+ name, self.paramwise_cfg.get('num_layers'))
+ logger.info(f'set param {name} as id {layer_id}')
+ else:
+ raise NotImplementedError()
+ elif decay_type == 'stage_wise':
+ if 'ConvNeXt' in module.backbone.__class__.__name__:
+ layer_id = get_stage_id_for_convnext(name, num_layers)
+ logger.info(f'set param {name} as id {layer_id}')
+ else:
+ raise NotImplementedError()
+ group_name = f'layer_{layer_id}_{group_name}'
+
+ if group_name not in parameter_groups:
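+                # With the typical decay_rate < 1, a smaller layer_id (closer
+                # to the input) yields a larger exponent and thus a smaller
+                # lr scale.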
+ scale = decay_rate**(num_layers - layer_id - 1)
+
+ parameter_groups[group_name] = {
+ 'weight_decay': this_weight_decay,
+ 'params': [],
+ 'param_names': [],
+ 'lr_scale': scale,
+ 'group_name': group_name,
+ 'lr': scale * self.base_lr,
+ }
+
+ parameter_groups[group_name]['params'].append(param)
+ parameter_groups[group_name]['param_names'].append(name)
+ rank, _ = get_dist_info()
+ if rank == 0:
+ to_display = {}
+ for key in parameter_groups:
+ to_display[key] = {
+ 'param_names': parameter_groups[key]['param_names'],
+ 'lr_scale': parameter_groups[key]['lr_scale'],
+ 'lr': parameter_groups[key]['lr'],
+ 'weight_decay': parameter_groups[key]['weight_decay'],
+ }
+ logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
+ params.extend(parameter_groups.values())
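+
+
+# Illustrative mapping for ``layer_wise`` decay with a ConvNeXt backbone and
+# paramwise_cfg num_layers=6 (so 6 + 2 = 8 groups): parameters under
+# 'backbone.downsample_layers.0' map to layer_id 0, 'backbone.stages.0' to 1,
+# 'backbone.stages.2.1' to 3 + 1 // 3 = 3, and non-backbone parameters to
+# max_layer_id + 1 = 7. Each group's lr is scaled by
+# decay_rate ** (num_layers - layer_id - 1), so earlier layers get smaller
+# learning rates when decay_rate < 1.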
diff --git a/mmdet/core/post_processing/__init__.py b/mmdet/core/post_processing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00376bd49ebf75d53c10a26ff810362917bae81c
--- /dev/null
+++ b/mmdet/core/post_processing/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .bbox_nms import fast_nms, multiclass_nms
+from .matrix_nms import mask_matrix_nms
+from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
+ merge_aug_proposals, merge_aug_scores)
+
+__all__ = [
+ 'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
+ 'merge_aug_scores', 'merge_aug_masks', 'mask_matrix_nms', 'fast_nms'
+]
diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fcf57bb501de25adbba08d3fb5fe2cc8d00cd1c
--- /dev/null
+++ b/mmdet/core/post_processing/bbox_nms.py
@@ -0,0 +1,171 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.ops.nms import batched_nms
+
+from mmdet.core.bbox.iou_calculators import bbox_overlaps
+
+
+def multiclass_nms(multi_bboxes,
+ multi_scores,
+ score_thr,
+ nms_cfg,
+ max_num=-1,
+ score_factors=None,
+ return_inds=False):
+ """NMS for multi-class bboxes.
+
+ Args:
+ multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+        multi_scores (Tensor): shape (n, #class+1), where the last column
+ contains scores of the background class, but this will be ignored.
+ score_thr (float): bbox threshold, bboxes with scores lower than it
+ will not be considered.
+ nms_cfg (dict): a dict that contains the arguments of nms operations
+ max_num (int, optional): if there are more than max_num bboxes after
+ NMS, only top max_num will be kept. Default to -1.
+ score_factors (Tensor, optional): The factors multiplied to scores
+ before applying NMS. Default to None.
+ return_inds (bool, optional): Whether return the indices of kept
+ bboxes. Default to False.
+
+ Returns:
+ tuple: (dets, labels, indices (optional)), tensors of shape (k, 5),
+ (k), and (k). Dets are boxes with scores. Labels are 0-based.
+ """
+ num_classes = multi_scores.size(1) - 1
+ # exclude background category
+ if multi_bboxes.shape[1] > 4:
+ bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
+ else:
+ bboxes = multi_bboxes[:, None].expand(
+ multi_scores.size(0), num_classes, 4)
+
+ scores = multi_scores[:, :-1]
+
+ labels = torch.arange(num_classes, dtype=torch.long, device=scores.device)
+ labels = labels.view(1, -1).expand_as(scores)
+
+ bboxes = bboxes.reshape(-1, 4)
+ scores = scores.reshape(-1)
+ labels = labels.reshape(-1)
+
+ if not torch.onnx.is_in_onnx_export():
+ # NonZero not supported in TensorRT
+ # remove low scoring boxes
+ valid_mask = scores > score_thr
+    # multiply score_factor after thresholding to preserve more bboxes; this
+    # improves mAP by 1% for YOLOv3
+ if score_factors is not None:
+ # expand the shape to match original shape of score
+ score_factors = score_factors.view(-1, 1).expand(
+ multi_scores.size(0), num_classes)
+ score_factors = score_factors.reshape(-1)
+ scores = scores * score_factors
+
+ if not torch.onnx.is_in_onnx_export():
+ # NonZero not supported in TensorRT
+ inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
+ bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
+ else:
+ # TensorRT NMS plugin has invalid output filled with -1
+ # add dummy data to make detection output correct.
+ bboxes = torch.cat([bboxes, bboxes.new_zeros(1, 4)], dim=0)
+ scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
+ labels = torch.cat([labels, labels.new_zeros(1)], dim=0)
+
+ if bboxes.numel() == 0:
+ if torch.onnx.is_in_onnx_export():
+ raise RuntimeError('[ONNX Error] Can not record NMS '
+ 'as it has not been executed this time')
+ dets = torch.cat([bboxes, scores[:, None]], -1)
+ if return_inds:
+ return dets, labels, inds
+ else:
+ return dets, labels
+
+ dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
+
+ if max_num > 0:
+ dets = dets[:max_num]
+ keep = keep[:max_num]
+
+ if return_inds:
+ return dets, labels[keep], inds[keep]
+ else:
+ return dets, labels[keep]
+
+
+def fast_nms(multi_bboxes,
+ multi_scores,
+ multi_coeffs,
+ score_thr,
+ iou_thr,
+ top_k,
+ max_num=-1):
+ """Fast NMS in `YOLACT `_.
+
+ Fast NMS allows already-removed detections to suppress other detections so
+ that every instance can be decided to be kept or discarded in parallel,
+ which is not possible in traditional NMS. This relaxation allows us to
+ implement Fast NMS entirely in standard GPU-accelerated matrix operations.
+
+ Args:
+ multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+ multi_scores (Tensor): shape (n, #class+1), where the last column
+ contains scores of the background class, but this will be ignored.
+ multi_coeffs (Tensor): shape (n, #class*coeffs_dim).
+ score_thr (float): bbox threshold, bboxes with scores lower than it
+ will not be considered.
+ iou_thr (float): IoU threshold to be considered as conflicted.
+ top_k (int): if there are more than top_k bboxes before NMS,
+ only top top_k will be kept.
+ max_num (int): if there are more than max_num bboxes after NMS,
+ only top max_num will be kept. If -1, keep all the bboxes.
+ Default: -1.
+
+ Returns:
+        tuple: (dets, labels, coefficients), tensors of shape (k, 5), (k,),
+ and (k, coeffs_dim). Dets are boxes with scores.
+ Labels are 0-based.
+ """
+
+ scores = multi_scores[:, :-1].t() # [#class, n]
+ scores, idx = scores.sort(1, descending=True)
+
+ idx = idx[:, :top_k].contiguous()
+ scores = scores[:, :top_k] # [#class, topk]
+ num_classes, num_dets = idx.size()
+ boxes = multi_bboxes[idx.view(-1), :].view(num_classes, num_dets, 4)
+ coeffs = multi_coeffs[idx.view(-1), :].view(num_classes, num_dets, -1)
+
+ iou = bbox_overlaps(boxes, boxes) # [#class, topk, topk]
+ iou.triu_(diagonal=1)
+ iou_max, _ = iou.max(dim=1)
+
+ # Now just filter out the ones higher than the threshold
+ keep = iou_max <= iou_thr
+
+ # Second thresholding introduces 0.2 mAP gain at negligible time cost
+ keep *= scores > score_thr
+
+ # Assign each kept detection to its corresponding class
+ classes = torch.arange(
+ num_classes, device=boxes.device)[:, None].expand_as(keep)
+ classes = classes[keep]
+
+ boxes = boxes[keep]
+ coeffs = coeffs[keep]
+ scores = scores[keep]
+
+ # Only keep the top max_num highest scores across all classes
+ scores, idx = scores.sort(0, descending=True)
+ if max_num > 0:
+ idx = idx[:max_num]
+ scores = scores[:max_num]
+
+ classes = classes[idx]
+ boxes = boxes[idx]
+ coeffs = coeffs[idx]
+
+ cls_dets = torch.cat([boxes, scores[:, None]], dim=1)
+ return cls_dets, classes, coeffs
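+
+
+# A minimal multiclass_nms sketch (illustrative, doctest-style comment; not
+# executed as part of this module; requires the mmcv NMS op):
+#
+# >>> multi_bboxes = torch.rand(200, 4) * 32
+# >>> multi_bboxes[:, 2:] += multi_bboxes[:, :2]   # ensure x2 > x1, y2 > y1
+# >>> multi_scores = torch.rand(200, 81)           # 80 classes + background
+# >>> dets, labels = multiclass_nms(
+# ...     multi_bboxes, multi_scores, score_thr=0.05,
+# ...     nms_cfg=dict(type='nms', iou_threshold=0.5), max_num=100)
+# >>> # dets: (k, 5) boxes with scores; labels: (k,) 0-based class indices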
diff --git a/mmdet/core/post_processing/matrix_nms.py b/mmdet/core/post_processing/matrix_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dc8c4f74e28127fb69ccc684f0bdb2bd3943b20
--- /dev/null
+++ b/mmdet/core/post_processing/matrix_nms.py
@@ -0,0 +1,121 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def mask_matrix_nms(masks,
+ labels,
+ scores,
+ filter_thr=-1,
+ nms_pre=-1,
+ max_num=-1,
+ kernel='gaussian',
+ sigma=2.0,
+ mask_area=None):
+ """Matrix NMS for multi-class masks.
+
+ Args:
+ masks (Tensor): Has shape (num_instances, h, w)
+ labels (Tensor): Labels of corresponding masks,
+ has shape (num_instances,).
+ scores (Tensor): Mask scores of corresponding masks,
+ has shape (num_instances).
+ filter_thr (float): Score threshold to filter the masks
+ after matrix nms. Default: -1, which means do not
+ use filter_thr.
+ nms_pre (int): The max number of instances to do the matrix nms.
+ Default: -1, which means do not use nms_pre.
+        max_num (int, optional): If there are more than max_num masks after
+            matrix nms, only top max_num will be kept. Default: -1, which means
+ do not use max_num.
+ kernel (str): 'linear' or 'gaussian'.
+ sigma (float): std in gaussian method.
+ mask_area (Tensor): The sum of seg_masks.
+
+ Returns:
+ tuple(Tensor): Processed mask results.
+
+ - scores (Tensor): Updated scores, has shape (n,).
+        - labels (Tensor): Remaining labels, has shape (n,).
+        - masks (Tensor): Remaining masks, has shape (n, h, w).
+        - keep_inds (Tensor): The indices of the remaining masks in the
+            input masks, has shape (n,).
+ """
+ assert len(labels) == len(masks) == len(scores)
+ if len(labels) == 0:
+ return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
+ 0, *masks.shape[-2:]), labels.new_zeros(0)
+ if mask_area is None:
+ mask_area = masks.sum((1, 2)).float()
+ else:
+ assert len(masks) == len(mask_area)
+
+ # sort and keep top nms_pre
+ scores, sort_inds = torch.sort(scores, descending=True)
+
+ keep_inds = sort_inds
+ if nms_pre > 0 and len(sort_inds) > nms_pre:
+ sort_inds = sort_inds[:nms_pre]
+ keep_inds = keep_inds[:nms_pre]
+ scores = scores[:nms_pre]
+ masks = masks[sort_inds]
+ mask_area = mask_area[sort_inds]
+ labels = labels[sort_inds]
+
+ num_masks = len(labels)
+ flatten_masks = masks.reshape(num_masks, -1).float()
+ # inter.
+ inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
+ expanded_mask_area = mask_area.expand(num_masks, num_masks)
+ # Upper triangle iou matrix.
+ iou_matrix = (inter_matrix /
+ (expanded_mask_area + expanded_mask_area.transpose(1, 0) -
+ inter_matrix)).triu(diagonal=1)
+ # label_specific matrix.
+ expanded_labels = labels.expand(num_masks, num_masks)
+ # Upper triangle label matrix.
+ label_matrix = (expanded_labels == expanded_labels.transpose(
+ 1, 0)).triu(diagonal=1)
+
+ # IoU compensation
+ compensate_iou, _ = (iou_matrix * label_matrix).max(0)
+ compensate_iou = compensate_iou.expand(num_masks,
+ num_masks).transpose(1, 0)
+
+ # IoU decay
+ decay_iou = iou_matrix * label_matrix
+
+ # Calculate the decay_coefficient
+ if kernel == 'gaussian':
+ decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
+ compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
+ decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
+ elif kernel == 'linear':
+ decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
+ decay_coefficient, _ = decay_matrix.min(0)
+ else:
+ raise NotImplementedError(
+ f'{kernel} kernel is not supported in matrix nms!')
+ # update the score.
+ scores = scores * decay_coefficient
+
+ if filter_thr > 0:
+ keep = scores >= filter_thr
+ keep_inds = keep_inds[keep]
+ if not keep.any():
+ return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
+ 0, *masks.shape[-2:]), labels.new_zeros(0)
+ masks = masks[keep]
+ scores = scores[keep]
+ labels = labels[keep]
+
+ # sort and keep top max_num
+ scores, sort_inds = torch.sort(scores, descending=True)
+ keep_inds = keep_inds[sort_inds]
+ if max_num > 0 and len(sort_inds) > max_num:
+ sort_inds = sort_inds[:max_num]
+ keep_inds = keep_inds[:max_num]
+ scores = scores[:max_num]
+ masks = masks[sort_inds]
+ labels = labels[sort_inds]
+
+ return scores, labels, masks, keep_inds
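
A minimal usage sketch of mask_matrix_nms, assuming the package installs as laid out in this diff; the masks, labels and scores below are illustrative toy data.

import torch
from mmdet.core.post_processing.matrix_nms import mask_matrix_nms

# Three instance masks on a 10x10 grid; masks 0 and 1 overlap heavily and
# share a class, so the lower-scoring one gets its score decayed.
masks = torch.zeros(3, 10, 10)
masks[0, :6, :6] = 1
masks[1, 1:7, 1:7] = 1
masks[2, 6:, 6:] = 1
labels = torch.tensor([0, 0, 1])
scores = torch.tensor([0.9, 0.8, 0.7])

scores, labels, masks, keep_inds = mask_matrix_nms(
    masks, labels, scores, kernel='gaussian', sigma=2.0, filter_thr=0.05)
print(scores, labels, keep_inds)
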
diff --git a/mmdet/core/post_processing/merge_augs.py b/mmdet/core/post_processing/merge_augs.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ac4603a1aea9e463e35d7041a0bf00bd3b13529
--- /dev/null
+++ b/mmdet/core/post_processing/merge_augs.py
@@ -0,0 +1,154 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+import numpy as np
+import torch
+from mmcv import ConfigDict
+from mmcv.ops import nms
+
+from ..bbox import bbox_mapping_back
+
+
+def merge_aug_proposals(aug_proposals, img_metas, cfg):
+ """Merge augmented proposals (multiscale, flip, etc.)
+
+ Args:
+ aug_proposals (list[Tensor]): proposals from different testing
+ schemes, shape (n, 5). Note that they are not rescaled to the
+ original image size.
+
+ img_metas (list[dict]): list of image info dict where each dict has:
+ 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+ cfg (dict): rpn test config.
+
+ Returns:
+ Tensor: shape (n, 5), proposals corresponding to original image scale.
+ """
+
+ cfg = copy.deepcopy(cfg)
+
+ # deprecate arguments warning
+ if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
+ warnings.warn(
+ 'In rpn_proposal or test_cfg, '
+ 'nms_thr has been moved to a dict named nms as '
+ 'iou_threshold, and max_num has been renamed as max_per_img. '
+ 'The original argument names and the old way to specify the '
+ 'iou_threshold of NMS are deprecated.')
+ if 'nms' not in cfg:
+ cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
+ if 'max_num' in cfg:
+ if 'max_per_img' in cfg:
+ assert cfg.max_num == cfg.max_per_img, f'You set max_num and ' \
+ f'max_per_img at the same time, but get {cfg.max_num} ' \
+ f'and {cfg.max_per_img} respectively. ' \
+ f'Please delete max_num which will be deprecated.'
+ else:
+ cfg.max_per_img = cfg.max_num
+ if 'nms_thr' in cfg:
+ assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
+ f'iou_threshold in nms and ' \
+ f'nms_thr at the same time, but get ' \
+ f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
+ f' respectively. Please delete the nms_thr ' \
+ f'which will be deprecated.'
+
+ recovered_proposals = []
+ for proposals, img_info in zip(aug_proposals, img_metas):
+ img_shape = img_info['img_shape']
+ scale_factor = img_info['scale_factor']
+ flip = img_info['flip']
+ flip_direction = img_info['flip_direction']
+ _proposals = proposals.clone()
+ _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
+ scale_factor, flip,
+ flip_direction)
+ recovered_proposals.append(_proposals)
+ aug_proposals = torch.cat(recovered_proposals, dim=0)
+ merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
+ aug_proposals[:, -1].contiguous(),
+ cfg.nms.iou_threshold)
+ scores = merged_proposals[:, 4]
+ _, order = scores.sort(0, descending=True)
+ num = min(cfg.max_per_img, merged_proposals.shape[0])
+ order = order[:num]
+ merged_proposals = merged_proposals[order, :]
+ return merged_proposals
+
+
+def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
+ """Merge augmented detection bboxes and scores.
+
+ Args:
+ aug_bboxes (list[Tensor]): shape (n, 4*#class)
+ aug_scores (list[Tensor] or None): shape (n, #class)
+ img_metas (list[list[dict]]): Meta info of each augmented image,
+ including 'img_shape', 'scale_factor', 'flip' and 'flip_direction'.
+ rcnn_test_cfg (dict): rcnn test config.
+
+ Returns:
+ tuple: (bboxes, scores)
+ """
+ recovered_bboxes = []
+ for bboxes, img_info in zip(aug_bboxes, img_metas):
+ img_shape = img_info[0]['img_shape']
+ scale_factor = img_info[0]['scale_factor']
+ flip = img_info[0]['flip']
+ flip_direction = img_info[0]['flip_direction']
+ bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
+ flip_direction)
+ recovered_bboxes.append(bboxes)
+ bboxes = torch.stack(recovered_bboxes).mean(dim=0)
+ if aug_scores is None:
+ return bboxes
+ else:
+ scores = torch.stack(aug_scores).mean(dim=0)
+ return bboxes, scores
+
+
+def merge_aug_scores(aug_scores):
+ """Merge augmented bbox scores."""
+ if isinstance(aug_scores[0], torch.Tensor):
+ return torch.mean(torch.stack(aug_scores), dim=0)
+ else:
+ return np.mean(aug_scores, axis=0)
+
+
+def merge_aug_masks(aug_masks, img_metas, rcnn_test_cfg, weights=None):
+ """Merge augmented mask prediction.
+
+ Args:
+ aug_masks (list[ndarray]): shape (n, #class, h, w)
+ img_metas (list[list[dict]]): Meta info of each augmented image,
+ including 'flip' and 'flip_direction'.
+ rcnn_test_cfg (dict): rcnn test config.
+ weights (list[float], optional): Weights used to average the
+ augmented masks. Default: None.
+
+ Returns:
+ ndarray: Merged masks, shape (n, #class, h, w).
+ """
+ recovered_masks = []
+ for mask, img_info in zip(aug_masks, img_metas):
+ flip = img_info[0]['flip']
+ if flip:
+ flip_direction = img_info[0]['flip_direction']
+ if flip_direction == 'horizontal':
+ mask = mask[:, :, :, ::-1]
+ elif flip_direction == 'vertical':
+ mask = mask[:, :, ::-1, :]
+ elif flip_direction == 'diagonal':
+ mask = mask[:, :, :, ::-1]
+ mask = mask[:, :, ::-1, :]
+ else:
+ raise ValueError(
+ f"Invalid flipping direction '{flip_direction}'")
+ recovered_masks.append(mask)
+
+ if weights is None:
+ merged_masks = np.mean(recovered_masks, axis=0)
+ else:
+ merged_masks = np.average(
+ np.array(recovered_masks), axis=0, weights=np.array(weights))
+ return merged_masks
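
A small sketch of how merge_aug_masks undoes test-time flips before averaging, on toy arrays; rcnn_test_cfg is not used inside this function, so it is passed as None here.

import numpy as np
from mmdet.core.post_processing.merge_augs import merge_aug_masks

# Two augmented mask predictions for one image: the original view and a
# horizontally flipped view. Flipping the second one back recovers the first.
mask = np.random.rand(1, 2, 4, 4)                 # (n, #class, h, w)
aug_masks = [mask, mask[:, :, :, ::-1]]
img_metas = [[dict(flip=False)],
             [dict(flip=True, flip_direction='horizontal')]]

merged = merge_aug_masks(aug_masks, img_metas, rcnn_test_cfg=None)
assert np.allclose(merged, mask)                  # average of two identical predictions
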
diff --git a/mmdet/core/utils/__init__.py b/mmdet/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f0d07081a265d249d0ddb3a80ce39bf29e668e9
--- /dev/null
+++ b/mmdet/core/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads,
+ reduce_mean, sync_random_seed)
+from .misc import (center_of_mass, filter_scores_and_topk, flip_tensor,
+ generate_coordinate, mask2ndarray, multi_apply,
+ select_single_mlvl, unmap)
+
+__all__ = [
+ 'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
+ 'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict',
+ 'center_of_mass', 'generate_coordinate', 'select_single_mlvl',
+ 'filter_scores_and_topk', 'sync_random_seed'
+]
diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8760774fd90e666c03ca4d553111363065a08426
--- /dev/null
+++ b/mmdet/core/utils/dist_utils.py
@@ -0,0 +1,193 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import pickle
+import warnings
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import OptimizerHook, get_dist_info
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+ if bucket_size_mb > 0:
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
+ buckets = _take_tensors(tensors, bucket_size_bytes)
+ else:
+ buckets = OrderedDict()
+ for tensor in tensors:
+ tp = tensor.type()
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(tensor)
+ buckets = buckets.values()
+
+ for bucket in buckets:
+ flat_tensors = _flatten_dense_tensors(bucket)
+ dist.all_reduce(flat_tensors)
+ flat_tensors.div_(world_size)
+ for tensor, synced in zip(
+ bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+ tensor.copy_(synced)
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+ """Allreduce gradients.
+
+ Args:
+ params (list[torch.Parameters]): List of parameters of a model
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
+ Defaults to True.
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+ Defaults to -1.
+ """
+ grads = [
+ param.grad.data for param in params
+ if param.requires_grad and param.grad is not None
+ ]
+ world_size = dist.get_world_size()
+ if coalesce:
+ _allreduce_coalesced(grads, world_size, bucket_size_mb)
+ else:
+ for tensor in grads:
+ dist.all_reduce(tensor.div_(world_size))
+
+
+class DistOptimizerHook(OptimizerHook):
+ """Deprecated optimizer hook for distributed training."""
+
+ def __init__(self, *args, **kwargs):
+ warnings.warn('"DistOptimizerHook" is deprecated, please switch to'
+ '"mmcv.runner.OptimizerHook".')
+ super().__init__(*args, **kwargs)
+
+
+def reduce_mean(tensor):
+ """"Obtain the mean of tensor on different GPUs."""
+ if not (dist.is_available() and dist.is_initialized()):
+ return tensor
+ tensor = tensor.clone()
+ dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return tensor
+
+
+def obj2tensor(pyobj, device='cuda'):
+ """Serialize picklable python object to tensor."""
+ storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
+ return torch.ByteTensor(storage).to(device=device)
+
+
+def tensor2obj(tensor):
+ """Deserialize tensor to picklable python object."""
+ return pickle.loads(tensor.cpu().numpy().tobytes())
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """Return a process group based on gloo backend, containing all the ranks
+ The result is cached."""
+ if dist.get_backend() == 'nccl':
+ return dist.new_group(backend='gloo')
+ else:
+ return dist.group.WORLD
+
+
+def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
+ """Apply all reduce function for python dict object.
+
+ The code is modified from https://github.com/Megvii-
+ BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.
+
+ NOTE: make sure that py_dict has the same keys on every rank and
+ that the values have the same shapes. Currently this only supports
+ the NCCL backend.
+
+ Args:
+ py_dict (dict): Dict to be applied all reduce op.
+ op (str): Operator, could be 'sum' or 'mean'. Default: 'sum'
+ group (:obj:`torch.distributed.group`, optional): Distributed group,
+ Default: None.
+ to_float (bool): Whether to convert all values of dict to float.
+ Default: True.
+
+ Returns:
+ OrderedDict: reduced python dict object.
+ """
+ warnings.warn(
+ '`group` is deprecated. Currently only supports NCCL backend.')
+ _, world_size = get_dist_info()
+ if world_size == 1:
+ return py_dict
+
+ # all reduce logic across different devices.
+ py_key = list(py_dict.keys())
+ if not isinstance(py_dict, OrderedDict):
+ py_key_tensor = obj2tensor(py_key)
+ dist.broadcast(py_key_tensor, src=0)
+ py_key = tensor2obj(py_key_tensor)
+
+ tensor_shapes = [py_dict[k].shape for k in py_key]
+ tensor_numels = [py_dict[k].numel() for k in py_key]
+
+ if to_float:
+ warnings.warn('Note: the "to_float" is True, you need to '
+ 'ensure that the behavior is reasonable.')
+ flatten_tensor = torch.cat(
+ [py_dict[k].flatten().float() for k in py_key])
+ else:
+ flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
+
+ dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM)
+ if op == 'mean':
+ flatten_tensor /= world_size
+
+ split_tensors = [
+ x.reshape(shape) for x, shape in zip(
+ torch.split(flatten_tensor, tensor_numels), tensor_shapes)
+ ]
+ out_dict = {k: v for k, v in zip(py_key, split_tensors)}
+ if isinstance(py_dict, OrderedDict):
+ out_dict = OrderedDict(out_dict)
+ return out_dict
+
+
+def sync_random_seed(seed=None, device='cuda'):
+ """Make sure different ranks share the same seed.
+
+ All workers must call this function, otherwise it will deadlock.
+ This method is generally used in `DistributedSampler`,
+ because the seed should be identical across all processes
+ in the distributed group.
+
+ In distributed sampling, different ranks should sample non-overlapped
+ data in the dataset. Therefore, this function is used to make sure that
+ each rank shuffles the data indices in the same order based
+ on the same seed. Then different ranks could use different indices
+ to select non-overlapped data from the same data list.
+
+ Args:
+ seed (int, Optional): The seed. Default to None.
+ device (str): The device where the seed will be put on.
+ Default to 'cuda'.
+
+ Returns:
+ int: Seed to be used.
+ """
+ if seed is None:
+ seed = np.random.randint(2**31)
+ assert isinstance(seed, int)
+
+ rank, world_size = get_dist_info()
+
+ if world_size == 1:
+ return seed
+
+ if rank == 0:
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+ else:
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
+ dist.broadcast(random_num, src=0)
+ return random_num.item()
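
A quick sketch showing that the helpers above degrade gracefully on a single process (world_size == 1), which is why callers can invoke them unconditionally in training code; the values are illustrative.

import torch
from mmdet.core.utils import reduce_mean, sync_random_seed

# Without an initialized process group these are effectively no-ops.
loss = torch.tensor(4.0)
print(reduce_mean(loss))      # tensor(4.) -- returned unchanged, no all-reduce
print(sync_random_seed(42))   # 42 -- the broadcast is skipped when world_size == 1
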
diff --git a/mmdet/core/utils/misc.py b/mmdet/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..14cb745e38e7f2a9c0fea43be926eb2f0dddd734
--- /dev/null
+++ b/mmdet/core/utils/misc.py
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+
+import numpy as np
+import torch
+from six.moves import map, zip
+
+from ..mask.structures import BitmapMasks, PolygonMasks
+
+
+def multi_apply(func, *args, **kwargs):
+ """Apply function to a list of arguments.
+
+ Note:
+ This function applies the ``func`` to multiple inputs and
+ map the multiple outputs of the ``func`` into different
+ list. Each list contains the same type of outputs corresponding
+ to different inputs.
+
+ Args:
+ func (Function): A function that will be applied to a list of
+ arguments
+
+ Returns:
+ tuple(list): A tuple containing multiple list, each list contains \
+ a kind of returned results by the function
+ """
+ pfunc = partial(func, **kwargs) if kwargs else func
+ map_results = map(pfunc, *args)
+ return tuple(map(list, zip(*map_results)))
+
+
+def unmap(data, count, inds, fill=0):
+ """Unmap a subset of item (data) back to the original set of items (of size
+ count)"""
+ if data.dim() == 1:
+ ret = data.new_full((count, ), fill)
+ ret[inds.type(torch.bool)] = data
+ else:
+ new_size = (count, ) + data.size()[1:]
+ ret = data.new_full(new_size, fill)
+ ret[inds.type(torch.bool), :] = data
+ return ret
+
+
+def mask2ndarray(mask):
+ """Convert Mask to ndarray..
+
+ Args:
+ mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or
+ torch.Tensor or np.ndarray): The mask to be converted.
+
+ Returns:
+ np.ndarray: Ndarray mask of shape (n, h, w) that has been converted
+ """
+ if isinstance(mask, (BitmapMasks, PolygonMasks)):
+ mask = mask.to_ndarray()
+ elif isinstance(mask, torch.Tensor):
+ mask = mask.detach().cpu().numpy()
+ elif not isinstance(mask, np.ndarray):
+ raise TypeError(f'Unsupported {type(mask)} data type')
+ return mask
+
+
+def flip_tensor(src_tensor, flip_direction):
+ """flip tensor base on flip_direction.
+
+ Args:
+ src_tensor (Tensor): input feature map, shape (B, C, H, W).
+ flip_direction (str): The flipping direction. Options are
+ 'horizontal', 'vertical', 'diagonal'.
+
+ Returns:
+ out_tensor (Tensor): Flipped tensor.
+ """
+ assert src_tensor.ndim == 4
+ valid_directions = ['horizontal', 'vertical', 'diagonal']
+ assert flip_direction in valid_directions
+ if flip_direction == 'horizontal':
+ out_tensor = torch.flip(src_tensor, [3])
+ elif flip_direction == 'vertical':
+ out_tensor = torch.flip(src_tensor, [2])
+ else:
+ out_tensor = torch.flip(src_tensor, [2, 3])
+ return out_tensor
+
+
+def select_single_mlvl(mlvl_tensors, batch_id, detach=True):
+ """Extract a multi-scale single image tensor from a multi-scale batch
+ tensor based on batch index.
+
+ Note: The default value of detach is True, because the proposal gradient
+ needs to be detached during the training of two-stage models,
+ e.g. Cascade Mask R-CNN.
+
+ Args:
+ mlvl_tensors (list[Tensor]): Batch tensor for all scale levels,
+ each is a 4D-tensor.
+ batch_id (int): Batch index.
+ detach (bool): Whether detach gradient. Default True.
+
+ Returns:
+ list[Tensor]: Multi-scale single image tensor.
+ """
+ assert isinstance(mlvl_tensors, (list, tuple))
+ num_levels = len(mlvl_tensors)
+
+ if detach:
+ mlvl_tensor_list = [
+ mlvl_tensors[i][batch_id].detach() for i in range(num_levels)
+ ]
+ else:
+ mlvl_tensor_list = [
+ mlvl_tensors[i][batch_id] for i in range(num_levels)
+ ]
+ return mlvl_tensor_list
+
+
+def filter_scores_and_topk(scores, score_thr, topk, results=None):
+ """Filter results using score threshold and topk candidates.
+
+ Args:
+ scores (Tensor): The scores, shape (num_bboxes, K).
+ score_thr (float): The score filter threshold.
+ topk (int): The number of topk candidates.
+ results (dict or list or Tensor, Optional): The results to
+ which the filtering rule is to be applied. The shape
+ of each item is (num_bboxes, N).
+
+ Returns:
+ tuple: Filtered results
+
+ - scores (Tensor): The scores after being filtered, \
+ shape (num_bboxes_filtered, ).
+ - labels (Tensor): The class labels, shape \
+ (num_bboxes_filtered, ).
+ - keep_idxs (Tensor): The indices of the kept boxes, shape \
+ (num_bboxes_filtered, ).
+ - filtered_results (dict or list or Tensor, Optional): \
+ The filtered results. The shape of each item is \
+ (num_bboxes_filtered, N).
+ """
+ valid_mask = scores > score_thr
+ scores = scores[valid_mask]
+ valid_idxs = torch.nonzero(valid_mask)
+
+ num_topk = min(topk, valid_idxs.size(0))
+ # torch.sort is actually faster than .topk (at least on GPUs)
+ scores, idxs = scores.sort(descending=True)
+ scores = scores[:num_topk]
+ topk_idxs = valid_idxs[idxs[:num_topk]]
+ keep_idxs, labels = topk_idxs.unbind(dim=1)
+
+ filtered_results = None
+ if results is not None:
+ if isinstance(results, dict):
+ filtered_results = {k: v[keep_idxs] for k, v in results.items()}
+ elif isinstance(results, list):
+ filtered_results = [result[keep_idxs] for result in results]
+ elif isinstance(results, torch.Tensor):
+ filtered_results = results[keep_idxs]
+ else:
+ raise NotImplementedError(f'Only supports dict or list or Tensor, '
+ f'but get {type(results)}.')
+ return scores, labels, keep_idxs, filtered_results
+
+
+def center_of_mass(mask, esp=1e-6):
+ """Calculate the centroid coordinates of the mask.
+
+ Args:
+ mask (Tensor): The mask to be calculated, shape (h, w).
+ esp (float): Avoid dividing by zero. Default: 1e-6.
+
+ Returns:
+ tuple[Tensor]: the coordinates of the center point of the mask.
+
+ - center_h (Tensor): the center point of the height.
+ - center_w (Tensor): the center point of the width.
+ """
+ h, w = mask.shape
+ grid_h = torch.arange(h, device=mask.device)[:, None]
+ grid_w = torch.arange(w, device=mask.device)
+ normalizer = mask.sum().float().clamp(min=esp)
+ center_h = (mask * grid_h).sum() / normalizer
+ center_w = (mask * grid_w).sum() / normalizer
+ return center_h, center_w
+
+
+def generate_coordinate(featmap_sizes, device='cuda'):
+ """Generate the coordinate.
+
+ Args:
+ featmap_sizes (tuple): The size of the feature map to generate
+ coordinates for, in (N, C, H, W) order.
+ device (str): The device where the feature will be put on.
+ Returns:
+ coord_feat (Tensor): The coordinate feature, of shape (N, 2, H, W).
+ """
+
+ x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
+ y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
+ y, x = torch.meshgrid(y_range, x_range)
+ y = y.expand([featmap_sizes[0], 1, -1, -1])
+ x = x.expand([featmap_sizes[0], 1, -1, -1])
+ coord_feat = torch.cat([x, y], 1)
+
+ return coord_feat
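
Two small usage sketches for the helpers above, assuming the package is importable as laid out in this diff; all inputs are toy values.

import torch
from mmdet.core.utils import multi_apply, filter_scores_and_topk

# multi_apply: apply a function element-wise over lists and transpose the outputs.
def scale_and_shift(x, scale, shift=0):
    return x * scale, x + shift

scaled, shifted = multi_apply(scale_and_shift, [1, 2, 3], [10, 20, 30], shift=1)
# scaled == [10, 40, 90], shifted == [2, 3, 4]

# filter_scores_and_topk: threshold multi-class scores, then keep the top-k.
scores = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.05, 0.6]])  # (num_bboxes, num_classes)
kept_scores, labels, keep_idxs, _ = filter_scores_and_topk(scores, 0.3, topk=2)
# kept_scores are the two highest scores above 0.3; labels are their class
# indices, keep_idxs the row (bbox) indices they came from.
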
diff --git a/mmdet/core/visualization/__init__.py b/mmdet/core/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eb17c4b32bc0c5c76db31e22e995716ba718222
--- /dev/null
+++ b/mmdet/core/visualization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .image import (color_val_matplotlib, imshow_det_bboxes,
+ imshow_gt_det_bboxes)
+from .palette import get_palette, palette_val
+
+__all__ = [
+ 'imshow_det_bboxes', 'imshow_gt_det_bboxes', 'color_val_matplotlib',
+ 'palette_val', 'get_palette'
+]
diff --git a/mmdet/core/visualization/image.py b/mmdet/core/visualization/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..63eae8a2846b78394f0ed554d182e04a0da36021
--- /dev/null
+++ b/mmdet/core/visualization/image.py
@@ -0,0 +1,563 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+
+import cv2
+import matplotlib.pyplot as plt
+import mmcv
+import numpy as np
+import pycocotools.mask as mask_util
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon
+
+from mmdet.core.evaluation.panoptic_utils import INSTANCE_OFFSET
+from ..mask.structures import bitmap_to_polygon
+from ..utils import mask2ndarray
+from .palette import get_palette, palette_val
+
+__all__ = [
+ 'color_val_matplotlib', 'draw_masks', 'draw_bboxes', 'draw_labels',
+ 'imshow_det_bboxes', 'imshow_gt_det_bboxes'
+]
+
+EPS = 1e-2
+
+
+def color_val_matplotlib(color):
+ """Convert various input in BGR order to normalized RGB matplotlib color
+ tuples.
+
+ Args:
+ color (:obj`Color` | str | tuple | int | ndarray): Color inputs.
+
+ Returns:
+ tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
+ """
+ color = mmcv.color_val(color)
+ color = [color / 255 for color in color[::-1]]
+ return tuple(color)
+
+
+def _get_adaptive_scales(areas, min_area=800, max_area=30000):
+ """Get adaptive scales according to areas.
+
+ The scale range is [0.5, 1.0]. When the area is less than
+ ``min_area``, the scale is 0.5; when the area is larger than
+ ``max_area``, the scale is 1.0.
+
+ Args:
+ areas (ndarray): The areas of bboxes or masks with the
+ shape of (n, ).
+ min_area (int): Lower bound areas for adaptive scales.
+ Default: 800.
+ max_area (int): Upper bound areas for adaptive scales.
+ Default: 30000.
+
+ Returns:
+ ndarray: The adaptive scales with the shape of (n, ).
+ """
+ scales = 0.5 + (areas - min_area) / (max_area - min_area)
+ scales = np.clip(scales, 0.5, 1.0)
+ return scales
+
+
+def _get_bias_color(base, max_dist=30):
+ """Get different colors for each masks.
+
+ Get different colors for each masks by adding a bias
+ color to the base category color.
+ Args:
+ base (ndarray): The base category color with the shape
+ of (3, ).
+ max_dist (int): The max distance of bias. Default: 30.
+
+ Returns:
+ ndarray: The new color for a mask with the shape of (3, ).
+ """
+ new_color = base + np.random.randint(
+ low=-max_dist, high=max_dist + 1, size=3)
+ return np.clip(new_color, 0, 255, new_color)
+
+
+def draw_bboxes(ax, bboxes, color='g', alpha=0.8, thickness=2):
+ """Draw bounding boxes on the axes.
+
+ Args:
+ ax (matplotlib.Axes): The input axes.
+ bboxes (ndarray): The input bounding boxes with the shape
+ of (n, 4).
+ color (list[tuple] | matplotlib.color): The colors for each
+ bounding box.
+ alpha (float): Transparency of bounding boxes. Default: 0.8.
+ thickness (int): Thickness of lines. Default: 2.
+
+ Returns:
+ matplotlib.Axes: The result axes.
+ """
+ polygons = []
+ for i, bbox in enumerate(bboxes):
+ bbox_int = bbox.astype(np.int32)
+ poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
+ [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
+ np_poly = np.array(poly).reshape((4, 2))
+ polygons.append(Polygon(np_poly))
+ p = PatchCollection(
+ polygons,
+ facecolor='none',
+ edgecolors=color,
+ linewidths=thickness,
+ alpha=alpha)
+ ax.add_collection(p)
+
+ return ax
+
+
+def draw_labels(ax,
+ labels,
+ positions,
+ scores=None,
+ class_names=None,
+ color='w',
+ font_size=8,
+ scales=None,
+ horizontal_alignment='left'):
+ """Draw labels on the axes.
+
+ Args:
+ ax (matplotlib.Axes): The input axes.
+ labels (ndarray): The labels with the shape of (n, ).
+ positions (ndarray): The positions to draw each label.
+ scores (ndarray): The scores for each label.
+ class_names (list[str]): The class names.
+ color (list[tuple] | matplotlib.color): The colors for labels.
+ font_size (int): Font size of texts. Default: 8.
+ scales (list[float]): Scales of texts. Default: None.
+ horizontal_alignment (str): The horizontal alignment method of
+ texts. Default: 'left'.
+
+ Returns:
+ matplotlib.Axes: The result axes.
+ """
+ for i, (pos, label) in enumerate(zip(positions, labels)):
+ label_text = class_names[
+ label] if class_names is not None else f'class {label}'
+ if scores is not None:
+ label_text += f'|{scores[i]:.02f}'
+ text_color = color[i] if isinstance(color, list) else color
+
+ font_size_mask = font_size if scales is None else font_size * scales[i]
+ ax.text(
+ pos[0],
+ pos[1],
+ f'{label_text}',
+ bbox={
+ 'facecolor': 'black',
+ 'alpha': 0.8,
+ 'pad': 0.7,
+ 'edgecolor': 'none'
+ },
+ color=text_color,
+ fontsize=font_size_mask,
+ verticalalignment='top',
+ horizontalalignment=horizontal_alignment)
+
+ return ax
+
+
+def draw_masks(ax, img, masks, color=None, with_edge=True, alpha=0.8):
+ """Draw masks on the image and their edges on the axes.
+
+ Args:
+ ax (matplotlib.Axes): The input axes.
+ img (ndarray): The image with the shape of (3, h, w).
+ masks (ndarray): The masks with the shape of (n, h, w).
+ color (ndarray): The colors for each mask with the shape
+ of (n, 3).
+ with_edge (bool): Whether to draw edges. Default: True.
+ alpha (float): Transparency of masks. Default: 0.8.
+
+ Returns:
+ matplotlib.Axes: The result axes.
+ ndarray: The result image.
+ """
+ taken_colors = set([0, 0, 0])
+ if color is None:
+ random_colors = np.random.randint(0, 255, (masks.size(0), 3))
+ color = [tuple(c) for c in random_colors]
+ color = np.array(color, dtype=np.uint8)
+ polygons = []
+ for i, mask in enumerate(masks):
+ if with_edge:
+ contours, _ = bitmap_to_polygon(mask)
+ polygons += [Polygon(c) for c in contours]
+
+ color_mask = color[i]
+ while tuple(color_mask) in taken_colors:
+ color_mask = _get_bias_color(color_mask)
+ taken_colors.add(tuple(color_mask))
+
+ mask = mask.astype(bool)
+ img[mask] = img[mask] * (1 - alpha) + color_mask * alpha
+
+ p = PatchCollection(
+ polygons, facecolor='none', edgecolors='w', linewidths=1, alpha=0.8)
+ ax.add_collection(p)
+
+ return ax, img
+
+
+def imshow_det_bboxes(img,
+ bboxes=None,
+ labels=None,
+ segms=None,
+ class_names=None,
+ score_thr=0,
+ bbox_color='green',
+ text_color='green',
+ mask_color=None,
+ thickness=2,
+ font_size=8,
+ win_name='',
+ show=True,
+ wait_time=0,
+ out_file=None):
+ """Draw bboxes and class labels (with scores) on an image.
+
+ Args:
+ img (str | ndarray): The image to be displayed.
+ bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
+ (n, 5).
+ labels (ndarray): Labels of bboxes.
+ segms (ndarray | None): Masks, shaped (n,h,w) or None.
+ class_names (list[str]): Names of each classes.
+ score_thr (float): Minimum score of bboxes to be shown. Default: 0.
+ bbox_color (list[tuple] | tuple | str | None): Colors of bbox lines.
+ If a single color is given, it will be applied to all classes.
+ The tuple of color should be in RGB order. Default: 'green'.
+ text_color (list[tuple] | tuple | str | None): Colors of texts.
+ If a single color is given, it will be applied to all classes.
+ The tuple of color should be in RGB order. Default: 'green'.
+ mask_color (list[tuple] | tuple | str | None, optional): Colors of
+ masks. If a single color is given, it will be applied to all
+ classes. The tuple of color should be in RGB order.
+ Default: None.
+ thickness (int): Thickness of lines. Default: 2.
+ font_size (int): Font size of texts. Default: 8.
+ show (bool): Whether to show the image. Default: True.
+ win_name (str): The window name. Default: ''.
+ wait_time (float): Value of waitKey param. Default: 0.
+ out_file (str, optional): The filename to write the image.
+ Default: None.
+
+ Returns:
+ ndarray: The image with bboxes drawn on it.
+ """
+ assert bboxes is None or bboxes.ndim == 2, \
+ f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
+ assert labels.ndim == 1, \
+ f' labels ndim should be 1, but its ndim is {labels.ndim}.'
+ assert bboxes is None or bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
+ f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.'
+ assert bboxes is None or bboxes.shape[0] <= labels.shape[0], \
+ 'labels.shape[0] should not be less than bboxes.shape[0].'
+ assert segms is None or segms.shape[0] == labels.shape[0], \
+ 'segms.shape[0] and labels.shape[0] should have the same length.'
+ assert segms is not None or bboxes is not None, \
+ 'segms and bboxes should not be None at the same time.'
+
+ img = mmcv.imread(img).astype(np.uint8)
+
+ if score_thr > 0:
+ assert bboxes is not None and bboxes.shape[1] == 5
+ scores = bboxes[:, -1]
+ inds = scores > score_thr
+ bboxes = bboxes[inds, :]
+ labels = labels[inds]
+ if segms is not None:
+ segms = segms[inds, ...]
+
+ img = mmcv.bgr2rgb(img)
+ width, height = img.shape[1], img.shape[0]
+ img = np.ascontiguousarray(img)
+
+ fig = plt.figure(win_name, frameon=False)
+ plt.title(win_name)
+ canvas = fig.canvas
+ dpi = fig.get_dpi()
+ # add a small EPS to avoid precision lost due to matplotlib's truncation
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
+ fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
+
+ # remove white edges by set subplot margin
+ plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
+ ax = plt.gca()
+ ax.axis('off')
+
+ max_label = int(max(labels) if len(labels) > 0 else 0)
+ text_palette = palette_val(get_palette(text_color, max_label + 1))
+ text_colors = [text_palette[label] for label in labels]
+
+ num_bboxes = 0
+ if bboxes is not None:
+ num_bboxes = bboxes.shape[0]
+ bbox_palette = palette_val(get_palette(bbox_color, max_label + 1))
+ colors = [bbox_palette[label] for label in labels[:num_bboxes]]
+ draw_bboxes(ax, bboxes, colors, alpha=0.8, thickness=thickness)
+
+ horizontal_alignment = 'left'
+ positions = bboxes[:, :2].astype(np.int32) + thickness
+ areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+ scales = _get_adaptive_scales(areas)
+ scores = bboxes[:, 4] if bboxes.shape[1] == 5 else None
+ draw_labels(
+ ax,
+ labels[:num_bboxes],
+ positions,
+ scores=scores,
+ class_names=class_names,
+ color=text_colors,
+ font_size=font_size,
+ scales=scales,
+ horizontal_alignment=horizontal_alignment)
+
+ if segms is not None:
+ mask_palette = get_palette(mask_color, max_label + 1)
+ colors = [mask_palette[label] for label in labels]
+ colors = np.array(colors, dtype=np.uint8)
+ draw_masks(ax, img, segms, colors, with_edge=True)
+
+ if num_bboxes < segms.shape[0]:
+ segms = segms[num_bboxes:]
+ horizontal_alignment = 'center'
+ areas = []
+ positions = []
+ for mask in segms:
+ _, _, stats, centroids = cv2.connectedComponentsWithStats(
+ mask.astype(np.uint8), connectivity=8)
+ largest_id = np.argmax(stats[1:, -1]) + 1
+ positions.append(centroids[largest_id])
+ areas.append(stats[largest_id, -1])
+ areas = np.stack(areas, axis=0)
+ scales = _get_adaptive_scales(areas)
+ draw_labels(
+ ax,
+ labels[num_bboxes:],
+ positions,
+ class_names=class_names,
+ color=text_colors,
+ font_size=font_size,
+ scales=scales,
+ horizontal_alignment=horizontal_alignment)
+
+ plt.imshow(img)
+
+ stream, _ = canvas.print_to_buffer()
+ buffer = np.frombuffer(stream, dtype='uint8')
+ if sys.platform == 'darwin':
+ width, height = canvas.get_width_height(physical=True)
+ img_rgba = buffer.reshape(height, width, 4)
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
+ img = rgb.astype('uint8')
+ img = mmcv.rgb2bgr(img)
+
+ if show:
+ # We do not use cv2 for display because in some cases, opencv will
+ # conflict with Qt, it will output a warning: Current thread
+ # is not the object's thread. You can refer to
+ # https://github.com/opencv/opencv-python/issues/46 for details
+ if wait_time == 0:
+ plt.show()
+ else:
+ plt.show(block=False)
+ plt.pause(wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ plt.close()
+
+ return img
+
+
+def imshow_gt_det_bboxes(img,
+ annotation,
+ result,
+ class_names=None,
+ score_thr=0,
+ gt_bbox_color=(61, 102, 255),
+ gt_text_color=(200, 200, 200),
+ gt_mask_color=(61, 102, 255),
+ det_bbox_color=(241, 101, 72),
+ det_text_color=(200, 200, 200),
+ det_mask_color=(241, 101, 72),
+ thickness=2,
+ font_size=13,
+ win_name='',
+ show=True,
+ wait_time=0,
+ out_file=None,
+ overlay_gt_pred=True):
+ """General visualization GT and result function.
+
+ Args:
+ img (str | ndarray): The image to be displayed.
+ annotation (dict): Ground truth annotations which contain the keys
+ 'gt_bboxes' and 'gt_labels' or 'gt_masks'.
+ result (tuple[list] | list): The detection result, can be either
+ (bbox, segm) or just bbox.
+ class_names (list[str]): Names of each classes.
+ score_thr (float): Minimum score of bboxes to be shown. Default: 0.
+ gt_bbox_color (list[tuple] | tuple | str | None): Colors of bbox lines.
+ If a single color is given, it will be applied to all classes.
+ The tuple of color should be in RGB order. Default: (61, 102, 255).
+ gt_text_color (list[tuple] | tuple | str | None): Colors of texts.
+ If a single color is given, it will be applied to all classes.
+ The tuple of color should be in RGB order. Default: (200, 200, 200).
+ gt_mask_color (list[tuple] | tuple | str | None, optional): Colors of
+ masks. If a single color is given, it will be applied to all classes.
+ The tuple of color should be in RGB order. Default: (61, 102, 255).
+ det_bbox_color (list[tuple] | tuple | str | None):Colors of bbox lines.
+ If a single color is given, it will be applied to all classes.
+ The tuple of color should be in RGB order. Default: (241, 101, 72).
+ det_text_color (list[tuple] | tuple | str | None):Colors of texts.
+ If a single color is given, it will be applied to all classes.
+ The tuple of color should be in RGB order. Default: (200, 200, 200).
+ det_mask_color (list[tuple] | tuple | str | None, optional): Color of
+ masks. If a single color is given, it will be applied to all classes.
+ The tuple of color should be in RGB order. Default: (241, 101, 72).
+ thickness (int): Thickness of lines. Default: 2.
+ font_size (int): Font size of texts. Default: 13.
+ win_name (str): The window name. Default: ''.
+ show (bool): Whether to show the image. Default: True.
+ wait_time (float): Value of waitKey param. Default: 0.
+ out_file (str, optional): The filename to write the image.
+ Default: None.
+ overlay_gt_pred (bool): Whether to plot gts and predictions on the
+ same image. If False, predictions and gts will be plotted on two
+ separate images which are then concatenated vertically, with the
+ gt image on top and the prediction image below. Default: True.
+
+ Returns:
+ ndarray: The image with bboxes or masks drawn on it.
+ """
+ assert 'gt_bboxes' in annotation
+ assert 'gt_labels' in annotation
+ assert isinstance(result, (tuple, list, dict)), 'Expected ' \
+ f'tuple or list or dict, but get {type(result)}'
+
+ gt_bboxes = annotation['gt_bboxes']
+ gt_labels = annotation['gt_labels']
+ gt_masks = annotation.get('gt_masks', None)
+ if gt_masks is not None:
+ gt_masks = mask2ndarray(gt_masks)
+
+ gt_seg = annotation.get('gt_semantic_seg', None)
+ if gt_seg is not None:
+ pad_value = 255 # the padding value of gt_seg
+ sem_labels = np.unique(gt_seg)
+ all_labels = np.concatenate((gt_labels, sem_labels), axis=0)
+ all_labels, counts = np.unique(all_labels, return_counts=True)
+ stuff_labels = all_labels[np.logical_and(counts < 2,
+ all_labels != pad_value)]
+ stuff_masks = gt_seg[None] == stuff_labels[:, None, None]
+ gt_labels = np.concatenate((gt_labels, stuff_labels), axis=0)
+ gt_masks = np.concatenate((gt_masks, stuff_masks.astype(np.uint8)),
+ axis=0)
+ # If you need to show the bounding boxes,
+ # please comment the following line
+ # gt_bboxes = None
+
+ img = mmcv.imread(img)
+
+ img_with_gt = imshow_det_bboxes(
+ img,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ class_names=class_names,
+ bbox_color=gt_bbox_color,
+ text_color=gt_text_color,
+ mask_color=gt_mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=False)
+
+ if not isinstance(result, dict):
+ if isinstance(result, tuple):
+ bbox_result, segm_result = result
+ if isinstance(segm_result, tuple):
+ segm_result = segm_result[0] # ms rcnn
+ else:
+ bbox_result, segm_result = result, None
+
+ bboxes = np.vstack(bbox_result)
+ labels = [
+ np.full(bbox.shape[0], i, dtype=np.int32)
+ for i, bbox in enumerate(bbox_result)
+ ]
+ labels = np.concatenate(labels)
+
+ segms = None
+ if segm_result is not None and len(labels) > 0: # non empty
+ segms = mmcv.concat_list(segm_result)
+ segms = mask_util.decode(segms)
+ segms = segms.transpose(2, 0, 1)
+ else:
+ assert class_names is not None, 'We need to know the number ' \
+ 'of classes.'
+ VOID = len(class_names)
+ bboxes = None
+ pan_results = result['pan_results']
+ # keep objects ahead
+ ids = np.unique(pan_results)[::-1]
+ legal_indices = ids != VOID
+ ids = ids[legal_indices]
+ labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)
+ segms = (pan_results[None] == ids[:, None, None])
+
+ if overlay_gt_pred:
+ img = imshow_det_bboxes(
+ img_with_gt,
+ bboxes,
+ labels,
+ segms=segms,
+ class_names=class_names,
+ score_thr=score_thr,
+ bbox_color=det_bbox_color,
+ text_color=det_text_color,
+ mask_color=det_mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file)
+ else:
+ img_with_det = imshow_det_bboxes(
+ img,
+ bboxes,
+ labels,
+ segms=segms,
+ class_names=class_names,
+ score_thr=score_thr,
+ bbox_color=det_bbox_color,
+ text_color=det_text_color,
+ mask_color=det_mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=False)
+ img = np.concatenate([img_with_gt, img_with_det], axis=0)
+
+ plt.imshow(img)
+ if show:
+ if wait_time == 0:
+ plt.show()
+ else:
+ plt.show(block=False)
+ plt.pause(wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+ plt.close()
+
+ return img
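
A minimal usage sketch of imshow_det_bboxes that writes the visualization to a file instead of opening a window; the class names and output path are illustrative placeholders.

import numpy as np
from mmdet.core.visualization.image import imshow_det_bboxes

# Draw two detections on a blank image and save the result to disk.
img = np.full((240, 320, 3), 255, dtype=np.uint8)
bboxes = np.array([[30, 40, 120, 160, 0.95],
                   [150, 60, 300, 200, 0.60]], dtype=np.float32)
labels = np.array([0, 1])
drawn = imshow_det_bboxes(
    img, bboxes, labels,
    class_names=['cat', 'dog'],
    score_thr=0.3,
    show=False,
    out_file='demo_det.png')   # hypothetical output path
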
diff --git a/mmdet/core/visualization/palette.py b/mmdet/core/visualization/palette.py
new file mode 100644
index 0000000000000000000000000000000000000000..11692cdd086301d9d3be4a4702dc12881b8e8d6e
--- /dev/null
+++ b/mmdet/core/visualization/palette.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+
+
+def palette_val(palette):
+ """Convert palette to matplotlib palette.
+
+ Args:
+ palette (list[tuple]): A list of color tuples.
+
+ Returns:
+ List[tuple[float]]: A list of RGB matplotlib color tuples.
+ """
+ new_palette = []
+ for color in palette:
+ color = [c / 255 for c in color]
+ new_palette.append(tuple(color))
+ return new_palette
+
+
+def get_palette(palette, num_classes):
+ """Get palette from various inputs.
+
+ Args:
+ palette (list[tuple] | str | tuple | :obj:`Color`): palette inputs.
+ num_classes (int): the number of classes.
+
+ Returns:
+ list[tuple[int]]: A list of color tuples.
+ """
+ assert isinstance(num_classes, int)
+
+ if isinstance(palette, list):
+ dataset_palette = palette
+ elif isinstance(palette, tuple):
+ dataset_palette = [palette] * num_classes
+ elif palette == 'random' or palette is None:
+ state = np.random.get_state()
+ # random color
+ np.random.seed(42)
+ palette = np.random.randint(0, 256, size=(num_classes, 3))
+ np.random.set_state(state)
+ dataset_palette = [tuple(c) for c in palette]
+ elif palette == 'coco':
+ from mmdet.datasets import CocoDataset, CocoPanopticDataset
+ dataset_palette = CocoDataset.PALETTE
+ if len(dataset_palette) < num_classes:
+ dataset_palette = CocoPanopticDataset.PALETTE
+ elif palette == 'citys':
+ from mmdet.datasets import CityscapesDataset
+ dataset_palette = CityscapesDataset.PALETTE
+ elif palette == 'voc':
+ from mmdet.datasets import VOCDataset
+ dataset_palette = VOCDataset.PALETTE
+ elif mmcv.is_str(palette):
+ dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes
+ else:
+ raise TypeError(f'Invalid type for palette: {type(palette)}')
+
+ assert len(dataset_palette) >= num_classes, \
+ 'The length of palette should not be less than `num_classes`.'
+ return dataset_palette
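
A short usage sketch of the palette helpers above; the colours and class count are illustrative.

from mmdet.core.visualization.palette import get_palette, palette_val

print(get_palette((255, 0, 0), 3))   # [(255, 0, 0), (255, 0, 0), (255, 0, 0)]
print(get_palette('random', 3))      # three reproducible pseudo-random RGB tuples
print(palette_val([(255, 0, 0)]))    # [(1.0, 0.0, 0.0)] -- normalised for matplotlib
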
diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..46c49fd42c112a0d9058d6e9da9eecbcb1a475e7
--- /dev/null
+++ b/mmdet/datasets/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from .cityscapes import CityscapesDataset
+from .coco import CocoDataset
+from .coco_occluded import OccludedSeparatedCocoDataset
+from .coco_panoptic import CocoPanopticDataset
+from .custom import CustomDataset
+from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
+ MultiImageMixDataset, RepeatDataset)
+from .deepfashion import DeepFashionDataset
+from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
+from .objects365 import Objects365V1Dataset, Objects365V2Dataset
+from .openimages import OpenImagesChallengeDataset, OpenImagesDataset
+from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
+from .utils import (NumClassCheckHook, get_loading_pipeline,
+ replace_ImageToTensor)
+from .voc import VOCDataset
+from .wider_face import WIDERFaceDataset
+from .xml_style import XMLDataset
+
+__all__ = [
+ 'CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset',
+ 'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset',
+ 'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler',
+ 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+ 'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
+ 'build_dataset', 'replace_ImageToTensor', 'get_loading_pipeline',
+ 'NumClassCheckHook', 'CocoPanopticDataset', 'MultiImageMixDataset',
+ 'OpenImagesDataset', 'OpenImagesChallengeDataset', 'Objects365V1Dataset',
+ 'Objects365V2Dataset', 'OccludedSeparatedCocoDataset'
+]
diff --git a/mmdet/datasets/api_wrappers/__init__.py b/mmdet/datasets/api_wrappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..af8557593b6a50541bba1198dc9361ab5382547f
--- /dev/null
+++ b/mmdet/datasets/api_wrappers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .coco_api import COCO, COCOeval
+from .panoptic_evaluation import pq_compute_multi_core, pq_compute_single_core
+
+__all__ = [
+ 'COCO', 'COCOeval', 'pq_compute_multi_core', 'pq_compute_single_core'
+]
diff --git a/mmdet/datasets/api_wrappers/coco_api.py b/mmdet/datasets/api_wrappers/coco_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..eef6341ebbd33c222b5cda9c43c21bac1a9575da
--- /dev/null
+++ b/mmdet/datasets/api_wrappers/coco_api.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This file adds snake_case aliases for the COCO API.
+
+import warnings
+
+import pycocotools
+from pycocotools.coco import COCO as _COCO
+from pycocotools.cocoeval import COCOeval as _COCOeval
+
+
+class COCO(_COCO):
+ """This class is almost the same as official pycocotools package.
+
+ It implements some snake case function aliases. So that the COCO class has
+ the same interface as LVIS class.
+ """
+
+ def __init__(self, annotation_file=None):
+ if getattr(pycocotools, '__version__', '0') >= '12.0.2':
+ warnings.warn(
+ 'mmpycocotools is deprecated. Please install official pycocotools by "pip install pycocotools"', # noqa: E501
+ UserWarning)
+ super().__init__(annotation_file=annotation_file)
+ self.img_ann_map = self.imgToAnns
+ self.cat_img_map = self.catToImgs
+
+ def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
+ return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)
+
+ def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
+ return self.getCatIds(cat_names, sup_names, cat_ids)
+
+ def get_img_ids(self, img_ids=[], cat_ids=[]):
+ return self.getImgIds(img_ids, cat_ids)
+
+ def load_anns(self, ids):
+ return self.loadAnns(ids)
+
+ def load_cats(self, ids):
+ return self.loadCats(ids)
+
+ def load_imgs(self, ids):
+ return self.loadImgs(ids)
+
+
+# just for the ease of import
+COCOeval = _COCOeval
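
A brief usage sketch of the wrapper: it keeps the pycocotools behaviour but exposes snake_case aliases so the same calls work for COCO- and LVIS-style APIs. The annotation path below is a hypothetical placeholder.

from mmdet.datasets.api_wrappers import COCO

coco = COCO('annotations/instances_val2017.json')  # hypothetical annotation path
img_ids = coco.get_img_ids()
ann_ids = coco.get_ann_ids(img_ids=img_ids[:1])
anns = coco.load_anns(ann_ids)
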
diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f57bf4a4ca3554ab90ac768dc9ec06e9c878d2
--- /dev/null
+++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py
@@ -0,0 +1,228 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Copyright (c) 2018, Alexander Kirillov
+# This file supports `file_client` for `panopticapi`,
+# the source code is copied from `panopticapi`,
+# only the way to load the gt images is modified.
+import multiprocessing
+import os
+
+import mmcv
+import numpy as np
+
+try:
+ from panopticapi.evaluation import OFFSET, VOID, PQStat
+ from panopticapi.utils import rgb2id
+except ImportError:
+ PQStat = None
+ rgb2id = None
+ VOID = 0
+ OFFSET = 256 * 256 * 256
+
+
+def pq_compute_single_core(proc_id,
+ annotation_set,
+ gt_folder,
+ pred_folder,
+ categories,
+ file_client=None,
+ print_log=False):
+ """The single core function to evaluate the metric of Panoptic
+ Segmentation.
+
+ Same as the function with the same name in `panopticapi`. Only the function
+ to load the images is changed to use the file client.
+
+ Args:
+ proc_id (int): The id of the mini process.
+ annotation_set (list): Matched (gt_ann, pred_ann) annotation pairs
+ to be evaluated by this process.
+ gt_folder (str): The path of the ground truth images.
+ pred_folder (str): The path of the prediction images.
+ categories (dict): The categories of the dataset, keyed by
+ category id.
+ file_client (object): The file client of the dataset. If None,
+ the backend will be set to `disk`.
+ print_log (bool): Whether to print the log. Defaults to False.
+ """
+ if PQStat is None:
+ raise RuntimeError(
+ 'panopticapi is not installed, please install it by: '
+ 'pip install git+https://github.com/cocodataset/'
+ 'panopticapi.git.')
+
+ if file_client is None:
+ file_client_args = dict(backend='disk')
+ file_client = mmcv.FileClient(**file_client_args)
+
+ pq_stat = PQStat()
+
+ idx = 0
+ for gt_ann, pred_ann in annotation_set:
+ if print_log and idx % 100 == 0:
+ print('Core: {}, {} from {} images processed'.format(
+ proc_id, idx, len(annotation_set)))
+ idx += 1
+ # The gt images can be on the local disk or `ceph`, so we use
+ # file_client here.
+ img_bytes = file_client.get(
+ os.path.join(gt_folder, gt_ann['file_name']))
+ pan_gt = mmcv.imfrombytes(img_bytes, flag='color', channel_order='rgb')
+ pan_gt = rgb2id(pan_gt)
+
+ # The predictions can only be on the local disk now.
+ pan_pred = mmcv.imread(
+ os.path.join(pred_folder, pred_ann['file_name']),
+ flag='color',
+ channel_order='rgb')
+ pan_pred = rgb2id(pan_pred)
+
+ gt_segms = {el['id']: el for el in gt_ann['segments_info']}
+ pred_segms = {el['id']: el for el in pred_ann['segments_info']}
+
+ # predicted segments area calculation + prediction sanity checks
+ pred_labels_set = set(el['id'] for el in pred_ann['segments_info'])
+ labels, labels_cnt = np.unique(pan_pred, return_counts=True)
+ for label, label_cnt in zip(labels, labels_cnt):
+ if label not in pred_segms:
+ if label == VOID:
+ continue
+ raise KeyError(
+ 'In the image with ID {} segment with ID {} is '
+ 'presented in PNG and not presented in JSON.'.format(
+ gt_ann['image_id'], label))
+ pred_segms[label]['area'] = label_cnt
+ pred_labels_set.remove(label)
+ if pred_segms[label]['category_id'] not in categories:
+ raise KeyError(
+ 'In the image with ID {} segment with ID {} has '
+ 'unknown category_id {}.'.format(
+ gt_ann['image_id'], label,
+ pred_segms[label]['category_id']))
+ if len(pred_labels_set) != 0:
+ raise KeyError(
+ 'In the image with ID {} the following segment IDs {} '
+ 'are presented in JSON and not presented in PNG.'.format(
+ gt_ann['image_id'], list(pred_labels_set)))
+
+ # confusion matrix calculation
+ pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(
+ np.uint64)
+ gt_pred_map = {}
+ labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
+ for label, intersection in zip(labels, labels_cnt):
+ gt_id = label // OFFSET
+ pred_id = label % OFFSET
+ gt_pred_map[(gt_id, pred_id)] = intersection
+
+ # count all matched pairs
+ gt_matched = set()
+ pred_matched = set()
+ for label_tuple, intersection in gt_pred_map.items():
+ gt_label, pred_label = label_tuple
+ if gt_label not in gt_segms:
+ continue
+ if pred_label not in pred_segms:
+ continue
+ if gt_segms[gt_label]['iscrowd'] == 1:
+ continue
+ if gt_segms[gt_label]['category_id'] != pred_segms[pred_label][
+ 'category_id']:
+ continue
+
+ union = pred_segms[pred_label]['area'] + gt_segms[gt_label][
+ 'area'] - intersection - gt_pred_map.get((VOID, pred_label), 0)
+ iou = intersection / union
+ if iou > 0.5:
+ pq_stat[gt_segms[gt_label]['category_id']].tp += 1
+ pq_stat[gt_segms[gt_label]['category_id']].iou += iou
+ gt_matched.add(gt_label)
+ pred_matched.add(pred_label)
+
+ # count false negatives
+ crowd_labels_dict = {}
+ for gt_label, gt_info in gt_segms.items():
+ if gt_label in gt_matched:
+ continue
+ # crowd segments are ignored
+ if gt_info['iscrowd'] == 1:
+ crowd_labels_dict[gt_info['category_id']] = gt_label
+ continue
+ pq_stat[gt_info['category_id']].fn += 1
+
+ # count false positives
+ for pred_label, pred_info in pred_segms.items():
+ if pred_label in pred_matched:
+ continue
+ # intersection of the segment with VOID
+ intersection = gt_pred_map.get((VOID, pred_label), 0)
+ # plus intersection with corresponding CROWD region if it exists
+ if pred_info['category_id'] in crowd_labels_dict:
+ intersection += gt_pred_map.get(
+ (crowd_labels_dict[pred_info['category_id']], pred_label),
+ 0)
+ # predicted segment is ignored if more than half of
+ # the segment correspond to VOID and CROWD regions
+ if intersection / pred_info['area'] > 0.5:
+ continue
+ pq_stat[pred_info['category_id']].fp += 1
+
+ if print_log:
+ print('Core: {}, all {} images processed'.format(
+ proc_id, len(annotation_set)))
+ return pq_stat
+
+
+def pq_compute_multi_core(matched_annotations_list,
+ gt_folder,
+ pred_folder,
+ categories,
+ file_client=None,
+ nproc=32):
+ """Evaluate the metrics of Panoptic Segmentation with multithreading.
+
+ Same as the function with the same name in `panopticapi`.
+
+ Args:
+ matched_annotations_list (list): The matched annotation list. Each
+ element is a tuple of annotations of the same image with the
+ format (gt_anns, pred_anns).
+ gt_folder (str): The path of the ground truth images.
+ pred_folder (str): The path of the prediction images.
+ categories (dict): The categories of the dataset, keyed by
+ category id.
+ file_client (object): The file client of the dataset. If None,
+ the backend will be set to `disk`.
+ nproc (int): Number of processes for panoptic quality computing.
+ Defaults to 32. When `nproc` exceeds the number of cpu cores,
+ the number of cpu cores is used.
+ """
+ if PQStat is None:
+ raise RuntimeError(
+ 'panopticapi is not installed, please install it by: '
+ 'pip install git+https://github.com/cocodataset/'
+ 'panopticapi.git.')
+
+ if file_client is None:
+ file_client_args = dict(backend='disk')
+ file_client = mmcv.FileClient(**file_client_args)
+
+ cpu_num = min(nproc, multiprocessing.cpu_count())
+
+ annotations_split = np.array_split(matched_annotations_list, cpu_num)
+ print('Number of cores: {}, images per core: {}'.format(
+ cpu_num, len(annotations_split[0])))
+ workers = multiprocessing.Pool(processes=cpu_num)
+ processes = []
+ for proc_id, annotation_set in enumerate(annotations_split):
+ p = workers.apply_async(pq_compute_single_core,
+ (proc_id, annotation_set, gt_folder,
+ pred_folder, categories, file_client))
+ processes.append(p)
+
+ # Close the process pool, otherwise it will lead to memory
+ # leaking problems.
+ workers.close()
+ workers.join()
+
+ pq_stat = PQStat()
+ for p in processes:
+ pq_stat += p.get()
+
+ return pq_stat
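
For reference, a toy illustration of the OFFSET encoding used above to build the gt/pred confusion map from two panoptic id maps; the segment ids are illustrative.

import numpy as np

OFFSET = 256 * 256 * 256

# Two tiny panoptic id maps (values are segment ids).
pan_gt = np.array([[1, 1], [2, 2]], dtype=np.uint64)
pan_pred = np.array([[1, 3], [3, 3]], dtype=np.uint64)

# Encode each (gt_id, pred_id) pair into one integer, count, then decode.
pan_gt_pred = pan_gt * OFFSET + pan_pred
labels, counts = np.unique(pan_gt_pred, return_counts=True)
for label, intersection in zip(labels, counts):
    gt_id, pred_id = label // OFFSET, label % OFFSET
    print(f'gt {gt_id} vs pred {pred_id}: {intersection} px')
# gt 1 vs pred 1: 1 px, gt 1 vs pred 3: 1 px, gt 2 vs pred 3: 2 px
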
diff --git a/mmdet/datasets/builder.py b/mmdet/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1936296a55824dc6f04e0fe0cc39a7a532724b59
--- /dev/null
+++ b/mmdet/datasets/builder.py
@@ -0,0 +1,215 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import platform
+import random
+import warnings
+from functools import partial
+
+import numpy as np
+import torch
+from mmcv.parallel import collate
+from mmcv.runner import get_dist_info
+from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
+from torch.utils.data import DataLoader
+
+from .samplers import (ClassAwareSampler, DistributedGroupSampler,
+ DistributedSampler, GroupSampler, InfiniteBatchSampler,
+ InfiniteGroupBatchSampler)
+
+if platform.system() != 'Windows':
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ base_soft_limit = rlimit[0]
+ hard_limit = rlimit[1]
+ soft_limit = min(max(4096, base_soft_limit), hard_limit)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')
+
+
+def _concat_dataset(cfg, default_args=None):
+ from .dataset_wrappers import ConcatDataset
+ ann_files = cfg['ann_file']
+ img_prefixes = cfg.get('img_prefix', None)
+ seg_prefixes = cfg.get('seg_prefix', None)
+ proposal_files = cfg.get('proposal_file', None)
+ separate_eval = cfg.get('separate_eval', True)
+
+ datasets = []
+ num_dset = len(ann_files)
+ for i in range(num_dset):
+ data_cfg = copy.deepcopy(cfg)
+ # pop 'separate_eval' since it is not a valid key for common datasets.
+ if 'separate_eval' in data_cfg:
+ data_cfg.pop('separate_eval')
+ data_cfg['ann_file'] = ann_files[i]
+ if isinstance(img_prefixes, (list, tuple)):
+ data_cfg['img_prefix'] = img_prefixes[i]
+ if isinstance(seg_prefixes, (list, tuple)):
+ data_cfg['seg_prefix'] = seg_prefixes[i]
+ if isinstance(proposal_files, (list, tuple)):
+ data_cfg['proposal_file'] = proposal_files[i]
+ datasets.append(build_dataset(data_cfg, default_args))
+
+ return ConcatDataset(datasets, separate_eval)
+
+
+def build_dataset(cfg, default_args=None):
+ from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
+ MultiImageMixDataset, RepeatDataset)
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg['type'] == 'ConcatDataset':
+ dataset = ConcatDataset(
+ [build_dataset(c, default_args) for c in cfg['datasets']],
+ cfg.get('separate_eval', True))
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
+ elif cfg['type'] == 'ClassBalancedDataset':
+ dataset = ClassBalancedDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
+ elif cfg['type'] == 'MultiImageMixDataset':
+ cp_cfg = copy.deepcopy(cfg)
+ cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
+ cp_cfg.pop('type')
+ dataset = MultiImageMixDataset(**cp_cfg)
+ elif isinstance(cfg.get('ann_file'), (list, tuple)):
+ dataset = _concat_dataset(cfg, default_args)
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+ return dataset
+
+
+def build_dataloader(dataset,
+ samples_per_gpu,
+ workers_per_gpu,
+ num_gpus=1,
+ dist=True,
+ shuffle=True,
+ seed=None,
+ runner_type='EpochBasedRunner',
+ persistent_workers=False,
+ class_aware_sampler=None,
+ **kwargs):
+ """Build PyTorch DataLoader.
+
+ In distributed training, each GPU/process has a dataloader.
+ In non-distributed training, there is only one dataloader for all GPUs.
+
+ Args:
+ dataset (Dataset): A PyTorch dataset.
+ samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+ batch size of each GPU.
+ workers_per_gpu (int): How many subprocesses to use for data loading
+ for each GPU.
+ num_gpus (int): Number of GPUs. Only used in non-distributed training.
+ dist (bool): Distributed training/test or not. Default: True.
+ shuffle (bool): Whether to shuffle the data at every epoch.
+ Default: True.
+ seed (int, Optional): Seed to be used. Default: None.
+ runner_type (str): Type of runner. Default: `EpochBasedRunner`
+        persistent_workers (bool): If True, the data loader will not shut
+            down the worker processes after a dataset has been consumed once.
+            This keeps the worker `Dataset` instances alive.
+            This argument is only valid when PyTorch>=1.7.0. Default: False.
+        class_aware_sampler (dict, optional): Config of `ClassAwareSampler`,
+            e.g. dict(num_sample_class=1). If given, `ClassAwareSampler` is
+            used as the sampler during training. Default: None.
+ kwargs: any keyword argument to be used to initialize DataLoader
+
+ Returns:
+ DataLoader: A PyTorch dataloader.
+ """
+ rank, world_size = get_dist_info()
+
+ if dist:
+ # When model is :obj:`DistributedDataParallel`,
+ # `batch_size` of :obj:`dataloader` is the
+ # number of training samples on each GPU.
+ batch_size = samples_per_gpu
+ num_workers = workers_per_gpu
+ else:
+        # When model is :obj:`DataParallel`,
+        # the batch size is the total number of samples on all GPUs.
+ batch_size = num_gpus * samples_per_gpu
+ num_workers = num_gpus * workers_per_gpu
+
+ if runner_type == 'IterBasedRunner':
+        # This is a batch sampler, which can yield
+        # the indices of a mini-batch each time.
+        # It can be used with both `DataParallel` and
+        # `DistributedDataParallel`.
+ if shuffle:
+ batch_sampler = InfiniteGroupBatchSampler(
+ dataset, batch_size, world_size, rank, seed=seed)
+ else:
+ batch_sampler = InfiniteBatchSampler(
+ dataset,
+ batch_size,
+ world_size,
+ rank,
+ seed=seed,
+ shuffle=False)
+ batch_size = 1
+ sampler = None
+ else:
+ if class_aware_sampler is not None:
+ # ClassAwareSampler can be used in both distributed and
+ # non-distributed training.
+ num_sample_class = class_aware_sampler.get('num_sample_class', 1)
+ sampler = ClassAwareSampler(
+ dataset,
+ samples_per_gpu,
+ world_size,
+ rank,
+ seed=seed,
+ num_sample_class=num_sample_class)
+ elif dist:
+            # DistributedGroupSampler always shuffles the data so that
+            # the images on each GPU belong to the same group
+ if shuffle:
+ sampler = DistributedGroupSampler(
+ dataset, samples_per_gpu, world_size, rank, seed=seed)
+ else:
+ sampler = DistributedSampler(
+ dataset, world_size, rank, shuffle=False, seed=seed)
+ else:
+ sampler = GroupSampler(dataset,
+ samples_per_gpu) if shuffle else None
+ batch_sampler = None
+
+ init_fn = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank,
+ seed=seed) if seed is not None else None
+
+ if (TORCH_VERSION != 'parrots'
+ and digit_version(TORCH_VERSION) >= digit_version('1.7.0')):
+ kwargs['persistent_workers'] = persistent_workers
+ elif persistent_workers is True:
+ warnings.warn('persistent_workers is invalid because your pytorch '
+ 'version is lower than 1.7.0')
+
+ data_loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ batch_sampler=batch_sampler,
+ collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+ pin_memory=kwargs.pop('pin_memory', False),
+ worker_init_fn=init_fn,
+ **kwargs)
+
+ return data_loader
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # The seed of each worker equals
+    # num_workers * rank + worker_id + user_seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
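Together, `build_dataset` and `build_dataloader` are the entry points a training script uses to turn a config dict into an iterable of collated batches. Below is a minimal usage sketch; the dataset type, annotation path, image prefix, and empty pipeline are illustrative assumptions, and a real config would list transform dicts registered in `PIPELINES`.

from mmdet.datasets.builder import build_dataloader, build_dataset

# Placeholder config: the paths are hypothetical and the pipeline is left
# empty; a real config lists transform dicts registered in PIPELINES.
dataset_cfg = dict(
    type='CocoDataset',
    ann_file='data/coco/annotations/instances_train2017.json',
    img_prefix='data/coco/train2017/',
    pipeline=[])

# Dispatches on cfg['type'] and builds the dataset from the DATASETS registry.
dataset = build_dataset(dataset_cfg)

# Non-distributed loader: the effective batch size is
# num_gpus * samples_per_gpu, and each worker is seeded via worker_init_fn.
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=2,
    workers_per_gpu=2,
    num_gpus=1,
    dist=False,
    shuffle=True,
    seed=0)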
diff --git a/mmdet/datasets/cityscapes.py b/mmdet/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..c998d1253fb89f9a6ae3835d3c3b22ef2912b29e
--- /dev/null
+++ b/mmdet/datasets/cityscapes.py
@@ -0,0 +1,339 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/cityscapes.py # noqa
+# and https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
+
+import glob
+import os
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import pycocotools.mask as maskUtils
+from mmcv.utils import print_log
+
+from .builder import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class CityscapesDataset(CocoDataset):
+
+ CLASSES = ('person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle')
+
+ PALETTE = [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
+ (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)]
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without ground truths."""
+ valid_inds = []
+ # obtain images that contain annotation
+ ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
+ # obtain images that contain annotations of the required categories
+ ids_in_cat = set()
+ for i, class_id in enumerate(self.cat_ids):
+ ids_in_cat |= set(self.coco.cat_img_map[class_id])
+ # merge the image id sets of the two conditions and use the merged set
+ # to filter out images if self.filter_empty_gt=True
+ ids_in_cat &= ids_with_ann
+
+ valid_img_ids = []
+ for i, img_info in enumerate(self.data_infos):
+ img_id = img_info['id']
+ ann_ids = self.coco.getAnnIds(imgIds=[img_id])
+ ann_info = self.coco.loadAnns(ann_ids)
+ all_iscrowd = all([_['iscrowd'] for _ in ann_info])
+ if self.filter_empty_gt and (self.img_ids[i] not in ids_in_cat
+ or all_iscrowd):
+ continue
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ valid_img_ids.append(img_id)
+ self.img_ids = valid_img_ids
+ return valid_inds
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse bbox and mask annotation.
+
+ Args:
+ img_info (dict): Image info of an image.
+ ann_info (list[dict]): Annotation info of an image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, \
+ bboxes_ignore, labels, masks, seg_map. \
+ "masks" are already decoded into binary masks.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ann = []
+
+ for i, ann in enumerate(ann_info):
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(bbox)
+ else:
+ gt_bboxes.append(bbox)
+ gt_labels.append(self.cat2label[ann['category_id']])
+ gt_masks_ann.append(ann['segmentation'])
+
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_masks_ann,
+ seg_map=img_info['segm_file'])
+
+ return ann
+
+ def results2txt(self, results, outfile_prefix):
+ """Dump the detection results to a txt file.
+
+ Args:
+ results (list[list | tuple]): Testing results of the
+ dataset.
+ outfile_prefix (str): The filename prefix of the json files.
+ If the prefix is "somepath/xxx",
+ the txt files will be named "somepath/xxx.txt".
+
+ Returns:
+ list[str]: Result txt files which contains corresponding \
+ instance segmentation images.
+ """
+ try:
+ import cityscapesscripts.helpers.labels as CSLabels
+ except ImportError:
+            raise ImportError('Please run "pip install cityscapesscripts" to '
+                              'install cityscapesscripts first.')
+ result_files = []
+ os.makedirs(outfile_prefix, exist_ok=True)
+ prog_bar = mmcv.ProgressBar(len(self))
+ for idx in range(len(self)):
+ result = results[idx]
+ filename = self.data_infos[idx]['filename']
+ basename = osp.splitext(osp.basename(filename))[0]
+ pred_txt = osp.join(outfile_prefix, basename + '_pred.txt')
+
+ bbox_result, segm_result = result
+ bboxes = np.vstack(bbox_result)
+ # segm results
+ if isinstance(segm_result, tuple):
+ # Some detectors use different scores for bbox and mask,
+ # like Mask Scoring R-CNN. Score of segm will be used instead
+ # of bbox score.
+ segms = mmcv.concat_list(segm_result[0])
+ mask_score = segm_result[1]
+ else:
+ # use bbox score for mask score
+ segms = mmcv.concat_list(segm_result)
+ mask_score = [bbox[-1] for bbox in bboxes]
+ labels = [
+ np.full(bbox.shape[0], i, dtype=np.int32)
+ for i, bbox in enumerate(bbox_result)
+ ]
+ labels = np.concatenate(labels)
+
+ assert len(bboxes) == len(segms) == len(labels)
+ num_instances = len(bboxes)
+ prog_bar.update()
+ with open(pred_txt, 'w') as fout:
+ for i in range(num_instances):
+ pred_class = labels[i]
+ classes = self.CLASSES[pred_class]
+ class_id = CSLabels.name2label[classes].id
+ score = mask_score[i]
+ mask = maskUtils.decode(segms[i]).astype(np.uint8)
+ png_filename = osp.join(outfile_prefix,
+ basename + f'_{i}_{classes}.png')
+ mmcv.imwrite(mask, png_filename)
+ fout.write(f'{osp.basename(png_filename)} {class_id} '
+ f'{score}\n')
+ result_files.append(pred_txt)
+
+ return result_files
+
+ def format_results(self, results, txtfile_prefix=None):
+ """Format the results to txt (standard format for Cityscapes
+ evaluation).
+
+ Args:
+ results (list): Testing results of the dataset.
+ txtfile_prefix (str | None): The prefix of txt files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+
+ Returns:
+            tuple: (result_files, tmp_dir), result_files is a list of the \
+                result txt filepaths, tmp_dir is the temporary directory \
+                created for saving txt/png files when txtfile_prefix is \
+                not specified.
+ """
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ if txtfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ txtfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2txt(results, txtfile_prefix)
+
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ outfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=np.arange(0.5, 0.96, 0.05)):
+ """Evaluation in Cityscapes/COCO protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ outfile_prefix (str | None): The prefix of output file. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If results are evaluated with COCO protocol, it would be the
+ prefix of output json file. For example, the metric is 'bbox'
+ and 'segm', then json files would be "a/b/prefix.bbox.json" and
+ "a/b/prefix.segm.json".
+ If results are evaluated with cityscapes protocol, it would be
+ the prefix of output txt/png files. The output files would be
+ png images under folder "a/b/prefix/xxx/" and the file name of
+ images would be written into a txt file
+ "a/b/prefix/xxx_pred.txt", where "xxx" is the video name of
+ cityscapes. If not specified, a temp file will be created.
+ Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+            iou_thrs (Sequence[float]): IoU thresholds used for evaluating
+                recalls. If set to a list, the average recall of all IoUs will
+                also be computed. Default: np.arange(0.5, 0.96, 0.05).
+
+ Returns:
+ dict[str, float]: COCO style evaluation metric or cityscapes mAP \
+ and AP@50.
+ """
+ eval_results = dict()
+
+ metrics = metric.copy() if isinstance(metric, list) else [metric]
+
+ if 'cityscapes' in metrics:
+ eval_results.update(
+ self._evaluate_cityscapes(results, outfile_prefix, logger))
+ metrics.remove('cityscapes')
+
+ # left metrics are all coco metric
+ if len(metrics) > 0:
+ # create CocoDataset with CityscapesDataset annotation
+ self_coco = CocoDataset(self.ann_file, self.pipeline.transforms,
+ None, self.data_root, self.img_prefix,
+ self.seg_prefix, self.seg_suffix,
+ self.proposal_file, self.test_mode,
+ self.filter_empty_gt)
+ # TODO: remove this in the future
+ # reload annotations of correct class
+ self_coco.CLASSES = self.CLASSES
+ self_coco.data_infos = self_coco.load_annotations(self.ann_file)
+ eval_results.update(
+ self_coco.evaluate(results, metrics, logger, outfile_prefix,
+ classwise, proposal_nums, iou_thrs))
+
+ return eval_results
+
+ def _evaluate_cityscapes(self, results, txtfile_prefix, logger):
+ """Evaluation in Cityscapes protocol.
+
+ Args:
+ results (list): Testing results of the dataset.
+ txtfile_prefix (str | None): The prefix of output txt file
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+
+ Returns:
+            dict[str, float]: Cityscapes evaluation results, containing \
+                'mAP' and 'AP@50'.
+ """
+
+ try:
+ import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval # noqa
+ except ImportError:
+            raise ImportError('Please run "pip install cityscapesscripts" to '
+                              'install cityscapesscripts first.')
+ msg = 'Evaluating in Cityscapes style'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ result_files, tmp_dir = self.format_results(results, txtfile_prefix)
+
+ if tmp_dir is None:
+ result_dir = osp.join(txtfile_prefix, 'results')
+ else:
+ result_dir = osp.join(tmp_dir.name, 'results')
+
+ eval_results = OrderedDict()
+ print_log(f'Evaluating results under {result_dir} ...', logger=logger)
+
+ # set global states in cityscapes evaluation API
+ CSEval.args.cityscapesPath = os.path.join(self.img_prefix, '../..')
+ CSEval.args.predictionPath = os.path.abspath(result_dir)
+ CSEval.args.predictionWalk = None
+ CSEval.args.JSONOutput = False
+ CSEval.args.colorized = False
+ CSEval.args.gtInstancesFile = os.path.join(result_dir,
+ 'gtInstances.json')
+ CSEval.args.groundTruthSearch = os.path.join(
+ self.img_prefix.replace('leftImg8bit', 'gtFine'),
+ '*/*_gtFine_instanceIds.png')
+
+ groundTruthImgList = glob.glob(CSEval.args.groundTruthSearch)
+ assert len(groundTruthImgList), 'Cannot find ground truth images' \
+ f' in {CSEval.args.groundTruthSearch}.'
+ predictionImgList = []
+ for gt in groundTruthImgList:
+ predictionImgList.append(CSEval.getPrediction(gt, CSEval.args))
+ CSEval_results = CSEval.evaluateImgLists(predictionImgList,
+ groundTruthImgList,
+ CSEval.args)['averages']
+
+ eval_results['mAP'] = CSEval_results['allAp']
+ eval_results['AP@50'] = CSEval_results['allAp50%']
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
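As a usage sketch of the evaluation path above (hedged: it assumes a `dataset` built from a CityscapesDataset config and a `results` list produced by a finished test run), mixing the Cityscapes protocol with a COCO-style metric looks roughly like this; the output prefix is a hypothetical path.

# `dataset` is assumed to be a built CityscapesDataset and `results` the
# per-image (bbox_result, segm_result) tuples from a finished test run.
eval_metrics = dataset.evaluate(
    results,
    metric=['cityscapes', 'segm'],   # Cityscapes protocol + COCO mask AP
    outfile_prefix='work_dirs/cityscapes_eval/results')  # hypothetical prefix

# 'mAP' and 'AP@50' come from the Cityscapes scripts; the remaining keys
# (e.g. 'segm_mAP') come from the COCO-style evaluation.
print(eval_metrics['mAP'], eval_metrics['AP@50'])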
diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..d20a121ca5a747d6930f02d7a2e35d02f942df1c
--- /dev/null
+++ b/mmdet/datasets/coco.py
@@ -0,0 +1,649 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import contextlib
+import io
+import itertools
+import logging
+import os.path as osp
+import tempfile
+import warnings
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from mmdet.core import eval_recalls
+from .api_wrappers import COCO, COCOeval
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class CocoDataset(CustomDataset):
+
+ CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+ 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+ 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+ 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+ 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+ 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+ 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+ 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
+
+ PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
+ (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
+ (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
+ (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
+ (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
+ (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
+ (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
+ (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
+ (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
+ (134, 134, 103), (145, 148, 174), (255, 208, 186),
+ (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
+ (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
+ (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
+ (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
+ (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
+ (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
+ (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
+ (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
+ (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
+ (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
+ (191, 162, 208)]
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+
+ self.coco = COCO(ann_file)
+ # The order of returned `cat_ids` will not
+ # change with the order of the CLASSES
+ self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ total_ann_ids = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ ann_ids = self.coco.get_ann_ids(img_ids=[i])
+ total_ann_ids.extend(ann_ids)
+ assert len(set(total_ann_ids)) == len(
+ total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
+ return data_infos
+
+ def get_ann_info(self, idx):
+ """Get COCO annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ return self._parse_ann_info(self.data_infos[idx], ann_info)
+
+ def get_cat_ids(self, idx):
+ """Get COCO category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ return [ann['category_id'] for ann in ann_info]
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without ground truths."""
+ valid_inds = []
+ # obtain images that contain annotation
+ ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
+ # obtain images that contain annotations of the required categories
+ ids_in_cat = set()
+ for i, class_id in enumerate(self.cat_ids):
+ ids_in_cat |= set(self.coco.cat_img_map[class_id])
+ # merge the image id sets of the two conditions and use the merged set
+ # to filter out images if self.filter_empty_gt=True
+ ids_in_cat &= ids_with_ann
+
+ valid_img_ids = []
+ for i, img_info in enumerate(self.data_infos):
+ img_id = self.img_ids[i]
+ if self.filter_empty_gt and img_id not in ids_in_cat:
+ continue
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ valid_img_ids.append(img_id)
+ self.img_ids = valid_img_ids
+ return valid_inds
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse bbox and mask annotation.
+
+        Args:
+            img_info (dict): Image info of an image.
+            ann_info (list[dict]): Annotation info of an image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,\
+ labels, masks, seg_map. "masks" are raw annotations and not \
+ decoded into binary masks.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ann = []
+ for i, ann in enumerate(ann_info):
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+ if inter_w * inter_h == 0:
+ continue
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(bbox)
+ else:
+ gt_bboxes.append(bbox)
+ gt_labels.append(self.cat2label[ann['category_id']])
+ gt_masks_ann.append(ann.get('segmentation', None))
+
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ seg_map = img_info['filename'].rsplit('.', 1)[0] + self.seg_suffix
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_masks_ann,
+ seg_map=seg_map)
+
+ return ann
+
+ def xyxy2xywh(self, bbox):
+ """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
+ evaluation.
+
+ Args:
+ bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
+ ``xyxy`` order.
+
+ Returns:
+ list[float]: The converted bounding boxes, in ``xywh`` order.
+ """
+
+ _bbox = bbox.tolist()
+ return [
+ _bbox[0],
+ _bbox[1],
+ _bbox[2] - _bbox[0],
+ _bbox[3] - _bbox[1],
+ ]
+
+ def _proposal2json(self, results):
+ """Convert proposal results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ bboxes = results[idx]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = 1
+ json_results.append(data)
+ return json_results
+
+ def _det2json(self, results):
+ """Convert detection results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ result = results[idx]
+ for label in range(len(result)):
+ bboxes = result[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ json_results.append(data)
+ return json_results
+
+ def _segm2json(self, results):
+ """Convert instance segmentation results to COCO json style."""
+ bbox_json_results = []
+ segm_json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ det, seg = results[idx]
+ for label in range(len(det)):
+ # bbox results
+ bboxes = det[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ bbox_json_results.append(data)
+
+ # segm results
+ # some detectors use different scores for bbox and mask
+ if isinstance(seg, tuple):
+ segms = seg[0][label]
+ mask_score = seg[1][label]
+ else:
+ segms = seg[label]
+ mask_score = [bbox[4] for bbox in bboxes]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(mask_score[i])
+ data['category_id'] = self.cat_ids[label]
+ if isinstance(segms[i]['counts'], bytes):
+ segms[i]['counts'] = segms[i]['counts'].decode()
+ data['segmentation'] = segms[i]
+ segm_json_results.append(data)
+ return bbox_json_results, segm_json_results
+
+ def results2json(self, results, outfile_prefix):
+ """Dump the detection results to a COCO style json file.
+
+ There are 3 types of results: proposals, bbox predictions, mask
+ predictions, and they have different data types. This method will
+ automatically recognize the type, and dump them to json files.
+
+ Args:
+ results (list[list | tuple | ndarray]): Testing results of the
+ dataset.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
+ "somepath/xxx.proposal.json".
+
+ Returns:
+            dict[str, str]: Possible keys are "bbox", "segm", "proposal", and \
+ values are corresponding filenames.
+ """
+ result_files = dict()
+ if isinstance(results[0], list):
+ json_results = self._det2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ mmcv.dump(json_results, result_files['bbox'])
+ elif isinstance(results[0], tuple):
+ json_results = self._segm2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ result_files['segm'] = f'{outfile_prefix}.segm.json'
+ mmcv.dump(json_results[0], result_files['bbox'])
+ mmcv.dump(json_results[1], result_files['segm'])
+ elif isinstance(results[0], np.ndarray):
+ json_results = self._proposal2json(results)
+ result_files['proposal'] = f'{outfile_prefix}.proposal.json'
+ mmcv.dump(json_results, result_files['proposal'])
+ else:
+ raise TypeError('invalid type of results')
+ return result_files
+
+ def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
+ gt_bboxes = []
+ for i in range(len(self.img_ids)):
+ ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
+ ann_info = self.coco.load_anns(ann_ids)
+ if len(ann_info) == 0:
+ gt_bboxes.append(np.zeros((0, 4)))
+ continue
+ bboxes = []
+ for ann in ann_info:
+ if ann.get('ignore', False) or ann['iscrowd']:
+ continue
+ x1, y1, w, h = ann['bbox']
+ bboxes.append([x1, y1, x1 + w, y1 + h])
+ bboxes = np.array(bboxes, dtype=np.float32)
+ if bboxes.shape[0] == 0:
+ bboxes = np.zeros((0, 4))
+ gt_bboxes.append(bboxes)
+
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
+ ar = recalls.mean(axis=1)
+ return ar
+
+ def format_results(self, results, jsonfile_prefix=None, **kwargs):
+ """Format the results to json (standard format for COCO evaluation).
+
+ Args:
+ results (list[tuple | numpy.ndarray]): Testing results of the
+ dataset.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a dict containing \
+                the json filepaths, tmp_dir is the temporary directory created \
+ for saving json files when jsonfile_prefix is not specified.
+ """
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ if jsonfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2json(results, jsonfile_prefix)
+ return result_files, tmp_dir
+
+ def evaluate_det_segm(self,
+ results,
+ result_files,
+ coco_gt,
+ metrics,
+ logger=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=None,
+ metric_items=None):
+ """Instance segmentation and object detection evaluation in COCO
+ protocol.
+
+ Args:
+ results (list[list | tuple | dict]): Testing results of the
+ dataset.
+            result_files (dict[str, str]): A dict containing the json file
+                paths.
+            coco_gt (COCO): COCO API object with ground truth annotation.
+            metrics (list[str]): Metrics to be evaluated. Options are
+                'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thrs (Sequence[float], optional): IoU threshold used for
+ evaluating recalls/mAPs. If set to a list, the average of all
+ IoUs will also be computed. If not specified, [0.50, 0.55,
+ 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
+ Default: None.
+ metric_items (list[str] | str, optional): Metric items that will
+ be returned. If not specified, ``['AR@100', 'AR@300',
+ 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
+ used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
+ 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
+ ``metric=='bbox' or metric=='segm'``.
+
+ Returns:
+ dict[str, float]: COCO style evaluation metric.
+ """
+ if iou_thrs is None:
+ iou_thrs = np.linspace(
+ .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+ if metric_items is not None:
+ if not isinstance(metric_items, list):
+ metric_items = [metric_items]
+
+ eval_results = OrderedDict()
+ for metric in metrics:
+ msg = f'Evaluating {metric}...'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ if metric == 'proposal_fast':
+ if isinstance(results[0], tuple):
+ raise KeyError('proposal_fast is not supported for '
+ 'instance segmentation result.')
+ ar = self.fast_eval_recall(
+ results, proposal_nums, iou_thrs, logger='silent')
+ log_msg = []
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+ log_msg = ''.join(log_msg)
+ print_log(log_msg, logger=logger)
+ continue
+
+ iou_type = 'bbox' if metric == 'proposal' else metric
+ if metric not in result_files:
+ raise KeyError(f'{metric} is not in results')
+ try:
+ predictions = mmcv.load(result_files[metric])
+ if iou_type == 'segm':
+ # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa
+ # When evaluating mask AP, if the results contain bbox,
+ # cocoapi will use the box area instead of the mask area
+ # for calculating the instance area. Though the overall AP
+ # is not affected, this leads to different
+ # small/medium/large mask AP results.
+ for x in predictions:
+ x.pop('bbox')
+ warnings.simplefilter('once')
+ warnings.warn(
+ 'The key "bbox" is deleted for more accurate mask AP '
+ 'of small/medium/large instances since v2.12.0. This '
+ 'does not change the overall mAP calculation.',
+ UserWarning)
+ coco_det = coco_gt.loadRes(predictions)
+ except IndexError:
+ print_log(
+                    'The testing results of the whole dataset are empty.',
+ logger=logger,
+ level=logging.ERROR)
+ break
+
+ cocoEval = COCOeval(coco_gt, coco_det, iou_type)
+ cocoEval.params.catIds = self.cat_ids
+ cocoEval.params.imgIds = self.img_ids
+ cocoEval.params.maxDets = list(proposal_nums)
+ cocoEval.params.iouThrs = iou_thrs
+ # mapping of cocoEval.stats
+ coco_metric_names = {
+ 'mAP': 0,
+ 'mAP_50': 1,
+ 'mAP_75': 2,
+ 'mAP_s': 3,
+ 'mAP_m': 4,
+ 'mAP_l': 5,
+ 'AR@100': 6,
+ 'AR@300': 7,
+ 'AR@1000': 8,
+ 'AR_s@1000': 9,
+ 'AR_m@1000': 10,
+ 'AR_l@1000': 11
+ }
+ if metric_items is not None:
+ for metric_item in metric_items:
+ if metric_item not in coco_metric_names:
+ raise KeyError(
+ f'metric item {metric_item} is not supported')
+
+ if metric == 'proposal':
+ cocoEval.params.useCats = 0
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+
+ # Save coco summarize print information to logger
+ redirect_string = io.StringIO()
+ with contextlib.redirect_stdout(redirect_string):
+ cocoEval.summarize()
+ print_log('\n' + redirect_string.getvalue(), logger=logger)
+
+ if metric_items is None:
+ metric_items = [
+ 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+ 'AR_m@1000', 'AR_l@1000'
+ ]
+
+ for item in metric_items:
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[item]]:.4f}')
+ eval_results[item] = val
+ else:
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+
+ # Save coco summarize print information to logger
+ redirect_string = io.StringIO()
+ with contextlib.redirect_stdout(redirect_string):
+ cocoEval.summarize()
+ print_log('\n' + redirect_string.getvalue(), logger=logger)
+
+ if classwise: # Compute per-category AP
+ # Compute per-category AP
+ # from https://github.com/facebookresearch/detectron2/
+ precisions = cocoEval.eval['precision']
+ # precision: (iou, recall, cls, area range, max dets)
+ assert len(self.cat_ids) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, catId in enumerate(self.cat_ids):
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self.coco.loadCats(catId)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ results_per_category.append(
+ (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(
+ itertools.chain(*results_per_category))
+ headers = ['category', 'AP'] * (num_columns // 2)
+ results_2d = itertools.zip_longest(*[
+ results_flatten[i::num_columns]
+ for i in range(num_columns)
+ ])
+ table_data = [headers]
+ table_data += [result for result in results_2d]
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+ if metric_items is None:
+ metric_items = [
+ 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+ ]
+
+ for metric_item in metric_items:
+ key = f'{metric}_{metric_item}'
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[metric_item]]:.4f}'
+ )
+ eval_results[key] = val
+ ap = cocoEval.stats[:6]
+ eval_results[f'{metric}_mAP_copypaste'] = (
+ f'{ap[0]:.4f} {ap[1]:.4f} {ap[2]:.4f} {ap[3]:.4f} '
+ f'{ap[4]:.4f} {ap[5]:.4f}')
+
+ return eval_results
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ jsonfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=None,
+ metric_items=None):
+ """Evaluation in COCO protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thrs (Sequence[float], optional): IoU threshold used for
+ evaluating recalls/mAPs. If set to a list, the average of all
+ IoUs will also be computed. If not specified, [0.50, 0.55,
+ 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
+ Default: None.
+ metric_items (list[str] | str, optional): Metric items that will
+ be returned. If not specified, ``['AR@100', 'AR@300',
+ 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
+ used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
+ 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
+ ``metric=='bbox' or metric=='segm'``.
+
+ Returns:
+ dict[str, float]: COCO style evaluation metric.
+ """
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+
+ coco_gt = self.coco
+ self.cat_ids = coco_gt.get_cat_ids(cat_names=self.CLASSES)
+
+ result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
+ eval_results = self.evaluate_det_segm(results, result_files, coco_gt,
+ metrics, logger, classwise,
+ proposal_nums, iou_thrs,
+ metric_items)
+
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
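The json dumping above leans on `xyxy2xywh`: detector outputs carry `[x1, y1, x2, y2, score]` boxes, while COCO json expects `[x, y, width, height]`. The following self-contained check mirrors that arithmetic with plain numpy (it re-implements the conversion rather than calling the method, which needs a dataset instance).

import numpy as np

# A detector-style box: [x1, y1, x2, y2, score].
det_box = np.array([10.0, 20.0, 110.0, 70.0, 0.93], dtype=np.float32)

# Same arithmetic as CocoDataset.xyxy2xywh: keep the top-left corner and
# convert the bottom-right corner into width/height.
x1, y1, x2, y2 = det_box[:4].tolist()
coco_box = [x1, y1, x2 - x1, y2 - y1]

assert coco_box == [10.0, 20.0, 100.0, 50.0]
print('COCO-style bbox:', coco_box, 'score:', float(det_box[4]))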
diff --git a/mmdet/datasets/coco_occluded.py b/mmdet/datasets/coco_occluded.py
new file mode 100644
index 0000000000000000000000000000000000000000..96e439a222b80014d6fc475d3171a5daa4fc0f87
--- /dev/null
+++ b/mmdet/datasets/coco_occluded.py
@@ -0,0 +1,219 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import mmcv
+import numpy as np
+from mmcv.fileio import load
+from mmcv.utils import print_log
+from pycocotools import mask as coco_mask
+from terminaltables import AsciiTable
+
+from .builder import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class OccludedSeparatedCocoDataset(CocoDataset):
+ """COCO dataset with evaluation on separated and occluded masks which
+ presented in paper `A Tri-Layer Plugin to Improve Occluded Detection.
+
+ `_.
+
+    Separated COCO and Occluded COCO are automatically generated subsets of
+    the COCO val dataset, collecting separated objects and partially occluded
+    objects for a large variety of categories. In this way, occlusion is
+    divided into two major categories: separated and partially occluded.
+
+ - Separation: target object segmentation mask is separated into distinct
+ regions by the occluder.
+ - Partial Occlusion: target object is partially occluded but the
+ segmentation mask is connected.
+
+    These two new scalable real-image datasets are used to benchmark a model's
+    capability to detect occluded objects across 80 common categories.
+
+ Please cite the paper if you use this dataset:
+
+ @article{zhan2022triocc,
+ title={A Tri-Layer Plugin to Improve Occluded Detection},
+ author={Zhan, Guanqi and Xie, Weidi and Zisserman, Andrew},
+ journal={British Machine Vision Conference},
+ year={2022}
+ }
+
+ Args:
+ occluded_ann (str): Path to the occluded coco annotation file.
+ separated_ann (str): Path to the separated coco annotation file.
+ """ # noqa
+
+ def __init__(
+ self,
+ *args,
+ occluded_ann='https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/occluded_coco.pkl', # noqa
+ separated_ann='https://www.robots.ox.ac.uk/~vgg/research/tpod/datasets/separated_coco.pkl', # noqa
+ **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # load from local file
+ if osp.isfile(occluded_ann) and not osp.isabs(occluded_ann):
+ occluded_ann = osp.join(self.data_root, occluded_ann)
+ if osp.isfile(separated_ann) and not osp.isabs(separated_ann):
+ separated_ann = osp.join(self.data_root, separated_ann)
+
+ self.occluded_ann = load(occluded_ann)
+ self.separated_ann = load(separated_ann)
+
+ def evaluate(self,
+ results,
+ metric=[],
+ score_thr=0.3,
+ iou_thr=0.75,
+ **kwargs):
+ """Occluded and separated mask evaluation in COCO protocol.
+
+ Args:
+ results (list[tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'. Defaults to [].
+ score_thr (float): Score threshold of the detection masks.
+ Defaults to 0.3.
+ iou_thr (float): IoU threshold for the recall calculation.
+ Defaults to 0.75.
+ Returns:
+ dict[str, float]: The recall of occluded and separated masks and
+ COCO style evaluation metric.
+ """
+ coco_metric_res = super().evaluate(results, metric=metric, **kwargs)
+ eval_res = self.evaluate_occluded_separated(results, score_thr,
+ iou_thr)
+ coco_metric_res.update(eval_res)
+ return coco_metric_res
+
+ def evaluate_occluded_separated(self,
+ results,
+ score_thr=0.3,
+ iou_thr=0.75):
+ """Compute the recall of occluded and separated masks.
+
+ Args:
+ results (list[tuple]): Testing results of the dataset.
+ score_thr (float): Score threshold of the detection masks.
+ Defaults to 0.3.
+ iou_thr (float): IoU threshold for the recall calculation.
+ Defaults to 0.75.
+ Returns:
+ dict[str, float]: The recall of occluded and separated masks.
+ """
+ dict_det = {}
+ print_log('processing detection results...')
+ prog_bar = mmcv.ProgressBar(len(results))
+ for i in range(len(results)):
+ cur_img_name = self.data_infos[i]['filename']
+ if cur_img_name not in dict_det.keys():
+ dict_det[cur_img_name] = []
+ for cat_id in range(len(results[i][1])):
+ assert len(results[i][1][cat_id]) == len(results[i][0][cat_id])
+ for instance_id in range(len(results[i][1][cat_id])):
+ cur_binary_mask = coco_mask.decode(
+ results[i][1][cat_id][instance_id])
+ cur_det_bbox = results[i][0][cat_id][instance_id][:4]
+ dict_det[cur_img_name].append([
+ results[i][0][cat_id][instance_id][4],
+ self.CLASSES[cat_id], cur_binary_mask, cur_det_bbox
+ ])
+ dict_det[cur_img_name].sort(
+ key=lambda x: (-x[0], x[3][0], x[3][1])
+            )  # rank by confidence (high to low); break ties by bbox x1, y1
+ prog_bar.update()
+ print_log('\ncomputing occluded mask recall...')
+ occluded_correct_num, occluded_recall = self.compute_recall(
+ dict_det,
+ gt_ann=self.occluded_ann,
+ score_thr=score_thr,
+ iou_thr=iou_thr,
+ is_occ=True)
+ print_log(f'\nCOCO occluded mask recall: {occluded_recall:.2f}%')
+ print_log(f'COCO occluded mask success num: {occluded_correct_num}')
+ print_log('computing separated mask recall...')
+ separated_correct_num, separated_recall = self.compute_recall(
+ dict_det,
+ gt_ann=self.separated_ann,
+ score_thr=score_thr,
+ iou_thr=iou_thr,
+ is_occ=False)
+ print_log(f'\nCOCO separated mask recall: {separated_recall:.2f}%')
+ print_log(f'COCO separated mask success num: {separated_correct_num}')
+ table_data = [
+ ['mask type', 'recall', 'num correct'],
+ ['occluded', f'{occluded_recall:.2f}%', occluded_correct_num],
+ ['separated', f'{separated_recall:.2f}%', separated_correct_num]
+ ]
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table)
+ return dict(
+ occluded_recall=occluded_recall, separated_recall=separated_recall)
+
+ def compute_recall(self,
+ result_dict,
+ gt_ann,
+ score_thr=0.3,
+ iou_thr=0.75,
+ is_occ=True):
+ """Compute the recall of occluded or separated masks.
+
+ Args:
+            result_dict (dict): Detection results of the dataset, indexed
+                by image filename.
+ gt_ann (list): Occluded or separated coco annotations.
+ score_thr (float): Score threshold of the detection masks.
+ Defaults to 0.3.
+ iou_thr (float): IoU threshold for the recall calculation.
+ Defaults to 0.75.
+ is_occ (bool): Whether the annotation is occluded mask.
+ Defaults to True.
+ Returns:
+ tuple: number of correct masks and the recall.
+ """
+ correct = 0
+ prog_bar = mmcv.ProgressBar(len(gt_ann))
+ for iter_i in range(len(gt_ann)):
+ cur_item = gt_ann[iter_i]
+ cur_img_name = cur_item[0]
+ cur_gt_bbox = cur_item[3]
+ if is_occ:
+ cur_gt_bbox = [
+ cur_gt_bbox[0], cur_gt_bbox[1],
+ cur_gt_bbox[0] + cur_gt_bbox[2],
+ cur_gt_bbox[1] + cur_gt_bbox[3]
+ ]
+ cur_gt_class = cur_item[1]
+ cur_gt_mask = coco_mask.decode(cur_item[4])
+
+ assert cur_img_name in result_dict.keys()
+ cur_detections = result_dict[cur_img_name]
+
+ correct_flag = False
+ for i in range(len(cur_detections)):
+ cur_det_confidence = cur_detections[i][0]
+ if cur_det_confidence < score_thr:
+ break
+ cur_det_class = cur_detections[i][1]
+ if cur_det_class != cur_gt_class:
+ continue
+ cur_det_mask = cur_detections[i][2]
+ cur_iou = self.mask_iou(cur_det_mask, cur_gt_mask)
+ if cur_iou >= iou_thr:
+ correct_flag = True
+ break
+ if correct_flag:
+ correct += 1
+ prog_bar.update()
+ recall = correct / len(gt_ann) * 100
+ return correct, recall
+
+ def mask_iou(self, mask1, mask2):
+ """Compute IoU between two masks."""
+ mask1_area = np.count_nonzero(mask1 == 1)
+ mask2_area = np.count_nonzero(mask2 == 1)
+ intersection = np.count_nonzero(np.logical_and(mask1 == 1, mask2 == 1))
+ iou = intersection / (mask1_area + mask2_area - intersection)
+ return iou
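The occluded/separated recall above ultimately reduces to `mask_iou`, a plain intersection-over-union on binary masks. A self-contained numpy sketch of the same computation (a standalone re-implementation, not the class method, which needs a dataset instance).

import numpy as np


def mask_iou(mask1, mask2):
    """IoU of two binary masks, mirroring the arithmetic used above."""
    area1 = np.count_nonzero(mask1 == 1)
    area2 = np.count_nonzero(mask2 == 1)
    inter = np.count_nonzero(np.logical_and(mask1 == 1, mask2 == 1))
    return inter / (area1 + area2 - inter)


# Two overlapping 4x4 squares on a 10x10 canvas:
# intersection = 4, union = 16 + 16 - 4 = 28, so IoU is roughly 0.143.
a = np.zeros((10, 10), dtype=np.uint8)
b = np.zeros((10, 10), dtype=np.uint8)
a[0:4, 0:4] = 1
b[2:6, 2:6] = 1
print(f'IoU: {mask_iou(a, b):.3f}')  # -> IoU: 0.143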
diff --git a/mmdet/datasets/coco_panoptic.py b/mmdet/datasets/coco_panoptic.py
new file mode 100644
index 0000000000000000000000000000000000000000..53ef5947d1e723dbd19b4fd1fbdeba672414e378
--- /dev/null
+++ b/mmdet/datasets/coco_panoptic.py
@@ -0,0 +1,692 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+import os
+from collections import defaultdict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from mmdet.core import INSTANCE_OFFSET
+from .api_wrappers import COCO, pq_compute_multi_core
+from .builder import DATASETS
+from .coco import CocoDataset
+
+try:
+ import panopticapi
+ from panopticapi.evaluation import VOID
+ from panopticapi.utils import id2rgb
+except ImportError:
+ panopticapi = None
+ id2rgb = None
+ VOID = None
+
+__all__ = ['CocoPanopticDataset']
+
+
+class COCOPanoptic(COCO):
+ """This wrapper is for loading the panoptic style annotation file.
+
+ The format is shown in the CocoPanopticDataset class.
+
+ Args:
+ annotation_file (str): Path of annotation file.
+ """
+
+ def __init__(self, annotation_file=None):
+ if panopticapi is None:
+ raise RuntimeError(
+ 'panopticapi is not installed, please install it by: '
+ 'pip install git+https://github.com/cocodataset/'
+ 'panopticapi.git.')
+
+ super(COCOPanoptic, self).__init__(annotation_file)
+
+ def createIndex(self):
+ # create index
+ print('creating index...')
+ # anns stores 'segment_id -> annotation'
+ anns, cats, imgs = {}, {}, {}
+ img_to_anns, cat_to_imgs = defaultdict(list), defaultdict(list)
+ if 'annotations' in self.dataset:
+ for ann, img_info in zip(self.dataset['annotations'],
+ self.dataset['images']):
+ img_info['segm_file'] = ann['file_name']
+ for seg_ann in ann['segments_info']:
+ # to match with instance.json
+ seg_ann['image_id'] = ann['image_id']
+ seg_ann['height'] = img_info['height']
+ seg_ann['width'] = img_info['width']
+ img_to_anns[ann['image_id']].append(seg_ann)
+                    # segment_id is not guaranteed to be unique in the
+                    # coco dataset, so collect annotations in a list
+ if seg_ann['id'] in anns.keys():
+ anns[seg_ann['id']].append(seg_ann)
+ else:
+ anns[seg_ann['id']] = [seg_ann]
+
+ if 'images' in self.dataset:
+ for img in self.dataset['images']:
+ imgs[img['id']] = img
+
+ if 'categories' in self.dataset:
+ for cat in self.dataset['categories']:
+ cats[cat['id']] = cat
+
+ if 'annotations' in self.dataset and 'categories' in self.dataset:
+ for ann in self.dataset['annotations']:
+ for seg_ann in ann['segments_info']:
+ cat_to_imgs[seg_ann['category_id']].append(ann['image_id'])
+
+ print('index created!')
+
+ self.anns = anns
+ self.imgToAnns = img_to_anns
+ self.catToImgs = cat_to_imgs
+ self.imgs = imgs
+ self.cats = cats
+
+ def load_anns(self, ids=[]):
+ """Load anns with the specified ids.
+
+        ``self.anns`` maps each id to a list of annotations rather than
+        a single annotation, so the loaded lists are concatenated.
+
+ Args:
+ ids (int array): integer ids specifying anns
+
+ Returns:
+ anns (object array): loaded ann objects
+ """
+ anns = []
+
+ if hasattr(ids, '__iter__') and hasattr(ids, '__len__'):
+            # self.anns[id] is a list of annotations, so extend the
+            # result instead of appending a single annotation
+ for id in ids:
+ anns += self.anns[id]
+ return anns
+ elif type(ids) == int:
+ return self.anns[ids]
+
+
+@DATASETS.register_module()
+class CocoPanopticDataset(CocoDataset):
+ """Coco dataset for Panoptic segmentation.
+
+ The annotation format is shown as follows. The `ann` field is optional
+ for testing.
+
+ .. code-block:: none
+
+ [
+ {
+ 'filename': f'{image_id:012}.png',
+                'image_id': 9,
+                'segments_info': [
+                    {
+                        'id': 8345037, (segment_id in panoptic png,
+                                        convert from rgb)
+                        'category_id': 51,
+                        'iscrowd': 0,
+                        'bbox': (x1, y1, w, h),
+                        'area': 24315,
+                        'segmentation': list,(coded mask)
+                    },
+                    ...
+                ]
+ },
+ ...
+ ]
+
+ Args:
+ ann_file (str): Panoptic segmentation annotation file path.
+ pipeline (list[dict]): Processing pipeline.
+ ins_ann_file (str): Instance segmentation annotation file path.
+ Defaults to None.
+ classes (str | Sequence[str], optional): Specify classes to load.
+ If is None, ``cls.CLASSES`` will be used. Defaults to None.
+ data_root (str, optional): Data root for ``ann_file``,
+            ``ins_ann_file``, ``img_prefix``, ``seg_prefix``, ``proposal_file``
+ if specified. Defaults to None.
+ img_prefix (str, optional): Prefix of path to images. Defaults to ''.
+ seg_prefix (str, optional): Prefix of path to segmentation files.
+ Defaults to None.
+ proposal_file (str, optional): Path to proposal file. Defaults to None.
+ test_mode (bool, optional): If set True, annotation will not be loaded.
+ Defaults to False.
+ filter_empty_gt (bool, optional): If set true, images without bounding
+ boxes of the dataset's classes will be filtered out. This option
+ only works when `test_mode=False`, i.e., we never filter images
+ during tests. Defaults to True.
+ file_client_args (:obj:`mmcv.ConfigDict` | dict): file client args.
+ Defaults to dict(backend='disk').
+ """
+ CLASSES = [
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',
+ 'blanket', 'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff',
+ 'floor-wood', 'flower', 'fruit', 'gravel', 'house', 'light',
+ 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
+ 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
+ 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
+ 'wall-wood', 'water-other', 'window-blind', 'window-other',
+ 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
+ 'cabinet-merged', 'table-merged', 'floor-other-merged',
+ 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
+ 'paper-merged', 'food-other-merged', 'building-other-merged',
+ 'rock-merged', 'wall-other-merged', 'rug-merged'
+ ]
+ THING_CLASSES = [
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+ 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+ ]
+ STUFF_CLASSES = [
+ 'banner', 'blanket', 'bridge', 'cardboard', 'counter', 'curtain',
+ 'door-stuff', 'floor-wood', 'flower', 'fruit', 'gravel', 'house',
+ 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield',
+ 'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow',
+ 'stairs', 'tent', 'towel', 'wall-brick', 'wall-stone', 'wall-tile',
+ 'wall-wood', 'water-other', 'window-blind', 'window-other',
+ 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged',
+ 'cabinet-merged', 'table-merged', 'floor-other-merged',
+ 'pavement-merged', 'mountain-merged', 'grass-merged', 'dirt-merged',
+ 'paper-merged', 'food-other-merged', 'building-other-merged',
+ 'rock-merged', 'wall-other-merged', 'rug-merged'
+ ]
+
+ PALETTE = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
+ (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
+ (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
+ (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
+ (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
+ (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
+ (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
+ (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
+ (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
+ (134, 134, 103), (145, 148, 174), (255, 208, 186),
+ (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
+ (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
+ (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
+ (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
+ (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
+ (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
+ (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
+ (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
+ (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
+ (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
+ (191, 162, 208), (255, 255, 128), (147, 211, 203),
+ (150, 100, 100), (168, 171, 172), (146, 112, 198),
+ (210, 170, 100), (92, 136, 89), (218, 88, 184), (241, 129, 0),
+ (217, 17, 255), (124, 74, 181), (70, 70, 70), (255, 228, 255),
+ (154, 208, 0), (193, 0, 92), (76, 91, 113), (255, 180, 195),
+ (106, 154, 176),
+ (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
+ (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
+ (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
+ (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
+ (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
+ (146, 139, 141),
+ (70, 130, 180), (134, 199, 156), (209, 226, 140), (96, 36, 108),
+ (96, 96, 96), (64, 170, 64), (152, 251, 152), (208, 229, 228),
+ (206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143),
+ (102, 102, 156), (250, 141, 255)]
+
+ def __init__(self,
+ ann_file,
+ pipeline,
+ ins_ann_file=None,
+ classes=None,
+ data_root=None,
+ img_prefix='',
+ seg_prefix=None,
+ proposal_file=None,
+ test_mode=False,
+ filter_empty_gt=True,
+ file_client_args=dict(backend='disk')):
+ super().__init__(
+ ann_file,
+ pipeline,
+ classes=classes,
+ data_root=data_root,
+ img_prefix=img_prefix,
+ seg_prefix=seg_prefix,
+ proposal_file=proposal_file,
+ test_mode=test_mode,
+ filter_empty_gt=filter_empty_gt,
+ file_client_args=file_client_args)
+ self.ins_ann_file = ins_ann_file
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO Panoptic style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+ self.coco = COCOPanoptic(ann_file)
+ self.cat_ids = self.coco.get_cat_ids()
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.categories = self.coco.cats
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ info['filename'] = info['file_name']
+ info['segm_file'] = info['filename'].replace('jpg', 'png')
+ data_infos.append(info)
+ return data_infos
+
+ def get_ann_info(self, idx):
+ """Get COCO annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ # filter out unmatched images
+ ann_info = [i for i in ann_info if i['image_id'] == img_id]
+ return self._parse_ann_info(self.data_infos[idx], ann_info)
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse annotations and load panoptic ground truths.
+
+ Args:
+ img_info (dict): Image info of an image.
+ ann_info (list[dict]): Annotation info of an image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,
+ labels, masks, seg_map.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_mask_infos = []
+
+ for i, ann in enumerate(ann_info):
+ x1, y1, w, h = ann['bbox']
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+
+ category_id = ann['category_id']
+ contiguous_cat_id = self.cat2label[category_id]
+
+ is_thing = self.coco.load_cats(ids=category_id)[0]['isthing']
+ if is_thing:
+ is_crowd = ann.get('iscrowd', False)
+ if not is_crowd:
+ gt_bboxes.append(bbox)
+ gt_labels.append(contiguous_cat_id)
+ else:
+ gt_bboxes_ignore.append(bbox)
+ is_thing = False
+
+ mask_info = {
+ 'id': ann['id'],
+ 'category': contiguous_cat_id,
+ 'is_thing': is_thing
+ }
+ gt_mask_infos.append(mask_info)
+
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_mask_infos,
+ seg_map=img_info['segm_file'])
+
+ return ann
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without ground truths."""
+ ids_with_ann = []
+ # check whether images have legal thing annotations.
+ for lists in self.coco.anns.values():
+ for item in lists:
+ category_id = item['category_id']
+ is_thing = self.coco.load_cats(ids=category_id)[0]['isthing']
+ if not is_thing:
+ continue
+ ids_with_ann.append(item['image_id'])
+ ids_with_ann = set(ids_with_ann)
+
+ valid_inds = []
+ valid_img_ids = []
+ for i, img_info in enumerate(self.data_infos):
+ img_id = self.img_ids[i]
+ if self.filter_empty_gt and img_id not in ids_with_ann:
+ continue
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ valid_img_ids.append(img_id)
+ self.img_ids = valid_img_ids
+ return valid_inds
+
+ def _pan2json(self, results, outfile_prefix):
+ """Convert panoptic results to COCO panoptic json style."""
+ label2cat = dict((v, k) for (k, v) in self.cat2label.items())
+ pred_annotations = []
+ outdir = os.path.join(os.path.dirname(outfile_prefix), 'panoptic')
+
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ segm_file = self.data_infos[idx]['segm_file']
+ pan = results[idx]
+
+ pan_labels = np.unique(pan)
+ segm_info = []
+ for pan_label in pan_labels:
+ sem_label = pan_label % INSTANCE_OFFSET
+ # We reserve the length of self.CLASSES for VOID label
+ if sem_label == len(self.CLASSES):
+ continue
+ # convert sem_label to json label
+ cat_id = label2cat[sem_label]
+ is_thing = self.categories[cat_id]['isthing']
+ mask = pan == pan_label
+ area = mask.sum()
+ segm_info.append({
+ 'id': int(pan_label),
+ 'category_id': cat_id,
+ 'isthing': is_thing,
+ 'area': int(area)
+ })
+ # evaluation script uses 0 for VOID label.
+ pan[pan % INSTANCE_OFFSET == len(self.CLASSES)] = VOID
+ pan = id2rgb(pan).astype(np.uint8)
+ mmcv.imwrite(pan[:, :, ::-1], os.path.join(outdir, segm_file))
+ record = {
+ 'image_id': img_id,
+ 'segments_info': segm_info,
+ 'file_name': segm_file
+ }
+ pred_annotations.append(record)
+ pan_json_results = dict(annotations=pred_annotations)
+ return pan_json_results
+
+ def results2json(self, results, outfile_prefix):
+ """Dump the results to a COCO style json file.
+
+ There are 4 types of results: proposals, bbox predictions, mask
+ predictions, panoptic segmentation predictions, and they have
+ different data types. This method will automatically recognize
+ the type, and dump them to json files.
+
+ .. code-block:: none
+
+ [
+ {
+ 'pan_results': np.array, # shape (h, w)
+ # ins_results which includes bboxes and RLE encoded masks
+ # is optional.
+ 'ins_results': (list[np.array], list[list[str]])
+ },
+ ...
+ ]
+
+ Args:
+ results (list[dict]): Testing results of the dataset.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.panoptic.json", "somepath/xxx.bbox.json",
+ "somepath/xxx.segm.json"
+
+ Returns:
+ dict[str: str]: Possible keys are "panoptic", "bbox", "segm", \
+ "proposal", and values are corresponding filenames.
+ """
+ result_files = dict()
+ # panoptic segmentation results
+ if 'pan_results' in results[0]:
+ pan_results = [result['pan_results'] for result in results]
+ pan_json_results = self._pan2json(pan_results, outfile_prefix)
+ result_files['panoptic'] = f'{outfile_prefix}.panoptic.json'
+ mmcv.dump(pan_json_results, result_files['panoptic'])
+
+ # instance segmentation results
+ if 'ins_results' in results[0]:
+ ins_results = [result['ins_results'] for result in results]
+ bbox_json_results, segm_json_results = self._segm2json(ins_results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ result_files['segm'] = f'{outfile_prefix}.segm.json'
+ mmcv.dump(bbox_json_results, result_files['bbox'])
+ mmcv.dump(segm_json_results, result_files['segm'])
+
+ return result_files
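+
+ # A hedged usage sketch (`dataset` and `results` are assumed to exist
+ # and to follow the format documented above):
+ #
+ #   result_files = dataset.results2json(results, 'work_dirs/res')
+ #   result_files['panoptic']   # -> 'work_dirs/res.panoptic.json'
+ #   # panoptic PNGs are written under 'work_dirs/panoptic/'; the
+ #   # 'bbox'/'segm'/'proposal' entries appear only when the results
+ #   # contain 'ins_results'.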
+
+ def evaluate_pan_json(self,
+ result_files,
+ outfile_prefix,
+ logger=None,
+ classwise=False,
+ nproc=32):
+ """Evaluate PQ according to the panoptic results json file."""
+ imgs = self.coco.imgs
+ gt_json = self.coco.img_ann_map # image to annotations
+ gt_json = [{
+ 'image_id': k,
+ 'segments_info': v,
+ 'file_name': imgs[k]['segm_file']
+ } for k, v in gt_json.items()]
+ pred_json = mmcv.load(result_files['panoptic'])
+ pred_json = dict(
+ (el['image_id'], el) for el in pred_json['annotations'])
+
+ # match the gt_anns and pred_anns in the same image
+ matched_annotations_list = []
+ for gt_ann in gt_json:
+ img_id = gt_ann['image_id']
+ if img_id not in pred_json.keys():
+ raise Exception('no prediction for the image'
+ ' with id: {}'.format(img_id))
+ matched_annotations_list.append((gt_ann, pred_json[img_id]))
+
+ gt_folder = self.seg_prefix
+ pred_folder = os.path.join(os.path.dirname(outfile_prefix), 'panoptic')
+
+ pq_stat = pq_compute_multi_core(
+ matched_annotations_list,
+ gt_folder,
+ pred_folder,
+ self.categories,
+ self.file_client,
+ nproc=nproc)
+
+ metrics = [('All', None), ('Things', True), ('Stuff', False)]
+ pq_results = {}
+
+ for name, isthing in metrics:
+ pq_results[name], classwise_results = pq_stat.pq_average(
+ self.categories, isthing=isthing)
+ if name == 'All':
+ pq_results['classwise'] = classwise_results
+
+ classwise_results = None
+ if classwise:
+ classwise_results = {
+ k: v
+ for k, v in zip(self.CLASSES, pq_results['classwise'].values())
+ }
+ print_panoptic_table(pq_results, classwise_results, logger=logger)
+ results = parse_pq_results(pq_results)
+ results['PQ_copypaste'] = (
+ f'{results["PQ"]:.3f} {results["SQ"]:.3f} '
+ f'{results["RQ"]:.3f} '
+ f'{results["PQ_th"]:.3f} {results["SQ_th"]:.3f} '
+ f'{results["RQ_th"]:.3f} '
+ f'{results["PQ_st"]:.3f} {results["SQ_st"]:.3f} '
+ f'{results["RQ_st"]:.3f}')
+
+ return results
+
+ def evaluate(self,
+ results,
+ metric='PQ',
+ logger=None,
+ jsonfile_prefix=None,
+ classwise=False,
+ nproc=32,
+ **kwargs):
+ """Evaluation in COCO Panoptic protocol.
+
+ Args:
+ results (list[dict]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. 'PQ', 'bbox',
+ 'segm', 'proposal' are supported. 'pq' will be regarded as 'PQ'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+ classwise (bool): Whether to print classwise evaluation results.
+ Default: False.
+ nproc (int): Number of processes for panoptic quality computing.
+ Defaults to 32. When `nproc` exceeds the number of cpu cores,
+ the number of cpu cores is used.
+
+ Returns:
+ dict[str, float]: COCO Panoptic style evaluation metric.
+ """
+ metrics = metric if isinstance(metric, list) else [metric]
+ # Compatible with lowercase 'pq'
+ metrics = ['PQ' if metric == 'pq' else metric for metric in metrics]
+ allowed_metrics = ['PQ', 'bbox', 'segm', 'proposal']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+
+ result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
+ eval_results = {}
+
+ outfile_prefix = os.path.join(tmp_dir.name, 'results') \
+ if tmp_dir is not None else jsonfile_prefix
+ if 'PQ' in metrics:
+ eval_pan_results = self.evaluate_pan_json(
+ result_files, outfile_prefix, logger, classwise, nproc=nproc)
+
+ eval_results.update(eval_pan_results)
+ metrics.remove('PQ')
+
+ if (('bbox' in metrics) or ('segm' in metrics)
+ or ('proposal' in metrics)):
+
+ assert 'ins_results' in results[0], 'instance segmentation ' \
+ 'results are absent from results'
+
+ assert self.ins_ann_file is not None, 'Annotation '\
+ 'file for instance segmentation or object detection ' \
+ 'should not be None'
+
+ coco_gt = COCO(self.ins_ann_file)
+ panoptic_cat_ids = self.cat_ids
+ self.cat_ids = coco_gt.get_cat_ids(cat_names=self.THING_CLASSES)
+
+ eval_ins_results = self.evaluate_det_segm(results, result_files,
+ coco_gt, metrics, logger,
+ classwise, **kwargs)
+ self.cat_ids = panoptic_cat_ids
+ eval_results.update(eval_ins_results)
+
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
+
+
+def parse_pq_results(pq_results):
+ """Parse the Panoptic Quality results."""
+ result = dict()
+ result['PQ'] = 100 * pq_results['All']['pq']
+ result['SQ'] = 100 * pq_results['All']['sq']
+ result['RQ'] = 100 * pq_results['All']['rq']
+ result['PQ_th'] = 100 * pq_results['Things']['pq']
+ result['SQ_th'] = 100 * pq_results['Things']['sq']
+ result['RQ_th'] = 100 * pq_results['Things']['rq']
+ result['PQ_st'] = 100 * pq_results['Stuff']['pq']
+ result['SQ_st'] = 100 * pq_results['Stuff']['sq']
+ result['RQ_st'] = 100 * pq_results['Stuff']['rq']
+ return result
+
+
+def print_panoptic_table(pq_results, classwise_results=None, logger=None):
+ """Print the panoptic evaluation results table.
+
+ Args:
+ pq_results(dict): The Panoptic Quality results.
+ classwise_results(dict | None): The classwise Panoptic Quality results.
+ The keys are class names and the values are metrics.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ """
+
+ headers = ['', 'PQ', 'SQ', 'RQ', 'categories']
+ data = [headers]
+ for name in ['All', 'Things', 'Stuff']:
+ numbers = [
+ f'{(pq_results[name][k] * 100):0.3f}' for k in ['pq', 'sq', 'rq']
+ ]
+ row = [name] + numbers + [pq_results[name]['n']]
+ data.append(row)
+ table = AsciiTable(data)
+ print_log('Panoptic Evaluation Results:\n' + table.table, logger=logger)
+
+ if classwise_results is not None:
+ class_metrics = [(name, ) + tuple(f'{(metrics[k] * 100):0.3f}'
+ for k in ['pq', 'sq', 'rq'])
+ for name, metrics in classwise_results.items()]
+ num_columns = min(8, len(class_metrics) * 4)
+ results_flatten = list(itertools.chain(*class_metrics))
+ headers = ['category', 'PQ', 'SQ', 'RQ'] * (num_columns // 4)
+ results_2d = itertools.zip_longest(
+ *[results_flatten[i::num_columns] for i in range(num_columns)])
+ data = [headers]
+ data += [result for result in results_2d]
+ table = AsciiTable(data)
+ print_log(
+ 'Classwise Panoptic Evaluation Results:\n' + table.table,
+ logger=logger)
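+
+
+# A hedged end-to-end sketch of the PQ evaluation path implemented above.
+# The config path and the `results` list are placeholders, not part of
+# this module:
+#
+#   import mmcv
+#   from mmdet.datasets import build_dataset
+#
+#   cfg = mmcv.Config.fromfile('my_panoptic_config.py')  # hypothetical
+#   dataset = build_dataset(cfg.data.val)
+#   # results[i] is a dict with a 'pan_results' (H, W) ndarray and,
+#   # optionally, an 'ins_results' tuple (see results2json above).
+#   metrics = dataset.evaluate(results, metric=['PQ'], classwise=True,
+#                              jsonfile_prefix='work_dirs/panoptic_eval')
+#   print(metrics['PQ'], metrics['PQ_th'], metrics['PQ_st'])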
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b97685bff49b5c4291a4ebe459e829cce2e54d0
--- /dev/null
+++ b/mmdet/datasets/custom.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+from torch.utils.data import Dataset
+
+from mmdet.core import eval_map, eval_recalls
+from .builder import DATASETS
+from .pipelines import Compose
+
+
+@DATASETS.register_module()
+class CustomDataset(Dataset):
+ """Custom dataset for detection.
+
+ The annotation format is shown as follows. The `ann` field is optional for
+ testing.
+
+ .. code-block:: none
+
+ [
+ {
+ 'filename': 'a.jpg',
+ 'width': 1280,
+ 'height': 720,
+ 'ann': {
+ 'bboxes': (n, 4) in (x1, y1, x2, y2) order.
+ 'labels': (n, ),
+ 'bboxes_ignore': (k, 4), (optional field)
+ 'labels_ignore': (k, ) (optional field)
+ }
+ },
+ ...
+ ]
+
+ Args:
+ ann_file (str): Annotation file path.
+ pipeline (list[dict]): Processing pipeline.
+ classes (str | Sequence[str], optional): Specify classes to load.
+ If is None, ``cls.CLASSES`` will be used. Default: None.
+ data_root (str, optional): Data root for ``ann_file``,
+ ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
+ test_mode (bool, optional): If set True, annotation will not be loaded.
+ filter_empty_gt (bool, optional): If set true, images without bounding
+ boxes of the dataset's classes will be filtered out. This option
+ only works when `test_mode=False`, i.e., we never filter images
+ during tests.
+ """
+
+ CLASSES = None
+
+ PALETTE = None
+
+ def __init__(self,
+ ann_file,
+ pipeline,
+ classes=None,
+ data_root=None,
+ img_prefix='',
+ seg_prefix=None,
+ seg_suffix='.png',
+ proposal_file=None,
+ test_mode=False,
+ filter_empty_gt=True,
+ file_client_args=dict(backend='disk')):
+ self.ann_file = ann_file
+ self.data_root = data_root
+ self.img_prefix = img_prefix
+ self.seg_prefix = seg_prefix
+ self.seg_suffix = seg_suffix
+ self.proposal_file = proposal_file
+ self.test_mode = test_mode
+ self.filter_empty_gt = filter_empty_gt
+ self.file_client = mmcv.FileClient(**file_client_args)
+ self.CLASSES = self.get_classes(classes)
+
+ # join paths if data_root is specified
+ if self.data_root is not None:
+ if not osp.isabs(self.ann_file):
+ self.ann_file = osp.join(self.data_root, self.ann_file)
+ if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
+ self.img_prefix = osp.join(self.data_root, self.img_prefix)
+ if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
+ self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
+ if not (self.proposal_file is None
+ or osp.isabs(self.proposal_file)):
+ self.proposal_file = osp.join(self.data_root,
+ self.proposal_file)
+ # load annotations (and proposals)
+ if hasattr(self.file_client, 'get_local_path'):
+ with self.file_client.get_local_path(self.ann_file) as local_path:
+ self.data_infos = self.load_annotations(local_path)
+ else:
+ warnings.warn(
+ 'The used MMCV version does not have get_local_path. '
+ f'We treat the {self.ann_file} as a local path and it '
+ 'might cause errors if the path is not a local path. '
+ 'Please use MMCV>= 1.3.16 if you meet errors.')
+ self.data_infos = self.load_annotations(self.ann_file)
+
+ if self.proposal_file is not None:
+ if hasattr(self.file_client, 'get_local_path'):
+ with self.file_client.get_local_path(
+ self.proposal_file) as local_path:
+ self.proposals = self.load_proposals(local_path)
+ else:
+ warnings.warn(
+ 'The used MMCV version does not have get_local_path. '
+ f'We treat the {self.proposal_file} as a local path and it '
+ 'might cause errors if the path is not a local path. '
+ 'Please use MMCV>= 1.3.16 if you meet errors.')
+ self.proposals = self.load_proposals(self.proposal_file)
+ else:
+ self.proposals = None
+
+ # filter images too small and containing no annotations
+ if not test_mode:
+ valid_inds = self._filter_imgs()
+ self.data_infos = [self.data_infos[i] for i in valid_inds]
+ if self.proposals is not None:
+ self.proposals = [self.proposals[i] for i in valid_inds]
+ # set group flag for the sampler
+ self._set_group_flag()
+
+ # processing pipeline
+ self.pipeline = Compose(pipeline)
+
+ def __len__(self):
+ """Total number of samples of data."""
+ return len(self.data_infos)
+
+ def load_annotations(self, ann_file):
+ """Load annotation from annotation file."""
+ return mmcv.load(ann_file)
+
+ def load_proposals(self, proposal_file):
+ """Load proposal from proposal file."""
+ return mmcv.load(proposal_file)
+
+ def get_ann_info(self, idx):
+ """Get annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ return self.data_infos[idx]['ann']
+
+ def get_cat_ids(self, idx):
+ """Get category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ return self.data_infos[idx]['ann']['labels'].astype(np.int64).tolist()
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['img_prefix'] = self.img_prefix
+ results['seg_prefix'] = self.seg_prefix
+ results['proposal_file'] = self.proposal_file
+ results['bbox_fields'] = []
+ results['mask_fields'] = []
+ results['seg_fields'] = []
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small."""
+ if self.filter_empty_gt:
+ warnings.warn(
+ 'CustomDataset does not support filtering empty gt images.')
+ valid_inds = []
+ for i, img_info in enumerate(self.data_infos):
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ return valid_inds
+
+ def _set_group_flag(self):
+ """Set flag according to image aspect ratio.
+
+ Images with aspect ratio greater than 1 will be set as group 1,
+ otherwise group 0.
+ """
+ self.flag = np.zeros(len(self), dtype=np.uint8)
+ for i in range(len(self)):
+ img_info = self.data_infos[i]
+ if img_info['width'] / img_info['height'] > 1:
+ self.flag[i] = 1
+
+ def _rand_another(self, idx):
+ """Get another random index from the same group as the given index."""
+ pool = np.where(self.flag == self.flag[idx])[0]
+ return np.random.choice(pool)
+
+ def __getitem__(self, idx):
+ """Get training/test data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training/test data (with annotation if `test_mode` is set \
+ False).
+ """
+
+ if self.test_mode:
+ return self.prepare_test_img(idx)
+ while True:
+ data = self.prepare_train_img(idx)
+ if data is None:
+ idx = self._rand_another(idx)
+ continue
+ return data
+
+ def prepare_train_img(self, idx):
+ """Get training data and annotations after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys \
+ introduced by pipeline.
+ """
+
+ img_info = self.data_infos[idx]
+ ann_info = self.get_ann_info(idx)
+ results = dict(img_info=img_info, ann_info=ann_info)
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by \
+ pipeline.
+ """
+
+ img_info = self.data_infos[idx]
+ results = dict(img_info=img_info)
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ @classmethod
+ def get_classes(cls, classes=None):
+ """Get class names of current dataset.
+
+ Args:
+ classes (Sequence[str] | str | None): If classes is None, use
+ default CLASSES defined by builtin dataset. If classes is a
+ string, take it as a file name. The file contains the name of
+ classes where each line contains one class name. If classes is
+ a tuple or list, override the CLASSES defined by the dataset.
+
+ Returns:
+ tuple[str] or list[str]: Names of categories of the dataset.
+ """
+ if classes is None:
+ return cls.CLASSES
+
+ if isinstance(classes, str):
+ # take it as a file path
+ class_names = mmcv.list_from_file(classes)
+ elif isinstance(classes, (tuple, list)):
+ class_names = classes
+ else:
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+ return class_names
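+
+ # A hedged example of the three accepted forms (the txt path is
+ # illustrative only):
+ #
+ #   CustomDataset.get_classes(None)               # -> cls.CLASSES
+ #   CustomDataset.get_classes(['cat', 'dog'])     # -> ['cat', 'dog']
+ #   CustomDataset.get_classes('my_classes.txt')   # one class name per line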
+
+ def get_cat2imgs(self):
+ """Get a dict with class as key and img_ids as values, which will be
+ used in :class:`ClassAwareSampler`.
+
+ Returns:
+ dict[int, list]: A dict mapping each label index to the list of
+ indices of images that contain that label.
+ """
+ if self.CLASSES is None:
+ raise ValueError('self.CLASSES can not be None')
+ # sort the label index
+ cat2imgs = {i: [] for i in range(len(self.CLASSES))}
+ for i in range(len(self)):
+ cat_ids = set(self.get_cat_ids(i))
+ for cat in cat_ids:
+ cat2imgs[cat].append(i)
+ return cat2imgs
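+
+ # A hedged sketch of the returned mapping for a 2-class dataset:
+ #   {0: [0, 2, 5, ...], 1: [1, 2, ...]}  # class index -> image indices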
+
+ def format_results(self, results, **kwargs):
+ """Place holder to format result to dataset specific output."""
+
+ def evaluate(self,
+ results,
+ metric='mAP',
+ logger=None,
+ proposal_nums=(100, 300, 1000),
+ iou_thr=0.5,
+ scale_ranges=None):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thr (float | list[float]): IoU threshold. Default: 0.5.
+ scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
+ Default: None.
+ """
+
+ if not isinstance(metric, str):
+ assert len(metric) == 1
+ metric = metric[0]
+ allowed_metrics = ['mAP', 'recall']
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ annotations = [self.get_ann_info(i) for i in range(len(self))]
+ eval_results = OrderedDict()
+ iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
+ if metric == 'mAP':
+ assert isinstance(iou_thrs, list)
+ mean_aps = []
+ for iou_thr in iou_thrs:
+ print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
+ mean_ap, _ = eval_map(
+ results,
+ annotations,
+ scale_ranges=scale_ranges,
+ iou_thr=iou_thr,
+ dataset=self.CLASSES,
+ logger=logger)
+ mean_aps.append(mean_ap)
+ eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
+ eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
+ elif metric == 'recall':
+ gt_bboxes = [ann['bboxes'] for ann in annotations]
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
+ for i, num in enumerate(proposal_nums):
+ for j, iou in enumerate(iou_thrs):
+ eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
+ if recalls.shape[1] > 1:
+ ar = recalls.mean(axis=1)
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ return eval_results
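+
+ # A hedged usage sketch (`dataset`, `det_results` and `proposals` are
+ # assumed to exist and follow the usual mmdet result conventions):
+ #
+ #   # metric='mAP': det_results[i] is a list of per-class (k, 5) arrays
+ #   out = dataset.evaluate(det_results, metric='mAP', iou_thr=[0.5, 0.75])
+ #   print(out['AP50'], out['AP75'], out['mAP'])
+ #
+ #   # metric='recall': proposals[i] is a single (k, 4) or (k, 5) array
+ #   out = dataset.evaluate(proposals, metric='recall')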
+
+ def __repr__(self):
+ """Print the number of instance number."""
+ dataset_type = 'Test' if self.test_mode else 'Train'
+ result = (f'\n{self.__class__.__name__} {dataset_type} dataset '
+ f'with number of images {len(self)}, '
+ f'and instance counts: \n')
+ if self.CLASSES is None:
+ result += 'Category names are not provided. \n'
+ return result
+ instance_count = np.zeros(len(self.CLASSES) + 1).astype(int)
+ # count the instance number in each image
+ for idx in range(len(self)):
+ label = self.get_ann_info(idx)['labels']
+ unique, counts = np.unique(label, return_counts=True)
+ if len(unique) > 0:
+ # add the occurrence number to each class
+ instance_count[unique] += counts
+ else:
+ # background is the last index
+ instance_count[-1] += 1
+ # create a table with category count
+ table_data = [['category', 'count'] * 5]
+ row_data = []
+ for cls, count in enumerate(instance_count):
+ if cls < len(self.CLASSES):
+ row_data += [f'{cls} [{self.CLASSES[cls]}]', f'{count}']
+ else:
+ # add the background number
+ row_data += ['-1 background', f'{count}']
+ if len(row_data) == 10:
+ table_data.append(row_data)
+ row_data = []
+ if len(row_data) >= 2:
+ if row_data[-1] == '0':
+ row_data = row_data[:-2]
+ if len(row_data) >= 2:
+ table_data.append([])
+ table_data.append(row_data)
+
+ table = AsciiTable(table_data)
+ result += table.table
+ return result
diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6ceffb8105dd05586a1a88c49de71777159f1f1
--- /dev/null
+++ b/mmdet/datasets/dataset_wrappers.py
@@ -0,0 +1,456 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import bisect
+import collections
+import copy
+import math
+from collections import defaultdict
+
+import numpy as np
+from mmcv.utils import build_from_cfg, print_log
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .builder import DATASETS, PIPELINES
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class ConcatDataset(_ConcatDataset):
+ """A wrapper of concatenated dataset.
+
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+ concat the group flag for image aspect ratio.
+
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ separate_eval (bool): Whether to evaluate the results
+ separately if it is used as validation dataset.
+ Defaults to True.
+ """
+
+ def __init__(self, datasets, separate_eval=True):
+ super(ConcatDataset, self).__init__(datasets)
+ self.CLASSES = datasets[0].CLASSES
+ self.PALETTE = getattr(datasets[0], 'PALETTE', None)
+ self.separate_eval = separate_eval
+ if not separate_eval:
+ if any([isinstance(ds, CocoDataset) for ds in datasets]):
+ raise NotImplementedError(
+ 'Evaluating concatenated CocoDataset as a whole is not'
+ ' supported! Please set "separate_eval=True"')
+ elif len(set([type(ds) for ds in datasets])) != 1:
+ raise NotImplementedError(
+ 'All the datasets should have the same type')
+
+ if hasattr(datasets[0], 'flag'):
+ flags = []
+ for i in range(0, len(datasets)):
+ flags.append(datasets[i].flag)
+ self.flag = np.concatenate(flags)
+
+ def get_cat_ids(self, idx):
+ """Get category ids of concatenated dataset by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError(
+ 'absolute value of index should not exceed dataset length')
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx].get_cat_ids(sample_idx)
+
+ def get_ann_info(self, idx):
+ """Get annotation of concatenated dataset by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError(
+ 'absolute value of index should not exceed dataset length')
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx].get_ann_info(sample_idx)
+
+ def evaluate(self, results, logger=None, **kwargs):
+ """Evaluate the results.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+
+ Returns:
+ dict[str: float]: AP results of the total dataset or each separate
+ dataset if `self.separate_eval=True`.
+ """
+ assert len(results) == self.cumulative_sizes[-1], \
+ ('Dataset and results have different sizes: '
+ f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
+
+ # Check whether all the datasets support evaluation
+ for dataset in self.datasets:
+ assert hasattr(dataset, 'evaluate'), \
+ f'{type(dataset)} does not implement evaluate function'
+
+ if self.separate_eval:
+ dataset_idx = -1
+ total_eval_results = dict()
+ for size, dataset in zip(self.cumulative_sizes, self.datasets):
+ start_idx = 0 if dataset_idx == -1 else \
+ self.cumulative_sizes[dataset_idx]
+ end_idx = self.cumulative_sizes[dataset_idx + 1]
+
+ results_per_dataset = results[start_idx:end_idx]
+ print_log(
+ f'\nEvaluating {dataset.ann_file} with '
+ f'{len(results_per_dataset)} images now',
+ logger=logger)
+
+ eval_results_per_dataset = dataset.evaluate(
+ results_per_dataset, logger=logger, **kwargs)
+ dataset_idx += 1
+ for k, v in eval_results_per_dataset.items():
+ total_eval_results.update({f'{dataset_idx}_{k}': v})
+
+ return total_eval_results
+ elif any([isinstance(ds, CocoDataset) for ds in self.datasets]):
+ raise NotImplementedError(
+ 'Evaluating concatenated CocoDataset as a whole is not'
+ ' supported! Please set "separate_eval=True"')
+ elif len(set([type(ds) for ds in self.datasets])) != 1:
+ raise NotImplementedError(
+ 'All the datasets should have the same type')
+ else:
+ original_data_infos = self.datasets[0].data_infos
+ self.datasets[0].data_infos = sum(
+ [dataset.data_infos for dataset in self.datasets], [])
+ eval_results = self.datasets[0].evaluate(
+ results, logger=logger, **kwargs)
+ self.datasets[0].data_infos = original_data_infos
+ return eval_results
+
+
+@DATASETS.register_module()
+class RepeatDataset:
+ """A wrapper of repeated dataset.
+
+ The length of repeated dataset will be `times` larger than the original
+ dataset. This is useful when the data loading time is long but the dataset
+ is small. Using RepeatDataset can reduce the data loading time between
+ epochs.
+
+ Args:
+ dataset (:obj:`Dataset`): The dataset to be repeated.
+ times (int): Repeat times.
+ """
+
+ def __init__(self, dataset, times):
+ self.dataset = dataset
+ self.times = times
+ self.CLASSES = dataset.CLASSES
+ self.PALETTE = getattr(dataset, 'PALETTE', None)
+ if hasattr(self.dataset, 'flag'):
+ self.flag = np.tile(self.dataset.flag, times)
+
+ self._ori_len = len(self.dataset)
+
+ def __getitem__(self, idx):
+ return self.dataset[idx % self._ori_len]
+
+ def get_cat_ids(self, idx):
+ """Get category ids of repeat dataset by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ return self.dataset.get_cat_ids(idx % self._ori_len)
+
+ def get_ann_info(self, idx):
+ """Get annotation of repeat dataset by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ return self.dataset.get_ann_info(idx % self._ori_len)
+
+ def __len__(self):
+ """Length after repetition."""
+ return self.times * self._ori_len
+
+
+# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
+@DATASETS.register_module()
+class ClassBalancedDataset:
+ """A wrapper of repeated dataset with repeat factor.
+
+ Suitable for training on class imbalanced datasets like LVIS. Following
+ the sampling strategy in the LVIS paper,
+ in each epoch, an image may appear multiple times based on its
+ "repeat factor".
+ The repeat factor for an image is a function of the frequency of the
+ rarest category labeled in that image. The "frequency of category c" in [0, 1]
+ is defined by the fraction of images in the training set (without repeats)
+ in which category c appears.
+ The dataset needs to instantiate :func:`self.get_cat_ids` to support
+ ClassBalancedDataset.
+
+ The repeat factor is computed as follows.
+
+ 1. For each category c, compute the fraction # of images
+ that contain it: :math:`f(c)`
+ 2. For each category c, compute the category-level repeat factor:
+ :math:`r(c) = max(1, sqrt(t/f(c)))`
+ 3. For each image I, compute the image-level repeat factor:
+ :math:`r(I) = max_{c in I} r(c)`
+
+ Args:
+ dataset (:obj:`CustomDataset`): The dataset to be repeated.
+ oversample_thr (float): frequency threshold below which data is
+ repeated. For categories with ``f_c >= oversample_thr``, there is
+ no oversampling. For categories with ``f_c < oversample_thr``, the
+ degree of oversampling follows the square-root inverse frequency
+ heuristic above.
+ filter_empty_gt (bool, optional): If set true, images without bounding
+ boxes will not be oversampled. Otherwise, they will be categorized
+ as the pure background class and involved into the oversampling.
+ Default: True.
+ """
+
+ def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
+ self.dataset = dataset
+ self.oversample_thr = oversample_thr
+ self.filter_empty_gt = filter_empty_gt
+ self.CLASSES = dataset.CLASSES
+ self.PALETTE = getattr(dataset, 'PALETTE', None)
+
+ repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
+ repeat_indices = []
+ for dataset_idx, repeat_factor in enumerate(repeat_factors):
+ repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
+ self.repeat_indices = repeat_indices
+
+ flags = []
+ if hasattr(self.dataset, 'flag'):
+ for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
+ flags.extend([flag] * int(math.ceil(repeat_factor)))
+ assert len(flags) == len(repeat_indices)
+ self.flag = np.asarray(flags, dtype=np.uint8)
+
+ def _get_repeat_factors(self, dataset, repeat_thr):
+ """Get repeat factor for each images in the dataset.
+
+ Args:
+ dataset (:obj:`CustomDataset`): The dataset
+ repeat_thr (float): The frequency threshold. If an image
+ contains categories whose frequency is below the threshold,
+ it will be repeated.
+
+ Returns:
+ list[float]: The repeat factor for each image in the dataset.
+ """
+
+ # 1. For each category c, compute the fraction # of images
+ # that contain it: f(c)
+ category_freq = defaultdict(int)
+ num_images = len(dataset)
+ for idx in range(num_images):
+ cat_ids = set(self.dataset.get_cat_ids(idx))
+ if len(cat_ids) == 0 and not self.filter_empty_gt:
+ cat_ids = set([len(self.CLASSES)])
+ for cat_id in cat_ids:
+ category_freq[cat_id] += 1
+ for k, v in category_freq.items():
+ category_freq[k] = v / num_images
+
+ # 2. For each category c, compute the category-level repeat factor:
+ # r(c) = max(1, sqrt(t/f(c)))
+ category_repeat = {
+ cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
+ for cat_id, cat_freq in category_freq.items()
+ }
+
+ # 3. For each image I, compute the image-level repeat factor:
+ # r(I) = max_{c in I} r(c)
+ repeat_factors = []
+ for idx in range(num_images):
+ cat_ids = set(self.dataset.get_cat_ids(idx))
+ if len(cat_ids) == 0 and not self.filter_empty_gt:
+ cat_ids = set([len(self.CLASSES)])
+ repeat_factor = 1
+ if len(cat_ids) > 0:
+ repeat_factor = max(
+ {category_repeat[cat_id]
+ for cat_id in cat_ids})
+ repeat_factors.append(repeat_factor)
+
+ return repeat_factors
+
+ def __getitem__(self, idx):
+ ori_index = self.repeat_indices[idx]
+ return self.dataset[ori_index]
+
+ def get_ann_info(self, idx):
+ """Get annotation of dataset by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+ ori_index = self.repeat_indices[idx]
+ return self.dataset.get_ann_info(ori_index)
+
+ def __len__(self):
+ """Length after repetition."""
+ return len(self.repeat_indices)
+
+
+@DATASETS.register_module()
+class MultiImageMixDataset:
+ """A wrapper of multiple images mixed dataset.
+
+ Suitable for training on multiple images mixed data augmentation like
+ mosaic and mixup. For the augmentation pipeline of mixed image data,
+ the `get_indexes` method needs to be provided to obtain the image
+ indexes, and you can set `skip_type_keys` to skip certain transforms
+ in the pipeline. The deprecated `dynamic_scale` parameter has been
+ replaced by the Resize pipeline for changing the output image size.
+ A sketch of a typical mixed-image config follows this docstring.
+
+ Args:
+ dataset (:obj:`CustomDataset`): The dataset to be mixed.
+ pipeline (Sequence[dict]): Sequence of transform object or
+ config dict to be composed.
+ dynamic_scale (tuple[int], optional): The image scale can be changed
+ dynamically. Default to None. It is deprecated.
+ skip_type_keys (list[str], optional): Sequence of type string to
+ be skip pipeline. Default to None.
+ max_refetch (int): The maximum number of retry iterations for getting
+ valid results from the pipeline. If the number of iterations is
+ greater than `max_refetch` but the results are still None, the
+ iteration is terminated and an error is raised. Default: 15.
+ """
+
+ def __init__(self,
+ dataset,
+ pipeline,
+ dynamic_scale=None,
+ skip_type_keys=None,
+ max_refetch=15):
+ if dynamic_scale is not None:
+ raise RuntimeError(
+ 'dynamic_scale is deprecated. Please use Resize pipeline '
+ 'to achieve similar functions')
+ assert isinstance(pipeline, collections.abc.Sequence)
+ if skip_type_keys is not None:
+ assert all([
+ isinstance(skip_type_key, str)
+ for skip_type_key in skip_type_keys
+ ])
+ self._skip_type_keys = skip_type_keys
+
+ self.pipeline = []
+ self.pipeline_types = []
+ for transform in pipeline:
+ if isinstance(transform, dict):
+ self.pipeline_types.append(transform['type'])
+ transform = build_from_cfg(transform, PIPELINES)
+ self.pipeline.append(transform)
+ else:
+ raise TypeError('pipeline must be a dict')
+
+ self.dataset = dataset
+ self.CLASSES = dataset.CLASSES
+ self.PALETTE = getattr(dataset, 'PALETTE', None)
+ if hasattr(self.dataset, 'flag'):
+ self.flag = dataset.flag
+ self.num_samples = len(dataset)
+ self.max_refetch = max_refetch
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ results = copy.deepcopy(self.dataset[idx])
+ for (transform, transform_type) in zip(self.pipeline,
+ self.pipeline_types):
+ if self._skip_type_keys is not None and \
+ transform_type in self._skip_type_keys:
+ continue
+
+ if hasattr(transform, 'get_indexes'):
+ for i in range(self.max_refetch):
+ # Make sure the results passed the loading pipeline
+ # of the original dataset is not None.
+ indexes = transform.get_indexes(self.dataset)
+ if not isinstance(indexes, collections.abc.Sequence):
+ indexes = [indexes]
+ mix_results = [
+ copy.deepcopy(self.dataset[index]) for index in indexes
+ ]
+ if None not in mix_results:
+ results['mix_results'] = mix_results
+ break
+ else:
+ raise RuntimeError(
+ 'The loading pipeline of the original dataset'
+ ' always returns None. Please check the correctness '
+ 'of the dataset and its pipeline.')
+
+ for i in range(self.max_refetch):
+ # To confirm the results passed the training pipeline
+ # of the wrapper is not None.
+ updated_results = transform(copy.deepcopy(results))
+ if updated_results is not None:
+ results = updated_results
+ break
+ else:
+ raise RuntimeError(
+ 'The training pipeline of the dataset wrapper'
+ ' always returns None. Please check the correctness '
+ 'of the dataset and its pipeline.')
+
+ if 'mix_results' in results:
+ results.pop('mix_results')
+
+ return results
+
+ def update_skip_type_keys(self, skip_type_keys):
+ """Update skip_type_keys. It is called by an external hook.
+
+ Args:
+ skip_type_keys (list[str], optional): Sequence of type
+ string to be skip pipeline.
+ """
+ assert all([
+ isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
+ ])
+ self._skip_type_keys = skip_type_keys
diff --git a/mmdet/datasets/deepfashion.py b/mmdet/datasets/deepfashion.py
new file mode 100644
index 0000000000000000000000000000000000000000..609f80913b4ac63a80359dc25fdd49293a29aa7e
--- /dev/null
+++ b/mmdet/datasets/deepfashion.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class DeepFashionDataset(CocoDataset):
+
+ CLASSES = ('top', 'skirt', 'leggings', 'dress', 'outer', 'pants', 'bag',
+ 'neckwear', 'headwear', 'eyeglass', 'belt', 'footwear', 'hair',
+ 'skin', 'face')
+
+ PALETTE = [(0, 192, 64), (0, 64, 96), (128, 192, 192), (0, 64, 64),
+ (0, 192, 224), (0, 192, 192), (128, 192, 64), (0, 192, 96),
+ (128, 32, 192), (0, 0, 224), (0, 0, 64), (0, 160, 192),
+ (128, 0, 96), (128, 0, 192), (0, 32, 192)]
diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6196eee59393b3a7733c76694d21dbb1279e68
--- /dev/null
+++ b/mmdet/datasets/lvis.py
@@ -0,0 +1,742 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+import logging
+import os.path as osp
+import tempfile
+import warnings
+from collections import OrderedDict
+
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from .builder import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class LVISV05Dataset(CocoDataset):
+
+ CLASSES = (
+ 'acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
+ 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
+ 'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
+ 'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
+ 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
+ 'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
+ 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
+ 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
+ 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
+ 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
+ 'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
+ 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
+ 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead',
+ 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
+ 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can',
+ 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench',
+ 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder', 'binoculars',
+ 'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse',
+ 'birthday_cake', 'birthday_card', 'biscuit_(bread)', 'pirate_flag',
+ 'black_sheep', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp',
+ 'blinker', 'blueberry', 'boar', 'gameboard', 'boat', 'bobbin',
+ 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet',
+ 'book', 'book_bag', 'bookcase', 'booklet', 'bookmark',
+ 'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet',
+ 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl',
+ 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
+ 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
+ 'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
+ 'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
+ 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
+ 'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
+ 'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
+ 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
+ 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
+ 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
+ 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
+ 'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
+ 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
+ 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
+ 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
+ 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
+ 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
+ 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
+ 'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
+ 'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
+ 'champagne', 'chandelier', 'chap', 'checkbook', 'checkerboard',
+ 'cherry', 'chessboard', 'chest_of_drawers_(furniture)',
+ 'chicken_(animal)', 'chicken_wire', 'chickpea', 'Chihuahua',
+ 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
+ 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
+ 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
+ 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
+ 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
+ 'clementine', 'clip', 'clipboard', 'clock', 'clock_tower',
+ 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
+ 'coat_hanger', 'coatrack', 'cock', 'coconut', 'coffee_filter',
+ 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin',
+ 'colander', 'coleslaw', 'coloring_material', 'combination_lock',
+ 'pacifier', 'comic_book', 'computer_keyboard', 'concrete_mixer',
+ 'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cookie',
+ 'cookie_jar', 'cooking_utensil', 'cooler_(for_food)',
+ 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn',
+ 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
+ 'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
+ 'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
+ 'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
+ 'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
+ 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
+ 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
+ 'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
+ 'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
+ 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
+ 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
+ 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
+ 'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
+ 'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
+ 'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
+ 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
+ 'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
+ 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
+ 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
+ 'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
+ 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
+ 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
+ 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
+ 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
+ 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm',
+ 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
+ 'fireplug', 'fish', 'fish_(food)', 'fishbowl', 'fishing_boat',
+ 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flash',
+ 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
+ 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
+ 'food_processor', 'football_(American)', 'football_helmet',
+ 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
+ 'freshener', 'frisbee', 'frog', 'fruit_juice', 'fruit_salad',
+ 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage',
+ 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic',
+ 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
+ 'gift_wrap', 'ginger', 'giraffe', 'cincture',
+ 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
+ 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
+ 'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
+ 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
+ 'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
+ 'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
+ 'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
+ 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
+ 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
+ 'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
+ 'headband', 'headboard', 'headlight', 'headscarf', 'headset',
+ 'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
+ 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
+ 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
+ 'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
+ 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
+ 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
+ 'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
+ 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
+ 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
+ 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
+ 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
+ 'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
+ 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
+ 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
+ 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
+ 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
+ 'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
+ 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
+ 'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
+ 'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
+ 'speaker_(stereo_equipment)', 'loveseat', 'machine_gun', 'magazine',
+ 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
+ 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
+ 'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
+ 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
+ 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
+ 'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
+ 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
+ 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
+ 'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
+ 'mound_(baseball)', 'mouse_(animal_rodent)',
+ 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
+ 'music_stool', 'musical_instrument', 'nailfile', 'nameplate', 'napkin',
+ 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newsstand',
+ 'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)',
+ 'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
+ 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
+ 'orange_(fruit)', 'orange_juice', 'oregano', 'ostrich', 'ottoman',
+ 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
+ 'padlock', 'paintbox', 'paintbrush', 'painting', 'pajamas', 'palette',
+ 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
+ 'papaya', 'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
+ 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
+ 'parchment', 'parka', 'parking_meter', 'parrot',
+ 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
+ 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
+ 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
+ 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
+ 'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
+ 'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
+ 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
+ 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
+ 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
+ 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
+ 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
+ 'plate', 'platter', 'playing_card', 'playpen', 'pliers',
+ 'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
+ 'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
+ 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
+ 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato',
+ 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'printer',
+ 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding',
+ 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet',
+ 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car',
+ 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft',
+ 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
+ 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
+ 'recliner', 'record_player', 'red_cabbage', 'reflector',
+ 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
+ 'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
+ 'Rollerblade', 'rolling_pin', 'root_beer',
+ 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
+ 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
+ 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
+ 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
+ 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
+ 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
+ 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
+ 'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
+ 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
+ 'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
+ 'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
+ 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
+ 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart',
+ 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head',
+ 'shower_curtain', 'shredder_(for_paper)', 'sieve', 'signboard', 'silo',
+ 'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka',
+ 'ski_pole', 'skirt', 'sled', 'sleeping_bag', 'sling_(bandage)',
+ 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
+ 'snowmobile', 'soap', 'soccer_ball', 'sock', 'soda_fountain',
+ 'carbonated_water', 'sofa', 'softball', 'solar_array', 'sombrero',
+ 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk',
+ 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear',
+ 'spectacles', 'spice_rack', 'spider', 'sponge', 'spoon', 'sportswear',
+ 'spotlight', 'squirrel', 'stapler_(stapling_machine)', 'starfish',
+ 'statue_(sculpture)', 'steak_(food)', 'steak_knife',
+ 'steamer_(kitchen_appliance)', 'steering_wheel', 'stencil',
+ 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer',
+ 'stirrup', 'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light',
+ 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
+ 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
+ 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
+ 'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
+ 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
+ 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
+ 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
+ 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
+ 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
+ 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
+ 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
+ 'telephone_pole', 'telephoto_lens', 'television_camera',
+ 'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
+ 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
+ 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil',
+ 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
+ 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
+ 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
+ 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light',
+ 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
+ 'tray', 'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
+ 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
+ 'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
+ 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
+ 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
+ 'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
+ 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
+ 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
+ 'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
+ 'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
+ 'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski',
+ 'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam',
+ 'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair',
+ 'whipped_cream', 'whiskey', 'whistle', 'wick', 'wig', 'wind_chime',
+ 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock',
+ 'wine_bottle', 'wine_bucket', 'wineglass', 'wing_chair',
+ 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath',
+ 'wrench', 'wristband', 'wristlet', 'yacht', 'yak', 'yogurt',
+ 'yoke_(animal_equipment)', 'zebra', 'zucchini')
+
+ PALETTE = None
+
+ def load_annotations(self, ann_file):
+ """Load annotation from lvis style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from LVIS api.
+ """
+
+ try:
+ import lvis
+ if getattr(lvis, '__version__', '0') >= '10.5.3':
+ warnings.warn(
+ 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
+ UserWarning)
+ from lvis import LVIS
+ except ImportError:
+ raise ImportError(
+ 'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
+ )
+ self.coco = LVIS(ann_file)
+ self.cat_ids = self.coco.get_cat_ids()
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ if info['file_name'].startswith('COCO'):
+                # Convert from the COCO 2014 file naming convention of
+ # COCO_[train/val/test]2014_000000000000.jpg to the 2017
+ # naming convention of 000000000000.jpg
+ # (LVIS v1 will fix this naming issue)
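+                # e.g. (illustrative value):
+                #   'COCO_val2014_000000391895.jpg'[-16:]
+                #   -> '000000391895.jpg'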
+ info['filename'] = info['file_name'][-16:]
+ else:
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ return data_infos
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ jsonfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=np.arange(0.5, 0.96, 0.05)):
+ """Evaluation in LVIS protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+            jsonfile_prefix (str | None): The prefix of json files. It
+                includes the file path and the prefix of filename, e.g.,
+                "a/b/prefix". If not specified, a temp file will be
+                created. Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+            iou_thrs (Sequence[float]): IoU thresholds used for evaluating
+                recalls. If set to a list, the average recall of all IoUs
+                will also be computed.
+                Default: np.arange(0.5, 0.96, 0.05).
+
+ Returns:
+ dict[str, float]: LVIS style metrics.
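+
+        Example (a minimal, illustrative sketch; it assumes the dataset has
+        been built and ``results`` holds per-image detections from a tested
+        model)::
+
+            >>> metrics = dataset.evaluate(results, metric='bbox')
+            >>> print(metrics['bbox_AP'])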
+ """
+
+ try:
+ import lvis
+ if getattr(lvis, '__version__', '0') >= '10.5.3':
+ warnings.warn(
+ 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
+ UserWarning)
+ from lvis import LVISEval, LVISResults
+ except ImportError:
+ raise ImportError(
+ 'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
+ )
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError('metric {} is not supported'.format(metric))
+
+ if jsonfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2json(results, jsonfile_prefix)
+
+ eval_results = OrderedDict()
+ # get original api
+ lvis_gt = self.coco
+ for metric in metrics:
+ msg = 'Evaluating {}...'.format(metric)
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ if metric == 'proposal_fast':
+ ar = self.fast_eval_recall(
+ results, proposal_nums, iou_thrs, logger='silent')
+ log_msg = []
+ for i, num in enumerate(proposal_nums):
+ eval_results['AR@{}'.format(num)] = ar[i]
+ log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i]))
+ log_msg = ''.join(log_msg)
+ print_log(log_msg, logger=logger)
+ continue
+
+ if metric not in result_files:
+ raise KeyError('{} is not in results'.format(metric))
+ try:
+ lvis_dt = LVISResults(lvis_gt, result_files[metric])
+ except IndexError:
+ print_log(
+                    'The testing results of the whole dataset are empty.',
+ logger=logger,
+ level=logging.ERROR)
+ break
+
+ iou_type = 'bbox' if metric == 'proposal' else metric
+ lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type)
+ lvis_eval.params.imgIds = self.img_ids
+ if metric == 'proposal':
+ lvis_eval.params.useCats = 0
+ lvis_eval.params.maxDets = list(proposal_nums)
+ lvis_eval.evaluate()
+ lvis_eval.accumulate()
+ lvis_eval.summarize()
+ for k, v in lvis_eval.get_results().items():
+ if k.startswith('AR'):
+ val = float('{:.4f}'.format(float(v)))
+ eval_results[k] = val
+ else:
+ lvis_eval.evaluate()
+ lvis_eval.accumulate()
+ lvis_eval.summarize()
+ lvis_results = lvis_eval.get_results()
+                if classwise:
+                    # Compute per-category AP
+                    # from https://github.com/facebookresearch/detectron2/
+ precisions = lvis_eval.eval['precision']
+                    # precision: (iou, recall, cls, area range)
+ assert len(self.cat_ids) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, catId in enumerate(self.cat_ids):
+                        # area range index 0: all area ranges
+                        # the dimensions of precisions are
+                        # [num_thrs, num_recalls, num_cats, num_area_rngs]
+ nm = self.coco.load_cats([catId])[0]
+ precision = precisions[:, :, idx, 0]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ results_per_category.append(
+ (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(
+ itertools.chain(*results_per_category))
+ headers = ['category', 'AP'] * (num_columns // 2)
+ results_2d = itertools.zip_longest(*[
+ results_flatten[i::num_columns]
+ for i in range(num_columns)
+ ])
+ table_data = [headers]
+ table_data += [result for result in results_2d]
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+ for k, v in lvis_results.items():
+ if k.startswith('AP'):
+ key = '{}_{}'.format(metric, k)
+ val = float('{:.4f}'.format(float(v)))
+ eval_results[key] = val
+ ap_summary = ' '.join([
+ '{}:{:.4f}'.format(k, float(v))
+ for k, v in lvis_results.items() if k.startswith('AP')
+ ])
+ eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary
+ lvis_eval.print_results()
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
+
+
+LVISDataset = LVISV05Dataset
+DATASETS.register_module(name='LVISDataset', module=LVISDataset)
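+
+# A registered dataset is referenced by its registry name from a config
+# dict. A minimal, illustrative sketch (the paths and `train_pipeline`
+# below are placeholders, not part of this patch):
+#
+#   data = dict(
+#       train=dict(
+#           type='LVISV1Dataset',
+#           ann_file='data/lvis_v1/annotations/lvis_v1_train.json',
+#           img_prefix='data/lvis_v1/',
+#           pipeline=train_pipeline))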
+
+
+@DATASETS.register_module()
+class LVISV1Dataset(LVISDataset):
+
+ CLASSES = (
+ 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol',
+ 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna',
+ 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
+ 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
+ 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
+ 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
+ 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
+ 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
+ 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
+ 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
+ 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
+ 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
+ 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
+ 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
+ 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
+ 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
+ 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
+ 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
+ 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
+ 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
+ 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
+ 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
+ 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
+ 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
+ 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)',
+ 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box',
+ 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
+ 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase',
+ 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts',
+ 'bubble_gum', 'bucket', 'horse_buggy', 'bull', 'bulldog', 'bulldozer',
+ 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn',
+ 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card',
+ 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
+ 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
+ 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
+ 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar',
+ 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup',
+ 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
+ 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
+ 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
+ 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
+ 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower',
+ 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone',
+ 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier',
+ 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard',
+ 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime',
+ 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
+ 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
+ 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
+ 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
+ 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine',
+ 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock',
+ 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster',
+ 'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach',
+ 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table',
+ 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw',
+ 'coloring_material', 'combination_lock', 'pacifier', 'comic_book',
+ 'compass', 'computer_keyboard', 'condiment', 'cone', 'control',
+ 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
+ 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
+ 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
+ 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
+ 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
+ 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
+ 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
+ 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
+ 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
+ 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
+ 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
+ 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
+ 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
+ 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
+ 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
+ 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
+ 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
+ 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
+ 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)',
+ 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell',
+ 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring',
+ 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
+ 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
+ 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
+ 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
+ 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm',
+ 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
+ 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl',
+ 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap',
+ 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
+ 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
+ 'folding_chair', 'food_processor', 'football_(American)',
+ 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
+ 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
+ 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage',
+ 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic',
+ 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator',
+ 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture',
+ 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
+ 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
+ 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat',
+ 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly',
+ 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet',
+ 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock',
+ 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
+ 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
+ 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband',
+ 'headboard', 'headlight', 'headscarf', 'headset',
+ 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
+ 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
+ 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
+ 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
+ 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
+ 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
+ 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
+ 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
+ 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
+ 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
+ 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
+ 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
+ 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
+ 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
+ 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
+ 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce',
+ 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
+ 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
+ 'lizard', 'log', 'lollipop', 'speaker_(stereo_equipment)', 'loveseat',
+ 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
+ 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger',
+ 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato',
+ 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox',
+ 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine',
+ 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone',
+ 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror',
+ 'mitten', 'mixer_(kitchen_tool)', 'money',
+ 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
+ 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)',
+ 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
+ 'music_stool', 'musical_instrument', 'nailfile', 'napkin',
+ 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper',
+ 'newsstand', 'nightshirt', 'nosebag_(for_animals)',
+ 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
+ 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
+ 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich',
+ 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad',
+ 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas',
+ 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake',
+ 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
+ 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol',
+ 'parchment', 'parka', 'parking_meter', 'parrot',
+ 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
+ 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
+ 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
+ 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
+ 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
+ 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
+ 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
+ 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
+ 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
+ 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
+ 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
+ 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
+ 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
+ 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
+ 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato',
+ 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel',
+ 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
+ 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
+ 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
+ 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
+ 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
+ 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
+ 'recliner', 'record_player', 'reflector', 'remote_control',
+ 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
+ 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
+ 'rolling_pin', 'root_beer', 'router_(computer_equipment)',
+ 'rubber_band', 'runner_(carpet)', 'plastic_bag',
+ 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
+ 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
+ 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
+ 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
+ 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
+ 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
+ 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
+ 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
+ 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
+ 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
+ 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
+ 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
+ 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
+ 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
+ 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
+ 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
+ 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
+ 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
+ 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
+ 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
+ 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
+ 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
+ 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
+ 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer',
+ 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer',
+ 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign',
+ 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl',
+ 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses',
+ 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband',
+ 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword',
+ 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
+ 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
+ 'tambourine', 'army_tank', 'tank_(storage_vessel)',
+ 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
+ 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
+ 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
+ 'telephone_pole', 'telephoto_lens', 'television_camera',
+ 'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
+ 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
+ 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil',
+ 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
+ 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
+ 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
+ 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light',
+ 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
+ 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle',
+ 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat',
+ 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
+ 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
+ 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
+ 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
+ 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
+ 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
+ 'washbasin', 'automatic_washer', 'watch', 'water_bottle',
+ 'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
+ 'water_gun', 'water_scooter', 'water_ski', 'water_tower',
+ 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
+ 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
+ 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
+ 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
+ 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
+ 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
+ 'yoke_(animal_equipment)', 'zebra', 'zucchini')
+
+ def load_annotations(self, ann_file):
+ try:
+ import lvis
+ if getattr(lvis, '__version__', '0') >= '10.5.3':
+ warnings.warn(
+ 'mmlvis is deprecated, please install official lvis-api by "pip install git+https://github.com/lvis-dataset/lvis-api.git"', # noqa: E501
+ UserWarning)
+ from lvis import LVIS
+ except ImportError:
+ raise ImportError(
+ 'Package lvis is not installed. Please run "pip install git+https://github.com/lvis-dataset/lvis-api.git".' # noqa: E501
+ )
+ self.coco = LVIS(ann_file)
+ self.cat_ids = self.coco.get_cat_ids()
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ # coco_url is used in LVISv1 instead of file_name
+ # e.g. http://images.cocodataset.org/train2017/000000391895.jpg
+            # the train/val split is specified in the url
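+            # Stripping the prefix leaves the split-relative path, e.g.
+            # 'train2017/000000391895.jpg' for the url above, which is then
+            # used as the image filename.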
+ info['filename'] = info['coco_url'].replace(
+ 'http://images.cocodataset.org/', '')
+ data_infos.append(info)
+ return data_infos
diff --git a/mmdet/datasets/objects365.py b/mmdet/datasets/objects365.py
new file mode 100644
index 0000000000000000000000000000000000000000..930f470f5b0772b7315f4a2fa661e95105915186
--- /dev/null
+++ b/mmdet/datasets/objects365.py
@@ -0,0 +1,232 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from .api_wrappers import COCO
+from .builder import DATASETS
+from .coco import CocoDataset
+
+# Images that exist in the annotations but are missing from the image folder.
+objv2_ignore_list = [
+ osp.join('patch16', 'objects365_v2_00908726.jpg'),
+ osp.join('patch6', 'objects365_v1_00320532.jpg'),
+ osp.join('patch6', 'objects365_v1_00320534.jpg'),
+]
+
+
+@DATASETS.register_module()
+class Objects365V1Dataset(CocoDataset):
+ """Objects365 v1 dataset for detection."""
+ CLASSES = (
+ 'person', 'sneakers', 'chair', 'hat', 'lamp', 'bottle',
+ 'cabinet/shelf', 'cup', 'car', 'glasses', 'picture/frame', 'desk',
+ 'handbag', 'street lights', 'book', 'plate', 'helmet', 'leather shoes',
+ 'pillow', 'glove', 'potted plant', 'bracelet', 'flower', 'tv',
+ 'storage box', 'vase', 'bench', 'wine glass', 'boots', 'bowl',
+ 'dining table', 'umbrella', 'boat', 'flag', 'speaker', 'trash bin/can',
+ 'stool', 'backpack', 'couch', 'belt', 'carpet', 'basket',
+ 'towel/napkin', 'slippers', 'barrel/bucket', 'coffee table', 'suv',
+ 'toy', 'tie', 'bed', 'traffic light', 'pen/pencil', 'microphone',
+ 'sandals', 'canned', 'necklace', 'mirror', 'faucet', 'bicycle',
+ 'bread', 'high heels', 'ring', 'van', 'watch', 'sink', 'horse', 'fish',
+ 'apple', 'camera', 'candle', 'teddy bear', 'cake', 'motorcycle',
+ 'wild bird', 'laptop', 'knife', 'traffic sign', 'cell phone', 'paddle',
+ 'truck', 'cow', 'power outlet', 'clock', 'drum', 'fork', 'bus',
+ 'hanger', 'nightstand', 'pot/pan', 'sheep', 'guitar', 'traffic cone',
+ 'tea pot', 'keyboard', 'tripod', 'hockey', 'fan', 'dog', 'spoon',
+ 'blackboard/whiteboard', 'balloon', 'air conditioner', 'cymbal',
+ 'mouse', 'telephone', 'pickup truck', 'orange', 'banana', 'airplane',
+ 'luggage', 'skis', 'soccer', 'trolley', 'oven', 'remote',
+ 'baseball glove', 'paper towel', 'refrigerator', 'train', 'tomato',
+ 'machinery vehicle', 'tent', 'shampoo/shower gel', 'head phone',
+ 'lantern', 'donut', 'cleaning products', 'sailboat', 'tangerine',
+ 'pizza', 'kite', 'computer box', 'elephant', 'toiletries', 'gas stove',
+ 'broccoli', 'toilet', 'stroller', 'shovel', 'baseball bat',
+ 'microwave', 'skateboard', 'surfboard', 'surveillance camera', 'gun',
+ 'life saver', 'cat', 'lemon', 'liquid soap', 'zebra', 'duck',
+ 'sports car', 'giraffe', 'pumpkin', 'piano', 'stop sign', 'radiator',
+ 'converter', 'tissue ', 'carrot', 'washing machine', 'vent', 'cookies',
+ 'cutting/chopping board', 'tennis racket', 'candy',
+ 'skating and skiing shoes', 'scissors', 'folder', 'baseball',
+ 'strawberry', 'bow tie', 'pigeon', 'pepper', 'coffee machine',
+ 'bathtub', 'snowboard', 'suitcase', 'grapes', 'ladder', 'pear',
+ 'american football', 'basketball', 'potato', 'paint brush', 'printer',
+ 'billiards', 'fire hydrant', 'goose', 'projector', 'sausage',
+ 'fire extinguisher', 'extension cord', 'facial mask', 'tennis ball',
+ 'chopsticks', 'electronic stove and gas stove', 'pie', 'frisbee',
+ 'kettle', 'hamburger', 'golf club', 'cucumber', 'clutch', 'blender',
+ 'tong', 'slide', 'hot dog', 'toothbrush', 'facial cleanser', 'mango',
+ 'deer', 'egg', 'violin', 'marker', 'ship', 'chicken', 'onion',
+ 'ice cream', 'tape', 'wheelchair', 'plum', 'bar soap', 'scale',
+ 'watermelon', 'cabbage', 'router/modem', 'golf ball', 'pine apple',
+ 'crane', 'fire truck', 'peach', 'cello', 'notepaper', 'tricycle',
+ 'toaster', 'helicopter', 'green beans', 'brush', 'carriage', 'cigar',
+ 'earphone', 'penguin', 'hurdle', 'swing', 'radio', 'CD',
+ 'parking meter', 'swan', 'garlic', 'french fries', 'horn', 'avocado',
+ 'saxophone', 'trumpet', 'sandwich', 'cue', 'kiwi fruit', 'bear',
+ 'fishing rod', 'cherry', 'tablet', 'green vegetables', 'nuts', 'corn',
+ 'key', 'screwdriver', 'globe', 'broom', 'pliers', 'volleyball',
+ 'hammer', 'eggplant', 'trophy', 'dates', 'board eraser', 'rice',
+ 'tape measure/ruler', 'dumbbell', 'hamimelon', 'stapler', 'camel',
+ 'lettuce', 'goldfish', 'meat balls', 'medal', 'toothpaste', 'antelope',
+ 'shrimp', 'rickshaw', 'trombone', 'pomegranate', 'coconut',
+ 'jellyfish', 'mushroom', 'calculator', 'treadmill', 'butterfly',
+ 'egg tart', 'cheese', 'pig', 'pomelo', 'race car', 'rice cooker',
+ 'tuba', 'crosswalk sign', 'papaya', 'hair drier', 'green onion',
+ 'chips', 'dolphin', 'sushi', 'urinal', 'donkey', 'electric drill',
+ 'spring rolls', 'tortoise/turtle', 'parrot', 'flute', 'measuring cup',
+ 'shark', 'steak', 'poker card', 'binoculars', 'llama', 'radish',
+ 'noodles', 'yak', 'mop', 'crab', 'microscope', 'barbell', 'bread/bun',
+ 'baozi', 'lion', 'red cabbage', 'polar bear', 'lighter', 'seal',
+ 'mangosteen', 'comb', 'eraser', 'pitaya', 'scallop', 'pencil case',
+ 'saw', 'table tennis paddle', 'okra', 'starfish', 'eagle', 'monkey',
+ 'durian', 'game board', 'rabbit', 'french horn', 'ambulance',
+ 'asparagus', 'hoverboard', 'pasta', 'target', 'hotair balloon',
+ 'chainsaw', 'lobster', 'iron', 'flashlight')
+
+ PALETTE = None
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+
+ self.coco = COCO(ann_file)
+        # The 'categories' lists in objects365_train.json and
+        # objects365_val.json are inconsistent, so they need to be sorted
+        # (as a list or dict) before calling get_cat_ids.
+ cats = self.coco.cats
+ sorted_cats = {i: cats[i] for i in sorted(cats)}
+ self.coco.cats = sorted_cats
+ categories = self.coco.dataset['categories']
+ sorted_categories = sorted(categories, key=lambda i: i['id'])
+ self.coco.dataset['categories'] = sorted_categories
+ # The order of returned `cat_ids` will not
+ # change with the order of the CLASSES
+ self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ total_ann_ids = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ ann_ids = self.coco.get_ann_ids(img_ids=[i])
+ total_ann_ids.extend(ann_ids)
+ assert len(set(total_ann_ids)) == len(
+ total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
+ return data_infos
+
+
+@DATASETS.register_module()
+class Objects365V2Dataset(CocoDataset):
+ """Objects365 v2 dataset for detection."""
+
+ CLASSES = (
+ 'Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp',
+ 'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf',
+ 'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet',
+ 'Book', 'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower',
+ 'Bench', 'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots',
+ 'Vase', 'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt',
+ 'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker',
+ 'Watch', 'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool',
+ 'Barrel/bucket', 'Van', 'Couch', 'Sandals', 'Bakset', 'Drum',
+ 'Pen/Pencil', 'Bus', 'Wild Bird', 'High Heels', 'Motorcycle', 'Guitar',
+ 'Carpet', 'Cell Phone', 'Bread', 'Camera', 'Canned', 'Truck',
+ 'Traffic cone', 'Cymbal', 'Lifesaver', 'Towel', 'Stuffed Toy',
+ 'Candle', 'Sailboat', 'Laptop', 'Awning', 'Bed', 'Faucet', 'Tent',
+ 'Horse', 'Mirror', 'Power outlet', 'Sink', 'Apple', 'Air Conditioner',
+ 'Knife', 'Hockey Stick', 'Paddle', 'Pickup Truck', 'Fork',
+ 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon',
+ 'Clock', 'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger',
+ 'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine',
+ 'Toiletry', 'Keyboard', 'Tomato', 'Lantern',
+ 'Machinery Vehicle', 'Fan', 'Green Vegetables', 'Banana',
+ 'Baseball Glove', 'Airplane', 'Mouse', 'Train', 'Pumpkin', 'Soccer',
+ 'Skiboard', 'Luggage', 'Nightstand', 'Tea pot', 'Telephone', 'Trolley',
+ 'Head Phone', 'Sports Car', 'Stop Sign', 'Dessert', 'Scooter',
+ 'Stroller', 'Crane', 'Remote', 'Refrigerator', 'Oven', 'Lemon', 'Duck',
+ 'Baseball Bat', 'Surveillance Camera', 'Cat', 'Jug', 'Broccoli',
+ 'Piano', 'Pizza', 'Elephant', 'Skateboard', 'Surfboard', 'Gun',
+ 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie', 'Carrot',
+ 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel', 'Pepper',
+ 'Computer Box', 'Toilet Paper', 'Cleaning Products', 'Chopsticks',
+ 'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board',
+ 'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder',
+ 'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball',
+ 'Zebra', 'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin',
+ 'Egg', 'Fire Extinguisher', 'Candy', 'Fire Truck', 'Billards',
+ 'Converter', 'Bathtub', 'Wheelchair', 'Golf Club', 'Briefcase',
+ 'Cucumber', 'Cigar/Cigarette ', 'Paint Brush', 'Pear', 'Heavy Truck',
+ 'Hamburger', 'Extractor', 'Extention Cord', 'Tong', 'Tennis Racket',
+ 'Folder', 'American Football', 'earphone', 'Mask', 'Kettle', 'Tennis',
+ 'Ship', 'Swing', 'Coffee Machine', 'Slide', 'Carriage', 'Onion',
+ 'Green beans', 'Projector', 'Frisbee',
+ 'Washing Machine/Drying Machine', 'Chicken', 'Printer', 'Watermelon',
+ 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream', 'Hotair ballon',
+ 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage', 'Hot dog',
+ 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball', 'Deer',
+ 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple',
+ 'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle',
+ 'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone',
+ 'Corn', 'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion',
+ 'Sandwich', 'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom',
+ 'Trombone', 'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit',
+ 'Router/modem', 'Poker Card', 'Toaster', 'Shrimp', 'Sushi', 'Cheese',
+ 'Notepaper', 'Cherry', 'Pliers', 'CD', 'Pasta', 'Hammer', 'Cue',
+ 'Avocado', 'Hamimelon', 'Flask', 'Mushroon', 'Screwdriver', 'Soap',
+ 'Recorder', 'Bear', 'Eggplant', 'Board Eraser', 'Coconut',
+ 'Tape Measur/ Ruler', 'Pig', 'Showerhead', 'Globe', 'Chips', 'Steak',
+ 'Crosswalk Sign', 'Stapler', 'Campel', 'Formula 1 ', 'Pomegranate',
+ 'Dishwasher', 'Crab', 'Hoverboard', 'Meat ball', 'Rice Cooker', 'Tuba',
+ 'Calculator', 'Papaya', 'Antelope', 'Parrot', 'Seal', 'Buttefly',
+ 'Dumbbell', 'Donkey', 'Lion', 'Urinal', 'Dolphin', 'Electric Drill',
+ 'Hair Dryer', 'Egg tart', 'Jellyfish', 'Treadmill', 'Lighter',
+ 'Grapefruit', 'Game board', 'Mop', 'Radish', 'Baozi', 'Target',
+ 'French', 'Spring Rolls', 'Monkey', 'Rabbit', 'Pencil Case', 'Yak',
+ 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell', 'Scallop',
+ 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle',
+ 'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster',
+ 'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling',
+ 'Table Tennis ')
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+
+ self.coco = COCO(ann_file)
+ # The order of returned `cat_ids` will not
+ # change with the order of the CLASSES
+ self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ total_ann_ids = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
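+            # Keep only the immediate parent directory plus the basename,
+            # e.g. (illustrative): '<prefix>/patch16/objects365_v2_xxx.jpg'
+            # becomes 'patch16/objects365_v2_xxx.jpg', matching the entries
+            # in `objv2_ignore_list`.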
+ file_name = osp.join(
+ osp.split(osp.split(info['file_name'])[0])[-1],
+ osp.split(info['file_name'])[-1])
+ info['file_name'] = file_name
+ if info['file_name'] in objv2_ignore_list:
+ continue
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ ann_ids = self.coco.get_ann_ids(img_ids=[i])
+ total_ann_ids.extend(ann_ids)
+ assert len(set(total_ann_ids)) == len(
+ total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
+ return data_infos
diff --git a/mmdet/datasets/openimages.py b/mmdet/datasets/openimages.py
new file mode 100644
index 0000000000000000000000000000000000000000..13153495126040810abda3dcbf3dc74b6c502c3f
--- /dev/null
+++ b/mmdet/datasets/openimages.py
@@ -0,0 +1,891 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import csv
+import json
+import os.path as osp
+import warnings
+from collections import OrderedDict, defaultdict
+
+import mmcv
+import numpy as np
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+from mmcv.utils import print_log
+
+from mmdet.core import eval_map
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class OpenImagesDataset(CustomDataset):
+ """Open Images dataset for detection.
+
+ Args:
+ ann_file (str): Annotation file path.
+ label_file (str): File path of the label description file that
+            maps the class names in MID format to their short
+ descriptions.
+ image_level_ann_file (str): Image level annotation, which is used
+ in evaluation.
+ get_supercategory (bool): Whether to get parent class of the
+ current class. Default: True.
+ hierarchy_file (str): The file path of the class hierarchy.
+ Default: None.
+ get_metas (bool): Whether to get image metas in testing or
+ validation time. This should be `True` during evaluation.
+            Default: True. The OpenImages annotations do not contain image
+            metas (width and height of the image), which are needed during
+            evaluation. We provide two ways to get image metas in
+            `OpenImagesDataset`:
+
+            - 1. `load from file`: Load image metas from a pkl file, which
+                is the suggested way. We provide a script to get image
+                metas: `tools/misc/get_image_metas.py`, which needs to be
+                run before training/testing. Please refer to
+                `config/openimages/README.md` for more details.
+
+            - 2. `load from pipeline`, which gets image metas during test
+                time. However, this may reduce the inference speed,
+                especially when using distributed testing.
+
+ load_from_file (bool): Whether to get image metas from pkl file.
+ meta_file (str): File path to get image metas.
+        filter_labels (bool): Whether to filter unannotated classes.
+            Default: True.
+        load_image_level_labels (bool): Whether to load and consider image
+            level labels during evaluation. Default: True.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ ann_file,
+ label_file='',
+ image_level_ann_file='',
+ get_supercategory=True,
+ hierarchy_file=None,
+ get_metas=True,
+ load_from_file=True,
+ meta_file='',
+ filter_labels=True,
+ load_image_level_labels=True,
+ file_client_args=dict(backend='disk'),
+ **kwargs):
+        # may get an error if using another file_client
+ self.file_client_args = file_client_args
+
+ self.cat2label = defaultdict(str)
+ self.index_dict = {}
+
+        # Although the file_client will be initialized in `CustomDataset`,
+        # it also needs to be initialized here.
+ file_client = mmcv.FileClient(**file_client_args)
+        # need to get `index_dict` before loading annotations
+ assert label_file.endswith('csv')
+ if hasattr(file_client, 'get_local_path'):
+ with file_client.get_local_path(label_file) as local_path:
+ class_names = self.get_classes_from_csv(local_path)
+ else:
+ class_names = self.get_classes_from_csv(label_file)
+ super(OpenImagesDataset, self).__init__(
+ ann_file=ann_file, file_client_args=file_client_args, **kwargs)
+ self.CLASSES = class_names
+ self.image_level_ann_file = image_level_ann_file
+ self.load_image_level_labels = load_image_level_labels
+ if get_supercategory is True:
+ assert hierarchy_file is not None
+ if self.__class__.__name__ == 'OpenImagesDataset':
+ assert hierarchy_file.endswith('json')
+ elif self.__class__.__name__ == 'OpenImagesChallengeDataset':
+ assert hierarchy_file.endswith('np')
+ else:
+ raise NotImplementedError
+ if hasattr(self.file_client, 'get_local_path'):
+ with self.file_client.get_local_path(
+ hierarchy_file) as local_path:
+ self.class_label_tree = self.get_relation_matrix(
+ local_path)
+ else:
+ self.class_label_tree = self.get_relation_matrix(
+ hierarchy_file)
+ self.get_supercategory = get_supercategory
+ self.get_metas = get_metas
+ self.load_from_file = load_from_file
+ self.meta_file = meta_file
+ if self.data_root is not None:
+ if not osp.isabs(self.meta_file):
+ self.meta_file = osp.join(self.data_root, self.meta_file)
+ self.filter_labels = filter_labels
+ self.rank, self.world_size = get_dist_info()
+ self.temp_img_metas = []
+ self.test_img_metas = []
+ self.test_img_shapes = []
+ self.load_from_pipeline = False if load_from_file else True
+
+ def get_classes_from_csv(self, label_file):
+ """Get classes name from file.
+
+ Args:
+ label_file (str): File path of the label description file that
+                maps the class names in MID format to their short
+ descriptions.
+
+ Returns:
+ list[str]: Class name of OpenImages.
+ """
+
+ index_list = []
+ classes_names = []
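+        # Each row maps a MID to its display name, e.g. (illustrative row):
+        #   /m/011k07,Tortoise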
+ with open(label_file, 'r') as f:
+ reader = csv.reader(f)
+ for line in reader:
+ self.cat2label[line[0]] = line[1]
+ classes_names.append(line[1])
+ index_list.append(line[0])
+ self.index_dict = {index: i for i, index in enumerate(index_list)}
+ return classes_names
+
+ def load_annotations(self, ann_file):
+ """Load annotation from annotation file.
+
+        This function also builds `self.ann_infos` (defaultdict[list[dict]]):
+        annotations keyed by image id, where each image has (n) dicts.
+        Keys of the dicts are:
+
+ - `bbox` (list): coordinates of the box, in normalized image
+ coordinates, of shape 4.
+ - `label` (int): the label id.
+ - `is_group_of` (bool): Indicates that the box spans a group
+ of objects (e.g., a bed of flowers or a crowd of people).
+ - `is_occluded` (bool): Indicates that the object is occluded
+ by another object in the image.
+ - `is_truncated` (bool): Indicates that the object extends
+ beyond the boundary of the image.
+ - `is_depiction` (bool): Indicates that the object is a
+ depiction.
+ - `is_inside` (bool): Indicates a picture taken from the
+ inside of the object.
+
+ Args:
+ ann_file (str): CSV style annotation file path.
+
+ Returns:
+ list[dict]: Data infos where each item of the list
+ indicates an image. Keys of annotations are:
+
+ - `img_id` (str): Image name.
+ - `filename` (str): Image name with suffix.
+ """
+ self.ann_infos = defaultdict(list)
+ data_infos = []
+ cp_filename = None
+ with open(ann_file, 'r') as f:
+ reader = csv.reader(f)
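+            # The column indices below assume the OpenImages v6 bbox csv
+            # header, i.e. roughly:
+            #   ImageID,Source,LabelName,Confidence,XMin,XMax,YMin,YMax,
+            #   IsOccluded,IsTruncated,IsGroupOf,IsDepiction,IsInside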
+ for i, line in enumerate(reader):
+ if i == 0:
+ continue
+ img_id = line[0]
+ filename = f'{img_id}.jpg'
+ label_id = line[2]
+ assert label_id in self.index_dict
+ label = int(self.index_dict[label_id])
+ bbox = [
+ float(line[4]), # xmin
+ float(line[6]), # ymin
+ float(line[5]), # xmax
+ float(line[7]) # ymax
+ ]
+ is_occluded = True if int(line[8]) == 1 else False
+ is_truncated = True if int(line[9]) == 1 else False
+ is_group_of = True if int(line[10]) == 1 else False
+ is_depiction = True if int(line[11]) == 1 else False
+ is_inside = True if int(line[12]) == 1 else False
+
+ self.ann_infos[img_id].append(
+ dict(
+ bbox=bbox,
+ label=label,
+ is_occluded=is_occluded,
+ is_truncated=is_truncated,
+ is_group_of=is_group_of,
+ is_depiction=is_depiction,
+ is_inside=is_inside))
+ if filename != cp_filename:
+ data_infos.append(dict(img_id=img_id, filename=filename))
+ cp_filename = filename
+ return data_infos
+
+ def get_ann_info(self, idx):
+ """Get OpenImages annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+ img_id = self.data_infos[idx]['img_id']
+ bboxes = []
+ labels = []
+ bboxes_ignore = []
+ labels_ignore = []
+ is_occludeds = []
+ is_truncateds = []
+ is_group_ofs = []
+ is_depictions = []
+ is_insides = []
+ for obj in self.ann_infos[img_id]:
+ label = int(obj['label'])
+ bbox = [
+ float(obj['bbox'][0]),
+ float(obj['bbox'][1]),
+ float(obj['bbox'][2]),
+ float(obj['bbox'][3])
+ ]
+ bboxes.append(bbox)
+ labels.append(label)
+
+ # Other parameters
+ is_occludeds.append(obj['is_occluded'])
+ is_truncateds.append(obj['is_truncated'])
+ is_group_ofs.append(obj['is_group_of'])
+ is_depictions.append(obj['is_depiction'])
+ is_insides.append(obj['is_inside'])
+ if not bboxes:
+ bboxes = np.zeros((0, 4))
+ labels = np.zeros((0, ))
+ else:
+ bboxes = np.array(bboxes)
+ labels = np.array(labels)
+ if not bboxes_ignore:
+ bboxes_ignore = np.zeros((0, 4))
+ labels_ignore = np.zeros((0, ))
+ else:
+ bboxes_ignore = np.array(bboxes_ignore)
+ labels_ignore = np.array(labels_ignore)
+
+ assert len(is_group_ofs) == len(labels) == len(bboxes)
+ gt_is_group_ofs = np.array(is_group_ofs, dtype=bool)
+
+        # These parameters are not used yet.
+ is_occludeds = np.array(is_occludeds, dtype=bool)
+ is_truncateds = np.array(is_truncateds, dtype=bool)
+ is_depictions = np.array(is_depictions, dtype=bool)
+ is_insides = np.array(is_insides, dtype=bool)
+
+ ann = dict(
+ bboxes=bboxes.astype(np.float32),
+ labels=labels.astype(np.int64),
+ bboxes_ignore=bboxes_ignore.astype(np.float32),
+ labels_ignore=labels_ignore.astype(np.int64),
+ gt_is_group_ofs=gt_is_group_ofs,
+ is_occludeds=is_occludeds,
+ is_truncateds=is_truncateds,
+ is_depictions=is_depictions,
+ is_insides=is_insides)
+
+ return ann
+
+ def get_meta_from_file(self, meta_file=''):
+ """Get image metas from pkl file."""
+ metas = mmcv.load(
+ meta_file,
+ file_format='pkl',
+ file_client_args=self.file_client_args)
+ assert len(metas) == len(self)
+ for i in range(len(metas)):
+ file_name = osp.split(metas[i]['filename'])[-1]
+ img_info = self.data_infos[i].get('img_info', None)
+ if img_info is not None:
+ assert file_name == osp.split(img_info['filename'])[-1]
+ else:
+ assert file_name == self.data_infos[i]['filename']
+ hw = metas[i]['ori_shape'][:2]
+ self.test_img_shapes.append(hw)
+
+ def get_meta_from_pipeline(self, results):
+ """Get image metas from pipeline."""
+ self.temp_img_metas.extend(results['img_metas'])
+ if dist.is_available() and self.world_size > 1:
+ from mmdet.apis.test import collect_results_cpu
+
+ self.test_img_metas = collect_results_cpu(self.temp_img_metas,
+ len(self))
+ else:
+ self.test_img_metas = self.temp_img_metas
+
+ def get_img_shape(self, metas):
+ """Set images original shape into data_infos."""
+ assert len(metas) == len(self)
+ for i in range(len(metas)):
+ file_name = osp.split(metas[i].data['ori_filename'])[-1]
+ img_info = self.data_infos[i].get('img_info', None)
+ if img_info is not None:
+ assert file_name == osp.split(img_info['filename'])[-1]
+ else:
+ assert file_name == self.data_infos[i]['filename']
+ hw = metas[i].data['ori_shape'][:2]
+ self.test_img_shapes.append(hw)
+
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline."""
+ img_info = self.data_infos[idx]
+ results = dict(img_info=img_info)
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+ results = self.pipeline(results)
+ if self.get_metas and self.load_from_pipeline:
+ self.get_meta_from_pipeline(results)
+ return results
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small."""
+ if self.filter_empty_gt:
+            warnings.warn('OpenImagesDataset does not support '
+ 'filtering empty gt images.')
+ valid_inds = [i for i in range(len(self))]
+ return valid_inds
+
+ def _set_group_flag(self):
+ """Set flag according to image aspect ratio."""
+ self.flag = np.zeros(len(self), dtype=np.uint8)
+ # TODO: set flag without width and height
+
+ def get_relation_matrix(self, hierarchy_file):
+ """Get hierarchy for classes.
+
+ Args:
+            hierarchy_file (str): File path to the hierarchy for classes.
+
+ Returns:
+ ndarray: The matrix of the corresponding relationship between
+ the parent class and the child class, of shape
+ (class_num, class_num).
+ """
+
+ if self.data_root is not None:
+ if not osp.isabs(hierarchy_file):
+ hierarchy_file = osp.join(self.data_root, hierarchy_file)
+ with open(hierarchy_file, 'r') as f:
+ hierarchy = json.load(f)
+ class_num = len(self.CLASSES)
+ class_label_tree = np.eye(class_num, class_num)
+ class_label_tree = self._convert_hierarchy_tree(
+ hierarchy, class_label_tree)
+ return class_label_tree
+
+ def _convert_hierarchy_tree(self,
+ hierarchy_map,
+ class_label_tree,
+ parents=[],
+ get_all_parents=True):
+ """Get matrix of the corresponding relationship between the parent
+ class and the child class.
+
+ Args:
+ hierarchy_map (dict): Including label name and corresponding
+ subcategory. Keys of dicts are:
+
+                - `LabelName` (str): Name of the label.
+ - `Subcategory` (dict | list): Corresponding subcategory(ies).
+ class_label_tree (ndarray): The matrix of the corresponding
+ relationship between the parent class and the child class,
+ of shape (class_num, class_num).
+ parents (list): Corresponding parent class.
+            get_all_parents (bool): Whether to get all parent names.
+                Default: True.
+
+ Returns:
+ ndarray: The matrix of the corresponding relationship between
+ the parent class and the child class, of shape
+ (class_num, class_num).
+ """
+
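+        # `hierarchy_map` is a nested dict of the form (illustrative):
+        #   {'LabelName': '/m/0bl9f',
+        #    'Subcategory': [{'LabelName': '/m/0cmf2',
+        #                     'Subcategory': [...]}, ...]}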
+ if 'Subcategory' in hierarchy_map:
+ for node in hierarchy_map['Subcategory']:
+ if 'LabelName' in node:
+ children_name = node['LabelName']
+ children_index = self.index_dict[children_name]
+ children = [children_index]
+ else:
+ continue
+ if len(parents) > 0:
+ for parent_index in parents:
+ if get_all_parents:
+ children.append(parent_index)
+ class_label_tree[children_index, parent_index] = 1
+
+ class_label_tree = self._convert_hierarchy_tree(
+ node, class_label_tree, parents=children)
+
+ return class_label_tree
+
+ def add_supercategory_ann(self, annotations):
+ """Add parent classes of the corresponding class of the ground truth
+ bboxes."""
+ for i, ann in enumerate(annotations):
+ assert len(ann['labels']) == len(ann['bboxes']) == \
+ len(ann['gt_is_group_ofs'])
+ gt_bboxes = []
+ gt_is_group_ofs = []
+ gt_labels = []
+ for j in range(len(ann['labels'])):
+ label = ann['labels'][j]
+ bbox = ann['bboxes'][j]
+ is_group = ann['gt_is_group_ofs'][j]
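+                # Row `label` of `class_label_tree` is 1 at the class itself
+                # and at all of its ancestors, so the box is appended once
+                # per parent category below.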
+ label = np.where(self.class_label_tree[label])[0]
+ if len(label) > 1:
+ for k in range(len(label)):
+ gt_bboxes.append(bbox)
+ gt_is_group_ofs.append(is_group)
+ gt_labels.append(label[k])
+ else:
+ gt_bboxes.append(bbox)
+ gt_is_group_ofs.append(is_group)
+ gt_labels.append(label[0])
+ annotations[i] = dict(
+ bboxes=np.array(gt_bboxes).astype(np.float32),
+ labels=np.array(gt_labels).astype(np.int64),
+ bboxes_ignore=ann['bboxes_ignore'],
+ gt_is_group_ofs=np.array(gt_is_group_ofs).astype(bool))
+
+ return annotations
+
+ def process_results(self, det_results, annotations,
+ image_level_annotations):
+ """Process results of the corresponding class of the detection bboxes.
+
+ Note: It will choose to do the following two processing according to
+ the parameters:
+
+ 1. Whether to add parent classes of the corresponding class of the
+ detection bboxes.
+
+        2. Whether to ignore the classes that are unannotated in that image.
+ """
+ if image_level_annotations is not None:
+ assert len(annotations) == \
+ len(image_level_annotations) == \
+ len(det_results)
+ else:
+ assert len(annotations) == len(det_results)
+ for i in range(len(det_results)):
+ results = copy.deepcopy(det_results[i])
+ valid_classes = np.where(
+ np.array([[bbox.shape[0]] for bbox in det_results[i]]) != 0)[0]
+ if image_level_annotations is not None:
+ labels = annotations[i]['labels']
+ image_level_labels = \
+ image_level_annotations[i]['image_level_labels']
+                allowed_labels = np.unique(
+ np.append(labels, image_level_labels))
+ else:
+                allowed_labels = np.unique(annotations[i]['labels'])
+
+ for valid_class in valid_classes:
+ det_cls = np.where(self.class_label_tree[valid_class])[0]
+ for index in det_cls:
+                    if index in allowed_labels and \
+ index != valid_class and \
+ self.get_supercategory:
+ det_results[i][index] = \
+ np.concatenate((det_results[i][index],
+ results[valid_class]))
+                    elif index not in allowed_labels and self.filter_labels:
+ # Remove useless parts
+ det_results[i][index] = np.empty(
+ (0, 5)).astype(np.float32)
+ return det_results
+
+ def load_image_label_from_csv(self, image_level_ann_file):
+ """Load image level annotations from csv style ann_file.
+
+ Args:
+ image_level_ann_file (str): CSV style image level annotation
+ file path.
+
+ Returns:
+            defaultdict[list[dict]]: Annotations keyed by image id, where
+                each image has (n) dicts. Keys of the dicts are:
+
+ - `image_level_label` (int): Label id.
+ - `confidence` (float): Labels that are human-verified to be
+ present in an image have confidence = 1 (positive labels).
+ Labels that are human-verified to be absent from an image
+ have confidence = 0 (negative labels). Machine-generated
+ labels have fractional confidences, generally >= 0.5.
+ The higher the confidence, the smaller the chance for
+ the label to be a false positive.
+ """
+
+ item_lists = defaultdict(list)
+ with open(image_level_ann_file, 'r') as f:
+ reader = csv.reader(f)
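+            # The indices below assume the v6 image-level labels csv header,
+            # i.e. roughly: ImageID,Source,LabelName,Confidence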
+ for i, line in enumerate(reader):
+ if i == 0:
+ continue
+ img_id = line[0]
+ item_lists[img_id].append(
+ dict(
+ image_level_label=int(self.index_dict[line[2]]),
+ confidence=float(line[3])))
+ return item_lists
+
+ def get_image_level_ann(self, image_level_ann_file):
+ """Get OpenImages annotation by index.
+
+ Args:
+ image_level_ann_file (str): CSV style image level annotation
+ file path.
+
+ Returns:
+            list[dict]: Image level annotations, one dict per image.
+ """
+
+ if hasattr(self.file_client, 'get_local_path'):
+ with self.file_client.get_local_path(image_level_ann_file) \
+ as local_path:
+ item_lists = self.load_image_label_from_csv(local_path)
+ else:
+ item_lists = self.load_image_label_from_csv(image_level_ann_file)
+ image_level_annotations = []
+ for i in range(len(self)):
+ img_info = self.data_infos[i].get('img_info', None)
+ if img_info is not None:
+ # for Open Images Challenges
+ img_id = osp.split(img_info['filename'])[-1][:-4]
+ else:
+ # for Open Images v6
+ img_id = self.data_infos[i]['img_id']
+ item_list = item_lists.get(img_id, None)
+ if item_list is not None:
+ image_level_labels = []
+ confidences = []
+ for obj in item_list:
+ image_level_label = int(obj['image_level_label'])
+ confidence = float(obj['confidence'])
+
+ image_level_labels.append(image_level_label)
+ confidences.append(confidence)
+
+ if not image_level_labels:
+ image_level_labels = np.zeros((0, ))
+ confidences = np.zeros((0, ))
+ else:
+ image_level_labels = np.array(image_level_labels)
+ confidences = np.array(confidences)
+ else:
+ image_level_labels = np.zeros((0, ))
+ confidences = np.zeros((0, ))
+ ann = dict(
+ image_level_labels=image_level_labels.astype(np.int64),
+ confidences=confidences.astype(np.float32))
+ image_level_annotations.append(ann)
+
+ return image_level_annotations
+
+ def denormalize_gt_bboxes(self, annotations):
+ """Convert ground truth bboxes from relative position to absolute
+ position.
+
+ Only used in evaluating time.
+ """
+ assert len(self.test_img_shapes) == len(annotations)
+ for i in range(len(annotations)):
+ h, w = self.test_img_shapes[i]
+ annotations[i]['bboxes'][:, 0::2] *= w
+ annotations[i]['bboxes'][:, 1::2] *= h
+ return annotations
+
+ def get_cat_ids(self, idx):
+ """Get category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+        return self.get_ann_info(idx)['labels'].astype(np.int64).tolist()
+
+ def evaluate(self,
+ results,
+ metric='mAP',
+ logger=None,
+ iou_thr=0.5,
+ ioa_thr=0.5,
+ scale_ranges=None,
+ denorm_gt_bbox=True,
+ use_group_of=True):
+ """Evaluate in OpenImages.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Option is
+ 'mAP'. Default: 'mAP'.
+ logger (logging.Logger | str, optional): Logger used for printing
+ related information during evaluation. Default: None.
+ iou_thr (float | list[float]): IoU threshold. Default: 0.5.
+ ioa_thr (float | list[float]): IoA threshold. Default: 0.5.
+ scale_ranges (list[tuple], optional): Scale ranges for evaluating
+ mAP. If not specified, all bounding boxes would be included in
+ evaluation. Default: None
+ denorm_gt_bbox (bool): Whether to denorm ground truth bboxes from
+ relative position to absolute position. Default: True
+            use_group_of (bool): Whether to consider group-of ground truth
+                bboxes during evaluation. Default: True.
+
+ Returns:
+ dict[str, float]: AP metrics.
+ """
+
+ if not isinstance(metric, str):
+ assert len(metric) == 1
+ metric = metric[0]
+ allowed_metrics = ['mAP']
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ annotations = [self.get_ann_info(i) for i in range(len(self))]
+
+ if self.load_image_level_labels:
+ image_level_annotations = \
+ self.get_image_level_ann(self.image_level_ann_file)
+ else:
+ image_level_annotations = None
+
+ # load metas from file
+ if self.get_metas and self.load_from_file:
+ assert self.meta_file.endswith(
+                'pkl'), 'File name must have a pkl suffix.'
+ self.get_meta_from_file(self.meta_file)
+ # load metas from pipeline
+ else:
+ self.get_img_shape(self.test_img_metas)
+
+ if len(self.test_img_shapes) > len(self):
+ self.test_img_shapes = self.test_img_shapes[:len(self)]
+
+ if denorm_gt_bbox:
+ annotations = self.denormalize_gt_bboxes(annotations)
+
+ # Reset test_image_metas, temp_image_metas and test_img_shapes
+ # to avoid potential error
+ self.temp_img_metas = []
+ self.test_img_shapes = []
+ self.test_img_metas = []
+ if self.get_supercategory:
+ annotations = self.add_supercategory_ann(annotations)
+
+ results = self.process_results(results, annotations,
+ image_level_annotations)
+ if use_group_of:
+ assert ioa_thr is not None, \
+ 'ioa_thr must have value when using group_of in evaluation.'
+
+ eval_results = OrderedDict()
+ iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
+ ioa_thrs = [ioa_thr] if isinstance(ioa_thr, float) or ioa_thr is None \
+ else ioa_thr
+
+ # get dataset type
+ if len(self.CLASSES) == 500:
+ ds_name = 'oid_challenge'
+ elif len(self.CLASSES) == 601:
+ ds_name = 'oid_v6'
+ else:
+ ds_name = self.CLASSES
+            warnings.warn('Cannot infer dataset type from the length of the '
+                          'classes. The class names will be used as the '
+                          'dataset type instead.')
+
+ if metric == 'mAP':
+ assert isinstance(iou_thrs, list) and isinstance(ioa_thrs, list)
+ assert len(ioa_thrs) == len(iou_thrs)
+ mean_aps = []
+ for iou_thr, ioa_thr in zip(iou_thrs, ioa_thrs):
+ print_log(f'\n{"-" * 15}iou_thr, ioa_thr: {iou_thr}, {ioa_thr}'
+ f'{"-" * 15}')
+ mean_ap, _ = eval_map(
+ results,
+ annotations,
+ scale_ranges=scale_ranges,
+ iou_thr=iou_thr,
+ ioa_thr=ioa_thr,
+ dataset=ds_name,
+ logger=logger,
+ use_group_of=use_group_of)
+ mean_aps.append(mean_ap)
+ eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
+ eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
+ return eval_results
+
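+    # Sketch of a typical call (``results`` comes from ``single_gpu_test`` or
+    # similar; the names below are illustrative, not part of this module):
+    #   eval_res = dataset.evaluate(results, metric='mAP',
+    #                               iou_thr=0.5, ioa_thr=0.5)
+    #   # eval_res == {'AP50': ..., 'mAP': ...} for a single IoU threshold
+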
+
+@DATASETS.register_module()
+class OpenImagesChallengeDataset(OpenImagesDataset):
+ """Open Images Challenge dataset for detection."""
+
+ def __init__(self, ann_file, **kwargs):
+ assert ann_file.endswith('txt')
+ super(OpenImagesChallengeDataset, self).__init__(
+ ann_file=ann_file, **kwargs)
+
+ def get_classes_from_csv(self, label_file):
+ """Get classes name from file.
+
+ Args:
+ label_file (str): File path of the label description file that
+ maps the classes names in MID format to their short
+ descriptions.
+
+ Returns:
+            list: Class names of OpenImages.
+ """
+
+ label_list = []
+ id_list = []
+ with open(label_file, 'r') as f:
+ reader = csv.reader(f)
+ for line in reader:
+ label_name = line[0]
+ label_id = int(line[2])
+
+ label_list.append(line[1])
+ id_list.append(label_id)
+ self.index_dict[label_name] = label_id - 1
+
+ indexes = np.argsort(id_list)
+ classes_names = []
+ for index in indexes:
+ classes_names.append(label_list[index])
+ return classes_names
+
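+    # Assumed layout of the challenge label file, inferred from the parser
+    # above (one row per class, no header):
+    #   <label_MID>,<class name>,<1-based class id>
+    # e.g. ``/m/xxx,Cat,5`` sets index_dict['/m/xxx'] to 4 and places 'Cat'
+    # at index 4 of the returned list (assuming ids 1..N are all present).
+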
+ def load_annotations(self, ann_file):
+ """Load annotation from annotation file."""
+ with open(ann_file) as f:
+ lines = f.readlines()
+ i = 0
+ ann_infos = []
+ while i < len(lines):
+ bboxes = []
+ labels = []
+ is_group_ofs = []
+ filename = lines[i].rstrip()
+ i += 2
+ img_gt_size = int(lines[i])
+ i += 1
+ for j in range(img_gt_size):
+ sp = lines[i + j].split()
+ bboxes.append(
+ [float(sp[1]),
+ float(sp[2]),
+ float(sp[3]),
+ float(sp[4])])
+ labels.append(int(sp[0]) - 1) # labels begin from 1
+                is_group_ofs.append(int(sp[5]) == 1)
+ i += img_gt_size
+
+ gt_bboxes = np.array(bboxes, dtype=np.float32)
+ gt_labels = np.array(labels, dtype=np.int64)
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+ gt_is_group_ofs = np.array(is_group_ofs, dtype=bool)
+
+ img_info = dict(filename=filename)
+ ann_info = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ gt_is_group_ofs=gt_is_group_ofs)
+ ann_infos.append(dict(img_info=img_info, ann_info=ann_info))
+
+ return ann_infos
+
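+    # Record layout consumed by the loop above (a sketch; field meanings
+    # beyond what the code reads are not assumed):
+    #   <filename>
+    #   <one line that is skipped by ``i += 2``>
+    #   <number of gt boxes>
+    #   <label> <x1> <y1> <x2> <y2> <is_group_of>   (repeated per box)
+    # Labels are 1-based in the file and shifted to 0-based here.
+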
+ def prepare_train_img(self, idx):
+ """Get training data and annotations after pipeline."""
+ ann_info = self.data_infos[idx]
+ results = dict(
+ img_info=ann_info['img_info'],
+ ann_info=ann_info['ann_info'],
+ )
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline."""
+ ann_info = self.data_infos[idx]
+ results = dict(img_info=ann_info['img_info'])
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+
+ results = self.pipeline(results)
+ if self.get_metas and self.load_from_pipeline:
+ self.get_meta_from_pipeline(results)
+ return results
+
+ def get_relation_matrix(self, hierarchy_file):
+ """Get hierarchy for classes.
+
+ Args:
+ hierarchy_file (str): File path to the hierarchy for classes.
+
+ Returns:
+ ndarray: The matrix of the corresponding
+ relationship between the parent class and the child class,
+ of shape (class_num, class_num).
+ """
+ class_label_tree = np.load(hierarchy_file, allow_pickle=True)
+ return class_label_tree[1:, 1:]
+
+ def get_ann_info(self, idx):
+ """Get OpenImages annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+ # avoid some potential error
+ data_infos = copy.deepcopy(self.data_infos[idx]['ann_info'])
+ return data_infos
+
+ def load_image_label_from_csv(self, image_level_ann_file):
+ """Load image level annotations from csv style ann_file.
+
+ Args:
+ image_level_ann_file (str): CSV style image level annotation
+ file path.
+
+ Returns:
+            defaultdict[list[dict]]: Annotations where each item of the
+                defaultdict corresponds to one image and holds a list of
+                dicts. Keys of the dicts are:
+
+ - `image_level_label` (int): of shape 1.
+ - `confidence` (float): of shape 1.
+ """
+
+ item_lists = defaultdict(list)
+ with open(image_level_ann_file, 'r') as f:
+ reader = csv.reader(f)
+            for i, line in enumerate(reader):
+                if i == 0:
+                    continue
+                img_id = line[0]
+                label_id = line[1]
+                assert label_id in self.index_dict
+                image_level_label = int(self.index_dict[label_id])
+                confidence = float(line[2])
+                item_lists[img_id].append(
+                    dict(
+                        image_level_label=image_level_label,
+                        confidence=confidence))
+ return item_lists
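+
+    # Note the column layout differs from ``OpenImagesDataset`` above: here
+    # line[1] holds the label MID and line[2] the confidence (a sketch of an
+    # assumed row: ``<image_id>,<label_MID>,<confidence>``).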
diff --git a/mmdet/datasets/pipelines/__init__.py b/mmdet/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8260da642682e3ea509c544170b0b4d1f5f23199
--- /dev/null
+++ b/mmdet/datasets/pipelines/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .auto_augment import (AutoAugment, BrightnessTransform, ColorTransform,
+ ContrastTransform, EqualizeTransform, Rotate, Shear,
+ Translate)
+from .compose import Compose
+from .formatting import (Collect, DefaultFormatBundle, ImageToTensor,
+ ToDataContainer, ToTensor, Transpose, to_tensor)
+from .instaboost import InstaBoost
+from .loading import (FilterAnnotations, LoadAnnotations, LoadImageFromFile,
+ LoadImageFromWebcam, LoadMultiChannelImageFromFiles,
+ LoadPanopticAnnotations, LoadProposals)
+from .test_time_aug import MultiScaleFlipAug
+from .transforms import (Albu, CopyPaste, CutOut, Expand, MinIoURandomCrop,
+ MixUp, Mosaic, Normalize, Pad, PhotoMetricDistortion,
+ RandomAffine, RandomCenterCropPad, RandomCrop,
+ RandomFlip, RandomShift, Resize, SegRescale,
+ YOLOXHSVRandomAug)
+
+__all__ = [
+ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+ 'Transpose', 'Collect', 'DefaultFormatBundle', 'LoadAnnotations',
+ 'LoadImageFromFile', 'LoadImageFromWebcam', 'LoadPanopticAnnotations',
+ 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'FilterAnnotations',
+ 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
+ 'Normalize', 'SegRescale', 'MinIoURandomCrop', 'Expand',
+ 'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad',
+ 'AutoAugment', 'CutOut', 'Shear', 'Rotate', 'ColorTransform',
+ 'EqualizeTransform', 'BrightnessTransform', 'ContrastTransform',
+ 'Translate', 'RandomShift', 'Mosaic', 'MixUp', 'RandomAffine',
+ 'YOLOXHSVRandomAug', 'CopyPaste'
+]
diff --git a/mmdet/datasets/pipelines/auto_augment.py b/mmdet/datasets/pipelines/auto_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ff67dbdd99c1889c424b59a9f0f12cfb216ba4
--- /dev/null
+++ b/mmdet/datasets/pipelines/auto_augment.py
@@ -0,0 +1,894 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import cv2
+import mmcv
+import numpy as np
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+_MAX_LEVEL = 10
+
+
+def level_to_value(level, max_value):
+ """Map from level to values based on max_value."""
+ return (level / _MAX_LEVEL) * max_value
+
+
+def enhance_level_to_value(level, a=1.8, b=0.1):
+ """Map from level to values."""
+ return (level / _MAX_LEVEL) * a + b
+
+
+def random_negative(value, random_negative_prob):
+ """Randomly negate value based on random_negative_prob."""
+ return -value if np.random.rand() < random_negative_prob else value
+
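+# Worked values for the level mappings above (with _MAX_LEVEL == 10):
+#   level_to_value(5, max_value=0.3)  -> 0.15
+#   level_to_value(10, max_value=30)  -> 30.0
+#   enhance_level_to_value(0)         -> 0.1
+#   enhance_level_to_value(10)        -> 1.9
+# random_negative(v, 0.5) flips the sign of ``v`` with probability 0.5.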
+
+def bbox2fields():
+ """The key correspondence from bboxes to labels, masks and
+ segmentations."""
+ bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+ bbox2seg = {
+ 'gt_bboxes': 'gt_semantic_seg',
+ }
+ return bbox2label, bbox2mask, bbox2seg
+
+
+@PIPELINES.register_module()
+class AutoAugment:
+ """Auto augmentation.
+
+ This data augmentation is proposed in `Learning Data Augmentation
+    Strategies for Object Detection <https://arxiv.org/abs/1906.11172>`_.
+
+ TODO: Implement 'Shear', 'Sharpness' and 'Rotate' transforms
+
+ Args:
+ policies (list[list[dict]]): The policies of auto augmentation. Each
+ policy in ``policies`` is a specific augmentation policy, and is
+ composed by several augmentations (dict). When AutoAugment is
+ called, a random policy in ``policies`` will be selected to
+ augment images.
+
+ Examples:
+ >>> replace = (104, 116, 124)
+ >>> policies = [
+ >>> [
+ >>> dict(type='Sharpness', prob=0.0, level=8),
+ >>> dict(
+ >>> type='Shear',
+ >>> prob=0.4,
+ >>> level=0,
+ >>> replace=replace,
+ >>> axis='x')
+ >>> ],
+ >>> [
+ >>> dict(
+ >>> type='Rotate',
+ >>> prob=0.6,
+ >>> level=10,
+ >>> replace=replace),
+ >>> dict(type='Color', prob=1.0, level=6)
+ >>> ]
+ >>> ]
+ >>> augmentation = AutoAugment(policies)
+        >>> img = np.ones((100, 100, 3))
+        >>> gt_bboxes = np.ones((10, 4))
+ >>> results = dict(img=img, gt_bboxes=gt_bboxes)
+ >>> results = augmentation(results)
+ """
+
+ def __init__(self, policies):
+ assert isinstance(policies, list) and len(policies) > 0, \
+ 'Policies must be a non-empty list.'
+ for policy in policies:
+ assert isinstance(policy, list) and len(policy) > 0, \
+ 'Each policy in policies must be a non-empty list.'
+ for augment in policy:
+ assert isinstance(augment, dict) and 'type' in augment, \
+ 'Each specific augmentation must be a dict with key' \
+ ' "type".'
+
+ self.policies = copy.deepcopy(policies)
+ self.transforms = [Compose(policy) for policy in self.policies]
+
+ def __call__(self, results):
+ transform = np.random.choice(self.transforms)
+ return transform(results)
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(policies={self.policies})'
+
+
+@PIPELINES.register_module()
+class Shear:
+ """Apply Shear Transformation to image (and its corresponding bbox, mask,
+ segmentation).
+
+ Args:
+ level (int | float): The level should be in range [0,_MAX_LEVEL].
+        img_fill_val (int | float | tuple): The filled values for image
+            border. If float, the same fill value will be used for all three
+            channels of the image. If tuple, it should have 3 elements.
+        seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+            of the corresponding config. Default 255.
+ prob (float): The probability for performing Shear and should be in
+ range [0, 1].
+ direction (str): The direction for shear, either "horizontal"
+ or "vertical".
+ max_shear_magnitude (float): The maximum magnitude for Shear
+ transformation.
+ random_negative_prob (float): The probability that turns the
+ offset negative. Should be in range [0,1]
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+ """
+
+ def __init__(self,
+ level,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ prob=0.5,
+ direction='horizontal',
+ max_shear_magnitude=0.3,
+ random_negative_prob=0.5,
+ interpolation='bilinear'):
+ assert isinstance(level, (int, float)), 'The level must be type ' \
+ f'int or float, got {type(level)}.'
+ assert 0 <= level <= _MAX_LEVEL, 'The level should be in range ' \
+ f'[0,{_MAX_LEVEL}], got {level}.'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, 'img_fill_val as tuple must ' \
+ f'have 3 elements. got {len(img_fill_val)}.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError(
+ 'img_fill_val must be float or tuple with 3 elements.')
+        assert np.all([0 <= val <= 255 for val in img_fill_val]), 'all ' \
+            'elements of img_fill_val should be in range [0, 255]. ' \
+            f'got {img_fill_val}.'
+ assert 0 <= prob <= 1.0, 'The probability of shear should be in ' \
+ f'range [0,1]. got {prob}.'
+        assert direction in ('horizontal', 'vertical'), 'direction must ' \
+            f'be either "horizontal" or "vertical". got {direction}.'
+ assert isinstance(max_shear_magnitude, float), 'max_shear_magnitude ' \
+ f'should be type float. got {type(max_shear_magnitude)}.'
+        assert 0. <= max_shear_magnitude <= 1., 'max_shear_magnitude ' \
+            f'should be in range [0,1]. got {max_shear_magnitude}.'
+ self.level = level
+ self.magnitude = level_to_value(level, max_shear_magnitude)
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+ self.direction = direction
+ self.max_shear_magnitude = max_shear_magnitude
+ self.random_negative_prob = random_negative_prob
+ self.interpolation = interpolation
+
+ def _shear_img(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ interpolation='bilinear'):
+ """Shear the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The direction for shear, either "horizontal"
+ or "vertical".
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ img_sheared = mmcv.imshear(
+ img,
+ magnitude,
+ direction,
+ border_value=self.img_fill_val,
+ interpolation=interpolation)
+ results[key] = img_sheared.astype(img.dtype)
+ results['img_shape'] = results[key].shape
+
+ def _shear_bboxes(self, results, magnitude):
+ """Shear the bboxes."""
+ h, w, c = results['img_shape']
+ if self.direction == 'horizontal':
+ shear_matrix = np.stack([[1, magnitude],
+ [0, 1]]).astype(np.float32) # [2, 2]
+ else:
+ shear_matrix = np.stack([[1, 0], [magnitude,
+ 1]]).astype(np.float32)
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ coordinates = np.stack([[min_x, min_y], [max_x, min_y],
+ [min_x, max_y],
+ [max_x, max_y]]) # [4, 2, nb_box, 1]
+ coordinates = coordinates[..., 0].transpose(
+ (2, 1, 0)).astype(np.float32) # [nb_box, 2, 4]
+ new_coords = np.matmul(shear_matrix[None, :, :],
+ coordinates) # [nb_box, 2, 4]
+ min_x = np.min(new_coords[:, 0, :], axis=-1)
+ min_y = np.min(new_coords[:, 1, :], axis=-1)
+ max_x = np.max(new_coords[:, 0, :], axis=-1)
+ max_y = np.max(new_coords[:, 1, :], axis=-1)
+ min_x = np.clip(min_x, a_min=0, a_max=w)
+ min_y = np.clip(min_y, a_min=0, a_max=h)
+ max_x = np.clip(max_x, a_min=min_x, a_max=w)
+ max_y = np.clip(max_y, a_min=min_y, a_max=h)
+ results[key] = np.stack([min_x, min_y, max_x, max_y],
+ axis=-1).astype(results[key].dtype)
+
+ def _shear_masks(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ fill_val=0,
+ interpolation='bilinear'):
+ """Shear the masks."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.shear((h, w),
+ magnitude,
+ direction,
+ border_value=fill_val,
+ interpolation=interpolation)
+
+ def _shear_seg(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ fill_val=255,
+ interpolation='bilinear'):
+ """Shear the segmentation maps."""
+ for key in results.get('seg_fields', []):
+ seg = results[key]
+ results[key] = mmcv.imshear(
+ seg,
+ magnitude,
+ direction,
+ border_value=fill_val,
+ interpolation=interpolation).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_bbox_size=0):
+ """Filter bboxes and corresponding masks too small after shear
+ augmentation."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+
+ def __call__(self, results):
+ """Call function to shear images, bounding boxes, masks and semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Sheared results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ magnitude = random_negative(self.magnitude, self.random_negative_prob)
+ self._shear_img(results, magnitude, self.direction, self.interpolation)
+ self._shear_bboxes(results, magnitude)
+ # fill_val set to 0 for background of mask.
+ self._shear_masks(
+ results,
+ magnitude,
+ self.direction,
+ fill_val=0,
+ interpolation=self.interpolation)
+ self._shear_seg(
+ results,
+ magnitude,
+ self.direction,
+ fill_val=self.seg_ignore_label,
+ interpolation=self.interpolation)
+ self._filter_invalid(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'img_fill_val={self.img_fill_val}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'direction={self.direction}, '
+ repr_str += f'max_shear_magnitude={self.max_shear_magnitude}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}, '
+ repr_str += f'interpolation={self.interpolation})'
+ return repr_str
+
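+# A minimal pipeline entry using Shear (an illustrative config sketch; the
+# surrounding train pipeline and fill values are assumptions):
+#   dict(type='Shear', level=5, prob=0.5, direction='horizontal',
+#        img_fill_val=(104, 116, 124))
+# With level=5 the shear magnitude is level_to_value(5, 0.3) == 0.15, and the
+# sign is flipped with probability ``random_negative_prob``.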
+
+@PIPELINES.register_module()
+class Rotate:
+ """Apply Rotate Transformation to image (and its corresponding bbox, mask,
+ segmentation).
+
+ Args:
+        level (int | float): The level should be in range [0,_MAX_LEVEL].
+ scale (int | float): Isotropic scale factor. Same in
+ ``mmcv.imrotate``.
+ center (int | float | tuple[float]): Center point (w, h) of the
+ rotation in the source image. If None, the center of the
+ image will be used. Same in ``mmcv.imrotate``.
+        img_fill_val (int | float | tuple): The fill value for image border.
+            If float, the same value will be used for all three channels of
+            the image. If tuple, it should have 3 elements (i.e. equal to
+            the number of channels of the image).
+        seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+            of the corresponding config. Default 255.
+        prob (float): The probability of performing the transformation,
+            which should be in range [0, 1].
+ max_rotate_angle (int | float): The maximum angles for rotate
+ transformation.
+ random_negative_prob (float): The probability that turns the
+ offset negative.
+ """
+
+ def __init__(self,
+ level,
+ scale=1,
+ center=None,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ prob=0.5,
+ max_rotate_angle=30,
+ random_negative_prob=0.5):
+ assert isinstance(level, (int, float)), \
+ f'The level must be type int or float. got {type(level)}.'
+        assert 0 <= level <= _MAX_LEVEL, \
+            f'The level should be in range [0,{_MAX_LEVEL}]. got {level}.'
+ assert isinstance(scale, (int, float)), \
+ f'The scale must be type int or float. got type {type(scale)}.'
+ if isinstance(center, (int, float)):
+ center = (center, center)
+ elif isinstance(center, tuple):
+ assert len(center) == 2, 'center with type tuple must have '\
+ f'2 elements. got {len(center)} elements.'
+ else:
+ assert center is None, 'center must be None or type int, '\
+ f'float or tuple, got type {type(center)}.'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\
+ f'have 3 elements. got {len(img_fill_val)}.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError(
+ 'img_fill_val must be float or tuple with 3 elements.')
+        assert np.all([0 <= val <= 255 for val in img_fill_val]), \
+            'all elements of img_fill_val should be in range [0, 255]. '\
+            f'got {img_fill_val}.'
+ assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\
+ f'got {prob}.'
+ assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\
+ f'should be type int or float. got type {type(max_rotate_angle)}.'
+ self.level = level
+ self.scale = scale
+ # Rotation angle in degrees. Positive values mean
+ # clockwise rotation.
+ self.angle = level_to_value(level, max_rotate_angle)
+ self.center = center
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+ self.max_rotate_angle = max_rotate_angle
+ self.random_negative_prob = random_negative_prob
+
+ def _rotate_img(self, results, angle, center=None, scale=1.0):
+ """Rotate the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ angle (float): Rotation angle in degrees, positive values
+ mean clockwise rotation. Same in ``mmcv.imrotate``.
+ center (tuple[float], optional): Center point (w, h) of the
+ rotation. Same in ``mmcv.imrotate``.
+ scale (int | float): Isotropic scale factor. Same in
+ ``mmcv.imrotate``.
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key].copy()
+ img_rotated = mmcv.imrotate(
+ img, angle, center, scale, border_value=self.img_fill_val)
+ results[key] = img_rotated.astype(img.dtype)
+ results['img_shape'] = results[key].shape
+
+ def _rotate_bboxes(self, results, rotate_matrix):
+ """Rotate the bboxes."""
+ h, w, c = results['img_shape']
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ coordinates = np.stack([[min_x, min_y], [max_x, min_y],
+ [min_x, max_y],
+ [max_x, max_y]]) # [4, 2, nb_bbox, 1]
+ # pad 1 to convert from format [x, y] to homogeneous
+ # coordinates format [x, y, 1]
+ coordinates = np.concatenate(
+ (coordinates,
+ np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)),
+ axis=1) # [4, 3, nb_bbox, 1]
+ coordinates = coordinates.transpose(
+ (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1]
+ rotated_coords = np.matmul(rotate_matrix,
+ coordinates) # [nb_bbox, 4, 2, 1]
+ rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2]
+ min_x, min_y = np.min(
+ rotated_coords[:, :, 0], axis=1), np.min(
+ rotated_coords[:, :, 1], axis=1)
+ max_x, max_y = np.max(
+ rotated_coords[:, :, 0], axis=1), np.max(
+ rotated_coords[:, :, 1], axis=1)
+ min_x, min_y = np.clip(
+ min_x, a_min=0, a_max=w), np.clip(
+ min_y, a_min=0, a_max=h)
+ max_x, max_y = np.clip(
+ max_x, a_min=min_x, a_max=w), np.clip(
+ max_y, a_min=min_y, a_max=h)
+ results[key] = np.stack([min_x, min_y, max_x, max_y],
+ axis=-1).astype(results[key].dtype)
+
+ def _rotate_masks(self,
+ results,
+ angle,
+ center=None,
+ scale=1.0,
+ fill_val=0):
+ """Rotate the masks."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.rotate((h, w), angle, center, scale, fill_val)
+
+ def _rotate_seg(self,
+ results,
+ angle,
+ center=None,
+ scale=1.0,
+ fill_val=255):
+ """Rotate the segmentation map."""
+ for key in results.get('seg_fields', []):
+ seg = results[key].copy()
+ results[key] = mmcv.imrotate(
+ seg, angle, center, scale,
+ border_value=fill_val).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_bbox_size=0):
+ """Filter bboxes and corresponding masks too small after rotate
+ augmentation."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+
+ def __call__(self, results):
+ """Call function to rotate images, bounding boxes, masks and semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Rotated results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ h, w = results['img'].shape[:2]
+ center = self.center
+ if center is None:
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
+ angle = random_negative(self.angle, self.random_negative_prob)
+ self._rotate_img(results, angle, center, self.scale)
+ rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale)
+ self._rotate_bboxes(results, rotate_matrix)
+ self._rotate_masks(results, angle, center, self.scale, fill_val=0)
+ self._rotate_seg(
+ results, angle, center, self.scale, fill_val=self.seg_ignore_label)
+ self._filter_invalid(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'center={self.center}, '
+ repr_str += f'img_fill_val={self.img_fill_val}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'max_rotate_angle={self.max_rotate_angle}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Translate:
+ """Translate the images, bboxes, masks and segmentation maps horizontally
+ or vertically.
+
+ Args:
+ level (int | float): The level for Translate and should be in
+ range [0,_MAX_LEVEL].
+ prob (float): The probability for performing translation and
+ should be in range [0, 1].
+        img_fill_val (int | float | tuple): The filled value for image
+            border. If float, the same fill value will be used for all
+            three channels of the image. If tuple, it should have 3
+            elements (i.e. equal to the number of channels of the image).
+        seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+            of the corresponding config. Default 255.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+        max_translate_offset (int | float): The maximum pixel offset for
+            Translate.
+ random_negative_prob (float): The probability that turns the
+ offset negative.
+        min_size (int | float): The minimum pixel size for filtering
+            invalid bboxes after the translation.
+ """
+
+ def __init__(self,
+ level,
+ prob=0.5,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ direction='horizontal',
+ max_translate_offset=250.,
+ random_negative_prob=0.5,
+ min_size=0):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level used for calculating Translate\'s offset should be ' \
+ 'in range [0,_MAX_LEVEL]'
+ assert 0 <= prob <= 1.0, \
+ 'The probability of translation should be in range [0, 1].'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, \
+ 'img_fill_val as tuple must have 3 elements.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError('img_fill_val must be type float or tuple.')
+        assert np.all([0 <= val <= 255 for val in img_fill_val]), \
+            'all elements of img_fill_val should be in range [0, 255].'
+ assert direction in ('horizontal', 'vertical'), \
+ 'direction should be "horizontal" or "vertical".'
+ assert isinstance(max_translate_offset, (int, float)), \
+ 'The max_translate_offset must be type int or float.'
+ # the offset used for translation
+ self.offset = int(level_to_value(level, max_translate_offset))
+ self.level = level
+ self.prob = prob
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.direction = direction
+ self.max_translate_offset = max_translate_offset
+ self.random_negative_prob = random_negative_prob
+ self.min_size = min_size
+
+ def _translate_img(self, results, offset, direction='horizontal'):
+ """Translate the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ offset (int | float): The offset for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key].copy()
+ results[key] = mmcv.imtranslate(
+ img, offset, direction, self.img_fill_val).astype(img.dtype)
+ results['img_shape'] = results[key].shape
+
+ def _translate_bboxes(self, results, offset):
+ """Shift bboxes horizontally or vertically, according to offset."""
+ h, w, c = results['img_shape']
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ if self.direction == 'horizontal':
+ min_x = np.maximum(0, min_x + offset)
+ max_x = np.minimum(w, max_x + offset)
+ elif self.direction == 'vertical':
+ min_y = np.maximum(0, min_y + offset)
+ max_y = np.minimum(h, max_y + offset)
+
+ # the boxes translated outside of image will be filtered along with
+ # the corresponding masks, by invoking ``_filter_invalid``.
+ results[key] = np.concatenate([min_x, min_y, max_x, max_y],
+ axis=-1)
+
+ def _translate_masks(self,
+ results,
+ offset,
+ direction='horizontal',
+ fill_val=0):
+ """Translate masks horizontally or vertically."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.translate((h, w), offset, direction, fill_val)
+
+ def _translate_seg(self,
+ results,
+ offset,
+ direction='horizontal',
+ fill_val=255):
+ """Translate segmentation maps horizontally or vertically."""
+ for key in results.get('seg_fields', []):
+ seg = results[key].copy()
+ results[key] = mmcv.imtranslate(seg, offset, direction,
+ fill_val).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_size=0):
+ """Filter bboxes and masks too small or translated out of image."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_size) & (bbox_h > min_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+ return results
+
+ def __call__(self, results):
+ """Call function to translate images, bounding boxes, masks and
+ semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Translated results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ offset = random_negative(self.offset, self.random_negative_prob)
+ self._translate_img(results, offset, self.direction)
+ self._translate_bboxes(results, offset)
+        # fill_val defaults to 0 for BitmapMasks and None for PolygonMasks.
+ self._translate_masks(results, offset, self.direction)
+ # fill_val set to ``seg_ignore_label`` for the ignored value
+ # of segmentation map.
+ self._translate_seg(
+ results, offset, self.direction, fill_val=self.seg_ignore_label)
+ self._filter_invalid(results, min_size=self.min_size)
+ return results
+
+
+@PIPELINES.register_module()
+class ColorTransform:
+ """Apply Color transformation to image. The bboxes, masks, and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Color transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_color_img(self, results, factor=1.0):
+ """Apply Color transformation to image."""
+ for key in results.get('img_fields', ['img']):
+            # NOTE: by default the image is expected to be in BGR format
+ img = results[key]
+ results[key] = mmcv.adjust_color(img, factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Color transformation.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Colored results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_color_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class EqualizeTransform:
+ """Apply Equalize transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ prob (float): The probability for performing Equalize transformation.
+ """
+
+ def __init__(self, prob=0.5):
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.prob = prob
+
+ def _imequalize(self, results):
+ """Equalizes the histogram of one image."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.imequalize(img).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Equalize transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._imequalize(results)
+ return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(prob={self.prob})'
+        return repr_str
+
+
+@PIPELINES.register_module()
+class BrightnessTransform:
+ """Apply Brightness transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Brightness transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_brightness_img(self, results, factor=1.0):
+ """Adjust the brightness of image."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.adjust_brightness(img,
+ factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Brightness transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_brightness_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class ContrastTransform:
+ """Apply Contrast transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Contrast transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_contrast_img(self, results, factor=1.0):
+ """Adjust the image contrast."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.adjust_contrast(img, factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Contrast transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_contrast_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
diff --git a/mmdet/datasets/pipelines/compose.py b/mmdet/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..d759220098440c769b8f53c1e3b902c046450ff4
--- /dev/null
+++ b/mmdet/datasets/pipelines/compose.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections
+
+from mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose:
+ """Compose multiple transforms sequentially.
+
+ Args:
+ transforms (Sequence[dict | callable]): Sequence of transform object or
+ config dict to be composed.
+ """
+
+ def __init__(self, transforms):
+ assert isinstance(transforms, collections.abc.Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError('transform must be callable or a dict')
+
+ def __call__(self, data):
+ """Call function to apply transforms sequentially.
+
+ Args:
+ data (dict): A result dict contains the data to transform.
+
+ Returns:
+ dict: Transformed data.
+ """
+
+ for t in self.transforms:
+ data = t(data)
+ if data is None:
+ return None
+ return data
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ str_ = t.__repr__()
+ if 'Compose(' in str_:
+ str_ = str_.replace('\n', '\n ')
+ format_string += '\n'
+ format_string += f' {str_}'
+ format_string += '\n)'
+ return format_string
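+
+
+# A short usage sketch (the transform names are registered elsewhere in this
+# package; the exact arguments below are illustrative):
+#   pipeline = Compose([
+#       dict(type='LoadImageFromFile'),
+#       dict(type='RandomFlip', flip_ratio=0.5),
+#   ])
+#   results = pipeline(results)  # returns None if any transform returns None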
diff --git a/mmdet/datasets/pipelines/formating.py b/mmdet/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b3e45abbb0714db18700ba9a12618a5aaa638d8
--- /dev/null
+++ b/mmdet/datasets/pipelines/formating.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+import warnings
+
+from .formatting import *
+
+warnings.warn('DeprecationWarning: mmdet.datasets.pipelines.formating will be '
+ 'deprecated, please replace it with '
+ 'mmdet.datasets.pipelines.formatting.')
diff --git a/mmdet/datasets/pipelines/formatting.py b/mmdet/datasets/pipelines/formatting.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e07f3894f0e7ab9703acd9b790135cd1f878672
--- /dev/null
+++ b/mmdet/datasets/pipelines/formatting.py
@@ -0,0 +1,403 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Sequence
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.parallel import DataContainer as DC
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
+
+ Args:
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+ be converted.
+ """
+
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+
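+# Doctest-style sketch of the conversions handled above:
+#   >>> to_tensor(np.zeros((2, 3), dtype=np.float32)).shape
+#   torch.Size([2, 3])
+#   >>> to_tensor(3)        # int -> LongTensor([3])
+#   >>> to_tensor(0.5)      # float -> FloatTensor([0.5])
+#   >>> to_tensor([1, 2])   # sequence -> tensor([1, 2])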
+
+@PIPELINES.register_module()
+class ToTensor:
+ """Convert some results to :obj:`torch.Tensor` by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys that need to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert data in results to :obj:`torch.Tensor`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted
+ to :obj:`torch.Tensor`.
+ """
+ for key in self.keys:
+ results[key] = to_tensor(results[key])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class ImageToTensor:
+ """Convert image to :obj:`torch.Tensor` by given keys.
+
+ The dimension order of input image is (H, W, C). The pipeline will convert
+    it to (C, H, W). If only 2 dimensions (H, W) are given, the output will
+    be (1, H, W).
+
+ Args:
+ keys (Sequence[str]): Key of images to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert image in results to :obj:`torch.Tensor` and
+ permute the channel order.
+
+ Args:
+ results (dict): Result dict contains the image data to convert.
+
+ Returns:
+ dict: The result dict contains the image converted
+ to :obj:`torch.Tensor` and permuted to (C, H, W) order.
+ """
+ for key in self.keys:
+ img = results[key]
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ results[key] = to_tensor(img).permute(2, 0, 1).contiguous()
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose:
+ """Transpose some results by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys of results to be transposed.
+ order (Sequence[int]): Order of transpose.
+ """
+
+ def __init__(self, keys, order):
+ self.keys = keys
+ self.order = order
+
+ def __call__(self, results):
+ """Call function to transpose the channel order of data in results.
+
+ Args:
+ results (dict): Result dict contains the data to transpose.
+
+ Returns:
+ dict: The result dict contains the data transposed to \
+ ``self.order``.
+ """
+ for key in self.keys:
+ results[key] = results[key].transpose(self.order)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class ToDataContainer:
+ """Convert results to :obj:`mmcv.DataContainer` by given fields.
+
+ Args:
+ fields (Sequence[dict]): Each field is a dict like
+ ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+ be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+ Default: ``(dict(key='img', stack=True), dict(key='gt_bboxes'),
+ dict(key='gt_labels'))``.
+ """
+
+ def __init__(self,
+ fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
+ dict(key='gt_labels'))):
+ self.fields = fields
+
+ def __call__(self, results):
+ """Call function to convert data in results to
+ :obj:`mmcv.DataContainer`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted to \
+ :obj:`mmcv.DataContainer`.
+ """
+
+ for field in self.fields:
+ field = field.copy()
+ key = field.pop('key')
+ results[key] = DC(results[key], **field)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(fields={self.fields})'
+
+
+@PIPELINES.register_module()
+class DefaultFormatBundle:
+ """Default formatting bundle.
+
+ It simplifies the pipeline of formatting common fields, including "img",
+ "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
+ These fields are formatted as follows.
+
+ - img: (1)transpose & to tensor, (2)to DataContainer (stack=True)
+ - proposals: (1)to tensor, (2)to DataContainer
+ - gt_bboxes: (1)to tensor, (2)to DataContainer
+ - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
+ - gt_labels: (1)to tensor, (2)to DataContainer
+ - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
+ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \
+ (3)to DataContainer (stack=True)
+
+ Args:
+ img_to_float (bool): Whether to force the image to be converted to
+ float type. Default: True.
+ pad_val (dict): A dict for padding value in batch collating,
+ the default value is `dict(img=0, masks=0, seg=255)`.
+ Without this argument, the padding value of "gt_semantic_seg"
+ will be set to 0 by default, which should be 255.
+ """
+
+ def __init__(self,
+ img_to_float=True,
+ pad_val=dict(img=0, masks=0, seg=255)):
+ self.img_to_float = img_to_float
+ self.pad_val = pad_val
+
+ def __call__(self, results):
+ """Call function to transform and format common fields in results.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with \
+ default bundle.
+ """
+
+ if 'img' in results:
+ img = results['img']
+ if self.img_to_float is True and img.dtype == np.uint8:
+                # Normally, the image is of uint8 type without normalization.
+                # In that case, it needs to be forcibly converted to
+                # float32, otherwise model training and inference
+                # will be wrong. Only used for YOLOX currently.
+ img = img.astype(np.float32)
+ # add default meta keys
+ results = self._add_default_meta_keys(results)
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+            # To improve the computational speed by 3-5 times, apply:
+ # If image is not contiguous, use
+ # `numpy.transpose()` followed by `numpy.ascontiguousarray()`
+ # If image is already contiguous, use
+ # `torch.permute()` followed by `torch.contiguous()`
+ # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
+ # for more details
+ if not img.flags.c_contiguous:
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
+ img = to_tensor(img)
+ else:
+ img = to_tensor(img).permute(2, 0, 1).contiguous()
+ results['img'] = DC(
+ img, padding_value=self.pad_val['img'], stack=True)
+ for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
+ if key not in results:
+ continue
+ results[key] = DC(to_tensor(results[key]))
+ if 'gt_masks' in results:
+ results['gt_masks'] = DC(
+ results['gt_masks'],
+ padding_value=self.pad_val['masks'],
+ cpu_only=True)
+ if 'gt_semantic_seg' in results:
+ results['gt_semantic_seg'] = DC(
+ to_tensor(results['gt_semantic_seg'][None, ...]),
+ padding_value=self.pad_val['seg'],
+ stack=True)
+ return results
+
+ def _add_default_meta_keys(self, results):
+ """Add default meta keys.
+
+ We set default meta keys including `pad_shape`, `scale_factor` and
+ `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and
+ `Pad` are implemented during the whole pipeline.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ results (dict): Updated result dict contains the data to convert.
+ """
+ img = results['img']
+ results.setdefault('pad_shape', img.shape)
+ results.setdefault('scale_factor', 1.0)
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results.setdefault(
+ 'img_norm_cfg',
+ dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(img_to_float={self.img_to_float})'
+
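+# Sketch of the effect on a loaded ``results`` dict (shapes are illustrative):
+#   before: results['img'].shape == (H, W, 3)  # numpy uint8
+#   after:  results['img'] is a DataContainer whose ``data`` is a float32
+#           tensor of shape (3, H, W), padded with pad_val['img'] when
+#           batched; 'gt_bboxes'/'gt_labels' become DataContainers of tensors
+#           and 'gt_masks' is wrapped with cpu_only=True.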
+
+@PIPELINES.register_module()
+class Collect:
+ """Collect data from the loader relevant to the specific task.
+
+    This is usually the last stage of the data loader pipeline. Typically
+    ``keys`` is set to some subset of "img", "proposals", "gt_bboxes",
+ "gt_bboxes_ignore", "gt_labels", and/or "gt_masks".
+
+ The "img_meta" item is always populated. The contents of the "img_meta"
+    dictionary depend on "meta_keys". By default this includes:
+
+ - "img_shape": shape of the image input to the network as a tuple \
+ (h, w, c). Note that images may be zero padded on the \
+ bottom/right if the batch tensor is larger than this shape.
+
+ - "scale_factor": a float indicating the preprocessing scale
+
+ - "flip": a boolean indicating if image flip transform was used
+
+ - "filename": path to the image file
+
+ - "ori_shape": original shape of the image as a tuple (h, w, c)
+
+ - "pad_shape": image shape after padding
+
+ - "img_norm_cfg": a dict of normalization information:
+
+ - mean - per channel mean subtraction
+ - std - per channel std divisor
+ - to_rgb - bool indicating if bgr was converted to rgb
+
+ Args:
+ keys (Sequence[str]): Keys of results to be collected in ``data``.
+ meta_keys (Sequence[str], optional): Meta keys to be converted to
+ ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+ Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'img_norm_cfg')``
+ """
+
+ def __init__(self,
+ keys,
+ meta_keys=('filename', 'ori_filename', 'ori_shape',
+ 'img_shape', 'pad_shape', 'scale_factor', 'flip',
+ 'flip_direction', 'img_norm_cfg')):
+ self.keys = keys
+ self.meta_keys = meta_keys
+
+ def __call__(self, results):
+ """Call function to collect keys in results. The keys in ``meta_keys``
+ will be converted to :obj:mmcv.DataContainer.
+
+ Args:
+ results (dict): Result dict contains the data to collect.
+
+ Returns:
+ dict: The result dict contains the following keys
+
+            - keys in ``self.keys``
+ - ``img_metas``
+ """
+
+ data = {}
+ img_meta = {}
+ for key in self.meta_keys:
+ img_meta[key] = results[key]
+ data['img_metas'] = DC(img_meta, cpu_only=True)
+ for key in self.keys:
+ data[key] = results[key]
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, meta_keys={self.meta_keys})'
+
+
+@PIPELINES.register_module()
+class WrapFieldsToLists:
+ """Wrap fields of the data dictionary into lists for evaluation.
+
+ This class can be used as a last step of a test or validation
+ pipeline for single image evaluation or inference.
+
+ Example:
+ >>> test_pipeline = [
+ >>> dict(type='LoadImageFromFile'),
+ >>> dict(type='Normalize',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ to_rgb=True),
+ >>> dict(type='Pad', size_divisor=32),
+ >>> dict(type='ImageToTensor', keys=['img']),
+ >>> dict(type='Collect', keys=['img']),
+ >>> dict(type='WrapFieldsToLists')
+ >>> ]
+ """
+
+ def __call__(self, results):
+ """Call function to wrap fields into lists.
+
+ Args:
+ results (dict): Result dict contains the data to wrap.
+
+ Returns:
+            dict: The result dict where the values of ``self.keys`` are \
+                wrapped into lists.
+ """
+
+ # Wrap dict fields into lists
+ for key, val in results.items():
+ results[key] = [val]
+ return results
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
diff --git a/mmdet/datasets/pipelines/instaboost.py b/mmdet/datasets/pipelines/instaboost.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca10c4c751f5309e37822fbe61ea3c7ed5de1b83
--- /dev/null
+++ b/mmdet/datasets/pipelines/instaboost.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class InstaBoost:
+ r"""Data augmentation method in `InstaBoost: Boosting Instance
+    Segmentation Via Probability Map Guided Copy-Pasting
+    <https://arxiv.org/abs/1908.07801>`_.
+
+ Refer to https://github.com/GothicAi/Instaboost for implementation details.
+
+ Args:
+ action_candidate (tuple): Action candidates. "normal", "horizontal", \
+ "vertical", "skip" are supported. Default: ('normal', \
+ 'horizontal', 'skip').
+ action_prob (tuple): Corresponding action probabilities. Should be \
+ the same length as action_candidate. Default: (1, 0, 0).
+ scale (tuple): (min scale, max scale). Default: (0.8, 1.2).
+ dx (int): The maximum x-axis shift will be (instance width) / dx.
+ Default 15.
+ dy (int): The maximum y-axis shift will be (instance height) / dy.
+ Default 15.
+ theta (tuple): (min rotation degree, max rotation degree). \
+ Default: (-1, 1).
+ color_prob (float): Probability of images for color augmentation.
+ Default 0.5.
+        hflag (bool): Whether to use heatmap-guided copy-pasting.
+            Default False.
+ aug_ratio (float): Probability of applying this transformation. \
+ Default 0.5.
+ """
+
+ def __init__(self,
+ action_candidate=('normal', 'horizontal', 'skip'),
+ action_prob=(1, 0, 0),
+ scale=(0.8, 1.2),
+ dx=15,
+ dy=15,
+ theta=(-1, 1),
+ color_prob=0.5,
+ hflag=False,
+ aug_ratio=0.5):
+ try:
+ import instaboostfast as instaboost
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install instaboostfast" '
+ 'to install instaboostfast first for instaboost augmentation.')
+ self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
+ scale, dx, dy, theta,
+ color_prob, hflag)
+ self.aug_ratio = aug_ratio
+
+ def _load_anns(self, results):
+ labels = results['ann_info']['labels']
+ masks = results['ann_info']['masks']
+ bboxes = results['ann_info']['bboxes']
+ n = len(labels)
+
+ anns = []
+ for i in range(n):
+ label = labels[i]
+ bbox = bboxes[i]
+ mask = masks[i]
+ x1, y1, x2, y2 = bbox
+ # assert (x2 - x1) >= 1 and (y2 - y1) >= 1
+ bbox = [x1, y1, x2 - x1, y2 - y1]
+ anns.append({
+ 'category_id': label,
+ 'segmentation': mask,
+ 'bbox': bbox
+ })
+
+ return anns
+
+ def _parse_anns(self, results, anns, img):
+ gt_bboxes = []
+ gt_labels = []
+ gt_masks_ann = []
+ for ann in anns:
+ x1, y1, w, h = ann['bbox']
+ # TODO: more essential bug need to be fixed in instaboost
+ if w <= 0 or h <= 0:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ gt_bboxes.append(bbox)
+ gt_labels.append(ann['category_id'])
+ gt_masks_ann.append(ann['segmentation'])
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ results['ann_info']['labels'] = gt_labels
+ results['ann_info']['bboxes'] = gt_bboxes
+ results['ann_info']['masks'] = gt_masks_ann
+ results['img'] = img
+ return results
+
+ def __call__(self, results):
+ img = results['img']
+ ori_type = img.dtype
+ anns = self._load_anns(results)
+ if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
+ try:
+ import instaboostfast as instaboost
+ except ImportError:
+ raise ImportError('Please run "pip install instaboostfast" '
+ 'to install instaboostfast first.')
+ anns, img = instaboost.get_new_data(
+ anns, img.astype(np.uint8), self.cfg, background=None)
+
+ results = self._parse_anns(results, anns, img.astype(ori_type))
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(cfg={self.cfg}, aug_ratio={self.aug_ratio})'
+ return repr_str
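+
+# Illustrative usage sketch (the aug_ratio value is only a placeholder): the
+# transform reads the raw 'ann_info', so it is typically inserted between
+# image loading and annotation loading in a training pipeline config, e.g.
+#
+#     train_pipeline = [
+#         dict(type='LoadImageFromFile'),
+#         dict(type='InstaBoost', aug_ratio=0.5),
+#         dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+#     ]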
diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..8af8cf352ca4298fca4d50f0f5760daa869a6aeb
--- /dev/null
+++ b/mmdet/datasets/pipelines/loading.py
@@ -0,0 +1,645 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import mmcv
+import numpy as np
+import pycocotools.mask as maskUtils
+
+from mmdet.core import BitmapMasks, PolygonMasks
+from ..builder import PIPELINES
+
+try:
+ from panopticapi.utils import rgb2id
+except ImportError:
+ rgb2id = None
+
+
+@PIPELINES.register_module()
+class LoadImageFromFile:
+ """Load an image from file.
+
+    Required keys are "img_prefix" and "img_info" (a dict that must contain the
+    key "filename"). Added or updated keys are "filename", "ori_filename",
+    "img", "img_shape", "ori_shape" (same as `img_shape`) and "img_fields".
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is an uint8 array.
+ Defaults to False.
+        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+            Defaults to 'color'.
+        channel_order (str): Order of the image channels, 'bgr' or 'rgb',
+            passed to :func:`mmcv.imfrombytes`. Defaults to 'bgr'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ to_float32=False,
+ color_type='color',
+ channel_order='bgr',
+ file_client_args=dict(backend='disk')):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.channel_order = channel_order
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def __call__(self, results):
+ """Call functions to load image and get image meta information.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results['img_prefix'] is not None:
+ filename = osp.join(results['img_prefix'],
+ results['img_info']['filename'])
+ else:
+ filename = results['img_info']['filename']
+
+ img_bytes = self.file_client.get(filename)
+ img = mmcv.imfrombytes(
+ img_bytes, flag=self.color_type, channel_order=self.channel_order)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ results['img_fields'] = ['img']
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32}, '
+ f"color_type='{self.color_type}', "
+ f"channel_order='{self.channel_order}', "
+ f'file_client_args={self.file_client_args})')
+ return repr_str
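+
+# Minimal call sketch (the path and filename below are hypothetical): the
+# transform expects 'img_prefix' and 'img_info' in the results dict and
+# fills in the decoded image plus its meta keys.
+#
+#     loader = LoadImageFromFile(to_float32=True)
+#     results = loader(dict(img_prefix='data/coco/train2017',
+#                           img_info=dict(filename='000000000009.jpg')))
+#     # results['img'] is a float32 HxWx3 array; results['ori_shape'] its shape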
+
+
+@PIPELINES.register_module()
+class LoadImageFromWebcam(LoadImageFromFile):
+ """Load an image from webcam.
+
+    Similar to :obj:`LoadImageFromFile`, but the image read from the webcam is
+    already stored in ``results['img']``.
+ """
+
+ def __call__(self, results):
+ """Call functions to add image meta information.
+
+ Args:
+ results (dict): Result dict with Webcam read image in
+ ``results['img']``.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ img = results['img']
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = None
+ results['ori_filename'] = None
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ results['img_fields'] = ['img']
+ return results
+
+
+@PIPELINES.register_module()
+class LoadMultiChannelImageFromFiles:
+ """Load multi-channel images from a list of separate channel files.
+
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename", which is expected to be a list of filenames).
+ Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+ "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is an uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ to_float32=False,
+ color_type='unchanged',
+ file_client_args=dict(backend='disk')):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def __call__(self, results):
+ """Call functions to load multiple images and get images meta
+ information.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded images and meta information.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results['img_prefix'] is not None:
+ filename = [
+ osp.join(results['img_prefix'], fname)
+ for fname in results['img_info']['filename']
+ ]
+ else:
+ filename = results['img_info']['filename']
+
+ img = []
+ for name in filename:
+ img_bytes = self.file_client.get(name)
+ img.append(mmcv.imfrombytes(img_bytes, flag=self.color_type))
+ img = np.stack(img, axis=-1)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ # Set initial values for default meta_keys
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = 1.0
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results['img_norm_cfg'] = dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False)
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32}, '
+ f"color_type='{self.color_type}', "
+ f'file_client_args={self.file_client_args})')
+ return repr_str
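+
+# Usage sketch (filenames are placeholders): 'filename' is a list here, and
+# the per-file images are stacked along the last axis.
+#
+#     results = dict(img_prefix='data/multispectral',
+#                    img_info=dict(filename=['band1.tif', 'band2.tif']))
+#     results = LoadMultiChannelImageFromFiles(color_type='unchanged')(results)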
+
+
+@PIPELINES.register_module()
+class LoadAnnotations:
+ """Load multiple types of annotations.
+
+ Args:
+ with_bbox (bool): Whether to parse and load the bbox annotation.
+ Default: True.
+ with_label (bool): Whether to parse and load the label annotation.
+ Default: True.
+ with_mask (bool): Whether to parse and load the mask annotation.
+ Default: False.
+ with_seg (bool): Whether to parse and load the semantic segmentation
+ annotation. Default: False.
+ poly2mask (bool): Whether to convert the instance masks from polygons
+ to bitmaps. Default: True.
+ denorm_bbox (bool): Whether to convert bbox from relative value to
+ absolute value. Only used in OpenImage Dataset.
+ Default: False.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ with_bbox=True,
+ with_label=True,
+ with_mask=False,
+ with_seg=False,
+ poly2mask=True,
+ denorm_bbox=False,
+ file_client_args=dict(backend='disk')):
+ self.with_bbox = with_bbox
+ self.with_label = with_label
+ self.with_mask = with_mask
+ self.with_seg = with_seg
+ self.poly2mask = poly2mask
+ self.denorm_bbox = denorm_bbox
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def _load_bboxes(self, results):
+ """Private function to load bounding box annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded bounding box annotations.
+ """
+
+ ann_info = results['ann_info']
+ results['gt_bboxes'] = ann_info['bboxes'].copy()
+
+ if self.denorm_bbox:
+ bbox_num = results['gt_bboxes'].shape[0]
+ if bbox_num != 0:
+ h, w = results['img_shape'][:2]
+ results['gt_bboxes'][:, 0::2] *= w
+ results['gt_bboxes'][:, 1::2] *= h
+
+ gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
+ if gt_bboxes_ignore is not None:
+ results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
+ results['bbox_fields'].append('gt_bboxes_ignore')
+ results['bbox_fields'].append('gt_bboxes')
+
+ gt_is_group_ofs = ann_info.get('gt_is_group_ofs', None)
+ if gt_is_group_ofs is not None:
+ results['gt_is_group_ofs'] = gt_is_group_ofs.copy()
+
+ return results
+
+ def _load_labels(self, results):
+ """Private function to load label annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded label annotations.
+ """
+
+ results['gt_labels'] = results['ann_info']['labels'].copy()
+ return results
+
+ def _poly2mask(self, mask_ann, img_h, img_w):
+ """Private function to convert masks represented with polygon to
+ bitmaps.
+
+ Args:
+ mask_ann (list | dict): Polygon mask annotation input.
+ img_h (int): The height of output mask.
+ img_w (int): The width of output mask.
+
+ Returns:
+ numpy.ndarray: The decode bitmap mask of shape (img_h, img_w).
+ """
+
+ if isinstance(mask_ann, list):
+ # polygon -- a single object might consist of multiple parts
+ # we merge all parts into one mask rle code
+ rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+ rle = maskUtils.merge(rles)
+ elif isinstance(mask_ann['counts'], list):
+ # uncompressed RLE
+ rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+ else:
+ # rle
+ rle = mask_ann
+ mask = maskUtils.decode(rle)
+ return mask
+
+ def process_polygons(self, polygons):
+ """Convert polygons to list of ndarray and filter invalid polygons.
+
+ Args:
+ polygons (list[list]): Polygons of one instance.
+
+ Returns:
+ list[numpy.ndarray]: Processed polygons.
+ """
+
+ polygons = [np.array(p) for p in polygons]
+ valid_polygons = []
+ for polygon in polygons:
+ if len(polygon) % 2 == 0 and len(polygon) >= 6:
+ valid_polygons.append(polygon)
+ return valid_polygons
+
+ def _load_masks(self, results):
+ """Private function to load mask annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded mask annotations.
+            If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
+            :obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks` is used.
+ """
+
+ h, w = results['img_info']['height'], results['img_info']['width']
+ gt_masks = results['ann_info']['masks']
+ if self.poly2mask:
+ gt_masks = BitmapMasks(
+ [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
+ else:
+ gt_masks = PolygonMasks(
+ [self.process_polygons(polygons) for polygons in gt_masks], h,
+ w)
+ results['gt_masks'] = gt_masks
+ results['mask_fields'].append('gt_masks')
+ return results
+
+ def _load_semantic_seg(self, results):
+ """Private function to load semantic segmentation annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+
+ Returns:
+ dict: The dict contains loaded semantic segmentation annotations.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ filename = osp.join(results['seg_prefix'],
+ results['ann_info']['seg_map'])
+ img_bytes = self.file_client.get(filename)
+ results['gt_semantic_seg'] = mmcv.imfrombytes(
+ img_bytes, flag='unchanged').squeeze()
+ results['seg_fields'].append('gt_semantic_seg')
+ return results
+
+ def __call__(self, results):
+ """Call function to load multiple types annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded bounding box, label, mask and
+ semantic segmentation annotations.
+ """
+
+ if self.with_bbox:
+ results = self._load_bboxes(results)
+ if results is None:
+ return None
+ if self.with_label:
+ results = self._load_labels(results)
+ if self.with_mask:
+ results = self._load_masks(results)
+ if self.with_seg:
+ results = self._load_semantic_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(with_bbox={self.with_bbox}, '
+ repr_str += f'with_label={self.with_label}, '
+ repr_str += f'with_mask={self.with_mask}, '
+ repr_str += f'with_seg={self.with_seg}, '
+ repr_str += f'poly2mask={self.poly2mask}, '
+ repr_str += f'file_client_args={self.file_client_args})'
+ return repr_str
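+
+# Typical pipeline entries (illustrative): bbox-only detection versus instance
+# segmentation, which additionally loads masks.
+#
+#     dict(type='LoadAnnotations', with_bbox=True)
+#     dict(type='LoadAnnotations', with_bbox=True, with_mask=True)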
+
+
+@PIPELINES.register_module()
+class LoadPanopticAnnotations(LoadAnnotations):
+ """Load multiple types of panoptic annotations.
+
+ Args:
+ with_bbox (bool): Whether to parse and load the bbox annotation.
+ Default: True.
+ with_label (bool): Whether to parse and load the label annotation.
+ Default: True.
+ with_mask (bool): Whether to parse and load the mask annotation.
+ Default: True.
+ with_seg (bool): Whether to parse and load the semantic segmentation
+ annotation. Default: True.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ with_bbox=True,
+ with_label=True,
+ with_mask=True,
+ with_seg=True,
+ file_client_args=dict(backend='disk')):
+ if rgb2id is None:
+ raise RuntimeError(
+ 'panopticapi is not installed, please install it by: '
+ 'pip install git+https://github.com/cocodataset/'
+ 'panopticapi.git.')
+
+ super(LoadPanopticAnnotations, self).__init__(
+ with_bbox=with_bbox,
+ with_label=with_label,
+ with_mask=with_mask,
+ with_seg=with_seg,
+ poly2mask=True,
+ denorm_bbox=False,
+ file_client_args=file_client_args)
+
+ def _load_masks_and_semantic_segs(self, results):
+ """Private function to load mask and semantic segmentation annotations.
+
+ In gt_semantic_seg, the foreground label is from `0` to
+ `num_things - 1`, the background label is from `num_things` to
+ `num_things + num_stuff - 1`, 255 means the ignored label (`VOID`).
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded mask and semantic segmentation
+ annotations. `BitmapMasks` is used for mask annotations.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ filename = osp.join(results['seg_prefix'],
+ results['ann_info']['seg_map'])
+ img_bytes = self.file_client.get(filename)
+ pan_png = mmcv.imfrombytes(
+ img_bytes, flag='color', channel_order='rgb').squeeze()
+ pan_png = rgb2id(pan_png)
+
+ gt_masks = []
+ gt_seg = np.zeros_like(pan_png) + 255 # 255 as ignore
+
+ for mask_info in results['ann_info']['masks']:
+ mask = (pan_png == mask_info['id'])
+ gt_seg = np.where(mask, mask_info['category'], gt_seg)
+
+ # The legal thing masks
+ if mask_info.get('is_thing'):
+ gt_masks.append(mask.astype(np.uint8))
+
+ if self.with_mask:
+ h, w = results['img_info']['height'], results['img_info']['width']
+ gt_masks = BitmapMasks(gt_masks, h, w)
+ results['gt_masks'] = gt_masks
+ results['mask_fields'].append('gt_masks')
+
+ if self.with_seg:
+ results['gt_semantic_seg'] = gt_seg
+ results['seg_fields'].append('gt_semantic_seg')
+ return results
+
+ def __call__(self, results):
+ """Call function to load multiple types panoptic annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded bounding box, label, mask and
+ semantic segmentation annotations.
+ """
+
+ if self.with_bbox:
+ results = self._load_bboxes(results)
+ if results is None:
+ return None
+ if self.with_label:
+ results = self._load_labels(results)
+ if self.with_mask or self.with_seg:
+            # The tasks completed by '_load_masks' and '_load_semantic_seg'
+            # in LoadAnnotations are merged into one function.
+ results = self._load_masks_and_semantic_segs(results)
+
+ return results
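+
+# Illustrative panoptic pipeline entry: all four flags default to True, so the
+# short form below loads boxes, labels, thing masks and the semantic map.
+#
+#     dict(type='LoadPanopticAnnotations')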
+
+
+@PIPELINES.register_module()
+class LoadProposals:
+ """Load proposal pipeline.
+
+ Required key is "proposals". Updated keys are "proposals", "bbox_fields".
+
+ Args:
+ num_max_proposals (int, optional): Maximum number of proposals to load.
+ If not specified, all proposals will be loaded.
+ """
+
+ def __init__(self, num_max_proposals=None):
+ self.num_max_proposals = num_max_proposals
+
+ def __call__(self, results):
+ """Call function to load proposals from file.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded proposal annotations.
+ """
+
+ proposals = results['proposals']
+ if proposals.shape[1] not in (4, 5):
+ raise AssertionError(
+ 'proposals should have shapes (n, 4) or (n, 5), '
+ f'but found {proposals.shape}')
+ proposals = proposals[:, :4]
+
+ if self.num_max_proposals is not None:
+ proposals = proposals[:self.num_max_proposals]
+
+ if len(proposals) == 0:
+ proposals = np.array([[0, 0, 0, 0]], dtype=np.float32)
+ results['proposals'] = proposals
+ results['bbox_fields'].append('proposals')
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(num_max_proposals={self.num_max_proposals})'
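+
+# Usage sketch (the proposal count is a placeholder): used with pre-computed
+# proposals, e.g. in Fast R-CNN style pipelines.
+#
+#     dict(type='LoadProposals', num_max_proposals=2000)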
+
+
+@PIPELINES.register_module()
+class FilterAnnotations:
+ """Filter invalid annotations.
+
+ Args:
+ min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth
+ boxes. Default: (1., 1.)
+ min_gt_mask_area (int): Minimum foreground area of ground truth masks.
+ Default: 1
+ by_box (bool): Filter instances with bounding boxes not meeting the
+ min_gt_bbox_wh threshold. Default: True
+ by_mask (bool): Filter instances with masks not meeting
+ min_gt_mask_area threshold. Default: False
+        keep_empty (bool): Whether to return None when no gt bboxes remain
+            after filtering. Default: True.
+ """
+
+ def __init__(self,
+ min_gt_bbox_wh=(1., 1.),
+ min_gt_mask_area=1,
+ by_box=True,
+ by_mask=False,
+ keep_empty=True):
+ # TODO: add more filter options
+ assert by_box or by_mask
+ self.min_gt_bbox_wh = min_gt_bbox_wh
+ self.min_gt_mask_area = min_gt_mask_area
+ self.by_box = by_box
+ self.by_mask = by_mask
+ self.keep_empty = keep_empty
+
+ def __call__(self, results):
+ if self.by_box:
+ assert 'gt_bboxes' in results
+ gt_bboxes = results['gt_bboxes']
+ instance_num = gt_bboxes.shape[0]
+ if self.by_mask:
+ assert 'gt_masks' in results
+ gt_masks = results['gt_masks']
+ instance_num = len(gt_masks)
+
+ if instance_num == 0:
+ return results
+
+ tests = []
+ if self.by_box:
+ w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
+ h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
+ tests.append((w > self.min_gt_bbox_wh[0])
+ & (h > self.min_gt_bbox_wh[1]))
+ if self.by_mask:
+ gt_masks = results['gt_masks']
+ tests.append(gt_masks.areas >= self.min_gt_mask_area)
+
+ keep = tests[0]
+ for t in tests[1:]:
+ keep = keep & t
+
+ keep = keep.nonzero()[0]
+
+ keys = ('gt_bboxes', 'gt_labels', 'gt_masks')
+ for key in keys:
+ if key in results:
+ results[key] = results[key][keep]
+ if keep.size == 0:
+ if self.keep_empty:
+ return None
+ return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            f'(min_gt_bbox_wh={self.min_gt_bbox_wh}, ' \
+            f'min_gt_mask_area={self.min_gt_mask_area}, ' \
+            f'by_box={self.by_box}, ' \
+            f'by_mask={self.by_mask}, ' \
+            f'keep_empty={self.keep_empty})'
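+
+# Illustrative config (thresholds are placeholders): drop boxes no wider or
+# taller than one pixel, but keep images even when every box is filtered out.
+#
+#     dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False)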
diff --git a/mmdet/datasets/pipelines/test_time_aug.py b/mmdet/datasets/pipelines/test_time_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f1ab7b7cc81891dd14d136a24cec5228495d2f0
--- /dev/null
+++ b/mmdet/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,121 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import mmcv
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiScaleFlipAug:
+ """Test-time augmentation with multiple scales and flipping.
+
+    An example configuration is as follows:
+
+ .. code-block::
+
+ img_scale=[(1333, 400), (1333, 800)],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ]
+
+    After MultiScaleFlipAug with the above configuration, the results are
+    wrapped into lists of the same length, as follows:
+
+ .. code-block::
+
+ dict(
+ img=[...],
+ img_shape=[...],
+ scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
+ flip=[False, True, False, True]
+ ...
+ )
+
+ Args:
+ transforms (list[dict]): Transforms to apply in each augmentation.
+ img_scale (tuple | list[tuple] | None): Images scales for resizing.
+ scale_factor (float | list[float] | None): Scale factors for resizing.
+ flip (bool): Whether apply flip augmentation. Default: False.
+ flip_direction (str | list[str]): Flip augmentation directions,
+ options are "horizontal", "vertical" and "diagonal". If
+ flip_direction is a list, multiple flip augmentations will be
+ applied. It has no effect when flip == False. Default:
+ "horizontal".
+ """
+
+ def __init__(self,
+ transforms,
+ img_scale=None,
+ scale_factor=None,
+ flip=False,
+ flip_direction='horizontal'):
+ self.transforms = Compose(transforms)
+ assert (img_scale is None) ^ (scale_factor is None), (
+            'Exactly one of img_scale and scale_factor must be set')
+ if img_scale is not None:
+ self.img_scale = img_scale if isinstance(img_scale,
+ list) else [img_scale]
+ self.scale_key = 'scale'
+ assert mmcv.is_list_of(self.img_scale, tuple)
+ else:
+ self.img_scale = scale_factor if isinstance(
+ scale_factor, list) else [scale_factor]
+ self.scale_key = 'scale_factor'
+
+ self.flip = flip
+ self.flip_direction = flip_direction if isinstance(
+ flip_direction, list) else [flip_direction]
+ assert mmcv.is_list_of(self.flip_direction, str)
+ if not self.flip and self.flip_direction != ['horizontal']:
+ warnings.warn(
+ 'flip_direction has no effect when flip is set to False')
+ if (self.flip
+ and not any([t['type'] == 'RandomFlip' for t in transforms])):
+ warnings.warn(
+ 'flip has no effect when RandomFlip is not in transforms')
+
+ def __call__(self, results):
+ """Call function to apply test time augment transforms on results.
+
+ Args:
+ results (dict): Result dict contains the data to transform.
+
+ Returns:
+ dict[str: list]: The augmented data, where each value is wrapped
+ into a list.
+ """
+
+ aug_data = []
+ flip_args = [(False, None)]
+ if self.flip:
+ flip_args += [(True, direction)
+ for direction in self.flip_direction]
+ for scale in self.img_scale:
+ for flip, direction in flip_args:
+ _results = results.copy()
+ _results[self.scale_key] = scale
+ _results['flip'] = flip
+ _results['flip_direction'] = direction
+ data = self.transforms(_results)
+ aug_data.append(data)
+ # list of dict to dict of list
+ aug_data_dict = {key: [] for key in aug_data[0]}
+ for data in aug_data:
+ for key, val in data.items():
+ aug_data_dict[key].append(val)
+ return aug_data_dict
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms}, '
+ repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
+ repr_str += f'flip_direction={self.flip_direction})'
+ return repr_str
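+
+# Besides img_scale, a list of scale factors can be used instead (values are
+# placeholders); every factor is combined with every flip setting.
+#
+#     dict(
+#         type='MultiScaleFlipAug',
+#         scale_factor=[0.5, 1.0, 2.0],
+#         flip=True,
+#         transforms=[...])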
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c9ef72c76f6a0eb3ab5eb5ca85286de432ee34b
--- /dev/null
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -0,0 +1,2968 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+import math
+import warnings
+
+import cv2
+import mmcv
+import numpy as np
+from numpy import random
+
+from mmdet.core import BitmapMasks, PolygonMasks, find_inside_bboxes
+from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
+from mmdet.utils import log_img_scale
+from ..builder import PIPELINES
+
+try:
+ from imagecorruptions import corrupt
+except ImportError:
+ corrupt = None
+
+try:
+ import albumentations
+ from albumentations import Compose
+except ImportError:
+ albumentations = None
+ Compose = None
+
+
+@PIPELINES.register_module()
+class Resize:
+ """Resize images & bbox & mask.
+
+ This transform resizes the input image to some scale. Bboxes and masks are
+ then resized with the same scale factor. If the input dict contains the key
+ "scale", then the scale in the input dict is used, otherwise the specified
+ scale in the init method is used. If the input dict contains the key
+ "scale_factor" (if MultiScaleFlipAug does not give img_scale but
+ scale_factor), the actual scale will be computed by image shape and
+ scale_factor.
+
+ `img_scale` can either be a tuple (single-scale) or a list of tuple
+ (multi-scale). There are 3 multiscale modes:
+
+ - ``ratio_range is not None``: randomly sample a ratio from the ratio \
+ range and multiply it with the image scale.
+ - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \
+ sample a scale from the multiscale range.
+ - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \
+ sample a scale from multiple scales.
+
+ Args:
+ img_scale (tuple or list[tuple]): Images scales for resizing.
+ multiscale_mode (str): Either "range" or "value".
+ ratio_range (tuple[float]): (min_ratio, max_ratio)
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image.
+ bbox_clip_border (bool, optional): Whether to clip the objects outside
+ the border of the image. In some dataset like MOT17, the gt bboxes
+ are allowed to cross the border of images. Therefore, we don't
+ need to clip the gt bboxes in these cases. Defaults to True.
+ backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
+            These two backends generate slightly different results. Defaults
+ to 'cv2'.
+ interpolation (str): Interpolation method, accepted values are
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+ backend, "nearest", "bilinear" for 'pillow' backend.
+        override (bool, optional): Whether to override `scale` and
+            `scale_factor` so as to call resize twice. If True, after the
+            first resizing, the existing `scale` and `scale_factor` will be
+            ignored so that a second resize is allowed. This option is a
+            workaround for resizing multiple times in DETR. Defaults to False.
+ """
+
+ def __init__(self,
+ img_scale=None,
+ multiscale_mode='range',
+ ratio_range=None,
+ keep_ratio=True,
+ bbox_clip_border=True,
+ backend='cv2',
+ interpolation='bilinear',
+ override=False):
+ if img_scale is None:
+ self.img_scale = None
+ else:
+ if isinstance(img_scale, list):
+ self.img_scale = img_scale
+ else:
+ self.img_scale = [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple)
+
+ if ratio_range is not None:
+ # mode 1: given a scale and a range of image ratio
+ assert len(self.img_scale) == 1
+ else:
+ # mode 2: given multiple scales or a range of scales
+ assert multiscale_mode in ['value', 'range']
+
+ self.backend = backend
+ self.multiscale_mode = multiscale_mode
+ self.ratio_range = ratio_range
+ self.keep_ratio = keep_ratio
+ # TODO: refactor the override option in Resize
+ self.interpolation = interpolation
+ self.override = override
+ self.bbox_clip_border = bbox_clip_border
+
+ @staticmethod
+ def random_select(img_scales):
+ """Randomly select an img_scale from given candidates.
+
+ Args:
+ img_scales (list[tuple]): Images scales for selection.
+
+ Returns:
+            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
+ where ``img_scale`` is the selected image scale and \
+ ``scale_idx`` is the selected index in the given candidates.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple)
+ scale_idx = np.random.randint(len(img_scales))
+ img_scale = img_scales[scale_idx]
+ return img_scale, scale_idx
+
+ @staticmethod
+ def random_sample(img_scales):
+ """Randomly sample an img_scale when ``multiscale_mode=='range'``.
+
+ Args:
+ img_scales (list[tuple]): Images scale range for sampling.
+ There must be two tuples in img_scales, which specify the lower
+ and upper bound of image scales.
+
+ Returns:
+ (tuple, None): Returns a tuple ``(img_scale, None)``, where \
+ ``img_scale`` is sampled scale and None is just a placeholder \
+ to be consistent with :func:`random_select`.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+ img_scale_long = [max(s) for s in img_scales]
+ img_scale_short = [min(s) for s in img_scales]
+ long_edge = np.random.randint(
+ min(img_scale_long),
+ max(img_scale_long) + 1)
+ short_edge = np.random.randint(
+ min(img_scale_short),
+ max(img_scale_short) + 1)
+ img_scale = (long_edge, short_edge)
+ return img_scale, None
+
+ @staticmethod
+ def random_sample_ratio(img_scale, ratio_range):
+ """Randomly sample an img_scale when ``ratio_range`` is specified.
+
+ A ratio will be randomly sampled from the range specified by
+ ``ratio_range``. Then it would be multiplied with ``img_scale`` to
+ generate sampled scale.
+
+ Args:
+ img_scale (tuple): Images scale base to multiply with ratio.
+ ratio_range (tuple[float]): The minimum and maximum ratio to scale
+ the ``img_scale``.
+
+ Returns:
+ (tuple, None): Returns a tuple ``(scale, None)``, where \
+ ``scale`` is sampled ratio multiplied with ``img_scale`` and \
+ None is just a placeholder to be consistent with \
+ :func:`random_select`.
+ """
+
+ assert isinstance(img_scale, tuple) and len(img_scale) == 2
+ min_ratio, max_ratio = ratio_range
+ assert min_ratio <= max_ratio
+ ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+ scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
+ return scale, None
+
+ def _random_scale(self, results):
+ """Randomly sample an img_scale according to ``ratio_range`` and
+ ``multiscale_mode``.
+
+ If ``ratio_range`` is specified, a ratio will be sampled and be
+ multiplied with ``img_scale``.
+ If multiple scales are specified by ``img_scale``, a scale will be
+ sampled according to ``multiscale_mode``.
+ Otherwise, single scale will be used.
+
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+
+ Returns:
+            dict: Two new keys ``scale`` and ``scale_idx`` are added into \
+ ``results``, which would be used by subsequent pipelines.
+ """
+
+ if self.ratio_range is not None:
+ scale, scale_idx = self.random_sample_ratio(
+ self.img_scale[0], self.ratio_range)
+ elif len(self.img_scale) == 1:
+ scale, scale_idx = self.img_scale[0], 0
+ elif self.multiscale_mode == 'range':
+ scale, scale_idx = self.random_sample(self.img_scale)
+ elif self.multiscale_mode == 'value':
+ scale, scale_idx = self.random_select(self.img_scale)
+ else:
+ raise NotImplementedError
+
+ results['scale'] = scale
+ results['scale_idx'] = scale_idx
+
+ def _resize_img(self, results):
+ """Resize images with ``results['scale']``."""
+ for key in results.get('img_fields', ['img']):
+ if self.keep_ratio:
+ img, scale_factor = mmcv.imrescale(
+ results[key],
+ results['scale'],
+ return_scale=True,
+ interpolation=self.interpolation,
+ backend=self.backend)
+ # the w_scale and h_scale has minor difference
+ # a real fix should be done in the mmcv.imrescale in the future
+ new_h, new_w = img.shape[:2]
+ h, w = results[key].shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ else:
+ img, w_scale, h_scale = mmcv.imresize(
+ results[key],
+ results['scale'],
+ return_scale=True,
+ interpolation=self.interpolation,
+ backend=self.backend)
+ results[key] = img
+
+ scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+ dtype=np.float32)
+ results['img_shape'] = img.shape
+ # in case that there is no padding
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = scale_factor
+ results['keep_ratio'] = self.keep_ratio
+
+ def _resize_bboxes(self, results):
+ """Resize bounding boxes with ``results['scale_factor']``."""
+ for key in results.get('bbox_fields', []):
+ bboxes = results[key] * results['scale_factor']
+ if self.bbox_clip_border:
+ img_shape = results['img_shape']
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
+ results[key] = bboxes
+
+ def _resize_masks(self, results):
+ """Resize masks with ``results['scale']``"""
+ for key in results.get('mask_fields', []):
+ if results[key] is None:
+ continue
+ if self.keep_ratio:
+ results[key] = results[key].rescale(results['scale'])
+ else:
+ results[key] = results[key].resize(results['img_shape'][:2])
+
+ def _resize_seg(self, results):
+ """Resize semantic segmentation map with ``results['scale']``."""
+ for key in results.get('seg_fields', []):
+ if self.keep_ratio:
+ gt_seg = mmcv.imrescale(
+ results[key],
+ results['scale'],
+ interpolation='nearest',
+ backend=self.backend)
+ else:
+ gt_seg = mmcv.imresize(
+ results[key],
+ results['scale'],
+ interpolation='nearest',
+ backend=self.backend)
+ results[key] = gt_seg
+
+ def __call__(self, results):
+ """Call function to resize images, bounding boxes, masks, semantic
+ segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \
+ 'keep_ratio' keys are added into result dict.
+ """
+
+ if 'scale' not in results:
+ if 'scale_factor' in results:
+ img_shape = results['img'].shape[:2]
+ scale_factor = results['scale_factor']
+ assert isinstance(scale_factor, float)
+ results['scale'] = tuple(
+ [int(x * scale_factor) for x in img_shape][::-1])
+ else:
+ self._random_scale(results)
+ else:
+ if not self.override:
+ assert 'scale_factor' not in results, (
+ 'scale and scale_factor cannot be both set.')
+ else:
+ results.pop('scale')
+ if 'scale_factor' in results:
+ results.pop('scale_factor')
+ self._random_scale(results)
+
+ self._resize_img(results)
+ self._resize_bboxes(results)
+ self._resize_masks(results)
+ self._resize_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(img_scale={self.img_scale}, '
+ repr_str += f'multiscale_mode={self.multiscale_mode}, '
+ repr_str += f'ratio_range={self.ratio_range}, '
+ repr_str += f'keep_ratio={self.keep_ratio}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
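+
+# Two common configs (scales are placeholders): a fixed single-scale resize,
+# and multiscale training that samples the short edge from a range.
+#
+#     dict(type='Resize', img_scale=(1333, 800), keep_ratio=True)
+#     dict(type='Resize',
+#          img_scale=[(1333, 640), (1333, 800)],
+#          multiscale_mode='range',
+#          keep_ratio=True)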
+
+
+@PIPELINES.register_module()
+class RandomFlip:
+ """Flip the image & bbox & mask.
+
+ If the input dict contains the key "flip", then the flag will be used,
+ otherwise it will be randomly decided by a ratio specified in the init
+ method.
+
+ When random flip is enabled, ``flip_ratio``/``direction`` can either be a
+ float/string or tuple of float/string. There are 3 flip modes:
+
+ - ``flip_ratio`` is float, ``direction`` is string: the image will be
+ ``direction``ly flipped with probability of ``flip_ratio`` .
+ E.g., ``flip_ratio=0.5``, ``direction='horizontal'``,
+ then image will be horizontally flipped with probability of 0.5.
+ - ``flip_ratio`` is float, ``direction`` is list of string: the image will
+ be ``direction[i]``ly flipped with probability of
+ ``flip_ratio/len(direction)``.
+ E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``,
+ then image will be horizontally flipped with probability of 0.25,
+ vertically with probability of 0.25.
+ - ``flip_ratio`` is list of float, ``direction`` is list of string:
+ given ``len(flip_ratio) == len(direction)``, the image will
+ be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``.
+ E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal',
+ 'vertical']``, then image will be horizontally flipped with probability
+ of 0.3, vertically with probability of 0.5.
+
+ Args:
+ flip_ratio (float | list[float], optional): The flipping probability.
+ Default: None.
+ direction(str | list[str], optional): The flipping direction. Options
+ are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'.
+ If input is a list, the length must equal ``flip_ratio``. Each
+ element in ``flip_ratio`` indicates the flip probability of
+ corresponding direction.
+ """
+
+ def __init__(self, flip_ratio=None, direction='horizontal'):
+ if isinstance(flip_ratio, list):
+ assert mmcv.is_list_of(flip_ratio, float)
+ assert 0 <= sum(flip_ratio) <= 1
+ elif isinstance(flip_ratio, float):
+ assert 0 <= flip_ratio <= 1
+ elif flip_ratio is None:
+ pass
+ else:
+ raise ValueError('flip_ratios must be None, float, '
+ 'or list of float')
+ self.flip_ratio = flip_ratio
+
+ valid_directions = ['horizontal', 'vertical', 'diagonal']
+ if isinstance(direction, str):
+ assert direction in valid_directions
+ elif isinstance(direction, list):
+ assert mmcv.is_list_of(direction, str)
+ assert set(direction).issubset(set(valid_directions))
+ else:
+ raise ValueError('direction must be either str or list of str')
+ self.direction = direction
+
+ if isinstance(flip_ratio, list):
+ assert len(self.flip_ratio) == len(self.direction)
+
+ def bbox_flip(self, bboxes, img_shape, direction):
+ """Flip bboxes horizontally.
+
+ Args:
+ bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
+ img_shape (tuple[int]): Image shape (height, width)
+ direction (str): Flip direction. Options are 'horizontal',
+ 'vertical'.
+
+ Returns:
+ numpy.ndarray: Flipped bounding boxes.
+ """
+
+ assert bboxes.shape[-1] % 4 == 0
+ flipped = bboxes.copy()
+ if direction == 'horizontal':
+ w = img_shape[1]
+ flipped[..., 0::4] = w - bboxes[..., 2::4]
+ flipped[..., 2::4] = w - bboxes[..., 0::4]
+ elif direction == 'vertical':
+ h = img_shape[0]
+ flipped[..., 1::4] = h - bboxes[..., 3::4]
+ flipped[..., 3::4] = h - bboxes[..., 1::4]
+ elif direction == 'diagonal':
+ w = img_shape[1]
+ h = img_shape[0]
+ flipped[..., 0::4] = w - bboxes[..., 2::4]
+ flipped[..., 1::4] = h - bboxes[..., 3::4]
+ flipped[..., 2::4] = w - bboxes[..., 0::4]
+ flipped[..., 3::4] = h - bboxes[..., 1::4]
+ else:
+ raise ValueError(f"Invalid flipping direction '{direction}'")
+ return flipped
+
+ def __call__(self, results):
+ """Call function to flip bounding boxes, masks, semantic segmentation
+ maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Flipped results, 'flip', 'flip_direction' keys are added \
+ into result dict.
+ """
+
+ if 'flip' not in results:
+ if isinstance(self.direction, list):
+ # None means non-flip
+ direction_list = self.direction + [None]
+ else:
+ # None means non-flip
+ direction_list = [self.direction, None]
+
+ if isinstance(self.flip_ratio, list):
+ non_flip_ratio = 1 - sum(self.flip_ratio)
+ flip_ratio_list = self.flip_ratio + [non_flip_ratio]
+ else:
+ non_flip_ratio = 1 - self.flip_ratio
+ # exclude non-flip
+ single_ratio = self.flip_ratio / (len(direction_list) - 1)
+ flip_ratio_list = [single_ratio] * (len(direction_list) -
+ 1) + [non_flip_ratio]
+
+ cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
+
+ results['flip'] = cur_dir is not None
+ if 'flip_direction' not in results:
+ results['flip_direction'] = cur_dir
+ if results['flip']:
+ # flip image
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction'])
+ # flip bboxes
+ for key in results.get('bbox_fields', []):
+ results[key] = self.bbox_flip(results[key],
+ results['img_shape'],
+ results['flip_direction'])
+ # flip masks
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].flip(results['flip_direction'])
+
+ # flip segs
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction'])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
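+
+# Illustrative configs: a plain horizontal flip, and the list form where each
+# direction gets its own probability (values are placeholders).
+#
+#     dict(type='RandomFlip', flip_ratio=0.5)
+#     dict(type='RandomFlip',
+#          flip_ratio=[0.3, 0.3],
+#          direction=['horizontal', 'vertical'])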
+
+
+@PIPELINES.register_module()
+class RandomShift:
+ """Shift the image and box given shift pixels and probability.
+
+ Args:
+ shift_ratio (float): Probability of shifts. Default 0.5.
+ max_shift_px (int): The max pixels for shifting. Default 32.
+ filter_thr_px (int): The width and height threshold for filtering.
+ The bbox and the rest of the targets below the width and
+ height threshold will be filtered. Default 1.
+ """
+
+ def __init__(self, shift_ratio=0.5, max_shift_px=32, filter_thr_px=1):
+ assert 0 <= shift_ratio <= 1
+ assert max_shift_px >= 0
+ self.shift_ratio = shift_ratio
+ self.max_shift_px = max_shift_px
+ self.filter_thr_px = int(filter_thr_px)
+ # The key correspondence from bboxes to labels.
+ self.bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+
+ def __call__(self, results):
+ """Call function to random shift images, bounding boxes.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Shift results.
+ """
+ if random.random() < self.shift_ratio:
+ img_shape = results['img'].shape[:2]
+
+ random_shift_x = random.randint(-self.max_shift_px,
+ self.max_shift_px)
+ random_shift_y = random.randint(-self.max_shift_px,
+ self.max_shift_px)
+ new_x = max(0, random_shift_x)
+ ori_x = max(0, -random_shift_x)
+ new_y = max(0, random_shift_y)
+ ori_y = max(0, -random_shift_y)
+
+ # TODO: support mask and semantic segmentation maps.
+ for key in results.get('bbox_fields', []):
+ bboxes = results[key].copy()
+ bboxes[..., 0::2] += random_shift_x
+ bboxes[..., 1::2] += random_shift_y
+
+ # clip border
+ bboxes[..., 0::2] = np.clip(bboxes[..., 0::2], 0, img_shape[1])
+ bboxes[..., 1::2] = np.clip(bboxes[..., 1::2], 0, img_shape[0])
+
+ # remove invalid bboxes
+ bbox_w = bboxes[..., 2] - bboxes[..., 0]
+ bbox_h = bboxes[..., 3] - bboxes[..., 1]
+ valid_inds = (bbox_w > self.filter_thr_px) & (
+ bbox_h > self.filter_thr_px)
+ # If the shift does not contain any gt-bbox area, skip this
+ # image.
+ if key == 'gt_bboxes' and not valid_inds.any():
+ return results
+ bboxes = bboxes[valid_inds]
+ results[key] = bboxes
+
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = self.bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ new_img = np.zeros_like(img)
+ img_h, img_w = img.shape[:2]
+ new_h = img_h - np.abs(random_shift_y)
+ new_w = img_w - np.abs(random_shift_x)
+ new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
+ = img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]
+ results[key] = new_img
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+        repr_str += f'(shift_ratio={self.shift_ratio}, '
+        repr_str += f'max_shift_px={self.max_shift_px}, '
+        repr_str += f'filter_thr_px={self.filter_thr_px})'
+ return repr_str
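+
+# Illustrative config (values are placeholders); note that only images and
+# bounding boxes are shifted, masks and segmentation maps are not yet handled.
+#
+#     dict(type='RandomShift', shift_ratio=0.5, max_shift_px=32)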
+
+
+@PIPELINES.register_module()
+class Pad:
+ """Pad the image & masks & segmentation map.
+
+ There are two padding modes: (1) pad to a fixed size and (2) pad to the
+ minimum size that is divisible by some number.
+ Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",
+
+ Args:
+ size (tuple, optional): Fixed padding size.
+ size_divisor (int, optional): The divisor of padded size.
+ pad_to_square (bool): Whether to pad the image into a square.
+ Currently only used for YOLOX. Default: False.
+ pad_val (dict, optional): A dict for padding value, the default
+ value is `dict(img=0, masks=0, seg=255)`.
+ """
+
+ def __init__(self,
+ size=None,
+ size_divisor=None,
+ pad_to_square=False,
+ pad_val=dict(img=0, masks=0, seg=255)):
+ self.size = size
+ self.size_divisor = size_divisor
+ if isinstance(pad_val, float) or isinstance(pad_val, int):
+ warnings.warn(
+ 'pad_val of float type is deprecated now, '
+ f'please use pad_val=dict(img={pad_val}, '
+ f'masks={pad_val}, seg=255) instead.', DeprecationWarning)
+ pad_val = dict(img=pad_val, masks=pad_val, seg=255)
+ assert isinstance(pad_val, dict)
+ self.pad_val = pad_val
+ self.pad_to_square = pad_to_square
+
+ if pad_to_square:
+ assert size is None and size_divisor is None, \
+ 'The size and size_divisor must be None ' \
+ 'when pad2square is True'
+ else:
+ assert size is not None or size_divisor is not None, \
+ 'only one of size and size_divisor should be valid'
+ assert size is None or size_divisor is None
+
+ def _pad_img(self, results):
+ """Pad images according to ``self.size``."""
+ pad_val = self.pad_val.get('img', 0)
+ for key in results.get('img_fields', ['img']):
+ if self.pad_to_square:
+ max_size = max(results[key].shape[:2])
+ self.size = (max_size, max_size)
+ if self.size is not None:
+ padded_img = mmcv.impad(
+ results[key], shape=self.size, pad_val=pad_val)
+ elif self.size_divisor is not None:
+ padded_img = mmcv.impad_to_multiple(
+ results[key], self.size_divisor, pad_val=pad_val)
+ results[key] = padded_img
+ results['pad_shape'] = padded_img.shape
+ results['pad_fixed_size'] = self.size
+ results['pad_size_divisor'] = self.size_divisor
+
+ def _pad_masks(self, results):
+ """Pad masks according to ``results['pad_shape']``."""
+ pad_shape = results['pad_shape'][:2]
+ pad_val = self.pad_val.get('masks', 0)
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].pad(pad_shape, pad_val=pad_val)
+
+ def _pad_seg(self, results):
+ """Pad semantic segmentation map according to
+ ``results['pad_shape']``."""
+ pad_val = self.pad_val.get('seg', 255)
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.impad(
+ results[key], shape=results['pad_shape'][:2], pad_val=pad_val)
+
+ def __call__(self, results):
+ """Call function to pad images, masks, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Updated result dict.
+ """
+ self._pad_img(results)
+ self._pad_masks(results)
+ self._pad_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(size={self.size}, '
+ repr_str += f'size_divisor={self.size_divisor}, '
+ repr_str += f'pad_to_square={self.pad_to_square}, '
+ repr_str += f'pad_val={self.pad_val})'
+ return repr_str
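+
+# The two padding modes as configs (sizes are placeholders): pad to a multiple
+# of the divisor, or pad to a fixed size.
+#
+#     dict(type='Pad', size_divisor=32)
+#     dict(type='Pad', size=(1024, 1024))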
+
+
+@PIPELINES.register_module()
+class Normalize:
+ """Normalize the image.
+
+ Added key is "img_norm_cfg".
+
+ Args:
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB,
+ default is true.
+ """
+
+ def __init__(self, mean, std, to_rgb=True):
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.to_rgb = to_rgb
+
+ def __call__(self, results):
+ """Call function to normalize images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Normalized results, 'img_norm_cfg' key is added into
+ result dict.
+ """
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
+ self.to_rgb)
+ results['img_norm_cfg'] = dict(
+ mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
+ return repr_str
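+
+# Illustrative config using the commonly used ImageNet mean/std, converting
+# BGR-loaded images to RGB before normalization.
+#
+#     img_norm_cfg = dict(
+#         mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375],
+#         to_rgb=True)
+#     dict(type='Normalize', **img_norm_cfg)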
+
+
+@PIPELINES.register_module()
+class RandomCrop:
+ """Random crop the image & bboxes & masks.
+
+ The absolute `crop_size` is sampled based on `crop_type` and `image_size`,
+ then the cropped results are generated.
+
+ Args:
+ crop_size (tuple): The relative ratio or absolute pixels of
+ height and width.
+ crop_type (str, optional): one of "relative_range", "relative",
+ "absolute", "absolute_range". "relative" randomly crops
+ (h * crop_size[0], w * crop_size[1]) part from an input of size
+ (h, w). "relative_range" uniformly samples relative crop size from
+ range [crop_size[0], 1] and [crop_size[1], 1] for height and width
+ respectively. "absolute" crops from an input with absolute size
+ (crop_size[0], crop_size[1]). "absolute_range" uniformly samples
+ crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w
+ in range [crop_size[0], min(w, crop_size[1])]. Default "absolute".
+ allow_negative_crop (bool, optional): Whether to allow a crop that does
+ not contain any bbox area. Default False.
+ recompute_bbox (bool, optional): Whether to re-compute the boxes based
+ on cropped instance masks. Default False.
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+
+ Note:
+ - If the image is smaller than the absolute crop size, return the
+ original image.
+ - The keys for bboxes, labels and masks must be aligned. That is,
+ `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and
+ `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and
+ `gt_masks_ignore`.
+ - If the crop does not contain any gt-bbox region and
+ `allow_negative_crop` is set to False, skip this image.
+ """
+
+ def __init__(self,
+ crop_size,
+ crop_type='absolute',
+ allow_negative_crop=False,
+ recompute_bbox=False,
+ bbox_clip_border=True):
+ if crop_type not in [
+ 'relative_range', 'relative', 'absolute', 'absolute_range'
+ ]:
+ raise ValueError(f'Invalid crop_type {crop_type}.')
+ if crop_type in ['absolute', 'absolute_range']:
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ assert isinstance(crop_size[0], int) and isinstance(
+ crop_size[1], int)
+ else:
+ assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
+ self.crop_size = crop_size
+ self.crop_type = crop_type
+ self.allow_negative_crop = allow_negative_crop
+ self.bbox_clip_border = bbox_clip_border
+ self.recompute_bbox = recompute_bbox
+ # The key correspondence from bboxes to labels and masks.
+ self.bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ self.bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+
+ def _crop_data(self, results, crop_size, allow_negative_crop):
+ """Function to randomly crop images, bounding boxes, masks, semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ crop_size (tuple): Expected absolute size after cropping, (h, w).
+ allow_negative_crop (bool): Whether to allow a crop that does not
+ contain any bbox area. Default to False.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ margin_h = max(img.shape[0] - crop_size[0], 0)
+ margin_w = max(img.shape[1] - crop_size[1], 0)
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
+
+ # crop the image
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+ img_shape = img.shape
+ results[key] = img
+ results['img_shape'] = img_shape
+
+ # crop bboxes accordingly and clip to the image boundary
+ for key in results.get('bbox_fields', []):
+ # e.g. gt_bboxes and gt_bboxes_ignore
+ bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
+ dtype=np.float32)
+ bboxes = results[key] - bbox_offset
+ if self.bbox_clip_border:
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
+ valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (
+ bboxes[:, 3] > bboxes[:, 1])
+ # If the crop does not contain any gt-bbox area and
+ # allow_negative_crop is False, skip this image.
+ if (key == 'gt_bboxes' and not valid_inds.any()
+ and not allow_negative_crop):
+ return None
+ results[key] = bboxes[valid_inds, :]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = self.bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = self.bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][
+ valid_inds.nonzero()[0]].crop(
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
+ if self.recompute_bbox:
+ results[key] = results[mask_key].get_bboxes()
+
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]
+
+ return results
+
+ def _get_crop_size(self, image_size):
+ """Randomly generates the absolute crop size based on `crop_type` and
+ `image_size`.
+
+ Args:
+ image_size (tuple): (h, w).
+
+ Returns:
+ crop_size (tuple): (crop_h, crop_w) in absolute pixels.
+ """
+ h, w = image_size
+ if self.crop_type == 'absolute':
+ return (min(self.crop_size[0], h), min(self.crop_size[1], w))
+ elif self.crop_type == 'absolute_range':
+ assert self.crop_size[0] <= self.crop_size[1]
+ crop_h = np.random.randint(
+ min(h, self.crop_size[0]),
+ min(h, self.crop_size[1]) + 1)
+ crop_w = np.random.randint(
+ min(w, self.crop_size[0]),
+ min(w, self.crop_size[1]) + 1)
+ return crop_h, crop_w
+ elif self.crop_type == 'relative':
+ crop_h, crop_w = self.crop_size
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+ elif self.crop_type == 'relative_range':
+ crop_size = np.asarray(self.crop_size, dtype=np.float32)
+ crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+
+ def __call__(self, results):
+ """Call function to randomly crop images, bounding boxes, masks,
+ semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+ image_size = results['img'].shape[:2]
+ crop_size = self._get_crop_size(image_size)
+ results = self._crop_data(results, crop_size, self.allow_negative_crop)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(crop_size={self.crop_size}, '
+ repr_str += f'crop_type={self.crop_type}, '
+ repr_str += f'allow_negative_crop={self.allow_negative_crop}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
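+
+# Two illustrative configs (sizes are placeholders): an absolute crop, and a
+# relative-range crop that samples between 30% and 100% of each side.
+#
+#     dict(type='RandomCrop', crop_size=(512, 512), crop_type='absolute')
+#     dict(type='RandomCrop', crop_size=(0.3, 0.3),
+#          crop_type='relative_range')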
+
+
+@PIPELINES.register_module()
+class SegRescale:
+ """Rescale semantic segmentation maps.
+
+ Args:
+ scale_factor (float): The scale factor of the final output.
+ backend (str): Image rescale backend, choices are 'cv2' and 'pillow'.
+            These two backends generate slightly different results. Defaults
+ to 'cv2'.
+ """
+
+ def __init__(self, scale_factor=1, backend='cv2'):
+ self.scale_factor = scale_factor
+ self.backend = backend
+
+ def __call__(self, results):
+ """Call function to scale the semantic segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with semantic segmentation map scaled.
+ """
+
+ for key in results.get('seg_fields', []):
+ if self.scale_factor != 1:
+ results[key] = mmcv.imrescale(
+ results[key],
+ self.scale_factor,
+ interpolation='nearest',
+ backend=self.backend)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
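+
+# Illustrative config: semantic maps are often supervised at a lower
+# resolution than the image, e.g. 1/8 scale (the factor is a placeholder).
+#
+#     dict(type='SegRescale', scale_factor=1 / 8)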
+
+
+@PIPELINES.register_module()
+class PhotoMetricDistortion:
+ """Apply photometric distortion to image sequentially, every transformation
+ is applied with a probability of 0.5. The position of random contrast is in
+ second or second to last.
+
+ 1. random brightness
+ 2. random contrast (mode 0)
+ 3. convert color from BGR to HSV
+ 4. random saturation
+ 5. random hue
+ 6. convert color from HSV to BGR
+ 7. random contrast (mode 1)
+ 8. randomly swap channels
+
+ Args:
+ brightness_delta (int): delta of brightness.
+ contrast_range (tuple): range of contrast.
+ saturation_range (tuple): range of saturation.
+ hue_delta (int): delta of hue.
+ """
+
+ def __init__(self,
+ brightness_delta=32,
+ contrast_range=(0.5, 1.5),
+ saturation_range=(0.5, 1.5),
+ hue_delta=18):
+ self.brightness_delta = brightness_delta
+ self.contrast_lower, self.contrast_upper = contrast_range
+ self.saturation_lower, self.saturation_upper = saturation_range
+ self.hue_delta = hue_delta
+
+ def __call__(self, results):
+ """Call function to perform photometric distortion on images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images distorted.
+ """
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+ img = img.astype(np.float32)
+ # random brightness
+ if random.randint(2):
+ delta = random.uniform(-self.brightness_delta,
+ self.brightness_delta)
+ img += delta
+
+ # mode == 0 --> do random contrast first
+ # mode == 1 --> do random contrast last
+ mode = random.randint(2)
+ if mode == 1:
+ if random.randint(2):
+ alpha = random.uniform(self.contrast_lower,
+ self.contrast_upper)
+ img *= alpha
+
+ # convert color from BGR to HSV
+ img = mmcv.bgr2hsv(img)
+
+ # random saturation
+ if random.randint(2):
+ img[..., 1] *= random.uniform(self.saturation_lower,
+ self.saturation_upper)
+
+ # random hue
+ if random.randint(2):
+ img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
+ img[..., 0][img[..., 0] > 360] -= 360
+ img[..., 0][img[..., 0] < 0] += 360
+
+ # convert color from HSV to BGR
+ img = mmcv.hsv2bgr(img)
+
+ # random contrast
+ if mode == 0:
+ if random.randint(2):
+ alpha = random.uniform(self.contrast_lower,
+ self.contrast_upper)
+ img *= alpha
+
+ # randomly swap channels
+ if random.randint(2):
+ img = img[..., random.permutation(3)]
+
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
+ repr_str += 'contrast_range='
+ repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
+ repr_str += 'saturation_range='
+ repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
+ repr_str += f'hue_delta={self.hue_delta})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Expand:
+    """Randomly expand the image & bboxes.
+
+    Randomly place the original image on a canvas of ``ratio`` x the original
+    image size, filled with the mean values. The ratio is sampled from
+    ``ratio_range``.
+
+ Args:
+        mean (tuple): Mean value of the dataset.
+        to_rgb (bool): Whether to reverse the order of ``mean`` so that it
+            aligns with RGB images.
+        ratio_range (tuple): Range of the expand ratio.
+        seg_ignore_label (int | None): Label used to fill the expanded area of
+            semantic segmentation maps. Defaults to None.
+        prob (float): Probability of applying this transformation.
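+
+    A minimal config sketch (illustrative; the mean values below simply
+    mirror a typical ``img_norm_cfg`` and are not prescribed by this class):
+
+    .. code-block::
+
+        dict(
+            type='Expand',
+            mean=(123.675, 116.28, 103.53),
+            to_rgb=True,
+            ratio_range=(1, 4),
+            prob=0.5)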
+ """
+
+ def __init__(self,
+ mean=(0, 0, 0),
+ to_rgb=True,
+ ratio_range=(1, 4),
+ seg_ignore_label=None,
+ prob=0.5):
+ self.to_rgb = to_rgb
+ self.ratio_range = ratio_range
+ if to_rgb:
+ self.mean = mean[::-1]
+ else:
+ self.mean = mean
+ self.min_ratio, self.max_ratio = ratio_range
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+
+ def __call__(self, results):
+ """Call function to expand images, bounding boxes.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images, bounding boxes expanded
+ """
+
+ if random.uniform(0, 1) > self.prob:
+ return results
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+
+ h, w, c = img.shape
+ ratio = random.uniform(self.min_ratio, self.max_ratio)
+        # speed up the expansion when the image is large
+ if np.all(self.mean == self.mean[0]):
+ expand_img = np.empty((int(h * ratio), int(w * ratio), c),
+ img.dtype)
+ expand_img.fill(self.mean[0])
+ else:
+ expand_img = np.full((int(h * ratio), int(w * ratio), c),
+ self.mean,
+ dtype=img.dtype)
+ left = int(random.uniform(0, w * ratio - w))
+ top = int(random.uniform(0, h * ratio - h))
+ expand_img[top:top + h, left:left + w] = img
+
+ results['img'] = expand_img
+ # expand bboxes
+ for key in results.get('bbox_fields', []):
+ results[key] = results[key] + np.tile(
+ (left, top), 2).astype(results[key].dtype)
+
+ # expand masks
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].expand(
+ int(h * ratio), int(w * ratio), top, left)
+
+ # expand segs
+ for key in results.get('seg_fields', []):
+ gt_seg = results[key]
+ expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
+ self.seg_ignore_label,
+ dtype=gt_seg.dtype)
+ expand_gt_seg[top:top + h, left:left + w] = gt_seg
+ results[key] = expand_gt_seg
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, to_rgb={self.to_rgb}, '
+ repr_str += f'ratio_range={self.ratio_range}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class MinIoURandomCrop:
+    """Randomly crop the image & bboxes. The cropped patches must satisfy a
+    minimum IoU requirement with the original bboxes; the IoU threshold is
+    randomly selected from ``min_ious``.
+
+ Args:
+        min_ious (tuple): Minimum IoU thresholds for all intersections with
+            the bounding boxes.
+        min_crop_size (float): Minimum crop size relative to the original
+            image size (i.e. h,w := a*h, a*w, where a >= min_crop_size).
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+
+ Note:
+ The keys for bboxes, labels and masks should be paired. That is, \
+ `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \
+ `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
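+
+    A minimal config sketch (illustrative; this transform is typically
+    placed after ``Expand`` in an SSD-style training pipeline):
+
+    .. code-block::
+
+        dict(
+            type='MinIoURandomCrop',
+            min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
+            min_crop_size=0.3)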
+ """
+
+ def __init__(self,
+ min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
+ min_crop_size=0.3,
+ bbox_clip_border=True):
+ # 1: return ori img
+ self.min_ious = min_ious
+ self.sample_mode = (1, *min_ious, 0)
+ self.min_crop_size = min_crop_size
+ self.bbox_clip_border = bbox_clip_border
+ self.bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ self.bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+
+ def __call__(self, results):
+ """Call function to crop images and bounding boxes with minimum IoU
+ constraint.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images and bounding boxes cropped, \
+ 'img_shape' key is updated.
+ """
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+ assert 'bbox_fields' in results
+ boxes = [results[key] for key in results['bbox_fields']]
+ boxes = np.concatenate(boxes, 0)
+ h, w, c = img.shape
+ while True:
+ mode = random.choice(self.sample_mode)
+ self.mode = mode
+ if mode == 1:
+ return results
+
+ min_iou = mode
+ for i in range(50):
+ new_w = random.uniform(self.min_crop_size * w, w)
+ new_h = random.uniform(self.min_crop_size * h, h)
+
+ # h / w in [0.5, 2]
+ if new_h / new_w < 0.5 or new_h / new_w > 2:
+ continue
+
+ left = random.uniform(w - new_w)
+ top = random.uniform(h - new_h)
+
+ patch = np.array(
+ (int(left), int(top), int(left + new_w), int(top + new_h)))
+ # Line or point crop is not allowed
+ if patch[2] == patch[0] or patch[3] == patch[1]:
+ continue
+ overlaps = bbox_overlaps(
+ patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
+ if len(overlaps) > 0 and overlaps.min() < min_iou:
+ continue
+
+ # center of boxes should inside the crop img
+ # only adjust boxes and instance masks when the gt is not empty
+ if len(overlaps) > 0:
+ # adjust boxes
+ def is_center_of_bboxes_in_patch(boxes, patch):
+ center = (boxes[:, :2] + boxes[:, 2:]) / 2
+ mask = ((center[:, 0] > patch[0]) *
+ (center[:, 1] > patch[1]) *
+ (center[:, 0] < patch[2]) *
+ (center[:, 1] < patch[3]))
+ return mask
+
+ mask = is_center_of_bboxes_in_patch(boxes, patch)
+ if not mask.any():
+ continue
+ for key in results.get('bbox_fields', []):
+ boxes = results[key].copy()
+ mask = is_center_of_bboxes_in_patch(boxes, patch)
+ boxes = boxes[mask]
+ if self.bbox_clip_border:
+ boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
+ boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
+ boxes -= np.tile(patch[:2], 2)
+
+ results[key] = boxes
+ # labels
+ label_key = self.bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][mask]
+
+ # mask fields
+ mask_key = self.bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][
+ mask.nonzero()[0]].crop(patch)
+ # adjust the img no matter whether the gt is empty before crop
+ img = img[patch[1]:patch[3], patch[0]:patch[2]]
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ # seg fields
+ for key in results.get('seg_fields', []):
+ results[key] = results[key][patch[1]:patch[3],
+ patch[0]:patch[2]]
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(min_ious={self.min_ious}, '
+ repr_str += f'min_crop_size={self.min_crop_size}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Corrupt:
+ """Corruption augmentation.
+
+    Corruption transforms implemented based on the ``imagecorruptions``
+    package.
+
+ Args:
+ corruption (str): Corruption name.
+ severity (int, optional): The severity of corruption. Default: 1.
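+
+    A minimal config sketch (illustrative; requires the optional
+    ``imagecorruptions`` package, and ``'gaussian_noise'`` is just one of
+    its corruption names):
+
+    .. code-block::
+
+        dict(type='Corrupt', corruption='gaussian_noise', severity=2)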
+ """
+
+ def __init__(self, corruption, severity=1):
+ self.corruption = corruption
+ self.severity = severity
+
+ def __call__(self, results):
+ """Call function to corrupt image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images corrupted.
+ """
+
+ if corrupt is None:
+ raise RuntimeError('imagecorruptions is not installed')
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ results['img'] = corrupt(
+ results['img'].astype(np.uint8),
+ corruption_name=self.corruption,
+ severity=self.severity)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(corruption={self.corruption}, '
+ repr_str += f'severity={self.severity})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Albu:
+ """Albumentation augmentation.
+
+    Adds custom transformations from the Albumentations library.
+    Please visit `https://albumentations.readthedocs.io`
+    for more information.
+
+ An example of ``transforms`` is as followed:
+
+ .. code-block::
+
+ [
+ dict(
+ type='ShiftScaleRotate',
+ shift_limit=0.0625,
+ scale_limit=0.0,
+ rotate_limit=0,
+ interpolation=1,
+ p=0.5),
+ dict(
+ type='RandomBrightnessContrast',
+ brightness_limit=[0.1, 0.3],
+ contrast_limit=[0.1, 0.3],
+ p=0.2),
+ dict(type='ChannelShuffle', p=0.1),
+ dict(
+ type='OneOf',
+ transforms=[
+ dict(type='Blur', blur_limit=3, p=1.0),
+ dict(type='MedianBlur', blur_limit=3, p=1.0)
+ ],
+ p=0.1),
+ ]
+
+ Args:
+        transforms (list[dict]): A list of albu transformations.
+        bbox_params (dict): Bbox_params for albumentations `Compose`.
+        keymap (dict): Contains {'input key':'albumentation-style key'}.
+        update_pad_shape (bool): Whether to update the 'pad_shape' key with
+            the shape of the augmented image.
+        skip_img_without_anno (bool): Whether to skip the image if no
+            annotations are left after augmentation.
+ """
+
+ def __init__(self,
+ transforms,
+ bbox_params=None,
+ keymap=None,
+ update_pad_shape=False,
+ skip_img_without_anno=False):
+ if Compose is None:
+ raise RuntimeError('albumentations is not installed')
+
+ # Args will be modified later, copying it will be safer
+ transforms = copy.deepcopy(transforms)
+ if bbox_params is not None:
+ bbox_params = copy.deepcopy(bbox_params)
+ if keymap is not None:
+ keymap = copy.deepcopy(keymap)
+ self.transforms = transforms
+ self.filter_lost_elements = False
+ self.update_pad_shape = update_pad_shape
+ self.skip_img_without_anno = skip_img_without_anno
+
+ # A simple workaround to remove masks without boxes
+ if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
+ and 'filter_lost_elements' in bbox_params):
+ self.filter_lost_elements = True
+ self.origin_label_fields = bbox_params['label_fields']
+ bbox_params['label_fields'] = ['idx_mapper']
+ del bbox_params['filter_lost_elements']
+
+ self.bbox_params = (
+ self.albu_builder(bbox_params) if bbox_params else None)
+ self.aug = Compose([self.albu_builder(t) for t in self.transforms],
+ bbox_params=self.bbox_params)
+
+ if not keymap:
+ self.keymap_to_albu = {
+ 'img': 'image',
+ 'gt_masks': 'masks',
+ 'gt_bboxes': 'bboxes'
+ }
+ else:
+ self.keymap_to_albu = keymap
+ self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
+
+ def albu_builder(self, cfg):
+ """Import a module from albumentations.
+
+ It inherits some of :func:`build_from_cfg` logic.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+
+ Returns:
+ obj: The constructed object.
+ """
+
+ assert isinstance(cfg, dict) and 'type' in cfg
+ args = cfg.copy()
+
+ obj_type = args.pop('type')
+ if mmcv.is_str(obj_type):
+ if albumentations is None:
+ raise RuntimeError('albumentations is not installed')
+ obj_cls = getattr(albumentations, obj_type)
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a str or valid type, but got {type(obj_type)}')
+
+ if 'transforms' in args:
+ args['transforms'] = [
+ self.albu_builder(transform)
+ for transform in args['transforms']
+ ]
+
+ return obj_cls(**args)
+
+ @staticmethod
+ def mapper(d, keymap):
+ """Dictionary mapper. Renames keys according to keymap provided.
+
+ Args:
+ d (dict): old dict
+ keymap (dict): {'old_key':'new_key'}
+ Returns:
+ dict: new dict.
+ """
+
+        updated_dict = {}
+        for k, v in d.items():
+            new_k = keymap.get(k, k)
+            updated_dict[new_k] = v
+        return updated_dict
+
+ def __call__(self, results):
+ # dict to albumentations format
+ results = self.mapper(results, self.keymap_to_albu)
+ # TODO: add bbox_fields
+ if 'bboxes' in results:
+ # to list of boxes
+ if isinstance(results['bboxes'], np.ndarray):
+ results['bboxes'] = [x for x in results['bboxes']]
+ # add pseudo-field for filtration
+ if self.filter_lost_elements:
+ results['idx_mapper'] = np.arange(len(results['bboxes']))
+
+ # TODO: Support mask structure in albu
+ if 'masks' in results:
+ if isinstance(results['masks'], PolygonMasks):
+ raise NotImplementedError(
+ 'Albu only supports BitMap masks now')
+ ori_masks = results['masks']
+ if albumentations.__version__ < '0.5':
+ results['masks'] = results['masks'].masks
+ else:
+ results['masks'] = [mask for mask in results['masks'].masks]
+
+ results = self.aug(**results)
+
+ if 'bboxes' in results:
+ if isinstance(results['bboxes'], list):
+ results['bboxes'] = np.array(
+ results['bboxes'], dtype=np.float32)
+ results['bboxes'] = results['bboxes'].reshape(-1, 4)
+
+ # filter label_fields
+ if self.filter_lost_elements:
+
+ for label in self.origin_label_fields:
+ results[label] = np.array(
+ [results[label][i] for i in results['idx_mapper']])
+ if 'masks' in results:
+ results['masks'] = np.array(
+ [results['masks'][i] for i in results['idx_mapper']])
+ results['masks'] = ori_masks.__class__(
+ results['masks'], results['image'].shape[0],
+ results['image'].shape[1])
+
+ if (not len(results['idx_mapper'])
+ and self.skip_img_without_anno):
+ return None
+
+ if 'gt_labels' in results:
+ if isinstance(results['gt_labels'], list):
+ results['gt_labels'] = np.array(results['gt_labels'])
+ results['gt_labels'] = results['gt_labels'].astype(np.int64)
+
+ # back to the original format
+ results = self.mapper(results, self.keymap_back)
+
+ # update final shape
+ if self.update_pad_shape:
+ results['pad_shape'] = results['img'].shape
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomCenterCropPad:
+ """Random center crop and random around padding for CornerNet.
+
+    This operation generates a randomly cropped image from the original image
+    and pads it simultaneously. Different from :class:`RandomCrop`, the output
+    shape may not strictly equal ``crop_size``: a random value is chosen from
+    ``ratios`` and the output shape can be larger or smaller than
+    ``crop_size``. The padding operation also differs from :class:`Pad`,
+    as around padding is used instead of right-bottom padding.
+
+ The relation between output image (padding image) and original image:
+
+ .. code:: text
+
+ output image
+
+ +----------------------------+
+ | padded area |
+ +------|----------------------------|----------+
+ | | cropped area | |
+ | | +---------------+ | |
+ | | | . center | | | original image
+ | | | range | | |
+ | | +---------------+ | |
+ +------|----------------------------|----------+
+ | padded area |
+ +----------------------------+
+
+ There are 5 main areas in the figure:
+
+    - output image: output image of this operation, also called the padding
+      image in the following instructions.
+    - original image: input image of this operation.
+    - padded area: non-intersecting area of the output image and the
+      original image.
+    - cropped area: the overlap of the output image and the original image.
+    - center range: a smaller area from which the random center is chosen.
+      The center range is computed from ``border`` and the original image's
+      shape so that the random center is not too close to the original
+      image's border.
+
+    This operation also acts differently in train and test mode; the
+    pipelines are summarized below.
+
+ Train pipeline:
+
+    1. Choose a ``random_ratio`` from ``ratios``; the shape of the padding
+       image will be ``random_ratio * crop_size``.
+    2. Choose a ``random_center`` in the center range.
+    3. Generate a padding image whose center matches the ``random_center``.
+    4. Initialize the padding image with pixel values equal to ``mean``.
+    5. Copy the cropped area to the padding image.
+    6. Refine annotations.
+
+    Test pipeline:
+
+    1. Compute the output shape according to ``test_pad_mode``.
+    2. Generate a padding image whose center matches the original image
+       center.
+    3. Initialize the padding image with pixel values equal to ``mean``.
+    4. Copy the ``cropped area`` to the padding image.
+
+ Args:
+        crop_size (tuple | None): Expected size after crop; the final size
+            will be computed according to the ratio. Requires (h, w) in train
+            mode, and None in test mode.
+        ratios (tuple): Randomly select a ratio from the tuple and crop the
+            image to (crop_size[0] * ratio) * (crop_size[1] * ratio).
+            Only available in train mode.
+        border (int): Max distance from the center select area to the image
+            border. Only available in train mode.
+        mean (sequence): Mean values of 3 channels.
+        std (sequence): Std values of 3 channels.
+        to_rgb (bool): Whether to convert the image from BGR to RGB.
+        test_mode (bool): Whether to involve random variables in the
+            transform. In train mode, crop_size is fixed, while the center
+            coords and ratio are randomly selected from predefined lists. In
+            test mode, crop_size is the image's original shape, and the
+            center coords and ratio are fixed.
+        test_pad_mode (tuple): Padding method and padding shape value, only
+            available in test mode. Default is using 'logical_or' with
+            127 as the padding shape value.
+
+ - 'logical_or': final_shape = input_shape | padding_shape_value
+ - 'size_divisor': final_shape = int(
+ ceil(input_shape / padding_shape_value) * padding_shape_value)
+ test_pad_add_pix (int): Extra padding pixel in test mode. Default 0.
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
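+
+    A worked example of the test-mode output shape (illustrative numbers
+    only):
+
+    .. code-block::
+
+        # test_pad_mode=('logical_or', 127), input 427 x 640:
+        #   target_h = 427 | 127 = 511, target_w = 640 | 127 = 767
+        # test_pad_mode=('size_divisor', 32), input 427 x 640:
+        #   target_h = ceil(427 / 32) * 32 = 448, target_w = 640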
+ """
+
+ def __init__(self,
+ crop_size=None,
+ ratios=(0.9, 1.0, 1.1),
+ border=128,
+ mean=None,
+ std=None,
+ to_rgb=None,
+ test_mode=False,
+ test_pad_mode=('logical_or', 127),
+ test_pad_add_pix=0,
+ bbox_clip_border=True):
+ if test_mode:
+ assert crop_size is None, 'crop_size must be None in test mode'
+ assert ratios is None, 'ratios must be None in test mode'
+ assert border is None, 'border must be None in test mode'
+ assert isinstance(test_pad_mode, (list, tuple))
+ assert test_pad_mode[0] in ['logical_or', 'size_divisor']
+ else:
+ assert isinstance(crop_size, (list, tuple))
+ assert crop_size[0] > 0 and crop_size[1] > 0, (
+ 'crop_size must > 0 in train mode')
+ assert isinstance(ratios, (list, tuple))
+ assert test_pad_mode is None, (
+ 'test_pad_mode must be None in train mode')
+
+ self.crop_size = crop_size
+ self.ratios = ratios
+ self.border = border
+ # We do not set default value to mean, std and to_rgb because these
+ # hyper-parameters are easy to forget but could affect the performance.
+ # Please use the same setting as Normalize for performance assurance.
+ assert mean is not None and std is not None and to_rgb is not None
+ self.to_rgb = to_rgb
+ self.input_mean = mean
+ self.input_std = std
+ if to_rgb:
+ self.mean = mean[::-1]
+ self.std = std[::-1]
+ else:
+ self.mean = mean
+ self.std = std
+ self.test_mode = test_mode
+ self.test_pad_mode = test_pad_mode
+ self.test_pad_add_pix = test_pad_add_pix
+ self.bbox_clip_border = bbox_clip_border
+
+ def _get_border(self, border, size):
+        """Get the final border for the target size.
+
+        This function generates a ``final_border`` according to the image's
+        shape. The area between ``final_border`` and ``size - final_border``
+        is the ``center range``. The center is randomly chosen from the
+        ``center range`` so that it is not too close to the original image's
+        border. The ``center range`` should also be larger than 0.
+
+ Args:
+ border (int): The initial border, default is 128.
+ size (int): The width or height of original image.
+ Returns:
+ int: The final border.
+ """
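+        # Worked example (illustrative): border=128, size=200 gives
+        # k = 1.28 -> i = 2 -> final border 64, keeping the center range
+        # (200 - 2 * 64 = 72) positive; border=128, size=600 gives
+        # k ~= 0.43 -> i = 1, so the border stays 128.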
+ k = 2 * border / size
+ i = pow(2, np.ceil(np.log2(np.ceil(k))) + (k == int(k)))
+ return border // i
+
+ def _filter_boxes(self, patch, boxes):
+ """Check whether the center of each box is in the patch.
+
+ Args:
+ patch (list[int]): The cropped area, [left, top, right, bottom].
+ boxes (numpy array, (N x 4)): Ground truth boxes.
+
+ Returns:
+ mask (numpy array, (N,)): Each box is inside or outside the patch.
+ """
+ center = (boxes[:, :2] + boxes[:, 2:]) / 2
+ mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (
+ center[:, 0] < patch[2]) * (
+ center[:, 1] < patch[3])
+ return mask
+
+ def _crop_image_and_paste(self, image, center, size):
+        """Crop the image with a given center and size, then paste the
+        cropped image onto a blank image with the two centers aligned.
+
+        This function is equivalent to generating a blank image with ``size``
+        as its shape, then laying it over the original image with the two
+        centers (the center of the blank image and the random center of the
+        original image) aligned. The overlap area is pasted from the original
+        image and the outside area is filled with the ``mean`` pixel value.
+
+ Args:
+ image (np array, H x W x C): Original image.
+ center (list[int]): Target crop center coord.
+ size (list[int]): Target crop size. [target_h, target_w]
+
+ Returns:
+ cropped_img (np array, target_h x target_w x C): Cropped image.
+ border (np array, 4): The distance of four border of
+ ``cropped_img`` to the original image area, [top, bottom,
+ left, right]
+ patch (list[int]): The cropped area, [left, top, right, bottom].
+ """
+ center_y, center_x = center
+ target_h, target_w = size
+ img_h, img_w, img_c = image.shape
+
+ x0 = max(0, center_x - target_w // 2)
+ x1 = min(center_x + target_w // 2, img_w)
+ y0 = max(0, center_y - target_h // 2)
+ y1 = min(center_y + target_h // 2, img_h)
+ patch = np.array((int(x0), int(y0), int(x1), int(y1)))
+
+ left, right = center_x - x0, x1 - center_x
+ top, bottom = center_y - y0, y1 - center_y
+
+ cropped_center_y, cropped_center_x = target_h // 2, target_w // 2
+ cropped_img = np.zeros((target_h, target_w, img_c), dtype=image.dtype)
+ for i in range(img_c):
+ cropped_img[:, :, i] += self.mean[i]
+ y_slice = slice(cropped_center_y - top, cropped_center_y + bottom)
+ x_slice = slice(cropped_center_x - left, cropped_center_x + right)
+ cropped_img[y_slice, x_slice, :] = image[y0:y1, x0:x1, :]
+
+ border = np.array([
+ cropped_center_y - top, cropped_center_y + bottom,
+ cropped_center_x - left, cropped_center_x + right
+ ],
+ dtype=np.float32)
+
+ return cropped_img, border, patch
+
+ def _train_aug(self, results):
+        """Randomly crop and around-pad the original image.
+
+        Args:
+            results (dict): Image information in the augment pipeline.
+
+ Returns:
+ results (dict): The updated dict.
+ """
+ img = results['img']
+ h, w, c = img.shape
+ boxes = results['gt_bboxes']
+ while True:
+ scale = random.choice(self.ratios)
+ new_h = int(self.crop_size[0] * scale)
+ new_w = int(self.crop_size[1] * scale)
+ h_border = self._get_border(self.border, h)
+ w_border = self._get_border(self.border, w)
+
+ for i in range(50):
+ center_x = random.randint(low=w_border, high=w - w_border)
+ center_y = random.randint(low=h_border, high=h - h_border)
+
+ cropped_img, border, patch = self._crop_image_and_paste(
+ img, [center_y, center_x], [new_h, new_w])
+
+ mask = self._filter_boxes(patch, boxes)
+                # if the image does not have a valid bbox, any crop patch is
+                # valid.
+ if not mask.any() and len(boxes) > 0:
+ continue
+
+ results['img'] = cropped_img
+ results['img_shape'] = cropped_img.shape
+ results['pad_shape'] = cropped_img.shape
+
+ x0, y0, x1, y1 = patch
+
+ left_w, top_h = center_x - x0, center_y - y0
+ cropped_center_x, cropped_center_y = new_w // 2, new_h // 2
+
+ # crop bboxes accordingly and clip to the image boundary
+ for key in results.get('bbox_fields', []):
+ mask = self._filter_boxes(patch, results[key])
+ bboxes = results[key][mask]
+ bboxes[:, 0:4:2] += cropped_center_x - left_w - x0
+ bboxes[:, 1:4:2] += cropped_center_y - top_h - y0
+ if self.bbox_clip_border:
+ bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w)
+ bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h)
+ keep = (bboxes[:, 2] > bboxes[:, 0]) & (
+ bboxes[:, 3] > bboxes[:, 1])
+ bboxes = bboxes[keep]
+ results[key] = bboxes
+ if key in ['gt_bboxes']:
+ if 'gt_labels' in results:
+ labels = results['gt_labels'][mask]
+ labels = labels[keep]
+ results['gt_labels'] = labels
+ if 'gt_masks' in results:
+ raise NotImplementedError(
+ 'RandomCenterCropPad only supports bbox.')
+
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ raise NotImplementedError(
+ 'RandomCenterCropPad only supports bbox.')
+ return results
+
+ def _test_aug(self, results):
+        """Around-pad the original image without cropping.
+
+        The padding mode and value are taken from ``test_pad_mode``.
+
+        Args:
+            results (dict): Image information in the augment pipeline.
+
+ Returns:
+ results (dict): The updated dict.
+ """
+ img = results['img']
+ h, w, c = img.shape
+ results['img_shape'] = img.shape
+ if self.test_pad_mode[0] in ['logical_or']:
+ # self.test_pad_add_pix is only used for centernet
+ target_h = (h | self.test_pad_mode[1]) + self.test_pad_add_pix
+ target_w = (w | self.test_pad_mode[1]) + self.test_pad_add_pix
+ elif self.test_pad_mode[0] in ['size_divisor']:
+ divisor = self.test_pad_mode[1]
+ target_h = int(np.ceil(h / divisor)) * divisor
+ target_w = int(np.ceil(w / divisor)) * divisor
+ else:
+ raise NotImplementedError(
+                'RandomCenterCropPad only supports two test pad modes: '
+                'logical_or and size_divisor.')
+
+ cropped_img, border, _ = self._crop_image_and_paste(
+ img, [h // 2, w // 2], [target_h, target_w])
+ results['img'] = cropped_img
+ results['pad_shape'] = cropped_img.shape
+ results['border'] = border
+ return results
+
+ def __call__(self, results):
+ img = results['img']
+ assert img.dtype == np.float32, (
+ 'RandomCenterCropPad needs the input image of dtype np.float32,'
+ ' please set "to_float32=True" in "LoadImageFromFile" pipeline')
+ h, w, c = img.shape
+ assert c == len(self.mean)
+ if self.test_mode:
+ return self._test_aug(results)
+ else:
+ return self._train_aug(results)
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(crop_size={self.crop_size}, '
+ repr_str += f'ratios={self.ratios}, '
+ repr_str += f'border={self.border}, '
+ repr_str += f'mean={self.input_mean}, '
+ repr_str += f'std={self.input_std}, '
+ repr_str += f'to_rgb={self.to_rgb}, '
+ repr_str += f'test_mode={self.test_mode}, '
+ repr_str += f'test_pad_mode={self.test_pad_mode}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class CutOut:
+ """CutOut operation.
+
+    Randomly drop some regions of the image, as proposed in the
+    `Cutout` paper.
+
+ Args:
+ n_holes (int | tuple[int, int]): Number of regions to be dropped.
+ If it is given as a list, number of holes will be randomly
+ selected from the closed interval [`n_holes[0]`, `n_holes[1]`].
+ cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate
+ shape of dropped regions. It can be `tuple[int, int]` to use a
+ fixed cutout shape, or `list[tuple[int, int]]` to randomly choose
+ shape from the list.
+ cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The
+ candidate ratio of dropped regions. It can be `tuple[float, float]`
+ to use a fixed ratio or `list[tuple[float, float]]` to randomly
+ choose ratio from the list. Please note that `cutout_shape`
+ and `cutout_ratio` cannot be both given at the same time.
+ fill_in (tuple[float, float, float] | tuple[int, int, int]): The value
+ of pixel to fill in the dropped regions. Default: (0, 0, 0).
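+
+    A minimal config sketch (illustrative values; ``cutout_shape`` is left
+    unset because ``cutout_ratio`` is given):
+
+    .. code-block::
+
+        dict(
+            type='CutOut',
+            n_holes=(2, 5),
+            cutout_ratio=[(0.05, 0.05), (0.1, 0.1)],
+            fill_in=(0, 0, 0))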
+ """
+
+ def __init__(self,
+ n_holes,
+ cutout_shape=None,
+ cutout_ratio=None,
+ fill_in=(0, 0, 0)):
+
+ assert (cutout_shape is None) ^ (cutout_ratio is None), \
+ 'Either cutout_shape or cutout_ratio should be specified.'
+ assert (isinstance(cutout_shape, (list, tuple))
+ or isinstance(cutout_ratio, (list, tuple)))
+ if isinstance(n_holes, tuple):
+ assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
+ else:
+ n_holes = (n_holes, n_holes)
+ self.n_holes = n_holes
+ self.fill_in = fill_in
+ self.with_ratio = cutout_ratio is not None
+ self.candidates = cutout_ratio if self.with_ratio else cutout_shape
+ if not isinstance(self.candidates, list):
+ self.candidates = [self.candidates]
+
+ def __call__(self, results):
+ """Call function to drop some regions of image."""
+ h, w, c = results['img'].shape
+ n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
+ for _ in range(n_holes):
+ x1 = np.random.randint(0, w)
+ y1 = np.random.randint(0, h)
+ index = np.random.randint(0, len(self.candidates))
+ if not self.with_ratio:
+ cutout_w, cutout_h = self.candidates[index]
+ else:
+ cutout_w = int(self.candidates[index][0] * w)
+ cutout_h = int(self.candidates[index][1] * h)
+
+ x2 = np.clip(x1 + cutout_w, 0, w)
+ y2 = np.clip(y1 + cutout_h, 0, h)
+ results['img'][y1:y2, x1:x2, :] = self.fill_in
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(n_holes={self.n_holes}, '
+ repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
+ else f'cutout_shape={self.candidates}, ')
+ repr_str += f'fill_in={self.fill_in})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Mosaic:
+ """Mosaic augmentation.
+
+ Given 4 images, mosaic transform combines them into
+ one output image. The output image is composed of the parts from each sub-
+ image.
+
+ .. code:: text
+
+ mosaic transform
+ center_x
+ +------------------------------+
+ | pad | pad |
+ | +-----------+ |
+ | | | |
+ | | image1 |--------+ |
+ | | | | |
+ | | | image2 | |
+ center_y |----+-------------+-----------|
+ | | cropped | |
+ |pad | image3 | image4 |
+ | | | |
+ +----|-------------+-----------+
+ | |
+ +-------------+
+
+ The mosaic transform steps are as follows:
+
+    1. Choose the mosaic center as the intersection point of the 4 images.
+    2. Get the top-left image according to the index, and randomly
+       sample another 3 images from the custom dataset.
+    3. Each sub-image is cropped if it is larger than the mosaic patch.
+
+ Args:
+ img_scale (Sequence[int]): Image size after mosaic pipeline of single
+ image. The shape order should be (height, width).
+ Default to (640, 640).
+ center_ratio_range (Sequence[float]): Center ratio range of mosaic
+ output. Default to (0.5, 1.5).
+ min_bbox_size (int | float): The minimum pixel for filtering
+ invalid bboxes after the mosaic pipeline. Default to 0.
+ bbox_clip_border (bool, optional): Whether to clip the objects outside
+ the border of the image. In some dataset like MOT17, the gt bboxes
+ are allowed to cross the border of images. Therefore, we don't
+ need to clip the gt bboxes in these cases. Defaults to True.
+ skip_filter (bool): Whether to skip filtering rules. If it
+ is True, the filter rule will not be applied, and the
+ `min_bbox_size` is invalid. Default to True.
+ pad_val (int): Pad value. Default to 114.
+ prob (float): Probability of applying this transformation.
+ Default to 1.0.
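+
+    A minimal usage sketch (illustrative; ``Mosaic`` expects the
+    ``mix_results`` key, normally provided by wrapping the dataset with
+    ``MultiImageMixDataset``):
+
+    .. code-block::
+
+        train_pipeline = [
+            dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0),
+            dict(type='Pad', size_divisor=32),
+        ]
+        train_dataset = dict(
+            type='MultiImageMixDataset',
+            dataset=base_dataset,  # any ordinary detection dataset config
+            pipeline=train_pipeline)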
+ """
+
+ def __init__(self,
+ img_scale=(640, 640),
+ center_ratio_range=(0.5, 1.5),
+ min_bbox_size=0,
+ bbox_clip_border=True,
+ skip_filter=True,
+ pad_val=114,
+ prob=1.0):
+ assert isinstance(img_scale, tuple)
+ assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\
+ f'got {prob}.'
+
+ log_img_scale(img_scale, skip_square=True)
+ self.img_scale = img_scale
+ self.center_ratio_range = center_ratio_range
+ self.min_bbox_size = min_bbox_size
+ self.bbox_clip_border = bbox_clip_border
+ self.skip_filter = skip_filter
+ self.pad_val = pad_val
+ self.prob = prob
+
+ def __call__(self, results):
+ """Call function to make a mosaic of image.
+
+ Args:
+ results (dict): Result dict.
+
+ Returns:
+ dict: Result dict with mosaic transformed.
+ """
+
+ if random.uniform(0, 1) > self.prob:
+ return results
+
+ results = self._mosaic_transform(results)
+ return results
+
+ def get_indexes(self, dataset):
+ """Call function to collect indexes.
+
+ Args:
+ dataset (:obj:`MultiImageMixDataset`): The dataset.
+
+ Returns:
+ list: indexes.
+ """
+
+ indexes = [random.randint(0, len(dataset)) for _ in range(3)]
+ return indexes
+
+ def _mosaic_transform(self, results):
+ """Mosaic transform function.
+
+ Args:
+ results (dict): Result dict.
+
+ Returns:
+ dict: Updated result dict.
+ """
+
+ assert 'mix_results' in results
+ mosaic_labels = []
+ mosaic_bboxes = []
+ if len(results['img'].shape) == 3:
+ mosaic_img = np.full(
+ (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3),
+ self.pad_val,
+ dtype=results['img'].dtype)
+ else:
+ mosaic_img = np.full(
+ (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
+ self.pad_val,
+ dtype=results['img'].dtype)
+
+ # mosaic center x, y
+ center_x = int(
+ random.uniform(*self.center_ratio_range) * self.img_scale[1])
+ center_y = int(
+ random.uniform(*self.center_ratio_range) * self.img_scale[0])
+ center_position = (center_x, center_y)
+
+ loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
+ for i, loc in enumerate(loc_strs):
+ if loc == 'top_left':
+ results_patch = copy.deepcopy(results)
+ else:
+ results_patch = copy.deepcopy(results['mix_results'][i - 1])
+
+ img_i = results_patch['img']
+ h_i, w_i = img_i.shape[:2]
+ # keep_ratio resize
+ scale_ratio_i = min(self.img_scale[0] / h_i,
+ self.img_scale[1] / w_i)
+ img_i = mmcv.imresize(
+ img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))
+
+ # compute the combine parameters
+ paste_coord, crop_coord = self._mosaic_combine(
+ loc, center_position, img_i.shape[:2][::-1])
+ x1_p, y1_p, x2_p, y2_p = paste_coord
+ x1_c, y1_c, x2_c, y2_c = crop_coord
+
+ # crop and paste image
+ mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]
+
+ # adjust coordinate
+ gt_bboxes_i = results_patch['gt_bboxes']
+ gt_labels_i = results_patch['gt_labels']
+
+ if gt_bboxes_i.shape[0] > 0:
+ padw = x1_p - x1_c
+ padh = y1_p - y1_c
+ gt_bboxes_i[:, 0::2] = \
+ scale_ratio_i * gt_bboxes_i[:, 0::2] + padw
+ gt_bboxes_i[:, 1::2] = \
+ scale_ratio_i * gt_bboxes_i[:, 1::2] + padh
+
+ mosaic_bboxes.append(gt_bboxes_i)
+ mosaic_labels.append(gt_labels_i)
+
+ if len(mosaic_labels) > 0:
+ mosaic_bboxes = np.concatenate(mosaic_bboxes, 0)
+ mosaic_labels = np.concatenate(mosaic_labels, 0)
+
+ if self.bbox_clip_border:
+ mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0,
+ 2 * self.img_scale[1])
+ mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0,
+ 2 * self.img_scale[0])
+
+ if not self.skip_filter:
+ mosaic_bboxes, mosaic_labels = \
+ self._filter_box_candidates(mosaic_bboxes, mosaic_labels)
+
+ # remove outside bboxes
+ inside_inds = find_inside_bboxes(mosaic_bboxes, 2 * self.img_scale[0],
+ 2 * self.img_scale[1])
+ mosaic_bboxes = mosaic_bboxes[inside_inds]
+ mosaic_labels = mosaic_labels[inside_inds]
+
+ results['img'] = mosaic_img
+ results['img_shape'] = mosaic_img.shape
+ results['gt_bboxes'] = mosaic_bboxes
+ results['gt_labels'] = mosaic_labels
+
+ return results
+
+ def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
+ """Calculate global coordinate of mosaic image and local coordinate of
+ cropped sub-image.
+
+ Args:
+ loc (str): Index for the sub-image, loc in ('top_left',
+ 'top_right', 'bottom_left', 'bottom_right').
+ center_position_xy (Sequence[float]): Mixing center for 4 images,
+ (x, y).
+ img_shape_wh (Sequence[int]): Width and height of sub-image
+
+ Returns:
+ tuple[tuple[float]]: Corresponding coordinate of pasting and
+ cropping
+ - paste_coord (tuple): paste corner coordinate in mosaic image.
+ - crop_coord (tuple): crop corner coordinate in mosaic image.
+ """
+ assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
+ if loc == 'top_left':
+ # index0 to top left part of image
+ x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
+ max(center_position_xy[1] - img_shape_wh[1], 0), \
+ center_position_xy[0], \
+ center_position_xy[1]
+ crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - (
+ y2 - y1), img_shape_wh[0], img_shape_wh[1]
+
+ elif loc == 'top_right':
+ # index1 to top right part of image
+ x1, y1, x2, y2 = center_position_xy[0], \
+ max(center_position_xy[1] - img_shape_wh[1], 0), \
+ min(center_position_xy[0] + img_shape_wh[0],
+ self.img_scale[1] * 2), \
+ center_position_xy[1]
+ crop_coord = 0, img_shape_wh[1] - (y2 - y1), min(
+ img_shape_wh[0], x2 - x1), img_shape_wh[1]
+
+ elif loc == 'bottom_left':
+ # index2 to bottom left part of image
+ x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
+ center_position_xy[1], \
+ center_position_xy[0], \
+ min(self.img_scale[0] * 2, center_position_xy[1] +
+ img_shape_wh[1])
+ crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min(
+ y2 - y1, img_shape_wh[1])
+
+ else:
+ # index3 to bottom right part of image
+ x1, y1, x2, y2 = center_position_xy[0], \
+ center_position_xy[1], \
+ min(center_position_xy[0] + img_shape_wh[0],
+ self.img_scale[1] * 2), \
+ min(self.img_scale[0] * 2, center_position_xy[1] +
+ img_shape_wh[1])
+ crop_coord = 0, 0, min(img_shape_wh[0],
+ x2 - x1), min(y2 - y1, img_shape_wh[1])
+
+ paste_coord = x1, y1, x2, y2
+ return paste_coord, crop_coord
+
+ def _filter_box_candidates(self, bboxes, labels):
+ """Filter out bboxes too small after Mosaic."""
+ bbox_w = bboxes[:, 2] - bboxes[:, 0]
+ bbox_h = bboxes[:, 3] - bboxes[:, 1]
+ valid_inds = (bbox_w > self.min_bbox_size) & \
+ (bbox_h > self.min_bbox_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ return bboxes[valid_inds], labels[valid_inds]
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+        repr_str += f'(img_scale={self.img_scale}, '
+ repr_str += f'center_ratio_range={self.center_ratio_range}, '
+ repr_str += f'pad_val={self.pad_val}, '
+ repr_str += f'min_bbox_size={self.min_bbox_size}, '
+ repr_str += f'skip_filter={self.skip_filter})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class MixUp:
+ """MixUp data augmentation.
+
+ .. code:: text
+
+ mixup transform
+ +------------------------------+
+ | mixup image | |
+ | +--------|--------+ |
+ | | | | |
+ |---------------+ | |
+ | | | |
+ | | image | |
+ | | | |
+ | | | |
+ | |-----------------+ |
+ | pad |
+ +------------------------------+
+
+ The mixup transform steps are as follows:
+
+    1. Another random image is picked by the dataset and embedded in
+       the top-left patch (after padding and resizing).
+    2. The target of the mixup transform is the weighted average of the
+       mixup image and the original image.
+
+ Args:
+ img_scale (Sequence[int]): Image output size after mixup pipeline.
+ The shape order should be (height, width). Default: (640, 640).
+ ratio_range (Sequence[float]): Scale ratio of mixup image.
+ Default: (0.5, 1.5).
+ flip_ratio (float): Horizontal flip ratio of mixup image.
+ Default: 0.5.
+ pad_val (int): Pad value. Default: 114.
+ max_iters (int): The maximum number of iterations. If the number of
+ iterations is greater than `max_iters`, but gt_bbox is still
+ empty, then the iteration is terminated. Default: 15.
+ min_bbox_size (float): Width and height threshold to filter bboxes.
+ If the height or width of a box is smaller than this value, it
+ will be removed. Default: 5.
+ min_area_ratio (float): Threshold of area ratio between
+ original bboxes and wrapped bboxes. If smaller than this value,
+ the box will be removed. Default: 0.2.
+ max_aspect_ratio (float): Aspect ratio of width and height
+ threshold to filter bboxes. If max(h/w, w/h) larger than this
+ value, the box will be removed. Default: 20.
+ bbox_clip_border (bool, optional): Whether to clip the objects outside
+ the border of the image. In some dataset like MOT17, the gt bboxes
+ are allowed to cross the border of images. Therefore, we don't
+ need to clip the gt bboxes in these cases. Defaults to True.
+ skip_filter (bool): Whether to skip filtering rules. If it
+ is True, the filter rule will not be applied, and the
+ `min_bbox_size` and `min_area_ratio` and `max_aspect_ratio`
+ is invalid. Default to True.
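+
+    A minimal usage sketch (illustrative; like ``Mosaic``, ``MixUp`` relies
+    on ``mix_results`` being supplied by ``MultiImageMixDataset``):
+
+    .. code-block::
+
+        train_pipeline = [
+            dict(type='Mosaic', img_scale=(640, 640)),
+            dict(type='MixUp', img_scale=(640, 640), ratio_range=(0.8, 1.6)),
+        ]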
+ """
+
+ def __init__(self,
+ img_scale=(640, 640),
+ ratio_range=(0.5, 1.5),
+ flip_ratio=0.5,
+ pad_val=114,
+ max_iters=15,
+ min_bbox_size=5,
+ min_area_ratio=0.2,
+ max_aspect_ratio=20,
+ bbox_clip_border=True,
+ skip_filter=True):
+ assert isinstance(img_scale, tuple)
+ log_img_scale(img_scale, skip_square=True)
+ self.dynamic_scale = img_scale
+ self.ratio_range = ratio_range
+ self.flip_ratio = flip_ratio
+ self.pad_val = pad_val
+ self.max_iters = max_iters
+ self.min_bbox_size = min_bbox_size
+ self.min_area_ratio = min_area_ratio
+ self.max_aspect_ratio = max_aspect_ratio
+ self.bbox_clip_border = bbox_clip_border
+ self.skip_filter = skip_filter
+
+ def __call__(self, results):
+ """Call function to make a mixup of image.
+
+ Args:
+ results (dict): Result dict.
+
+ Returns:
+ dict: Result dict with mixup transformed.
+ """
+
+ results = self._mixup_transform(results)
+ return results
+
+ def get_indexes(self, dataset):
+ """Call function to collect indexes.
+
+ Args:
+ dataset (:obj:`MultiImageMixDataset`): The dataset.
+
+ Returns:
+ list: indexes.
+ """
+
+ for i in range(self.max_iters):
+ index = random.randint(0, len(dataset))
+ gt_bboxes_i = dataset.get_ann_info(index)['bboxes']
+ if len(gt_bboxes_i) != 0:
+ break
+
+ return index
+
+ def _mixup_transform(self, results):
+ """MixUp transform function.
+
+ Args:
+ results (dict): Result dict.
+
+ Returns:
+ dict: Updated result dict.
+ """
+
+ assert 'mix_results' in results
+ assert len(
+            results['mix_results']) == 1, 'MixUp only supports 2 images now!'
+
+ if results['mix_results'][0]['gt_bboxes'].shape[0] == 0:
+ # empty bbox
+ return results
+
+ retrieve_results = results['mix_results'][0]
+ retrieve_img = retrieve_results['img']
+
+ jit_factor = random.uniform(*self.ratio_range)
+        is_flip = random.uniform(0, 1) < self.flip_ratio
+
+ if len(retrieve_img.shape) == 3:
+ out_img = np.ones(
+ (self.dynamic_scale[0], self.dynamic_scale[1], 3),
+ dtype=retrieve_img.dtype) * self.pad_val
+ else:
+ out_img = np.ones(
+ self.dynamic_scale, dtype=retrieve_img.dtype) * self.pad_val
+
+ # 1. keep_ratio resize
+ scale_ratio = min(self.dynamic_scale[0] / retrieve_img.shape[0],
+ self.dynamic_scale[1] / retrieve_img.shape[1])
+ retrieve_img = mmcv.imresize(
+ retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
+ int(retrieve_img.shape[0] * scale_ratio)))
+
+ # 2. paste
+ out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img
+
+ # 3. scale jit
+ scale_ratio *= jit_factor
+ out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
+ int(out_img.shape[0] * jit_factor)))
+
+ # 4. flip
+        if is_flip:
+ out_img = out_img[:, ::-1, :]
+
+ # 5. random crop
+ ori_img = results['img']
+ origin_h, origin_w = out_img.shape[:2]
+ target_h, target_w = ori_img.shape[:2]
+ padded_img = np.zeros(
+ (max(origin_h, target_h), max(origin_w,
+ target_w), 3)).astype(np.uint8)
+ padded_img[:origin_h, :origin_w] = out_img
+
+ x_offset, y_offset = 0, 0
+ if padded_img.shape[0] > target_h:
+ y_offset = random.randint(0, padded_img.shape[0] - target_h)
+ if padded_img.shape[1] > target_w:
+ x_offset = random.randint(0, padded_img.shape[1] - target_w)
+ padded_cropped_img = padded_img[y_offset:y_offset + target_h,
+ x_offset:x_offset + target_w]
+
+ # 6. adjust bbox
+ retrieve_gt_bboxes = retrieve_results['gt_bboxes']
+ retrieve_gt_bboxes[:, 0::2] = retrieve_gt_bboxes[:, 0::2] * scale_ratio
+ retrieve_gt_bboxes[:, 1::2] = retrieve_gt_bboxes[:, 1::2] * scale_ratio
+ if self.bbox_clip_border:
+ retrieve_gt_bboxes[:, 0::2] = np.clip(retrieve_gt_bboxes[:, 0::2],
+ 0, origin_w)
+ retrieve_gt_bboxes[:, 1::2] = np.clip(retrieve_gt_bboxes[:, 1::2],
+ 0, origin_h)
+
+        if is_flip:
+ retrieve_gt_bboxes[:, 0::2] = (
+ origin_w - retrieve_gt_bboxes[:, 0::2][:, ::-1])
+
+ # 7. filter
+ cp_retrieve_gt_bboxes = retrieve_gt_bboxes.copy()
+ cp_retrieve_gt_bboxes[:, 0::2] = \
+ cp_retrieve_gt_bboxes[:, 0::2] - x_offset
+ cp_retrieve_gt_bboxes[:, 1::2] = \
+ cp_retrieve_gt_bboxes[:, 1::2] - y_offset
+ if self.bbox_clip_border:
+ cp_retrieve_gt_bboxes[:, 0::2] = np.clip(
+ cp_retrieve_gt_bboxes[:, 0::2], 0, target_w)
+ cp_retrieve_gt_bboxes[:, 1::2] = np.clip(
+ cp_retrieve_gt_bboxes[:, 1::2], 0, target_h)
+
+ # 8. mix up
+ ori_img = ori_img.astype(np.float32)
+ mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(np.float32)
+
+ retrieve_gt_labels = retrieve_results['gt_labels']
+ if not self.skip_filter:
+ keep_list = self._filter_box_candidates(retrieve_gt_bboxes.T,
+ cp_retrieve_gt_bboxes.T)
+
+ retrieve_gt_labels = retrieve_gt_labels[keep_list]
+ cp_retrieve_gt_bboxes = cp_retrieve_gt_bboxes[keep_list]
+
+ mixup_gt_bboxes = np.concatenate(
+ (results['gt_bboxes'], cp_retrieve_gt_bboxes), axis=0)
+ mixup_gt_labels = np.concatenate(
+ (results['gt_labels'], retrieve_gt_labels), axis=0)
+
+ # remove outside bbox
+ inside_inds = find_inside_bboxes(mixup_gt_bboxes, target_h, target_w)
+ mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
+ mixup_gt_labels = mixup_gt_labels[inside_inds]
+
+ results['img'] = mixup_img.astype(np.uint8)
+ results['img_shape'] = mixup_img.shape
+ results['gt_bboxes'] = mixup_gt_bboxes
+ results['gt_labels'] = mixup_gt_labels
+
+ return results
+
+ def _filter_box_candidates(self, bbox1, bbox2):
+        """Compute which boxes are kept as candidates, based on the
+        following 5 quantities:
+
+        bbox1 before augmentation, bbox2 after augmentation,
+        min_bbox_size (pixels), min_area_ratio and max_aspect_ratio.
+ """
+
+ w1, h1 = bbox1[2] - bbox1[0], bbox1[3] - bbox1[1]
+ w2, h2 = bbox2[2] - bbox2[0], bbox2[3] - bbox2[1]
+ ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))
+ return ((w2 > self.min_bbox_size)
+ & (h2 > self.min_bbox_size)
+ & (w2 * h2 / (w1 * h1 + 1e-16) > self.min_area_ratio)
+ & (ar < self.max_aspect_ratio))
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+        repr_str += f'(dynamic_scale={self.dynamic_scale}, '
+ repr_str += f'ratio_range={self.ratio_range}, '
+ repr_str += f'flip_ratio={self.flip_ratio}, '
+ repr_str += f'pad_val={self.pad_val}, '
+ repr_str += f'max_iters={self.max_iters}, '
+ repr_str += f'min_bbox_size={self.min_bbox_size}, '
+ repr_str += f'min_area_ratio={self.min_area_ratio}, '
+ repr_str += f'max_aspect_ratio={self.max_aspect_ratio}, '
+ repr_str += f'skip_filter={self.skip_filter})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomAffine:
+ """Random affine transform data augmentation.
+
+    This operation randomly generates an affine transform matrix that
+    combines rotation, translation, shear and scaling transforms.
+
+ Args:
+ max_rotate_degree (float): Maximum degrees of rotation transform.
+ Default: 10.
+ max_translate_ratio (float): Maximum ratio of translation.
+ Default: 0.1.
+ scaling_ratio_range (tuple[float]): Min and max ratio of
+ scaling transform. Default: (0.5, 1.5).
+ max_shear_degree (float): Maximum degrees of shear
+ transform. Default: 2.
+ border (tuple[int]): Distance from height and width sides of input
+ image to adjust output shape. Only used in mosaic dataset.
+ Default: (0, 0).
+ border_val (tuple[int]): Border padding values of 3 channels.
+ Default: (114, 114, 114).
+ min_bbox_size (float): Width and height threshold to filter bboxes.
+ If the height or width of a box is smaller than this value, it
+ will be removed. Default: 2.
+ min_area_ratio (float): Threshold of area ratio between
+ original bboxes and wrapped bboxes. If smaller than this value,
+ the box will be removed. Default: 0.2.
+        max_aspect_ratio (float): Aspect ratio of width and height
+            threshold to filter bboxes. If max(h/w, w/h) is larger than this
+            value, the box will be removed. Default: 20.
+ bbox_clip_border (bool, optional): Whether to clip the objects outside
+ the border of the image. In some dataset like MOT17, the gt bboxes
+ are allowed to cross the border of images. Therefore, we don't
+ need to clip the gt bboxes in these cases. Defaults to True.
+ skip_filter (bool): Whether to skip filtering rules. If it
+ is True, the filter rule will not be applied, and the
+ `min_bbox_size` and `min_area_ratio` and `max_aspect_ratio`
+ is invalid. Default to True.
+ """
+
+ def __init__(self,
+ max_rotate_degree=10.0,
+ max_translate_ratio=0.1,
+ scaling_ratio_range=(0.5, 1.5),
+ max_shear_degree=2.0,
+ border=(0, 0),
+ border_val=(114, 114, 114),
+ min_bbox_size=2,
+ min_area_ratio=0.2,
+ max_aspect_ratio=20,
+ bbox_clip_border=True,
+ skip_filter=True):
+ assert 0 <= max_translate_ratio <= 1
+ assert scaling_ratio_range[0] <= scaling_ratio_range[1]
+ assert scaling_ratio_range[0] > 0
+ self.max_rotate_degree = max_rotate_degree
+ self.max_translate_ratio = max_translate_ratio
+ self.scaling_ratio_range = scaling_ratio_range
+ self.max_shear_degree = max_shear_degree
+ self.border = border
+ self.border_val = border_val
+ self.min_bbox_size = min_bbox_size
+ self.min_area_ratio = min_area_ratio
+ self.max_aspect_ratio = max_aspect_ratio
+ self.bbox_clip_border = bbox_clip_border
+ self.skip_filter = skip_filter
+
+ def __call__(self, results):
+ img = results['img']
+ height = img.shape[0] + self.border[0] * 2
+ width = img.shape[1] + self.border[1] * 2
+
+ # Rotation
+ rotation_degree = random.uniform(-self.max_rotate_degree,
+ self.max_rotate_degree)
+ rotation_matrix = self._get_rotation_matrix(rotation_degree)
+
+ # Scaling
+ scaling_ratio = random.uniform(self.scaling_ratio_range[0],
+ self.scaling_ratio_range[1])
+ scaling_matrix = self._get_scaling_matrix(scaling_ratio)
+
+ # Shear
+ x_degree = random.uniform(-self.max_shear_degree,
+ self.max_shear_degree)
+ y_degree = random.uniform(-self.max_shear_degree,
+ self.max_shear_degree)
+ shear_matrix = self._get_shear_matrix(x_degree, y_degree)
+
+ # Translation
+ trans_x = random.uniform(-self.max_translate_ratio,
+ self.max_translate_ratio) * width
+ trans_y = random.uniform(-self.max_translate_ratio,
+ self.max_translate_ratio) * height
+ translate_matrix = self._get_translation_matrix(trans_x, trans_y)
+
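+        # The matrices below act on column vectors, so they are applied
+        # right-to-left: scaling first, then rotation, shear and finally
+        # translation.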
+ warp_matrix = (
+ translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix)
+
+ img = cv2.warpPerspective(
+ img,
+ warp_matrix,
+ dsize=(width, height),
+ borderValue=self.border_val)
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ for key in results.get('bbox_fields', []):
+ bboxes = results[key]
+ num_bboxes = len(bboxes)
+ if num_bboxes:
+ # homogeneous coordinates
+ xs = bboxes[:, [0, 0, 2, 2]].reshape(num_bboxes * 4)
+ ys = bboxes[:, [1, 3, 3, 1]].reshape(num_bboxes * 4)
+ ones = np.ones_like(xs)
+ points = np.vstack([xs, ys, ones])
+
+ warp_points = warp_matrix @ points
+ warp_points = warp_points[:2] / warp_points[2]
+ xs = warp_points[0].reshape(num_bboxes, 4)
+ ys = warp_points[1].reshape(num_bboxes, 4)
+
+ warp_bboxes = np.vstack(
+ (xs.min(1), ys.min(1), xs.max(1), ys.max(1))).T
+
+ if self.bbox_clip_border:
+ warp_bboxes[:, [0, 2]] = \
+ warp_bboxes[:, [0, 2]].clip(0, width)
+ warp_bboxes[:, [1, 3]] = \
+ warp_bboxes[:, [1, 3]].clip(0, height)
+
+ # remove outside bbox
+ valid_index = find_inside_bboxes(warp_bboxes, height, width)
+ if not self.skip_filter:
+ # filter bboxes
+ filter_index = self.filter_gt_bboxes(
+ bboxes * scaling_ratio, warp_bboxes)
+ valid_index = valid_index & filter_index
+
+ results[key] = warp_bboxes[valid_index]
+ if key in ['gt_bboxes']:
+ if 'gt_labels' in results:
+ results['gt_labels'] = results['gt_labels'][
+ valid_index]
+
+ if 'gt_masks' in results:
+ raise NotImplementedError(
+ 'RandomAffine only supports bbox.')
+ return results
+
+ def filter_gt_bboxes(self, origin_bboxes, wrapped_bboxes):
+ origin_w = origin_bboxes[:, 2] - origin_bboxes[:, 0]
+ origin_h = origin_bboxes[:, 3] - origin_bboxes[:, 1]
+ wrapped_w = wrapped_bboxes[:, 2] - wrapped_bboxes[:, 0]
+ wrapped_h = wrapped_bboxes[:, 3] - wrapped_bboxes[:, 1]
+ aspect_ratio = np.maximum(wrapped_w / (wrapped_h + 1e-16),
+ wrapped_h / (wrapped_w + 1e-16))
+
+ wh_valid_idx = (wrapped_w > self.min_bbox_size) & \
+ (wrapped_h > self.min_bbox_size)
+ area_valid_idx = wrapped_w * wrapped_h / (origin_w * origin_h +
+ 1e-16) > self.min_area_ratio
+ aspect_ratio_valid_idx = aspect_ratio < self.max_aspect_ratio
+ return wh_valid_idx & area_valid_idx & aspect_ratio_valid_idx
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(max_rotate_degree={self.max_rotate_degree}, '
+ repr_str += f'max_translate_ratio={self.max_translate_ratio}, '
+        repr_str += f'scaling_ratio_range={self.scaling_ratio_range}, '
+ repr_str += f'max_shear_degree={self.max_shear_degree}, '
+ repr_str += f'border={self.border}, '
+ repr_str += f'border_val={self.border_val}, '
+ repr_str += f'min_bbox_size={self.min_bbox_size}, '
+ repr_str += f'min_area_ratio={self.min_area_ratio}, '
+ repr_str += f'max_aspect_ratio={self.max_aspect_ratio}, '
+ repr_str += f'skip_filter={self.skip_filter})'
+ return repr_str
+
+ @staticmethod
+ def _get_rotation_matrix(rotate_degrees):
+ radian = math.radians(rotate_degrees)
+ rotation_matrix = np.array(
+ [[np.cos(radian), -np.sin(radian), 0.],
+ [np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]],
+ dtype=np.float32)
+ return rotation_matrix
+
+ @staticmethod
+ def _get_scaling_matrix(scale_ratio):
+ scaling_matrix = np.array(
+ [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
+ dtype=np.float32)
+ return scaling_matrix
+
+ @staticmethod
+ def _get_share_matrix(scale_ratio):
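+        # NOTE: this appears to duplicate ``_get_scaling_matrix`` (the name
+        # is presumably a typo for "shear") and is not referenced within
+        # this class.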
+ scaling_matrix = np.array(
+ [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
+ dtype=np.float32)
+ return scaling_matrix
+
+ @staticmethod
+ def _get_shear_matrix(x_shear_degrees, y_shear_degrees):
+ x_radian = math.radians(x_shear_degrees)
+ y_radian = math.radians(y_shear_degrees)
+ shear_matrix = np.array([[1, np.tan(x_radian), 0.],
+ [np.tan(y_radian), 1, 0.], [0., 0., 1.]],
+ dtype=np.float32)
+ return shear_matrix
+
+ @staticmethod
+ def _get_translation_matrix(x, y):
+ translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
+ dtype=np.float32)
+ return translation_matrix
+
+
+@PIPELINES.register_module()
+class YOLOXHSVRandomAug:
+    """Apply HSV augmentation to the image sequentially. It is referenced
+    from the YOLOX implementation
+    (https://github.com/Megvii-BaseDetection/YOLOX,
+    ``yolox/data/data_augment.py#L21``).
+
+ Args:
+ hue_delta (int): delta of hue. Default: 5.
+ saturation_delta (int): delta of saturation. Default: 30.
+        value_delta (int): delta of value. Default: 30.
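+
+    A minimal config sketch (illustrative; the values shown are the
+    defaults):
+
+    .. code-block::
+
+        dict(type='YOLOXHSVRandomAug', hue_delta=5, saturation_delta=30)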
+ """
+
+ def __init__(self, hue_delta=5, saturation_delta=30, value_delta=30):
+ self.hue_delta = hue_delta
+ self.saturation_delta = saturation_delta
+ self.value_delta = value_delta
+
+ def __call__(self, results):
+ img = results['img']
+ hsv_gains = np.random.uniform(-1, 1, 3) * [
+ self.hue_delta, self.saturation_delta, self.value_delta
+ ]
+ # random selection of h, s, v
+ hsv_gains *= np.random.randint(0, 2, 3)
+ # prevent overflow
+ hsv_gains = hsv_gains.astype(np.int16)
+ img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
+
+ img_hsv[..., 0] = (img_hsv[..., 0] + hsv_gains[0]) % 180
+ img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_gains[1], 0, 255)
+ img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_gains[2], 0, 255)
+ cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img)
+
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(hue_delta={self.hue_delta}, '
+ repr_str += f'saturation_delta={self.saturation_delta}, '
+ repr_str += f'value_delta={self.value_delta})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class CopyPaste:
+    """Simple Copy-Paste is a Strong Data Augmentation Method for Instance
+    Segmentation. The simple copy-paste transform steps are as follows:
+
+ 1. The destination image is already resized with aspect ratio kept,
+ cropped and padded.
+ 2. Randomly select a source image, which is also already resized
+ with aspect ratio kept, cropped and padded in a similar way
+ as the destination image.
+ 3. Randomly select some objects from the source image.
+    4. Paste these source objects to the destination image directly,
+       since the source and destination images have the same size.
+    5. Update the object masks of the destination image, since some original
+       objects may be occluded.
+    6. Generate bboxes from the updated destination masks, filter out
+       objects that are totally occluded, and adjust the bboxes of
+       partly occluded objects.
+ 7. Append selected source bboxes, masks, and labels.
+
+ Args:
+ max_num_pasted (int): The maximum number of pasted objects.
+ Default: 100.
+ bbox_occluded_thr (int): The threshold of occluded bbox.
+ Default: 10.
+ mask_occluded_thr (int): The threshold of occluded mask.
+ Default: 300.
+        selected (bool): Whether to select objects or not. If False,
+            all objects of the source image will be pasted to the
+            destination image.
+            Default: True.
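+
+    Example:
+        >>> # A minimal sketch (assumes numpy arrays in the standard mmdet
+        >>> # results format); ``mix_results`` is normally filled in by
+        >>> # ``MultiImageMixDataset``, here it is prepared by hand.
+        >>> import numpy as np
+        >>> dst = dict(
+        ...     img=np.zeros((64, 64, 3), dtype=np.uint8),
+        ...     gt_bboxes=np.array([[10., 10., 30., 30.]], dtype=np.float32),
+        ...     gt_labels=np.array([0], dtype=np.int64))
+        >>> src = dict(
+        ...     img=np.ones((64, 64, 3), dtype=np.uint8),
+        ...     gt_bboxes=np.array([[20., 20., 40., 40.]], dtype=np.float32),
+        ...     gt_labels=np.array([1], dtype=np.int64))
+        >>> dst['mix_results'] = [src]
+        >>> results = CopyPaste()(dst)
+        >>> results['img'].shape
+        (64, 64, 3)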
+ """
+
+ def __init__(
+ self,
+ max_num_pasted=100,
+ bbox_occluded_thr=10,
+ mask_occluded_thr=300,
+ selected=True,
+ ):
+ self.max_num_pasted = max_num_pasted
+ self.bbox_occluded_thr = bbox_occluded_thr
+ self.mask_occluded_thr = mask_occluded_thr
+ self.selected = selected
+ self.paste_by_box = False
+
+ def get_indexes(self, dataset):
+ """Call function to collect indexes.s.
+
+ Args:
+ dataset (:obj:`MultiImageMixDataset`): The dataset.
+ Returns:
+            int: Index of the selected source image.
+ """
+ return random.randint(0, len(dataset))
+
+ def gen_masks_from_bboxes(self, bboxes, img_shape):
+ """Generate gt_masks based on gt_bboxes.
+
+ Args:
+ bboxes (list): The bboxes's list.
+ img_shape (tuple): The shape of image.
+ Returns:
+ BitmapMasks
+ """
+ self.paste_by_box = True
+ img_h, img_w = img_shape[:2]
+ xmin, ymin = bboxes[:, 0:1], bboxes[:, 1:2]
+ xmax, ymax = bboxes[:, 2:3], bboxes[:, 3:4]
+ gt_masks = np.zeros((len(bboxes), img_h, img_w), dtype=np.uint8)
+ for i in range(len(bboxes)):
+ gt_masks[i,
+ int(ymin[i]):int(ymax[i]),
+ int(xmin[i]):int(xmax[i])] = 1
+ return BitmapMasks(gt_masks, img_h, img_w)
+
+ def get_gt_masks(self, results):
+ """Get gt_masks originally or generated based on bboxes.
+
+ If gt_masks is not contained in results,
+ it will be generated based on gt_bboxes.
+ Args:
+ results (dict): Result dict.
+ Returns:
+ BitmapMasks: gt_masks, originally or generated based on bboxes.
+ """
+ if results.get('gt_masks', None) is not None:
+ return results['gt_masks']
+ else:
+ return self.gen_masks_from_bboxes(
+ results.get('gt_bboxes', []), results['img'].shape)
+
+ def __call__(self, results):
+ """Call function to make a copy-paste of image.
+
+ Args:
+ results (dict): Result dict.
+ Returns:
+ dict: Result dict with copy-paste transformed.
+ """
+
+ assert 'mix_results' in results
+ num_images = len(results['mix_results'])
+        assert num_images == 1, \
+            'CopyPaste only supports mixing one source image, ' \
+            f'but got {num_images}'
+
+ # Get gt_masks originally or generated based on bboxes.
+ results['gt_masks'] = self.get_gt_masks(results)
+ # only one mix picture
+ results['mix_results'][0]['gt_masks'] = self.get_gt_masks(
+ results['mix_results'][0])
+
+ if self.selected:
+ selected_results = self._select_object(results['mix_results'][0])
+ else:
+ selected_results = results['mix_results'][0]
+ return self._copy_paste(results, selected_results)
+
+ def _select_object(self, results):
+ """Select some objects from the source results."""
+ bboxes = results['gt_bboxes']
+ labels = results['gt_labels']
+ masks = results['gt_masks']
+ max_num_pasted = min(bboxes.shape[0] + 1, self.max_num_pasted)
+ num_pasted = np.random.randint(0, max_num_pasted)
+ selected_inds = np.random.choice(
+ bboxes.shape[0], size=num_pasted, replace=False)
+
+ selected_bboxes = bboxes[selected_inds]
+ selected_labels = labels[selected_inds]
+ selected_masks = masks[selected_inds]
+
+ results['gt_bboxes'] = selected_bboxes
+ results['gt_labels'] = selected_labels
+ results['gt_masks'] = selected_masks
+ return results
+
+ def _copy_paste(self, dst_results, src_results):
+ """CopyPaste transform function.
+
+ Args:
+ dst_results (dict): Result dict of the destination image.
+ src_results (dict): Result dict of the source image.
+ Returns:
+ dict: Updated result dict.
+ """
+ dst_img = dst_results['img']
+ dst_bboxes = dst_results['gt_bboxes']
+ dst_labels = dst_results['gt_labels']
+ dst_masks = dst_results['gt_masks']
+
+ src_img = src_results['img']
+ src_bboxes = src_results['gt_bboxes']
+ src_labels = src_results['gt_labels']
+ src_masks = src_results['gt_masks']
+
+ if len(src_bboxes) == 0:
+ if self.paste_by_box:
+ dst_results.pop('gt_masks')
+ return dst_results
+
+ # update masks and generate bboxes from updated masks
+ composed_mask = np.where(np.any(src_masks.masks, axis=0), 1, 0)
+ updated_dst_masks = self.get_updated_masks(dst_masks, composed_mask)
+ updated_dst_bboxes = updated_dst_masks.get_bboxes()
+ assert len(updated_dst_bboxes) == len(updated_dst_masks)
+
+ # filter totally occluded objects
+ bboxes_inds = np.all(
+ np.abs(
+ (updated_dst_bboxes - dst_bboxes)) <= self.bbox_occluded_thr,
+ axis=-1)
+ masks_inds = updated_dst_masks.masks.sum(
+ axis=(1, 2)) > self.mask_occluded_thr
+ valid_inds = bboxes_inds | masks_inds
+
+ # Paste source objects to destination image directly
+ img = dst_img * (1 - composed_mask[..., np.newaxis]
+ ) + src_img * composed_mask[..., np.newaxis]
+ bboxes = np.concatenate([updated_dst_bboxes[valid_inds], src_bboxes])
+ labels = np.concatenate([dst_labels[valid_inds], src_labels])
+ masks = np.concatenate(
+ [updated_dst_masks.masks[valid_inds], src_masks.masks])
+
+ dst_results['img'] = img
+ dst_results['gt_bboxes'] = bboxes
+ dst_results['gt_labels'] = labels
+ if self.paste_by_box:
+ dst_results.pop('gt_masks')
+ else:
+ dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1],
+ masks.shape[2])
+
+ return dst_results
+
+ def get_updated_masks(self, masks, composed_mask):
+ assert masks.masks.shape[-2:] == composed_mask.shape[-2:], \
+ 'Cannot compare two arrays of different size'
+ masks.masks = np.where(composed_mask, 0, masks.masks)
+ return masks
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'max_num_pasted={self.max_num_pasted}, '
+ repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
+ repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
+ repr_str += f'selected={self.selected}, '
+ return repr_str
diff --git a/mmdet/datasets/samplers/__init__.py b/mmdet/datasets/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4c7ea135af652712e5a9f14a2002c516c44a16b
--- /dev/null
+++ b/mmdet/datasets/samplers/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .class_aware_sampler import ClassAwareSampler
+from .distributed_sampler import DistributedSampler
+from .group_sampler import DistributedGroupSampler, GroupSampler
+from .infinite_sampler import InfiniteBatchSampler, InfiniteGroupBatchSampler
+
+__all__ = [
+ 'DistributedSampler', 'DistributedGroupSampler', 'GroupSampler',
+ 'InfiniteGroupBatchSampler', 'InfiniteBatchSampler', 'ClassAwareSampler'
+]
diff --git a/mmdet/datasets/samplers/class_aware_sampler.py b/mmdet/datasets/samplers/class_aware_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c52708eb8b98d85b3fac3ee55c7519be60681896
--- /dev/null
+++ b/mmdet/datasets/samplers/class_aware_sampler.py
@@ -0,0 +1,176 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+from mmcv.runner import get_dist_info
+from torch.utils.data import Sampler
+
+from mmdet.core.utils import sync_random_seed
+
+
+class ClassAwareSampler(Sampler):
+ r"""Sampler that restricts data loading to the label of the dataset.
+
+ A class-aware sampling strategy to effectively tackle the
+ non-uniform class distribution. The length of the training data is
+    consistent with the source data. Simple improvements based on `Relay
+    Backpropagation for Effective Learning of Deep Convolutional
+    Neural Networks <https://arxiv.org/abs/1512.05830>`_.
+
+    The implementation logic is adapted from
+ https://github.com/Sense-X/TSD/blob/master/mmdet/datasets/samplers/distributed_classaware_sampler.py
+
+ Args:
+ dataset: Dataset used for sampling.
+ samples_per_gpu (int): When model is :obj:`DistributedDataParallel`,
+ it is the number of training samples on each GPU.
+ When model is :obj:`DataParallel`, it is
+ `num_gpus * samples_per_gpu`.
+ Default : 1.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ seed (int, optional): random seed used to shuffle the sampler if
+ ``shuffle=True``. This number should be identical across all
+ processes in the distributed group. Default: 0.
+ num_sample_class (int): The number of samples taken from each
+ per-label list. Default: 1
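+
+    Example:
+        >>> # A minimal sketch; assumes ``dataset`` is an mmdet dataset
+        >>> # implementing ``get_cat2imgs()`` (e.g. CocoDataset).
+        >>> sampler = ClassAwareSampler(  # doctest: +SKIP
+        ...     dataset, samples_per_gpu=2)
+        >>> len(sampler) == len(list(iter(sampler)))  # doctest: +SKIP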
+ """
+
+ def __init__(self,
+ dataset,
+ samples_per_gpu=1,
+ num_replicas=None,
+ rank=None,
+ seed=0,
+ num_sample_class=1):
+ _rank, _num_replicas = get_dist_info()
+ if num_replicas is None:
+ num_replicas = _num_replicas
+ if rank is None:
+ rank = _rank
+
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.samples_per_gpu = samples_per_gpu
+ self.rank = rank
+ self.epoch = 0
+ # Must be the same across all workers. If None, will use a
+ # random seed shared among workers
+ # (require synchronization among all workers)
+ self.seed = sync_random_seed(seed)
+
+ # The number of samples taken from each per-label list
+ assert num_sample_class > 0 and isinstance(num_sample_class, int)
+ self.num_sample_class = num_sample_class
+ # Get per-label image list from dataset
+ assert hasattr(dataset, 'get_cat2imgs'), \
+ 'dataset must have `get_cat2imgs` function'
+ self.cat_dict = dataset.get_cat2imgs()
+
+ self.num_samples = int(
+ math.ceil(
+ len(self.dataset) * 1.0 / self.num_replicas /
+ self.samples_per_gpu)) * self.samples_per_gpu
+ self.total_size = self.num_samples * self.num_replicas
+
+ # get number of images containing each category
+ self.num_cat_imgs = [len(x) for x in self.cat_dict.values()]
+ # filter labels without images
+ self.valid_cat_inds = [
+ i for i, length in enumerate(self.num_cat_imgs) if length != 0
+ ]
+ self.num_classes = len(self.valid_cat_inds)
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch + self.seed)
+
+ # initialize label list
+ label_iter_list = RandomCycleIter(self.valid_cat_inds, generator=g)
+ # initialize each per-label image list
+ data_iter_dict = dict()
+ for i in self.valid_cat_inds:
+ data_iter_dict[i] = RandomCycleIter(self.cat_dict[i], generator=g)
+
+ def gen_cat_img_inds(cls_list, data_dict, num_sample_cls):
+ """Traverse the categories and extract `num_sample_cls` image
+ indexes of the corresponding categories one by one."""
+ id_indices = []
+ for _ in range(len(cls_list)):
+ cls_idx = next(cls_list)
+ for _ in range(num_sample_cls):
+ id = next(data_dict[cls_idx])
+ id_indices.append(id)
+ return id_indices
+
+ # deterministically shuffle based on epoch
+ num_bins = int(
+ math.ceil(self.total_size * 1.0 / self.num_classes /
+ self.num_sample_class))
+ indices = []
+ for i in range(num_bins):
+ indices += gen_cat_img_inds(label_iter_list, data_iter_dict,
+ self.num_sample_class)
+
+ # fix extra samples to make it evenly divisible
+ if len(indices) >= self.total_size:
+ indices = indices[:self.total_size]
+ else:
+ indices += indices[:(self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset:offset + self.num_samples]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+class RandomCycleIter:
+ """Shuffle the list and do it again after the list have traversed.
+
+ The implementation logic is referred to
+ https://github.com/wutong16/DistributionBalancedLoss/blob/master/mllt/datasets/loader/sampler.py
+
+ Example:
+ >>> label_list = [0, 1, 2, 4, 5]
+ >>> g = torch.Generator()
+ >>> g.manual_seed(0)
+ >>> label_iter_list = RandomCycleIter(label_list, generator=g)
+ >>> index = next(label_iter_list)
+
+    Args:
+ data (list or ndarray): The data that needs to be shuffled.
+        generator: A torch.Generator object, which is used to seed the
+            random number generation.
+ """ # noqa: W605
+
+ def __init__(self, data, generator=None):
+ self.data = data
+ self.length = len(data)
+ self.index = torch.randperm(self.length, generator=generator).numpy()
+ self.i = 0
+ self.generator = generator
+
+ def __iter__(self):
+ return self
+
+ def __len__(self):
+ return len(self.data)
+
+ def __next__(self):
+ if self.i == self.length:
+ self.index = torch.randperm(
+ self.length, generator=self.generator).numpy()
+ self.i = 0
+ idx = self.data[self.index[self.i]]
+ self.i += 1
+ return idx
diff --git a/mmdet/datasets/samplers/distributed_sampler.py b/mmdet/datasets/samplers/distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bc8b7c3602cee288e4ab8d661819c0a2490d4ee
--- /dev/null
+++ b/mmdet/datasets/samplers/distributed_sampler.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+from mmdet.core.utils import sync_random_seed
+from mmdet.utils import get_device
+
+
+class DistributedSampler(_DistributedSampler):
+
+ def __init__(self,
+ dataset,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ seed=0):
+ super().__init__(
+ dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+
+ # In distributed sampling, different ranks should sample
+ # non-overlapped data in the dataset. Therefore, this function
+ # is used to make sure that each rank shuffles the data indices
+ # in the same order based on the same seed. Then different ranks
+ # could use different indices to select non-overlapped data from the
+ # same data list.
+ device = get_device()
+ self.seed = sync_random_seed(seed, device)
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ if self.shuffle:
+ g = torch.Generator()
+ # When :attr:`shuffle=True`, this ensures all replicas
+ # use a different random ordering for each epoch.
+ # Otherwise, the next iteration of this sampler will
+ # yield the same ordering.
+ g.manual_seed(self.epoch + self.seed)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible; repeat the whole
+        # list because indices may be shorter than half of total_size
+ indices = (indices *
+ math.ceil(self.total_size / len(indices)))[:self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
diff --git a/mmdet/datasets/samplers/group_sampler.py b/mmdet/datasets/samplers/group_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..783d2b21cca753f12a7a617f049f84a2b6541dd9
--- /dev/null
+++ b/mmdet/datasets/samplers/group_sampler.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+from mmcv.runner import get_dist_info
+from torch.utils.data import Sampler
+
+
+class GroupSampler(Sampler):
+
+ def __init__(self, dataset, samples_per_gpu=1):
+ assert hasattr(dataset, 'flag')
+ self.dataset = dataset
+ self.samples_per_gpu = samples_per_gpu
+ self.flag = dataset.flag.astype(np.int64)
+ self.group_sizes = np.bincount(self.flag)
+ self.num_samples = 0
+ for i, size in enumerate(self.group_sizes):
+ self.num_samples += int(np.ceil(
+ size / self.samples_per_gpu)) * self.samples_per_gpu
+
+ def __iter__(self):
+ indices = []
+ for i, size in enumerate(self.group_sizes):
+ if size == 0:
+ continue
+ indice = np.where(self.flag == i)[0]
+ assert len(indice) == size
+ np.random.shuffle(indice)
+ num_extra = int(np.ceil(size / self.samples_per_gpu)
+ ) * self.samples_per_gpu - len(indice)
+ indice = np.concatenate(
+ [indice, np.random.choice(indice, num_extra)])
+ indices.append(indice)
+ indices = np.concatenate(indices)
+ indices = [
+ indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
+ for i in np.random.permutation(
+ range(len(indices) // self.samples_per_gpu))
+ ]
+ indices = np.concatenate(indices)
+ indices = indices.astype(np.int64).tolist()
+ assert len(indices) == self.num_samples
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+
+class DistributedGroupSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ It is especially useful in conjunction with
+    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
+    process can pass a DistributedGroupSampler instance as a DataLoader
+    sampler, and load a subset of the original dataset that is exclusive to
+    it.
+
+ .. note::
+ Dataset is assumed to be of constant size.
+
+    Args:
+        dataset: Dataset used for sampling.
+        samples_per_gpu (int): Number of training samples on each GPU.
+            Default: 1.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ seed (int, optional): random seed used to shuffle the sampler if
+ ``shuffle=True``. This number should be identical across all
+ processes in the distributed group. Default: 0.
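+
+    Example:
+        >>> # A minimal sketch; assumes ``dataset`` exposes a ``flag``
+        >>> # array (the aspect-ratio group of each sample), as mmdet's
+        >>> # CustomDataset does, and that torch.distributed is set up.
+        >>> sampler = DistributedGroupSampler(  # doctest: +SKIP
+        ...     dataset, samples_per_gpu=2, num_replicas=8, rank=0)
+        >>> loader = torch.utils.data.DataLoader(  # doctest: +SKIP
+        ...     dataset, batch_size=2, sampler=sampler)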
+ """
+
+ def __init__(self,
+ dataset,
+ samples_per_gpu=1,
+ num_replicas=None,
+ rank=None,
+ seed=0):
+ _rank, _num_replicas = get_dist_info()
+ if num_replicas is None:
+ num_replicas = _num_replicas
+ if rank is None:
+ rank = _rank
+ self.dataset = dataset
+ self.samples_per_gpu = samples_per_gpu
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.seed = seed if seed is not None else 0
+
+ assert hasattr(self.dataset, 'flag')
+ self.flag = self.dataset.flag
+ self.group_sizes = np.bincount(self.flag)
+
+ self.num_samples = 0
+ for i, j in enumerate(self.group_sizes):
+ self.num_samples += int(
+ math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
+ self.num_replicas)) * self.samples_per_gpu
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch + self.seed)
+
+ indices = []
+ for i, size in enumerate(self.group_sizes):
+ if size > 0:
+ indice = np.where(self.flag == i)[0]
+ assert len(indice) == size
+ # add .numpy() to avoid bug when selecting indice in parrots.
+ # TODO: check whether torch.randperm() can be replaced by
+ # numpy.random.permutation().
+ indice = indice[list(
+ torch.randperm(int(size), generator=g).numpy())].tolist()
+ extra = int(
+ math.ceil(
+ size * 1.0 / self.samples_per_gpu / self.num_replicas)
+ ) * self.samples_per_gpu * self.num_replicas - len(indice)
+ # pad indice
+ tmp = indice.copy()
+ for _ in range(extra // size):
+ indice.extend(tmp)
+ indice.extend(tmp[:extra % size])
+ indices.extend(indice)
+
+ assert len(indices) == self.total_size
+
+ indices = [
+ indices[j] for i in list(
+ torch.randperm(
+ len(indices) // self.samples_per_gpu, generator=g))
+ for j in range(i * self.samples_per_gpu, (i + 1) *
+ self.samples_per_gpu)
+ ]
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset:offset + self.num_samples]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
diff --git a/mmdet/datasets/samplers/infinite_sampler.py b/mmdet/datasets/samplers/infinite_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..d42487e6ac0c3e63cd8c4a0bb5ead9644b09a0ea
--- /dev/null
+++ b/mmdet/datasets/samplers/infinite_sampler.py
@@ -0,0 +1,186 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+
+import numpy as np
+import torch
+from mmcv.runner import get_dist_info
+from torch.utils.data.sampler import Sampler
+
+from mmdet.core.utils import sync_random_seed
+
+
+class InfiniteGroupBatchSampler(Sampler):
+ """Similar to `BatchSampler` warping a `GroupSampler. It is designed for
+ iteration-based runners like `IterBasedRunner` and yields a mini-batch
+ indices each time, all indices in a batch should be in the same group.
+
+ The implementation logic is referred to
+ https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py
+
+ Args:
+ dataset (object): The dataset.
+ batch_size (int): When model is :obj:`DistributedDataParallel`,
+ it is the number of training samples on each GPU.
+ When model is :obj:`DataParallel`, it is
+ `num_gpus * samples_per_gpu`.
+ Default : 1.
+ world_size (int, optional): Number of processes participating in
+ distributed training. Default: None.
+ rank (int, optional): Rank of current process. Default: None.
+ seed (int): Random seed. Default: 0.
+        shuffle (bool): Whether to shuffle the indices of a dummy `epoch`.
+            Note that `shuffle` cannot guarantee sequential indices, because
+            all indices in a batch must belong to the same group.
+            Default: True.
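+
+    Example:
+        >>> # A minimal sketch; assumes ``dataset`` exposes a ``flag``
+        >>> # array as mmdet's CustomDataset does. Each yielded item is a
+        >>> # list of ``batch_size`` indices drawn from a single group.
+        >>> batch_sampler = InfiniteGroupBatchSampler(  # doctest: +SKIP
+        ...     dataset, batch_size=2, world_size=1, rank=0)
+        >>> batch = next(iter(batch_sampler))  # doctest: +SKIP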
+ """ # noqa: W605
+
+ def __init__(self,
+ dataset,
+ batch_size=1,
+ world_size=None,
+ rank=None,
+ seed=0,
+ shuffle=True):
+ _rank, _world_size = get_dist_info()
+ if world_size is None:
+ world_size = _world_size
+ if rank is None:
+ rank = _rank
+ self.rank = rank
+ self.world_size = world_size
+ self.dataset = dataset
+ self.batch_size = batch_size
+ # In distributed sampling, different ranks should sample
+ # non-overlapped data in the dataset. Therefore, this function
+ # is used to make sure that each rank shuffles the data indices
+ # in the same order based on the same seed. Then different ranks
+ # could use different indices to select non-overlapped data from the
+ # same data list.
+ self.seed = sync_random_seed(seed)
+ self.shuffle = shuffle
+
+ assert hasattr(self.dataset, 'flag')
+ self.flag = self.dataset.flag
+ self.group_sizes = np.bincount(self.flag)
+ # buffer used to save indices of each group
+ self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))}
+
+ self.size = len(dataset)
+ self.indices = self._indices_of_rank()
+
+ def _infinite_indices(self):
+ """Infinitely yield a sequence of indices."""
+ g = torch.Generator()
+ g.manual_seed(self.seed)
+ while True:
+ if self.shuffle:
+ yield from torch.randperm(self.size, generator=g).tolist()
+
+ else:
+ yield from torch.arange(self.size).tolist()
+
+ def _indices_of_rank(self):
+ """Slice the infinite indices by rank."""
+ yield from itertools.islice(self._infinite_indices(), self.rank, None,
+ self.world_size)
+
+ def __iter__(self):
+ # once batch size is reached, yield the indices
+ for idx in self.indices:
+ flag = self.flag[idx]
+ group_buffer = self.buffer_per_group[flag]
+ group_buffer.append(idx)
+ if len(group_buffer) == self.batch_size:
+ yield group_buffer[:]
+ del group_buffer[:]
+
+ def __len__(self):
+ """Length of base dataset."""
+ return self.size
+
+ def set_epoch(self, epoch):
+ """Not supported in `IterationBased` runner."""
+ raise NotImplementedError
+
+
+class InfiniteBatchSampler(Sampler):
+ """Similar to `BatchSampler` warping a `DistributedSampler. It is designed
+ iteration-based runners like `IterBasedRunner` and yields a mini-batch
+ indices each time.
+
+ The implementation logic is referred to
+ https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py
+
+ Args:
+ dataset (object): The dataset.
+ batch_size (int): When model is :obj:`DistributedDataParallel`,
+ it is the number of training samples on each GPU,
+ When model is :obj:`DataParallel`, it is
+ `num_gpus * samples_per_gpu`.
+ Default : 1.
+ world_size (int, optional): Number of processes participating in
+ distributed training. Default: None.
+ rank (int, optional): Rank of current process. Default: None.
+ seed (int): Random seed. Default: 0.
+        shuffle (bool): Whether to shuffle the dataset or not. Default: True.
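+
+    Example:
+        >>> # A minimal sketch in a non-distributed environment; any
+        >>> # map-style object works since only ``len(dataset)`` is used.
+        >>> dataset = list(range(10))
+        >>> batch_sampler = InfiniteBatchSampler(
+        ...     dataset, batch_size=2, world_size=1, rank=0, shuffle=False)
+        >>> next(iter(batch_sampler))
+        [0, 1]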
+ """ # noqa: W605
+
+ def __init__(self,
+ dataset,
+ batch_size=1,
+ world_size=None,
+ rank=None,
+ seed=0,
+ shuffle=True):
+ _rank, _world_size = get_dist_info()
+ if world_size is None:
+ world_size = _world_size
+ if rank is None:
+ rank = _rank
+ self.rank = rank
+ self.world_size = world_size
+ self.dataset = dataset
+ self.batch_size = batch_size
+ # In distributed sampling, different ranks should sample
+ # non-overlapped data in the dataset. Therefore, this function
+ # is used to make sure that each rank shuffles the data indices
+ # in the same order based on the same seed. Then different ranks
+ # could use different indices to select non-overlapped data from the
+ # same data list.
+ self.seed = sync_random_seed(seed)
+ self.shuffle = shuffle
+ self.size = len(dataset)
+ self.indices = self._indices_of_rank()
+
+ def _infinite_indices(self):
+ """Infinitely yield a sequence of indices."""
+ g = torch.Generator()
+ g.manual_seed(self.seed)
+ while True:
+ if self.shuffle:
+ yield from torch.randperm(self.size, generator=g).tolist()
+
+ else:
+ yield from torch.arange(self.size).tolist()
+
+ def _indices_of_rank(self):
+ """Slice the infinite indices by rank."""
+ yield from itertools.islice(self._infinite_indices(), self.rank, None,
+ self.world_size)
+
+ def __iter__(self):
+ # once batch size is reached, yield the indices
+ batch_buffer = []
+ for idx in self.indices:
+ batch_buffer.append(idx)
+ if len(batch_buffer) == self.batch_size:
+ yield batch_buffer
+ batch_buffer = []
+
+ def __len__(self):
+ """Length of base dataset."""
+ return self.size
+
+ def set_epoch(self, epoch):
+ """Not supported in `IterationBased` runner."""
+ raise NotImplementedError
diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..26e922d2ba8edb5dd0a1242a96f32ad56505393f
--- /dev/null
+++ b/mmdet/datasets/utils.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+from mmcv.cnn import VGG
+from mmcv.runner.hooks import HOOKS, Hook
+
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines import (LoadAnnotations, LoadImageFromFile,
+ LoadPanopticAnnotations)
+from mmdet.models.dense_heads import GARPNHead, RPNHead
+from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
+
+
+def replace_ImageToTensor(pipelines):
+ """Replace the ImageToTensor transform in a data pipeline to
+ DefaultFormatBundle, which is normally useful in batch inference.
+
+ Args:
+ pipelines (list[dict]): Data pipeline configs.
+
+ Returns:
+ list: The new pipeline list with all ImageToTensor replaced by
+ DefaultFormatBundle.
+
+ Examples:
+ >>> pipelines = [
+ ... dict(type='LoadImageFromFile'),
+ ... dict(
+ ... type='MultiScaleFlipAug',
+ ... img_scale=(1333, 800),
+ ... flip=False,
+ ... transforms=[
+ ... dict(type='Resize', keep_ratio=True),
+ ... dict(type='RandomFlip'),
+ ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
+ ... dict(type='Pad', size_divisor=32),
+ ... dict(type='ImageToTensor', keys=['img']),
+ ... dict(type='Collect', keys=['img']),
+ ... ])
+ ... ]
+ >>> expected_pipelines = [
+ ... dict(type='LoadImageFromFile'),
+ ... dict(
+ ... type='MultiScaleFlipAug',
+ ... img_scale=(1333, 800),
+ ... flip=False,
+ ... transforms=[
+ ... dict(type='Resize', keep_ratio=True),
+ ... dict(type='RandomFlip'),
+ ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
+ ... dict(type='Pad', size_divisor=32),
+ ... dict(type='DefaultFormatBundle'),
+ ... dict(type='Collect', keys=['img']),
+ ... ])
+ ... ]
+ >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
+ """
+ pipelines = copy.deepcopy(pipelines)
+ for i, pipeline in enumerate(pipelines):
+ if pipeline['type'] == 'MultiScaleFlipAug':
+ assert 'transforms' in pipeline
+ pipeline['transforms'] = replace_ImageToTensor(
+ pipeline['transforms'])
+ elif pipeline['type'] == 'ImageToTensor':
+ warnings.warn(
+ '"ImageToTensor" pipeline is replaced by '
+ '"DefaultFormatBundle" for batch inference. It is '
+ 'recommended to manually replace it in the test '
+ 'data pipeline in your config file.', UserWarning)
+ pipelines[i] = {'type': 'DefaultFormatBundle'}
+ return pipelines
+
+
+def get_loading_pipeline(pipeline):
+ """Only keep loading image and annotations related configuration.
+
+ Args:
+ pipeline (list[dict]): Data pipeline configs.
+
+ Returns:
+        list[dict]: The new pipeline list that only keeps the image and
+            annotation loading related configuration.
+
+ Examples:
+ >>> pipelines = [
+ ... dict(type='LoadImageFromFile'),
+ ... dict(type='LoadAnnotations', with_bbox=True),
+ ... dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ ... dict(type='RandomFlip', flip_ratio=0.5),
+ ... dict(type='Normalize', **img_norm_cfg),
+ ... dict(type='Pad', size_divisor=32),
+ ... dict(type='DefaultFormatBundle'),
+ ... dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
+ ... ]
+ >>> expected_pipelines = [
+ ... dict(type='LoadImageFromFile'),
+ ... dict(type='LoadAnnotations', with_bbox=True)
+ ... ]
+ >>> assert expected_pipelines ==\
+ ... get_loading_pipeline(pipelines)
+ """
+ loading_pipeline_cfg = []
+ for cfg in pipeline:
+ obj_cls = PIPELINES.get(cfg['type'])
+ # TODO:use more elegant way to distinguish loading modules
+ if obj_cls is not None and obj_cls in (LoadImageFromFile,
+ LoadAnnotations,
+ LoadPanopticAnnotations):
+ loading_pipeline_cfg.append(cfg)
+ assert len(loading_pipeline_cfg) == 2, \
+ 'The data pipeline in your config file must include ' \
+ 'loading image and annotations related pipeline.'
+ return loading_pipeline_cfg
+
+
+@HOOKS.register_module()
+class NumClassCheckHook(Hook):
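+    """Check whether the ``num_classes`` in the detector's heads matches the
+    length of ``CLASSES`` in the dataset, before each train and val epoch.
+
+    Example:
+        >>> # A minimal config sketch; the hook is typically enabled through
+        >>> # the ``custom_hooks`` field of an mmdet config.
+        >>> custom_hooks = [dict(type='NumClassCheckHook')]
+    """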
+
+ def _check_head(self, runner):
+ """Check whether the `num_classes` in head matches the length of
+ `CLASSES` in `dataset`.
+
+ Args:
+ runner (obj:`EpochBasedRunner`): Epoch based Runner.
+ """
+ model = runner.model
+ dataset = runner.data_loader.dataset
+ if dataset.CLASSES is None:
+ runner.logger.warning(
+ f'Please set `CLASSES` '
+                f'in the {dataset.__class__.__name__} and '
+ f'check if it is consistent with the `num_classes` '
+ f'of head')
+ else:
+ assert type(dataset.CLASSES) is not str, \
+                (f'`CLASSES` in {dataset.__class__.__name__} '
+                 f'should be a tuple of str. '
+ f'Add comma if number of classes is 1 as '
+ f'CLASSES = ({dataset.CLASSES},)')
+ for name, module in model.named_modules():
+ if hasattr(module, 'num_classes') and not isinstance(
+ module, (RPNHead, VGG, FusedSemanticHead, GARPNHead)):
+ assert module.num_classes == len(dataset.CLASSES), \
+ (f'The `num_classes` ({module.num_classes}) in '
+ f'{module.__class__.__name__} of '
+                         f'{model.__class__.__name__} does not match '
+                         f'the length of `CLASSES` '
+                         f'({len(dataset.CLASSES)}) in '
+ f'{dataset.__class__.__name__}')
+
+ def before_train_epoch(self, runner):
+ """Check whether the training dataset is compatible with head.
+
+ Args:
+ runner (obj:`EpochBasedRunner`): Epoch based Runner.
+ """
+ self._check_head(runner)
+
+ def before_val_epoch(self, runner):
+ """Check whether the dataset in val epoch is compatible with head.
+
+ Args:
+ runner (obj:`EpochBasedRunner`): Epoch based Runner.
+ """
+ self._check_head(runner)
diff --git a/mmdet/datasets/voc.py b/mmdet/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a3ea7aac75c7ef3ee1576ec05f251fd47412b72
--- /dev/null
+++ b/mmdet/datasets/voc.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+from mmcv.utils import print_log
+
+from mmdet.core import eval_map, eval_recalls
+from .builder import DATASETS
+from .xml_style import XMLDataset
+
+
+@DATASETS.register_module()
+class VOCDataset(XMLDataset):
+
+ CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
+ 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+ 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
+ 'tvmonitor')
+
+ PALETTE = [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192),
+ (197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255),
+ (153, 69, 1), (120, 166, 157), (0, 182, 199), (0, 226, 252),
+ (182, 182, 255), (0, 0, 230), (220, 20, 60), (163, 255, 0),
+ (0, 82, 0), (3, 95, 161), (0, 80, 100), (183, 130, 88)]
+
+ def __init__(self, **kwargs):
+ super(VOCDataset, self).__init__(**kwargs)
+ if 'VOC2007' in self.img_prefix:
+ self.year = 2007
+ elif 'VOC2012' in self.img_prefix:
+ self.year = 2012
+ else:
+ raise ValueError('Cannot infer dataset year from img_prefix')
+
+ def evaluate(self,
+ results,
+ metric='mAP',
+ logger=None,
+ proposal_nums=(100, 300, 1000),
+ iou_thr=0.5,
+ scale_ranges=None):
+ """Evaluate in VOC protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'mAP', 'recall'.
+ logger (logging.Logger | str, optional): Logger used for printing
+ related information during evaluation. Default: None.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thr (float | list[float]): IoU threshold. Default: 0.5.
+ scale_ranges (list[tuple], optional): Scale ranges for evaluating
+ mAP. If not specified, all bounding boxes would be included in
+ evaluation. Default: None.
+
+ Returns:
+ dict[str, float]: AP/recall metrics.
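+
+        Example:
+            >>> # A minimal sketch; assumes ``dataset`` is an initialized
+            >>> # VOCDataset and ``results`` has one detection result per
+            >>> # image (e.g. collected by ``single_gpu_test``).
+            >>> metrics = dataset.evaluate(  # doctest: +SKIP
+            ...     results, metric='mAP', iou_thr=0.5)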
+ """
+
+ if not isinstance(metric, str):
+ assert len(metric) == 1
+ metric = metric[0]
+ allowed_metrics = ['mAP', 'recall']
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ annotations = [self.get_ann_info(i) for i in range(len(self))]
+ eval_results = OrderedDict()
+ iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
+ if metric == 'mAP':
+ assert isinstance(iou_thrs, list)
+ if self.year == 2007:
+ ds_name = 'voc07'
+ else:
+ ds_name = self.CLASSES
+ mean_aps = []
+ for iou_thr in iou_thrs:
+ print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
+ # Follow the official implementation,
+ # http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar
+ # we should use the legacy coordinate system in mmdet 1.x,
+ # which means w, h should be computed as 'x2 - x1 + 1` and
+ # `y2 - y1 + 1`
+ mean_ap, _ = eval_map(
+ results,
+ annotations,
+ scale_ranges=None,
+ iou_thr=iou_thr,
+ dataset=ds_name,
+ logger=logger,
+ use_legacy_coordinate=True)
+ mean_aps.append(mean_ap)
+ eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
+ eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
+ eval_results.move_to_end('mAP', last=False)
+ elif metric == 'recall':
+ gt_bboxes = [ann['bboxes'] for ann in annotations]
+ recalls = eval_recalls(
+ gt_bboxes,
+ results,
+ proposal_nums,
+ iou_thrs,
+ logger=logger,
+ use_legacy_coordinate=True)
+ for i, num in enumerate(proposal_nums):
+ for j, iou_thr in enumerate(iou_thrs):
+ eval_results[f'recall@{num}@{iou_thr}'] = recalls[i, j]
+ if recalls.shape[1] > 1:
+ ar = recalls.mean(axis=1)
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ return eval_results
diff --git a/mmdet/datasets/wider_face.py b/mmdet/datasets/wider_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..85a5fdc549659f9cf72e4511de28cc0ccb4a9f4c
--- /dev/null
+++ b/mmdet/datasets/wider_face.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import xml.etree.ElementTree as ET
+
+import mmcv
+
+from .builder import DATASETS
+from .xml_style import XMLDataset
+
+
+@DATASETS.register_module()
+class WIDERFaceDataset(XMLDataset):
+ """Reader for the WIDER Face dataset in PASCAL VOC format.
+
+ Conversion scripts can be found in
+ https://github.com/sovrasov/wider-face-pascal-voc-annotations
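+
+    Example:
+        >>> # A minimal config sketch; the paths are placeholders for a
+        >>> # locally converted WIDER Face dataset.
+        >>> dataset_cfg = dict(
+        ...     type='WIDERFaceDataset',
+        ...     ann_file='data/WIDERFace/train.txt',
+        ...     img_prefix='data/WIDERFace/WIDER_train/',
+        ...     pipeline=[dict(type='LoadImageFromFile')])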
+ """
+ CLASSES = ('face', )
+
+ PALETTE = [(0, 255, 0)]
+
+ def __init__(self, **kwargs):
+ super(WIDERFaceDataset, self).__init__(**kwargs)
+
+ def load_annotations(self, ann_file):
+ """Load annotation from WIDERFace XML style annotation file.
+
+ Args:
+ ann_file (str): Path of XML file.
+
+ Returns:
+ list[dict]: Annotation info from XML file.
+ """
+
+ data_infos = []
+ img_ids = mmcv.list_from_file(ann_file)
+ for img_id in img_ids:
+ filename = f'{img_id}.jpg'
+ xml_path = osp.join(self.img_prefix, 'Annotations',
+ f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ size = root.find('size')
+ width = int(size.find('width').text)
+ height = int(size.find('height').text)
+ folder = root.find('folder').text
+ data_infos.append(
+ dict(
+ id=img_id,
+ filename=osp.join(folder, filename),
+ width=width,
+ height=height))
+
+ return data_infos
diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py
new file mode 100644
index 0000000000000000000000000000000000000000..039d5d7d08fc9874b7378444c0ff63b5d8dd2ade
--- /dev/null
+++ b/mmdet/datasets/xml_style.py
@@ -0,0 +1,178 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import xml.etree.ElementTree as ET
+
+import mmcv
+import numpy as np
+from PIL import Image
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class XMLDataset(CustomDataset):
+ """XML dataset for detection.
+
+ Args:
+ min_size (int | float, optional): The minimum size of bounding
+ boxes in the images. If the size of a bounding box is less than
+            ``min_size``, it will be added to the ignored field.
+ img_subdir (str): Subdir where images are stored. Default: JPEGImages.
+ ann_subdir (str): Subdir where annotations are. Default: Annotations.
+ """
+
+ def __init__(self,
+ min_size=None,
+ img_subdir='JPEGImages',
+ ann_subdir='Annotations',
+ **kwargs):
+ assert self.CLASSES or kwargs.get(
+ 'classes', None), 'CLASSES in `XMLDataset` can not be None.'
+ self.img_subdir = img_subdir
+ self.ann_subdir = ann_subdir
+ super(XMLDataset, self).__init__(**kwargs)
+ self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
+ self.min_size = min_size
+
+ def load_annotations(self, ann_file):
+ """Load annotation from XML style ann_file.
+
+ Args:
+ ann_file (str): Path of XML file.
+
+ Returns:
+ list[dict]: Annotation info from XML file.
+ """
+
+ data_infos = []
+ img_ids = mmcv.list_from_file(ann_file)
+ for img_id in img_ids:
+ filename = osp.join(self.img_subdir, f'{img_id}.jpg')
+ xml_path = osp.join(self.img_prefix, self.ann_subdir,
+ f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ size = root.find('size')
+ if size is not None:
+ width = int(size.find('width').text)
+ height = int(size.find('height').text)
+ else:
+ img_path = osp.join(self.img_prefix, filename)
+ img = Image.open(img_path)
+ width, height = img.size
+ data_infos.append(
+ dict(id=img_id, filename=filename, width=width, height=height))
+
+ return data_infos
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without annotation."""
+ valid_inds = []
+ for i, img_info in enumerate(self.data_infos):
+ if min(img_info['width'], img_info['height']) < min_size:
+ continue
+ if self.filter_empty_gt:
+ img_id = img_info['id']
+ xml_path = osp.join(self.img_prefix, self.ann_subdir,
+ f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ for obj in root.findall('object'):
+ name = obj.find('name').text
+ if name in self.CLASSES:
+ valid_inds.append(i)
+ break
+ else:
+ valid_inds.append(i)
+ return valid_inds
+
+ def get_ann_info(self, idx):
+ """Get annotation from XML file by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ xml_path = osp.join(self.img_prefix, self.ann_subdir, f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ bboxes = []
+ labels = []
+ bboxes_ignore = []
+ labels_ignore = []
+ for obj in root.findall('object'):
+ name = obj.find('name').text
+ if name not in self.CLASSES:
+ continue
+ label = self.cat2label[name]
+ difficult = obj.find('difficult')
+ difficult = 0 if difficult is None else int(difficult.text)
+ bnd_box = obj.find('bndbox')
+ # TODO: check whether it is necessary to use int
+ # Coordinates may be float type
+ bbox = [
+ int(float(bnd_box.find('xmin').text)),
+ int(float(bnd_box.find('ymin').text)),
+ int(float(bnd_box.find('xmax').text)),
+ int(float(bnd_box.find('ymax').text))
+ ]
+ ignore = False
+ if self.min_size:
+ assert not self.test_mode
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ if w < self.min_size or h < self.min_size:
+ ignore = True
+ if difficult or ignore:
+ bboxes_ignore.append(bbox)
+ labels_ignore.append(label)
+ else:
+ bboxes.append(bbox)
+ labels.append(label)
+ if not bboxes:
+ bboxes = np.zeros((0, 4))
+ labels = np.zeros((0, ))
+ else:
+ bboxes = np.array(bboxes, ndmin=2) - 1
+ labels = np.array(labels)
+ if not bboxes_ignore:
+ bboxes_ignore = np.zeros((0, 4))
+ labels_ignore = np.zeros((0, ))
+ else:
+ bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
+ labels_ignore = np.array(labels_ignore)
+ ann = dict(
+ bboxes=bboxes.astype(np.float32),
+ labels=labels.astype(np.int64),
+ bboxes_ignore=bboxes_ignore.astype(np.float32),
+ labels_ignore=labels_ignore.astype(np.int64))
+ return ann
+
+ def get_cat_ids(self, idx):
+ """Get category ids in XML file by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ cat_ids = []
+ img_id = self.data_infos[idx]['id']
+ xml_path = osp.join(self.img_prefix, self.ann_subdir, f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ for obj in root.findall('object'):
+ name = obj.find('name').text
+ if name not in self.CLASSES:
+ continue
+ label = self.cat2label[name]
+ cat_ids.append(label)
+
+ return cat_ids
diff --git a/mmdet/models/__init__.py b/mmdet/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..12efb013d26d6b7ee27226a0f205d7e009e4b5f3
--- /dev/null
+++ b/mmdet/models/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .backbones import * # noqa: F401,F403
+from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
+ ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
+ build_detector, build_head, build_loss, build_neck,
+ build_roi_extractor, build_shared_head)
+from .dense_heads import * # noqa: F401,F403
+from .detectors import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .plugins import * # noqa: F401,F403
+from .roi_heads import * # noqa: F401,F403
+from .seg_heads import * # noqa: F401,F403
+
+__all__ = [
+ 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
+ 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
+ 'build_shared_head', 'build_head', 'build_loss', 'build_detector'
+]
diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b50d254a8866c7376286470c47e4de936d07ad
--- /dev/null
+++ b/mmdet/models/backbones/__init__.py
@@ -0,0 +1,26 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .csp_darknet import CSPDarknet
+from .darknet import Darknet
+from .detectors_resnet import DetectoRS_ResNet
+from .detectors_resnext import DetectoRS_ResNeXt
+from .efficientnet import EfficientNet
+from .hourglass import HourglassNet
+from .hrnet import HRNet
+from .mobilenet_v2 import MobileNetV2
+from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2
+from .regnet import RegNet
+from .res2net import Res2Net
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1d
+from .resnext import ResNeXt
+from .ssd_vgg import SSDVGG
+from .swin import SwinTransformer
+from .trident_resnet import TridentResNet
+
+__all__ = [
+ 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet',
+ 'MobileNetV2', 'Res2Net', 'HourglassNet', 'DetectoRS_ResNet',
+ 'DetectoRS_ResNeXt', 'Darknet', 'ResNeSt', 'TridentResNet', 'CSPDarknet',
+ 'SwinTransformer', 'PyramidVisionTransformer',
+ 'PyramidVisionTransformerV2', 'EfficientNet'
+]
diff --git a/mmdet/models/backbones/csp_darknet.py b/mmdet/models/backbones/csp_darknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bbf3968a818ad9c1d27d82e3ef17e9c2f8072bc
--- /dev/null
+++ b/mmdet/models/backbones/csp_darknet.py
@@ -0,0 +1,284 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.runner import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import CSPLayer
+
+
+class Focus(nn.Module):
+ """Focus width and height information into channel space.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+ kernel_size (int): The kernel size of the convolution. Default: 1
+ stride (int): The stride of the convolution. Default: 1
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN', momentum=0.03, eps=0.001).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='Swish').
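+
+    Example:
+        >>> # A minimal shape check (a sketch; requires mmcv's ConvModule).
+        >>> import torch
+        >>> focus = Focus(3, 32, kernel_size=3)
+        >>> x = torch.rand(1, 3, 64, 64)
+        >>> focus(x).shape
+        torch.Size([1, 32, 32, 32])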
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish')):
+ super().__init__()
+ self.conv = ConvModule(
+ in_channels * 4,
+ out_channels,
+ kernel_size,
+ stride,
+ padding=(kernel_size - 1) // 2,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, x):
+ # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
+ patch_top_left = x[..., ::2, ::2]
+ patch_top_right = x[..., ::2, 1::2]
+ patch_bot_left = x[..., 1::2, ::2]
+ patch_bot_right = x[..., 1::2, 1::2]
+ x = torch.cat(
+ (
+ patch_top_left,
+ patch_bot_left,
+ patch_top_right,
+ patch_bot_right,
+ ),
+ dim=1,
+ )
+ return self.conv(x)
+
+
+class SPPBottleneck(BaseModule):
+ """Spatial pyramid pooling layer used in YOLOv3-SPP.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+        kernel_sizes (tuple[int]): Sequence of kernel sizes of pooling
+ layers. Default: (5, 9, 13).
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='Swish').
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
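+
+    Example:
+        >>> # A minimal shape check (a sketch); the spatial size is kept
+        >>> # because each pooling uses stride 1 with matching padding.
+        >>> import torch
+        >>> spp = SPPBottleneck(64, 128)
+        >>> x = torch.rand(1, 64, 32, 32)
+        >>> spp(x).shape
+        torch.Size([1, 128, 32, 32])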
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_sizes=(5, 9, 13),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish'),
+ init_cfg=None):
+ super().__init__(init_cfg)
+ mid_channels = in_channels // 2
+ self.conv1 = ConvModule(
+ in_channels,
+ mid_channels,
+ 1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.poolings = nn.ModuleList([
+ nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
+ for ks in kernel_sizes
+ ])
+ conv2_channels = mid_channels * (len(kernel_sizes) + 1)
+ self.conv2 = ConvModule(
+ conv2_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1)
+ x = self.conv2(x)
+ return x
+
+
+@BACKBONES.register_module()
+class CSPDarknet(BaseModule):
+ """CSP-Darknet backbone used in YOLOv5 and YOLOX.
+
+ Args:
+ arch (str): Architecture of CSP-Darknet, from {P5, P6}.
+ Default: P5.
+ deepen_factor (float): Depth multiplier, multiply number of
+ blocks in CSP layer by this amount. Default: 1.0.
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ out_indices (Sequence[int]): Output from which stages.
+ Default: (2, 3, 4).
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
+ mode). -1 means not freezing any parameters. Default: -1.
+ use_depthwise (bool): Whether to use depthwise separable convolution.
+ Default: False.
+        arch_ovewrite (list): Overwrite default arch settings. Default: None.
+        spp_kernal_sizes (tuple[int]): Sequence of kernel sizes of SPP
+            layers. Default: (5, 9, 13).
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: dict(type='BN', momentum=0.03, eps=0.001).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='Swish').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ Example:
+ >>> from mmdet.models import CSPDarknet
+ >>> import torch
+        >>> self = CSPDarknet(arch='P5')
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 416, 416)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ ...
+ (1, 256, 52, 52)
+ (1, 512, 26, 26)
+ (1, 1024, 13, 13)
+ """
+ # From left to right:
+ # in_channels, out_channels, num_blocks, add_identity, use_spp
+ arch_settings = {
+ 'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
+ [256, 512, 9, True, False], [512, 1024, 3, False, True]],
+ 'P6': [[64, 128, 3, True, False], [128, 256, 9, True, False],
+ [256, 512, 9, True, False], [512, 768, 3, True, False],
+ [768, 1024, 3, False, True]]
+ }
+
+ def __init__(self,
+ arch='P5',
+ deepen_factor=1.0,
+ widen_factor=1.0,
+ out_indices=(2, 3, 4),
+ frozen_stages=-1,
+ use_depthwise=False,
+ arch_ovewrite=None,
+ spp_kernal_sizes=(5, 9, 13),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish'),
+ norm_eval=False,
+ init_cfg=dict(
+ type='Kaiming',
+ layer='Conv2d',
+ a=math.sqrt(5),
+ distribution='uniform',
+ mode='fan_in',
+ nonlinearity='leaky_relu')):
+ super().__init__(init_cfg)
+ arch_setting = self.arch_settings[arch]
+ if arch_ovewrite:
+ arch_setting = arch_ovewrite
+ assert set(out_indices).issubset(
+ i for i in range(len(arch_setting) + 1))
+ if frozen_stages not in range(-1, len(arch_setting) + 1):
+ raise ValueError('frozen_stages must be in range(-1, '
+ 'len(arch_setting) + 1). But received '
+ f'{frozen_stages}')
+
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.use_depthwise = use_depthwise
+ self.norm_eval = norm_eval
+ conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+
+ self.stem = Focus(
+ 3,
+ int(arch_setting[0][0] * widen_factor),
+ kernel_size=3,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.layers = ['stem']
+
+ for i, (in_channels, out_channels, num_blocks, add_identity,
+ use_spp) in enumerate(arch_setting):
+ in_channels = int(in_channels * widen_factor)
+ out_channels = int(out_channels * widen_factor)
+ num_blocks = max(round(num_blocks * deepen_factor), 1)
+ stage = []
+ conv_layer = conv(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ stage.append(conv_layer)
+ if use_spp:
+ spp = SPPBottleneck(
+ out_channels,
+ out_channels,
+ kernel_sizes=spp_kernal_sizes,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ stage.append(spp)
+ csp_layer = CSPLayer(
+ out_channels,
+ out_channels,
+ num_blocks=num_blocks,
+ add_identity=add_identity,
+ use_depthwise=use_depthwise,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ stage.append(csp_layer)
+ self.add_module(f'stage{i + 1}', nn.Sequential(*stage))
+ self.layers.append(f'stage{i + 1}')
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for i in range(self.frozen_stages + 1):
+ m = getattr(self, self.layers[i])
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(CSPDarknet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ def forward(self, x):
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
diff --git a/mmdet/models/backbones/darknet.py b/mmdet/models/backbones/darknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..adfb1159b507d9fdc5bc6af20fe64411a8b55f92
--- /dev/null
+++ b/mmdet/models/backbones/darknet.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+import warnings
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+
+
+class ResBlock(BaseModule):
+ """The basic residual block used in Darknet. Each ResBlock consists of two
+ ConvModules and the input is added to the final output. Each ConvModule is
+ composed of Conv, BN, and LeakyReLU. In YoloV3 paper, the first convLayer
+ has half of the number of the filters as much as the second convLayer. The
+ first convLayer has filter size of 1x1 and the second one has the filter
+ size of 3x3.
+
+ Args:
+ in_channels (int): The input channels. Must be even.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
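+
+    Example:
+        >>> # A minimal shape check (a sketch); the residual connection
+        >>> # keeps the input and output shapes identical.
+        >>> import torch
+        >>> block = ResBlock(64)
+        >>> x = torch.rand(1, 64, 56, 56)
+        >>> block(x).shape
+        torch.Size([1, 64, 56, 56])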
+ """
+
+ def __init__(self,
+ in_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
+ init_cfg=None):
+ super(ResBlock, self).__init__(init_cfg)
+ assert in_channels % 2 == 0 # ensure the in_channels is even
+ half_in_channels = in_channels // 2
+
+ # shortcut
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+ self.conv1 = ConvModule(in_channels, half_in_channels, 1, **cfg)
+ self.conv2 = ConvModule(
+ half_in_channels, in_channels, 3, padding=1, **cfg)
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.conv2(out)
+ out = out + residual
+
+ return out
+
+
+@BACKBONES.register_module()
+class Darknet(BaseModule):
+ """Darknet backbone.
+
+ Args:
+        depth (int): Depth of Darknet. Currently only depth 53 is supported.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+
+ Example:
+ >>> from mmdet.models import Darknet
+ >>> import torch
+ >>> self = Darknet(depth=53)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 416, 416)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ ...
+ (1, 256, 52, 52)
+ (1, 512, 26, 26)
+ (1, 1024, 13, 13)
+ """
+
+ # Dict(depth: (layers, channels))
+ arch_settings = {
+ 53: ((1, 2, 8, 8, 4), ((32, 64), (64, 128), (128, 256), (256, 512),
+ (512, 1024)))
+ }
+
+ def __init__(self,
+ depth=53,
+ out_indices=(3, 4, 5),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
+ norm_eval=True,
+ pretrained=None,
+ init_cfg=None):
+ super(Darknet, self).__init__(init_cfg)
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for darknet')
+
+ self.depth = depth
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.layers, self.channels = self.arch_settings[depth]
+
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+ self.conv1 = ConvModule(3, 32, 3, padding=1, **cfg)
+
+ self.cr_blocks = ['conv1']
+ for i, n_layers in enumerate(self.layers):
+ layer_name = f'conv_res_block{i + 1}'
+ in_c, out_c = self.channels[i]
+ self.add_module(
+ layer_name,
+ self.make_conv_res_block(in_c, out_c, n_layers, **cfg))
+ self.cr_blocks.append(layer_name)
+
+ self.norm_eval = norm_eval
+
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ outs = []
+ for i, layer_name in enumerate(self.cr_blocks):
+ cr_block = getattr(self, layer_name)
+ x = cr_block(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for i in range(self.frozen_stages):
+ m = getattr(self, self.cr_blocks[i])
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(Darknet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ @staticmethod
+ def make_conv_res_block(in_channels,
+ out_channels,
+ res_repeat,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU',
+ negative_slope=0.1)):
+ """In the Darknet backbone, a conv layer is usually followed by several
+ ResBlocks; this function builds such a stage. The conv layer always uses
+ a 3x3 kernel with stride=2, and its number of filters equals the output
+ channels of the ResBlocks that follow it.
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+ res_repeat (int): The number of ResBlocks.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ """
+
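+ # Illustrative sketch (values are hypothetical): make_conv_res_block(
+ # in_channels=64, out_channels=128, res_repeat=2) returns an nn.Sequential
+ # of one stride-2 3x3 ConvModule (64 -> 128) followed by two ResBlock(128)
+ # modules, i.e. one downsampling conv plus ``res_repeat`` residual blocks.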
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+ model = nn.Sequential()
+ model.add_module(
+ 'conv',
+ ConvModule(
+ in_channels, out_channels, 3, stride=2, padding=1, **cfg))
+ for idx in range(res_repeat):
+ model.add_module('res{}'.format(idx),
+ ResBlock(out_channels, **cfg))
+ return model
diff --git a/mmdet/models/backbones/detectors_resnet.py b/mmdet/models/backbones/detectors_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3c0d40b4284c1c2d5006df28620d230d93646cd
--- /dev/null
+++ b/mmdet/models/backbones/detectors_resnet.py
@@ -0,0 +1,353 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+ kaiming_init)
+from mmcv.runner import Sequential, load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import BasicBlock
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottleneck(_Bottleneck):
+ r"""Bottleneck for the ResNet backbone in `DetectoRS
+ <https://arxiv.org/abs/2006.02334>`_.
+
+ This bottleneck allows the users to specify whether to use
+ SAC (Switchable Atrous Convolution) and RFP (Recursive Feature Pyramid).
+
+ Args:
+ inplanes (int): The number of input channels.
+ planes (int): The number of output channels before expansion.
+ rfp_inplanes (int, optional): The number of channels from RFP.
+ Default: None. If specified, an additional conv layer will be
+ added for ``rfp_feat``. Otherwise, the structure is the same as
+ base class.
+ sac (dict, optional): Dictionary to construct SAC. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ rfp_inplanes=None,
+ sac=None,
+ init_cfg=None,
+ **kwargs):
+ super(Bottleneck, self).__init__(
+ inplanes, planes, init_cfg=init_cfg, **kwargs)
+
+ assert sac is None or isinstance(sac, dict)
+ self.sac = sac
+ self.with_sac = sac is not None
+ if self.with_sac:
+ self.conv2 = build_conv_layer(
+ self.sac,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ bias=False)
+
+ self.rfp_inplanes = rfp_inplanes
+ if self.rfp_inplanes:
+ self.rfp_conv = build_conv_layer(
+ None,
+ self.rfp_inplanes,
+ planes * self.expansion,
+ 1,
+ stride=1,
+ bias=True)
+ if init_cfg is None:
+ self.init_cfg = dict(
+ type='Constant', val=0, override=dict(name='rfp_conv'))
+
+ def rfp_forward(self, x, rfp_feat):
+ """The forward function that also takes the RFP features as input."""
+
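+ # The body below mirrors ``Bottleneck.forward``; the only difference is
+ # that, when ``rfp_inplanes`` is set, the 1x1 ``rfp_conv`` projection of
+ # ``rfp_feat`` is added to the residual output before the final ReLU.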
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ if self.rfp_inplanes:
+ rfp_feat = self.rfp_conv(rfp_feat)
+ out = out + rfp_feat
+
+ out = self.relu(out)
+
+ return out
+
+
+class ResLayer(Sequential):
+ """ResLayer to build a ResNet-style backbone for RFP in DetectoRS.
+
+ The difference between this module and the base class is that we pass
+ ``rfp_inplanes`` to the first block.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ rfp_inplanes (int, optional): The number of channels from RFP.
+ Default: None. If specified, an additional conv layer will be
+ added for ``rfp_feat``. Otherwise, the structure is the same as
+ base class.
+ """
+
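+ # Illustrative construction (argument values are hypothetical):
+ #   ResLayer(Bottleneck, inplanes=256, planes=128, num_blocks=4,
+ #            stride=2, rfp_inplanes=64)
+ # builds one stride-2 Bottleneck that consumes the RFP feature, followed by
+ # three stride-1 Bottlenecks without it.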
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ downsample_first=True,
+ rfp_inplanes=None,
+ **kwargs):
+ self.block = block
+ assert downsample_first, f'downsample_first={downsample_first} is ' \
+ 'not supported in DetectoRS'
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ if avg_down and stride != 1:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ rfp_inplanes=rfp_inplanes,
+ **kwargs))
+ inplanes = planes * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+
+ super(ResLayer, self).__init__(*layers)
+
+
+@BACKBONES.register_module()
+class DetectoRS_ResNet(ResNet):
+ """ResNet backbone for DetectoRS.
+
+ Args:
+ sac (dict, optional): Dictionary to construct SAC (Switchable Atrous
+ Convolution). Default: None.
+ stage_with_sac (list): Which stage to use sac. Default: (False, False,
+ False, False).
+ rfp_inplanes (int, optional): The number of channels from RFP.
+ Default: None. If specified, an additional conv layer will be
+ added for ``rfp_feat``. Otherwise, the structure is the same as
+ base class.
+ output_img (bool): If ``True``, the input image will be inserted into
+ the starting position of output. Default: False.
+ """
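+
+ # Illustrative usage (the keyword values are assumptions; any remaining
+ # kwargs are forwarded to the base ``ResNet``):
+ #   backbone = DetectoRS_ResNet(depth=50, rfp_inplanes=256, output_img=True)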
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ sac=None,
+ stage_with_sac=(False, False, False, False),
+ rfp_inplanes=None,
+ output_img=False,
+ pretrained=None,
+ init_cfg=None,
+ **kwargs):
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ self.pretrained = pretrained
+ if init_cfg is not None:
+ assert isinstance(init_cfg, dict), \
+ f'init_cfg must be a dict, but got {type(init_cfg)}'
+ if 'type' in init_cfg:
+ assert init_cfg.get('type') == 'Pretrained', \
+ 'the module can only be initialized by loading a pretrained model'
+ else:
+ raise KeyError('`init_cfg` must contain the key "type"')
+ self.pretrained = init_cfg.get('checkpoint')
+ self.sac = sac
+ self.stage_with_sac = stage_with_sac
+ self.rfp_inplanes = rfp_inplanes
+ self.output_img = output_img
+ super(DetectoRS_ResNet, self).__init__(**kwargs)
+
+ self.inplanes = self.stem_channels
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ sac = self.sac if self.stage_with_sac[i] else None
+ if self.plugins is not None:
+ stage_plugins = self.make_stage_plugins(self.plugins, i)
+ else:
+ stage_plugins = None
+ planes = self.base_channels * 2**i
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=dcn,
+ sac=sac,
+ rfp_inplanes=rfp_inplanes if i > 0 else None,
+ plugins=stage_plugins)
+ self.inplanes = planes * self.block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ # Overridden so that the backbone can be properly re-initialized by RFP.
+ def init_weights(self):
+ # Calling the parent method would raise a parameter initialization
+ # exception here, so it is deliberately skipped.
+ # super(DetectoRS_ResNet, self).init_weights()
+
+ if isinstance(self.pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, self.pretrained, strict=False, logger=logger)
+ elif self.pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.dcn is not None:
+ for m in self.modules():
+ if isinstance(m, Bottleneck) and hasattr(
+ m.conv2, 'conv_offset'):
+ constant_init(m.conv2.conv_offset, 0)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer`` for DetectoRS."""
+ return ResLayer(**kwargs)
+
+ def forward(self, x):
+ """Forward function."""
+ outs = list(super(DetectoRS_ResNet, self).forward(x))
+ if self.output_img:
+ outs.insert(0, x)
+ return tuple(outs)
+
+ def rfp_forward(self, x, rfp_feats):
+ """Forward function for RFP."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ rfp_feat = rfp_feats[i] if i > 0 else None
+ for layer in res_layer:
+ x = layer.rfp_forward(x, rfp_feat)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
diff --git a/mmdet/models/backbones/detectors_resnext.py b/mmdet/models/backbones/detectors_resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8b20a0266a9d7e37ff1d39b3a160abef565c85
--- /dev/null
+++ b/mmdet/models/backbones/detectors_resnext.py
@@ -0,0 +1,123 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from .detectors_resnet import Bottleneck as _Bottleneck
+from .detectors_resnet import DetectoRS_ResNet
+
+
+class Bottleneck(_Bottleneck):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ **kwargs):
+ """Bottleneck block for ResNeXt.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer; if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, width, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ self.with_modulated_dcn = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if self.with_sac:
+ self.conv2 = build_conv_layer(
+ self.sac,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+ elif not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ self.dcn,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+
+@BACKBONES.register_module()
+class DetectoRS_ResNeXt(DetectoRS_ResNet):
+ """ResNeXt backbone for DetectoRS.
+
+ Args:
+ groups (int): The number of groups in ResNeXt.
+ base_width (int): The base width of ResNeXt.
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, groups=1, base_width=4, **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ super(DetectoRS_ResNeXt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return super().make_res_layer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
diff --git a/mmdet/models/backbones/efficientnet.py b/mmdet/models/backbones/efficientnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ee359567d91d0e42aa09dd2ad3be4ba006176c0
--- /dev/null
+++ b/mmdet/models/backbones/efficientnet.py
@@ -0,0 +1,417 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn.bricks import ConvModule, DropPath
+from mmcv.runner import BaseModule, Sequential
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidual, SELayer, make_divisible
+
+
+class EdgeResidual(BaseModule):
+ """Edge Residual Block.
+
+ Args:
+ in_channels (int): The input channels of this module.
+ out_channels (int): The output channels of this module.
+ mid_channels (int): The input channels of the second convolution.
+ kernel_size (int): The kernel size of the first convolution.
+ Defaults to 3.
+ stride (int): The stride of the first convolution. Defaults to 1.
+ se_cfg (dict, optional): Config dict for se layer. Defaults to None,
+ which means no se layer.
+ with_residual (bool): Use residual connection. Defaults to True.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Defaults to None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): Config dict for activation layer.
+ Defaults to ``dict(type='ReLU')``.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ init_cfg (dict | list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ se_cfg=None,
+ with_residual=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ drop_path_rate=0.,
+ with_cp=False,
+ init_cfg=None,
+ **kwargs):
+ super(EdgeResidual, self).__init__(init_cfg=init_cfg)
+ assert stride in [1, 2]
+ self.with_cp = with_cp
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+ self.with_se = se_cfg is not None
+ self.with_residual = (
+ stride == 1 and in_channels == out_channels and with_residual)
+
+ if self.with_se:
+ assert isinstance(se_cfg, dict)
+
+ self.conv1 = ConvModule(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ if self.with_se:
+ self.se = SELayer(**se_cfg)
+
+ self.conv2 = ConvModule(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=stride,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = x
+ out = self.conv1(out)
+
+ if self.with_se:
+ out = self.se(out)
+
+ out = self.conv2(out)
+
+ if self.with_residual:
+ return x + self.drop_path(out)
+ else:
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
+
+
+def model_scaling(layer_setting, arch_setting):
+ """Scale the layer settings (channel widths and stage depths) according
+ to the arch_setting."""
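+ # For example (assumed numbers): with arch_setting = (width, depth, res) =
+ # (1.1, 1.2, 260), every block's out_channels is multiplied by 1.1 and
+ # rounded to a multiple of 8 via make_divisible, and a stage that originally
+ # stacks 2 blocks is scaled to ceil(2 * 1.2) = 3 blocks.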
+ # scale width
+ new_layer_setting = copy.deepcopy(layer_setting)
+ for layer_cfg in new_layer_setting:
+ for block_cfg in layer_cfg:
+ block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8)
+
+ # scale depth
+ split_layer_setting = [new_layer_setting[0]]
+ for layer_cfg in new_layer_setting[1:-1]:
+ tmp_index = [0]
+ for i in range(len(layer_cfg) - 1):
+ if layer_cfg[i + 1][1] != layer_cfg[i][1]:
+ tmp_index.append(i + 1)
+ tmp_index.append(len(layer_cfg))
+ for i in range(len(tmp_index) - 1):
+ split_layer_setting.append(
+ layer_cfg[tmp_index[i]:tmp_index[i + 1]])
+ split_layer_setting.append(new_layer_setting[-1])
+
+ num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]]
+ new_layers = [
+ int(math.ceil(arch_setting[1] * num)) for num in num_of_layers
+ ]
+
+ merge_layer_setting = [split_layer_setting[0]]
+ for i, layer_cfg in enumerate(split_layer_setting[1:-1]):
+ if new_layers[i] <= num_of_layers[i]:
+ tmp_layer_cfg = layer_cfg[:new_layers[i]]
+ else:
+ tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * (
+ new_layers[i] - num_of_layers[i])
+ if tmp_layer_cfg[0][3] == 1 and i != 0:
+ merge_layer_setting[-1] += tmp_layer_cfg.copy()
+ else:
+ merge_layer_setting.append(tmp_layer_cfg.copy())
+ merge_layer_setting.append(split_layer_setting[-1])
+
+ return merge_layer_setting
+
+
+@BACKBONES.register_module()
+class EfficientNet(BaseModule):
+ """EfficientNet backbone.
+
+ Args:
+ arch (str): Architecture of efficientnet. Defaults to b0.
+ out_indices (Sequence[int]): Output from which stages.
+ Defaults to (6, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Defaults to 0, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Defaults to None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Defaults to dict(type='Swish').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Defaults to False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Defaults to False.
+ """
+
+ # Parameters to build layers.
+ # 'b' represents the normal EfficientNet family, which includes
+ # 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7' and 'b8'.
+ # 'e' represents the EfficientNet-EdgeTPU family, which includes 'es',
+ # 'em' and 'el'.
+ # 6 parameters are needed to construct a layer. From left to right:
+ # - kernel_size: The kernel size of the block
+ # - out_channel: The number of out_channels of the block
+ # - se_ratio: The squeeze ratio of the SELayer.
+ # - stride: The stride of the block
+ # - expand_ratio: The expand_ratio of the mid_channels
+ # - block_type: -1: Not a block, 0: InvertedResidual, 1: EdgeResidual
+ layer_settings = {
+ 'b': [[[3, 32, 0, 2, 0, -1]],
+ [[3, 16, 4, 1, 1, 0]],
+ [[3, 24, 4, 2, 6, 0],
+ [3, 24, 4, 1, 6, 0]],
+ [[5, 40, 4, 2, 6, 0],
+ [5, 40, 4, 1, 6, 0]],
+ [[3, 80, 4, 2, 6, 0],
+ [3, 80, 4, 1, 6, 0],
+ [3, 80, 4, 1, 6, 0],
+ [5, 112, 4, 1, 6, 0],
+ [5, 112, 4, 1, 6, 0],
+ [5, 112, 4, 1, 6, 0]],
+ [[5, 192, 4, 2, 6, 0],
+ [5, 192, 4, 1, 6, 0],
+ [5, 192, 4, 1, 6, 0],
+ [5, 192, 4, 1, 6, 0],
+ [3, 320, 4, 1, 6, 0]],
+ [[1, 1280, 0, 1, 0, -1]]
+ ],
+ 'e': [[[3, 32, 0, 2, 0, -1]],
+ [[3, 24, 0, 1, 3, 1]],
+ [[3, 32, 0, 2, 8, 1],
+ [3, 32, 0, 1, 8, 1]],
+ [[3, 48, 0, 2, 8, 1],
+ [3, 48, 0, 1, 8, 1],
+ [3, 48, 0, 1, 8, 1],
+ [3, 48, 0, 1, 8, 1]],
+ [[5, 96, 0, 2, 8, 0],
+ [5, 96, 0, 1, 8, 0],
+ [5, 96, 0, 1, 8, 0],
+ [5, 96, 0, 1, 8, 0],
+ [5, 96, 0, 1, 8, 0],
+ [5, 144, 0, 1, 8, 0],
+ [5, 144, 0, 1, 8, 0],
+ [5, 144, 0, 1, 8, 0],
+ [5, 144, 0, 1, 8, 0]],
+ [[5, 192, 0, 2, 8, 0],
+ [5, 192, 0, 1, 8, 0]],
+ [[1, 1280, 0, 1, 0, -1]]
+ ]
+ } # yapf: disable
+
+ # Parameters to build different kinds of architecture.
+ # From left to right: scaling factor for width, scaling factor for depth,
+ # resolution.
+ arch_settings = {
+ 'b0': (1.0, 1.0, 224),
+ 'b1': (1.0, 1.1, 240),
+ 'b2': (1.1, 1.2, 260),
+ 'b3': (1.2, 1.4, 300),
+ 'b4': (1.4, 1.8, 380),
+ 'b5': (1.6, 2.2, 456),
+ 'b6': (1.8, 2.6, 528),
+ 'b7': (2.0, 3.1, 600),
+ 'b8': (2.2, 3.6, 672),
+ 'es': (1.0, 1.0, 224),
+ 'em': (1.0, 1.1, 240),
+ 'el': (1.2, 1.4, 300)
+ }
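+ # Illustrative usage: with the defaults above, EfficientNet(arch='b0')
+ # returns a single feature map taken from layer index 6 (the final 1x1
+ # conv); passing e.g. out_indices=(3, 4, 5) (an assumed choice for
+ # multi-scale detection necks) returns three stage outputs instead.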
+
+ def __init__(self,
+ arch='b0',
+ drop_path_rate=0.,
+ out_indices=(6, ),
+ frozen_stages=0,
+ conv_cfg=dict(type='Conv2dAdaptivePadding'),
+ norm_cfg=dict(type='BN', eps=1e-3),
+ act_cfg=dict(type='Swish'),
+ norm_eval=False,
+ with_cp=False,
+ init_cfg=[
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ layer=['_BatchNorm', 'GroupNorm'],
+ val=1)
+ ]):
+ super(EfficientNet, self).__init__(init_cfg)
+ assert arch in self.arch_settings, \
+ f'"{arch}" is not one of the arch_settings ' \
+ f'({", ".join(self.arch_settings.keys())})'
+ self.arch_setting = self.arch_settings[arch]
+ self.layer_setting = self.layer_settings[arch[:1]]
+ for index in out_indices:
+ if index not in range(0, len(self.layer_setting)):
+ raise ValueError('each item in out_indices must be in '
+ f'range(0, {len(self.layer_setting)}). '
+ f'But received {index}')
+
+ if frozen_stages not in range(len(self.layer_setting) + 1):
+ raise ValueError('frozen_stages must be in range(0, '
+ f'{len(self.layer_setting) + 1}). '
+ f'But received {frozen_stages}')
+ self.drop_path_rate = drop_path_rate
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.layer_setting = model_scaling(self.layer_setting,
+ self.arch_setting)
+ block_cfg_0 = self.layer_setting[0][0]
+ block_cfg_last = self.layer_setting[-1][0]
+ self.in_channels = make_divisible(block_cfg_0[1], 8)
+ self.out_channels = block_cfg_last[1]
+ self.layers = nn.ModuleList()
+ self.layers.append(
+ ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=block_cfg_0[0],
+ stride=block_cfg_0[3],
+ padding=block_cfg_0[0] // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.make_layer()
+ # Avoid building unused layers in mmdetection.
+ if len(self.layers) < max(self.out_indices) + 1:
+ self.layers.append(
+ ConvModule(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ kernel_size=block_cfg_last[0],
+ stride=block_cfg_last[3],
+ padding=block_cfg_last[0] // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ def make_layer(self):
+ # Without the first and the final conv block.
+ layer_setting = self.layer_setting[1:-1]
+
+ total_num_blocks = sum([len(x) for x in layer_setting])
+ block_idx = 0
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, self.drop_path_rate, total_num_blocks)
+ ] # stochastic depth decay rule
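+ # e.g. with drop_path_rate=0.2 (an assumed value) and 16 blocks in total,
+ # dpr ramps linearly from 0.0 for the first block to 0.2 for the last one.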
+
+ for i, layer_cfg in enumerate(layer_setting):
+ # Avoid building unused layers in mmdetection.
+ if i > max(self.out_indices) - 1:
+ break
+ layer = []
+ for i, block_cfg in enumerate(layer_cfg):
+ (kernel_size, out_channels, se_ratio, stride, expand_ratio,
+ block_type) = block_cfg
+
+ mid_channels = int(self.in_channels * expand_ratio)
+ out_channels = make_divisible(out_channels, 8)
+ if se_ratio <= 0:
+ se_cfg = None
+ else:
+ # In mmdetection, the `divisor` is deleted to align
+ # the logic of SELayer with mmcls.
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=expand_ratio * se_ratio,
+ act_cfg=(self.act_cfg, dict(type='Sigmoid')))
+ if block_type == 1: # edge tpu
+ if i > 0 and expand_ratio == 3:
+ with_residual = False
+ expand_ratio = 4
+ else:
+ with_residual = True
+ mid_channels = int(self.in_channels * expand_ratio)
+ if se_cfg is not None:
+ # In mmdetection, the `divisor` is deleted to align
+ # the logic of SELayer with mmcls.
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=se_ratio * expand_ratio,
+ act_cfg=(self.act_cfg, dict(type='Sigmoid')))
+ block = partial(EdgeResidual, with_residual=with_residual)
+ else:
+ block = InvertedResidual
+ layer.append(
+ block(
+ in_channels=self.in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ drop_path_rate=dpr[block_idx],
+ with_cp=self.with_cp,
+ # In mmdetection, `with_expand_conv` is set to align
+ # the logic of InvertedResidual with mmcls.
+ with_expand_conv=(mid_channels != self.in_channels)))
+ self.in_channels = out_channels
+ block_idx += 1
+ self.layers.append(Sequential(*layer))
+
+ def forward(self, x):
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ for i in range(self.frozen_stages):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(EfficientNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
diff --git a/mmdet/models/backbones/hourglass.py b/mmdet/models/backbones/hourglass.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0dfb434f8508d34d37a831230a8f794f0c354b4
--- /dev/null
+++ b/mmdet/models/backbones/hourglass.py
@@ -0,0 +1,222 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import BasicBlock
+
+
+class HourglassModule(BaseModule):
+ """Hourglass Module for HourglassNet backbone.
+
+ The module is generated recursively and uses BasicBlock as the base unit.
+
+ Args:
+ depth (int): Depth of current HourglassModule.
+ stage_channels (list[int]): Feature channels of sub-modules in current
+ and follow-up HourglassModule.
+ stage_blocks (list[int]): Number of sub-modules stacked in current and
+ follow-up HourglassModule.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ upsample_cfg (dict, optional): Config dict for interpolate layer.
+ Default: `dict(mode='nearest')`
+ """
+
+ def __init__(self,
+ depth,
+ stage_channels,
+ stage_blocks,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ init_cfg=None,
+ upsample_cfg=dict(mode='nearest')):
+ super(HourglassModule, self).__init__(init_cfg)
+
+ self.depth = depth
+
+ cur_block = stage_blocks[0]
+ next_block = stage_blocks[1]
+
+ cur_channel = stage_channels[0]
+ next_channel = stage_channels[1]
+
+ self.up1 = ResLayer(
+ BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)
+
+ self.low1 = ResLayer(
+ BasicBlock,
+ cur_channel,
+ next_channel,
+ cur_block,
+ stride=2,
+ norm_cfg=norm_cfg)
+
+ if self.depth > 1:
+ self.low2 = HourglassModule(depth - 1, stage_channels[1:],
+ stage_blocks[1:])
+ else:
+ self.low2 = ResLayer(
+ BasicBlock,
+ next_channel,
+ next_channel,
+ next_block,
+ norm_cfg=norm_cfg)
+
+ self.low3 = ResLayer(
+ BasicBlock,
+ next_channel,
+ cur_channel,
+ cur_block,
+ norm_cfg=norm_cfg,
+ downsample_first=False)
+
+ self.up2 = F.interpolate
+ self.upsample_cfg = upsample_cfg
+
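+ # Each level halves the resolution with ``low1`` (stride 2), recurses via
+ # ``low2`` (or bottoms out in a plain ResLayer at depth 1), restores the
+ # channel count with ``low3`` and upsamples, so ``forward`` below can fuse
+ # the same-resolution features ``up1 + up2``.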
+ def forward(self, x):
+ """Forward function."""
+ up1 = self.up1(x)
+ low1 = self.low1(x)
+ low2 = self.low2(low1)
+ low3 = self.low3(low2)
+ # Using a fixed `scale_factor` (e.g. 2) is common for upsampling, but in
+ # some cases the spatial sizes mismatch and an error would arise, so we
+ # upsample to the exact size of `up1` instead.
+ if 'scale_factor' in self.upsample_cfg:
+ up2 = self.up2(low3, **self.upsample_cfg)
+ else:
+ shape = up1.shape[2:]
+ up2 = self.up2(low3, size=shape, **self.upsample_cfg)
+ return up1 + up2
+
+
+@BACKBONES.register_module()
+class HourglassNet(BaseModule):
+ """HourglassNet backbone.
+
+ Stacked Hourglass Networks for Human Pose Estimation.
+ More details can be found in the `paper
+ <https://arxiv.org/abs/1603.06937>`_ .
+
+ Args:
+ downsample_times (int): Downsample times in a HourglassModule.
+ num_stacks (int): Number of HourglassModule modules stacked,
+ 1 for Hourglass-52, 2 for Hourglass-104.
+ stage_channels (list[int]): Feature channel of each sub-module in a
+ HourglassModule.
+ stage_blocks (list[int]): Number of sub-modules stacked in a
+ HourglassModule.
+ feat_channel (int): Feature channel of conv after a HourglassModule.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+
+ Example:
+ >>> from mmdet.models import HourglassNet
+ >>> import torch
+ >>> self = HourglassNet()
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 511, 511)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_output in level_outputs:
+ ... print(tuple(level_output.shape))
+ (1, 256, 128, 128)
+ (1, 256, 128, 128)
+ """
+
+ def __init__(self,
+ downsample_times=5,
+ num_stacks=2,
+ stage_channels=(256, 256, 384, 384, 384, 512),
+ stage_blocks=(2, 2, 2, 2, 2, 4),
+ feat_channel=256,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ pretrained=None,
+ init_cfg=None):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super(HourglassNet, self).__init__(init_cfg)
+
+ self.num_stacks = num_stacks
+ assert self.num_stacks >= 1
+ assert len(stage_channels) == len(stage_blocks)
+ assert len(stage_channels) > downsample_times
+
+ cur_channel = stage_channels[0]
+
+ self.stem = nn.Sequential(
+ ConvModule(
+ 3, cur_channel // 2, 7, padding=3, stride=2,
+ norm_cfg=norm_cfg),
+ ResLayer(
+ BasicBlock,
+ cur_channel // 2,
+ cur_channel,
+ 1,
+ stride=2,
+ norm_cfg=norm_cfg))
+
+ self.hourglass_modules = nn.ModuleList([
+ HourglassModule(downsample_times, stage_channels, stage_blocks)
+ for _ in range(num_stacks)
+ ])
+
+ self.inters = ResLayer(
+ BasicBlock,
+ cur_channel,
+ cur_channel,
+ num_stacks - 1,
+ norm_cfg=norm_cfg)
+
+ self.conv1x1s = nn.ModuleList([
+ ConvModule(
+ cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
+ for _ in range(num_stacks - 1)
+ ])
+
+ self.out_convs = nn.ModuleList([
+ ConvModule(
+ cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
+ for _ in range(num_stacks)
+ ])
+
+ self.remap_convs = nn.ModuleList([
+ ConvModule(
+ feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
+ for _ in range(num_stacks - 1)
+ ])
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def init_weights(self):
+ """Init module weights."""
+ # Training the CentripetalNet model requires resetting the Conv2d parameters
+ super(HourglassNet, self).init_weights()
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ m.reset_parameters()
+
+ def forward(self, x):
+ """Forward function."""
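+ # Each stack contributes one output feature map; between stacks the
+ # intermediate feature is refined as
+ #   inter_feat = inters[ind](relu(conv1x1s[ind](inter_feat)
+ #                                 + remap_convs[ind](out_feat)))
+ # which is why ``num_stacks`` feature maps are returned.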
+ inter_feat = self.stem(x)
+ out_feats = []
+
+ for ind in range(self.num_stacks):
+ single_hourglass = self.hourglass_modules[ind]
+ out_conv = self.out_convs[ind]
+
+ hourglass_feat = single_hourglass(inter_feat)
+ out_feat = out_conv(hourglass_feat)
+ out_feats.append(out_feat)
+
+ if ind < self.num_stacks - 1:
+ inter_feat = self.conv1x1s[ind](
+ inter_feat) + self.remap_convs[ind](
+ out_feat)
+ inter_feat = self.inters[ind](self.relu(inter_feat))
+
+ return out_feats
diff --git a/mmdet/models/backbones/hrnet.py b/mmdet/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..06c210a6d422ccc0f55c96dc6c29be052af5494f
--- /dev/null
+++ b/mmdet/models/backbones/hrnet.py
@@ -0,0 +1,589 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule, ModuleList, Sequential
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from .resnet import BasicBlock, Bottleneck
+
+
+class HRModule(BaseModule):
+ """High-Resolution Module for HRNet.
+
+ In this module, every branch has 4 BasicBlocks/Bottlenecks, and the
+ fusion/exchange across branches is also performed here.
+ """
+
+ def __init__(self,
+ num_branches,
+ blocks,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=True,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ block_init_cfg=None,
+ init_cfg=None):
+ super(HRModule, self).__init__(init_cfg)
+ self.block_init_cfg = block_init_cfg
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, blocks, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=False)
+
+ def _check_branches(self, num_branches, num_blocks, in_channels,
+ num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_BLOCKS({len(num_blocks)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_CHANNELS({len(num_channels)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_INCHANNELS({len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.in_channels[branch_index] != \
+ num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ self.in_channels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, num_channels[branch_index] *
+ block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ init_cfg=self.block_init_cfg))
+ self.in_channels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ init_cfg=self.block_init_cfg))
+
+ return Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return ModuleList(branches)
+
+ def _make_fuse_layers(self):
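+ # Builds a grid of fusion ops indexed by (output branch i, input branch j):
+ # j > i upsamples (1x1 conv + nearest upsample by 2**(j - i)), j == i is an
+ # identity (stored as None), and j < i downsamples through a chain of
+ # stride-2 3x3 convs.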
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ nn.Upsample(
+ scale_factor=2**(j - i), mode='nearest')))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=False)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+
+
+@BACKBONES.register_module()
+class HRNet(BaseModule):
+ """HRNet backbone.
+
+ `High-Resolution Representations for Labeling Pixels and Regions
+ (arXiv:1904.04514) <https://arxiv.org/abs/1904.04514>`_.
+
+ Args:
+ extra (dict): Detailed configuration for each stage of HRNet.
+ There must be 4 stages, the configuration for each stage must have
+ 5 keys:
+
+ - num_modules(int): The number of HRModule in this stage.
+ - num_branches(int): The number of branches in the HRModule.
+ - block(str): The type of convolution block.
+ - num_blocks(tuple): The number of blocks in each branch.
+ The length must be equal to num_branches.
+ - num_channels(tuple): The number of channels in each branch.
+ The length must be equal to num_branches.
+ in_channels (int): Number of input image channels. Default: 3.
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: True.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity. Default: False.
+ multiscale_output (bool): Whether to output multi-level features
+ produced by multiple branches. If False, only the first level
+ feature will be output. Default: True.
+ pretrained (str, optional): Model pretrained path. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+
+ Example:
+ >>> from mmdet.models import HRNet
+ >>> import torch
+ >>> extra = dict(
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(4, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> num_modules=3,
+ >>> num_branches=4,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4, 4),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ (1, 64, 4, 4)
+ (1, 128, 2, 2)
+ (1, 256, 1, 1)
+ """
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ norm_eval=True,
+ with_cp=False,
+ zero_init_residual=False,
+ multiscale_output=True,
+ pretrained=None,
+ init_cfg=None):
+ super(HRNet, self).__init__(init_cfg)
+
+ self.pretrained = pretrained
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ # Assert configurations of 4 stages are in extra
+ assert 'stage1' in extra and 'stage2' in extra \
+ and 'stage3' in extra and 'stage4' in extra
+ # Assert whether the length of `num_blocks` and `num_channels` are
+ # equal to `num_branches`
+ for i in range(4):
+ cfg = extra[f'stage{i + 1}']
+ assert len(cfg['num_blocks']) == cfg['num_branches'] and \
+ len(cfg['num_channels']) == cfg['num_branches']
+
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+
+ # stem net
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ 64,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # stage 1
+ self.stage1_cfg = self.extra['stage1']
+ num_channels = self.stage1_cfg['num_channels'][0]
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks'][0]
+
+ block = self.blocks_dict[block_type]
+ stage1_out_channels = num_channels * block.expansion
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+
+ # stage 2
+ self.stage2_cfg = self.extra['stage2']
+ num_channels = self.stage2_cfg['num_channels']
+ block_type = self.stage2_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition1 = self._make_transition_layer([stage1_out_channels],
+ num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ # stage 3
+ self.stage3_cfg = self.extra['stage3']
+ num_channels = self.stage3_cfg['num_channels']
+ block_type = self.stage3_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition2 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ # stage 4
+ self.stage4_cfg = self.extra['stage4']
+ num_channels = self.stage4_cfg['num_channels']
+ block_type = self.stage4_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition3 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels, multiscale_output=multiscale_output)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
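+ # For each branch of the new stage: adapt an existing branch with a 3x3
+ # conv when the channel counts differ (or keep it unchanged, stored as
+ # None), and create any extra lower-resolution branches from the last
+ # previous branch with chains of stride-2 3x3 convs.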
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+ layers = []
+ block_init_cfg = None
+ if self.pretrained is None and not hasattr(
+ self, 'init_cfg') and self.zero_init_residual:
+ if block is BasicBlock:
+ block_init_cfg = dict(
+ type='Constant', val=0, override=dict(name='norm2'))
+ elif block is Bottleneck:
+ block_init_cfg = dict(
+ type='Constant', val=0, override=dict(name='norm3'))
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ init_cfg=block_init_cfg,
+ ))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ init_cfg=block_init_cfg))
+
+ return Sequential(*layers)
+
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+
+ hr_modules = []
+ block_init_cfg = None
+ if self.pretrained is None and not hasattr(
+ self, 'init_cfg') and self.zero_init_residual:
+ if block is BasicBlock:
+ block_init_cfg = dict(
+ type='Constant', val=0, override=dict(name='norm2'))
+ elif block is Bottleneck:
+ block_init_cfg = dict(
+ type='Constant', val=0, override=dict(name='norm3'))
+
+ for i in range(num_modules):
+ # multiscale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ hr_modules.append(
+ HRModule(
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg,
+ block_init_cfg=block_init_cfg))
+
+ return Sequential(*hr_modules), in_channels
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ return y_list
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keeping the normalization
+ layers frozen."""
+ super(HRNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval() only has an effect on BatchNorm layers
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmdet/models/backbones/mobilenet_v2.py b/mmdet/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c6fcfaaa4c550b3568343f6b9baf1512d41b4db
--- /dev/null
+++ b/mmdet/models/backbones/mobilenet_v2.py
@@ -0,0 +1,197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidual, make_divisible
+
+
+@BACKBONES.register_module()
+class MobileNetV2(BaseModule):
+ """MobileNetV2 backbone.
+
+ Args:
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ out_indices (Sequence[int], optional): Output from which stages.
+ Default: (1, 2, 4, 7).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ # Parameters to build layers. 4 parameters are needed to construct a
+ # layer, from left to right: expand_ratio, channel, num_blocks, stride.
+ arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
+ [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
+ [6, 320, 1, 1]]
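+ # e.g. [6, 24, 2, 2] means: expand channels 6x inside each block, output 24
+ # channels (before widen_factor), stack 2 blocks, and downsample with
+ # stride 2 in the first block of the stage.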
+
+ def __init__(self,
+ widen_factor=1.,
+ out_indices=(1, 2, 4, 7),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ norm_eval=False,
+ with_cp=False,
+ pretrained=None,
+ init_cfg=None):
+ super(MobileNetV2, self).__init__(init_cfg)
+
+ self.pretrained = pretrained
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ self.widen_factor = widen_factor
+ self.out_indices = out_indices
+ if not set(out_indices).issubset(set(range(0, 8))):
+ raise ValueError('out_indices must be a subset of range'
+ f'(0, 8). But received {out_indices}')
+
+ if frozen_stages not in range(-1, 8):
+ raise ValueError('frozen_stages must be in range(-1, 8). '
+ f'But received {frozen_stages}')
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.in_channels = make_divisible(32 * widen_factor, 8)
+
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.layers = []
+
+ for i, layer_cfg in enumerate(self.arch_settings):
+ expand_ratio, channel, num_blocks, stride = layer_cfg
+ out_channels = make_divisible(channel * widen_factor, 8)
+ inverted_res_layer = self.make_layer(
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ stride=stride,
+ expand_ratio=expand_ratio)
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, inverted_res_layer)
+ self.layers.append(layer_name)
+
+ if widen_factor > 1.0:
+ self.out_channel = int(1280 * widen_factor)
+ else:
+ self.out_channel = 1280
+
+ layer = ConvModule(
+ in_channels=self.in_channels,
+ out_channels=self.out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.add_module('conv2', layer)
+ self.layers.append('conv2')
+
+ def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
+ """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+ Args:
+ out_channels (int): out_channels of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block.
+ expand_ratio (int): Expand the number of channels of the
+ hidden layer in InvertedResidual by this ratio.
+ """
+ layers = []
+ for i in range(num_blocks):
+ if i >= 1:
+ stride = 1
+ layers.append(
+ InvertedResidual(
+ self.in_channels,
+ out_channels,
+ mid_channels=int(round(self.in_channels * expand_ratio)),
+ stride=stride,
+ with_expand_conv=expand_ratio != 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keeping the normalization
+ layers frozen."""
+ super(MobileNetV2, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval have effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
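+
+
+# Usage sketch (illustrative, editor-added; not part of the original module).
+# The out_indices and the printed shapes assume the default arch_settings and
+# ``widen_factor=1.0`` and are estimates, not verified outputs.
+#
+#     >>> import torch
+#     >>> self = MobileNetV2(widen_factor=1.0, out_indices=(1, 2, 4, 7))
+#     >>> self.eval()
+#     >>> inputs = torch.rand(1, 3, 224, 224)
+#     >>> for level_out in self.forward(inputs):
+#     ...     print(tuple(level_out.shape))
+#     (1, 24, 56, 56)
+#     (1, 32, 28, 28)
+#     (1, 96, 14, 14)
+#     (1, 1280, 7, 7)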
diff --git a/mmdet/models/backbones/pvt.py b/mmdet/models/backbones/pvt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7d5d5344a7968b95a088f3c7822840016a52db
--- /dev/null
+++ b/mmdet/models/backbones/pvt.py
@@ -0,0 +1,591 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
+ constant_init, normal_init, trunc_normal_init)
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.transformer import MultiheadAttention
+from mmcv.cnn.utils.weight_init import trunc_normal_
+from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint,
+ load_state_dict)
+from torch.nn.modules.utils import _pair as to_2tuple
+
+from ...utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw, pvt_convert
+
+
+class MixFFN(BaseModule):
+ """An implementation of MixFFN of PVT.
+
+ The differences between MixFFN & FFN:
+        1. Use 1x1 Conv to replace Linear layer.
+        2. Introduce 3x3 depth-wise Conv to encode positional information.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ act_cfg (dict, optional): The activation config for FFNs.
+ Default: dict(type='GELU').
+ ffn_drop (float, optional): Probability of an element to be
+ zeroed in FFN. Default 0.0.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut.
+ Default: None.
+ use_conv (bool): If True, add 3x3 DWConv between two Linear layers.
+ Defaults: False.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ feedforward_channels,
+ act_cfg=dict(type='GELU'),
+ ffn_drop=0.,
+ dropout_layer=None,
+ use_conv=False,
+ init_cfg=None):
+ super(MixFFN, self).__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.act_cfg = act_cfg
+ activate = build_activation_layer(act_cfg)
+
+ in_channels = embed_dims
+ fc1 = Conv2d(
+ in_channels=in_channels,
+ out_channels=feedforward_channels,
+ kernel_size=1,
+ stride=1,
+ bias=True)
+ if use_conv:
+ # 3x3 depth wise conv to provide positional encode information
+ dw_conv = Conv2d(
+ in_channels=feedforward_channels,
+ out_channels=feedforward_channels,
+ kernel_size=3,
+ stride=1,
+ padding=(3 - 1) // 2,
+ bias=True,
+ groups=feedforward_channels)
+ fc2 = Conv2d(
+ in_channels=feedforward_channels,
+ out_channels=in_channels,
+ kernel_size=1,
+ stride=1,
+ bias=True)
+ drop = nn.Dropout(ffn_drop)
+ layers = [fc1, activate, drop, fc2, drop]
+ if use_conv:
+ layers.insert(1, dw_conv)
+ self.layers = Sequential(*layers)
+ self.dropout_layer = build_dropout(
+ dropout_layer) if dropout_layer else torch.nn.Identity()
+
+ def forward(self, x, hw_shape, identity=None):
+ out = nlc_to_nchw(x, hw_shape)
+ out = self.layers(out)
+ out = nchw_to_nlc(out)
+ if identity is None:
+ identity = x
+ return identity + self.dropout_layer(out)
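+
+# Shape sketch for MixFFN above (illustrative, editor-added; the concrete
+# numbers are assumptions): the module takes and returns token sequences in
+# (B, L, C) layout and converts to (B, C, H, W) internally so that the 1x1
+# and 3x3 convolutions can be applied.
+#
+#     >>> ffn = MixFFN(embed_dims=64, feedforward_channels=256, use_conv=True)
+#     >>> x = torch.rand(1, 56 * 56, 64)   # (B, L, C) with L = H * W
+#     >>> ffn(x, hw_shape=(56, 56)).shape
+#     torch.Size([1, 3136, 64])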
+
+
+class SpatialReductionAttention(MultiheadAttention):
+ """An implementation of Spatial Reduction Attention of PVT.
+
+ This module is modified from MultiheadAttention which is a module from
+ mmcv.cnn.bricks.transformer.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+ Default: 0.0.
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
+ when adding the shortcut. Default: None.
+        batch_first (bool): Key, Query and Value are shape of
+            (batch, n, embed_dim)
+            or (n, batch, embed_dim). Default: True.
+ qkv_bias (bool): enable bias for qkv if True. Default: True.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
+ Attention of PVT. Default: 1.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ attn_drop=0.,
+ proj_drop=0.,
+ dropout_layer=None,
+ batch_first=True,
+ qkv_bias=True,
+ norm_cfg=dict(type='LN'),
+ sr_ratio=1,
+ init_cfg=None):
+ super().__init__(
+ embed_dims,
+ num_heads,
+ attn_drop,
+ proj_drop,
+ batch_first=batch_first,
+ dropout_layer=dropout_layer,
+ bias=qkv_bias,
+ init_cfg=init_cfg)
+
+ self.sr_ratio = sr_ratio
+ if sr_ratio > 1:
+ self.sr = Conv2d(
+ in_channels=embed_dims,
+ out_channels=embed_dims,
+ kernel_size=sr_ratio,
+ stride=sr_ratio)
+ # The ret[0] of build_norm_layer is norm name.
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
+ from mmdet import digit_version, mmcv_version
+ if mmcv_version < digit_version('1.3.17'):
+            warnings.warn('The legacy version of the forward function in '
+                          'SpatialReductionAttention is deprecated in '
+                          'mmcv>=1.3.17 and will no longer be supported in '
+                          'the future. Please upgrade your mmcv.')
+ self.forward = self.legacy_forward
+
+ def forward(self, x, hw_shape, identity=None):
+
+ x_q = x
+ if self.sr_ratio > 1:
+ x_kv = nlc_to_nchw(x, hw_shape)
+ x_kv = self.sr(x_kv)
+ x_kv = nchw_to_nlc(x_kv)
+ x_kv = self.norm(x_kv)
+ else:
+ x_kv = x
+
+ if identity is None:
+ identity = x_q
+
+ # Because the dataflow('key', 'query', 'value') of
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+ # embed_dims), We should adjust the shape of dataflow from
+ # batch_first (batch, num_query, embed_dims) to num_query_first
+ # (num_query ,batch, embed_dims), and recover ``attn_output``
+ # from num_query_first to batch_first.
+ if self.batch_first:
+ x_q = x_q.transpose(0, 1)
+ x_kv = x_kv.transpose(0, 1)
+
+ out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+ if self.batch_first:
+ out = out.transpose(0, 1)
+
+ return identity + self.dropout_layer(self.proj_drop(out))
+
+ def legacy_forward(self, x, hw_shape, identity=None):
+ """multi head attention forward in mmcv version < 1.3.17."""
+ x_q = x
+ if self.sr_ratio > 1:
+ x_kv = nlc_to_nchw(x, hw_shape)
+ x_kv = self.sr(x_kv)
+ x_kv = nchw_to_nlc(x_kv)
+ x_kv = self.norm(x_kv)
+ else:
+ x_kv = x
+
+ if identity is None:
+ identity = x_q
+
+ out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+
+ return identity + self.dropout_layer(self.proj_drop(out))
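+
+# Sketch of the spatial reduction above (illustrative, editor-added; numbers
+# are assumptions): with ``sr_ratio=8`` the key/value sequence is reduced by
+# a factor of sr_ratio**2 = 64 before attention, while the query (and hence
+# the output) keeps the full resolution.
+#
+#     >>> attn = SpatialReductionAttention(
+#     ...     embed_dims=64, num_heads=1, sr_ratio=8)
+#     >>> x = torch.rand(1, 56 * 56, 64)    # 3136 query tokens
+#     >>> attn(x, hw_shape=(56, 56)).shape  # keys/values shrink to 7*7=49
+#     torch.Size([1, 3136, 64])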
+
+
+class PVTEncoderLayer(BaseModule):
+ """Implements one encoder layer in PVT.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+        drop_rate (float): Probability of an element to be zeroed
+            after the feed forward layer. Default: 0.0.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default: 0.0.
+ drop_path_rate (float): stochastic depth rate. Default: 0.0.
+ qkv_bias (bool): enable bias for qkv if True.
+ Default: True.
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
+ Attention of PVT. Default: 1.
+ use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
+ Default: False.
+ init_cfg (dict, optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ qkv_bias=True,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ sr_ratio=1,
+ use_conv_ffn=False,
+ init_cfg=None):
+ super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg)
+
+ # The ret[0] of build_norm_layer is norm name.
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ self.attn = SpatialReductionAttention(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ attn_drop=attn_drop_rate,
+ proj_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ qkv_bias=qkv_bias,
+ norm_cfg=norm_cfg,
+ sr_ratio=sr_ratio)
+
+ # The ret[0] of build_norm_layer is norm name.
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ self.ffn = MixFFN(
+ embed_dims=embed_dims,
+ feedforward_channels=feedforward_channels,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ use_conv=use_conv_ffn,
+ act_cfg=act_cfg)
+
+ def forward(self, x, hw_shape):
+ x = self.attn(self.norm1(x), hw_shape, identity=x)
+ x = self.ffn(self.norm2(x), hw_shape, identity=x)
+
+ return x
+
+
+class AbsolutePositionEmbedding(BaseModule):
+ """An implementation of the absolute position embedding in PVT.
+
+ Args:
+ pos_shape (int): The shape of the absolute position embedding.
+ pos_dim (int): The dimension of the absolute position embedding.
+ drop_rate (float): Probability of an element to be zeroed.
+ Default: 0.0.
+ """
+
+ def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(pos_shape, int):
+ pos_shape = to_2tuple(pos_shape)
+ elif isinstance(pos_shape, tuple):
+ if len(pos_shape) == 1:
+ pos_shape = to_2tuple(pos_shape[0])
+ assert len(pos_shape) == 2, \
+ f'The size of image should have length 1 or 2, ' \
+ f'but got {len(pos_shape)}'
+ self.pos_shape = pos_shape
+ self.pos_dim = pos_dim
+
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim))
+ self.drop = nn.Dropout(p=drop_rate)
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+
+ def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
+ """Resize pos_embed weights.
+
+ Resize pos_embed using bilinear interpolate method.
+
+ Args:
+ pos_embed (torch.Tensor): Position embedding weights.
+ input_shape (tuple): Tuple for (downsampled input image height,
+ downsampled input image width).
+ mode (str): Algorithm used for upsampling:
+ ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+ ``'trilinear'``. Default: ``'bilinear'``.
+
+ Return:
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C].
+ """
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
+ pos_h, pos_w = self.pos_shape
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+ pos_embed_weight = pos_embed_weight.reshape(
+ 1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous()
+ pos_embed_weight = F.interpolate(
+ pos_embed_weight, size=input_shape, mode=mode)
+ pos_embed_weight = torch.flatten(pos_embed_weight,
+ 2).transpose(1, 2).contiguous()
+ pos_embed = pos_embed_weight
+
+ return pos_embed
+
+ def forward(self, x, hw_shape, mode='bilinear'):
+ pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
+ return self.drop(x + pos_embed)
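+
+# Sketch of the resizing above (illustrative, editor-added; numbers are
+# assumptions): the learned table holds pos_shape[0] * pos_shape[1] entries
+# and is interpolated to the current feature resolution, so the input size
+# does not have to match the pretraining size.
+#
+#     >>> ape = AbsolutePositionEmbedding(pos_shape=7, pos_dim=64)
+#     >>> x = torch.rand(1, 14 * 14, 64)   # feature map larger than 7x7
+#     >>> ape(x, hw_shape=(14, 14)).shape
+#     torch.Size([1, 196, 64])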
+
+
+@BACKBONES.register_module()
+class PyramidVisionTransformer(BaseModule):
+ """Pyramid Vision Transformer (PVT)
+
+ Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
+ Dense Prediction without Convolutions
+    <https://arxiv.org/abs/2102.12122>`_.
+
+ Args:
+ pretrain_img_size (int | tuple[int]): The size of input image when
+ pretrain. Defaults: 224.
+ in_channels (int): Number of input channels. Default: 3.
+ embed_dims (int): Embedding dimension. Default: 64.
+        num_stages (int): The number of stages. Default: 4.
+ num_layers (Sequence[int]): The layer number of each transformer encode
+ layer. Default: [3, 4, 6, 3].
+ num_heads (Sequence[int]): The attention heads of each transformer
+ encode layer. Default: [1, 2, 5, 8].
+ patch_sizes (Sequence[int]): The patch_size of each patch embedding.
+ Default: [4, 2, 2, 2].
+ strides (Sequence[int]): The stride of each patch embedding.
+ Default: [4, 2, 2, 2].
+ paddings (Sequence[int]): The padding of each patch embedding.
+ Default: [0, 0, 0, 0].
+ sr_ratios (Sequence[int]): The spatial reduction rate of each
+ transformer encode layer. Default: [8, 4, 2, 1].
+ out_indices (Sequence[int] | int): Output from which stages.
+ Default: (0, 1, 2, 3).
+ mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
+ embedding dim of each transformer encode layer.
+ Default: [8, 8, 4, 4].
+ qkv_bias (bool): Enable bias for qkv if True. Default: True.
+ drop_rate (float): Probability of an element to be zeroed.
+ Default 0.0.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default 0.0.
+ drop_path_rate (float): stochastic depth rate. Default 0.1.
+ use_abs_pos_embed (bool): If True, add absolute position embedding to
+ the patch embedding. Defaults: True.
+ use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
+ Default: False.
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ pretrained (str, optional): model pretrained path. Default: None.
+ convert_weights (bool): The flag indicates whether the
+ pre-trained model is from the original repo. We may need
+ to convert some keys to make it compatible.
+ Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ pretrain_img_size=224,
+ in_channels=3,
+ embed_dims=64,
+ num_stages=4,
+ num_layers=[3, 4, 6, 3],
+ num_heads=[1, 2, 5, 8],
+ patch_sizes=[4, 2, 2, 2],
+ strides=[4, 2, 2, 2],
+ paddings=[0, 0, 0, 0],
+ sr_ratios=[8, 4, 2, 1],
+ out_indices=(0, 1, 2, 3),
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ use_abs_pos_embed=True,
+ norm_after_stage=False,
+ use_conv_ffn=False,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN', eps=1e-6),
+ pretrained=None,
+ convert_weights=True,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ self.convert_weights = convert_weights
+ if isinstance(pretrain_img_size, int):
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ elif isinstance(pretrain_img_size, tuple):
+ if len(pretrain_img_size) == 1:
+ pretrain_img_size = to_2tuple(pretrain_img_size[0])
+ assert len(pretrain_img_size) == 2, \
+ f'The size of image should have length 1 or 2, ' \
+ f'but got {len(pretrain_img_size)}'
+
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be setting at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ self.init_cfg = init_cfg
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ self.embed_dims = embed_dims
+
+ self.num_stages = num_stages
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.patch_sizes = patch_sizes
+ self.strides = strides
+ self.sr_ratios = sr_ratios
+ assert num_stages == len(num_layers) == len(num_heads) \
+ == len(patch_sizes) == len(strides) == len(sr_ratios)
+
+ self.out_indices = out_indices
+ assert max(out_indices) < self.num_stages
+ self.pretrained = pretrained
+
+ # transformer encoder
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, drop_path_rate, sum(num_layers))
+        ]  # stochastic depth decay rule
+
+ cur = 0
+ self.layers = ModuleList()
+ for i, num_layer in enumerate(num_layers):
+ embed_dims_i = embed_dims * num_heads[i]
+ patch_embed = PatchEmbed(
+ in_channels=in_channels,
+ embed_dims=embed_dims_i,
+ kernel_size=patch_sizes[i],
+ stride=strides[i],
+ padding=paddings[i],
+ bias=True,
+ norm_cfg=norm_cfg)
+
+ layers = ModuleList()
+ if use_abs_pos_embed:
+ pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
+ pos_embed = AbsolutePositionEmbedding(
+ pos_shape=pos_shape,
+ pos_dim=embed_dims_i,
+ drop_rate=drop_rate)
+ layers.append(pos_embed)
+ layers.extend([
+ PVTEncoderLayer(
+ embed_dims=embed_dims_i,
+ num_heads=num_heads[i],
+ feedforward_channels=mlp_ratios[i] * embed_dims_i,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[cur + idx],
+ qkv_bias=qkv_bias,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ sr_ratio=sr_ratios[i],
+ use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
+ ])
+ in_channels = embed_dims_i
+ # The ret[0] of build_norm_layer is norm name.
+ if norm_after_stage:
+ norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
+ else:
+ norm = nn.Identity()
+ self.layers.append(ModuleList([patch_embed, layers, norm]))
+ cur += num_layer
+
+ def init_weights(self):
+ logger = get_root_logger()
+ if self.init_cfg is None:
+ logger.warn(f'No pre-trained weights for '
+ f'{self.__class__.__name__}, '
+ f'training start from scratch')
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ constant_init(m, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[
+ 1] * m.out_channels
+ fan_out //= m.groups
+ normal_init(m, 0, math.sqrt(2.0 / fan_out))
+ elif isinstance(m, AbsolutePositionEmbedding):
+ m.init_weights()
+ else:
+ assert 'checkpoint' in self.init_cfg, f'Only support ' \
+ f'specify `Pretrained` in ' \
+ f'`init_cfg` in ' \
+ f'{self.__class__.__name__} '
+ checkpoint = _load_checkpoint(
+ self.init_cfg.checkpoint, logger=logger, map_location='cpu')
+ logger.warn(f'Load pre-trained model for '
+ f'{self.__class__.__name__} from original repo')
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+ if self.convert_weights:
+                # Because pvt backbones are not supported by mmcls,
+                # we need to convert the pre-trained weights to match this
+                # implementation.
+ state_dict = pvt_convert(state_dict)
+ load_state_dict(self, state_dict, strict=False, logger=logger)
+
+ def forward(self, x):
+ outs = []
+
+ for i, layer in enumerate(self.layers):
+ x, hw_shape = layer[0](x)
+
+ for block in layer[1]:
+ x = block(x, hw_shape)
+ x = layer[2](x)
+ x = nlc_to_nchw(x, hw_shape)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return outs
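+
+# Usage sketch for PyramidVisionTransformer (illustrative, editor-added; the
+# printed shapes assume the default arguments above and are estimates):
+#
+#     >>> import torch
+#     >>> self = PyramidVisionTransformer()
+#     >>> self.eval()
+#     >>> inputs = torch.rand(1, 3, 32, 32)
+#     >>> for level_out in self.forward(inputs):
+#     ...     print(tuple(level_out.shape))
+#     (1, 64, 8, 8)
+#     (1, 128, 4, 4)
+#     (1, 320, 2, 2)
+#     (1, 512, 1, 1)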
+
+
+@BACKBONES.register_module()
+class PyramidVisionTransformerV2(PyramidVisionTransformer):
+ """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
+    Transformer <https://arxiv.org/abs/2106.13797>`_."""
+
+ def __init__(self, **kwargs):
+ super(PyramidVisionTransformerV2, self).__init__(
+ patch_sizes=[7, 3, 3, 3],
+ paddings=[3, 1, 1, 1],
+ use_abs_pos_embed=False,
+ norm_after_stage=True,
+ use_conv_ffn=True,
+ **kwargs)
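+
+
+# Config-style sketch (illustrative, editor-added; the field values below are
+# assumptions for a small PVTv2-like model, not an official configuration).
+# In mmdet the backbone is normally built from a config dict via the
+# BACKBONES registry:
+#
+#     >>> import torch
+#     >>> from mmdet.models import build_backbone
+#     >>> backbone = build_backbone(
+#     ...     dict(
+#     ...         type='PyramidVisionTransformerV2',
+#     ...         embed_dims=32,
+#     ...         num_layers=[2, 2, 2, 2]))
+#     >>> [feat.shape[1] for feat in backbone(torch.rand(1, 3, 64, 64))]
+#     [32, 64, 160, 256]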
diff --git a/mmdet/models/backbones/regnet.py b/mmdet/models/backbones/regnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..63adc3c1deb3b48193c243eb4ec5178a0b62103b
--- /dev/null
+++ b/mmdet/models/backbones/regnet.py
@@ -0,0 +1,356 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from .resnet import ResNet
+from .resnext import Bottleneck
+
+
+@BACKBONES.register_module()
+class RegNet(ResNet):
+ """RegNet backbone.
+
+    More details can be found in `paper <https://arxiv.org/abs/2003.13678>`_ .
+
+ Args:
+ arch (dict): The parameter of RegNets.
+
+ - w0 (int): initial width
+ - wa (float): slope of width
+ - wm (float): quantization parameter to quantize the width
+ - depth (int): depth of the backbone
+ - group_w (int): width of group
+ - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ base_channels (int): Base channels after stem layer.
+ in_channels (int): Number of input image channels. Default: 3.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+
+ Example:
+ >>> from mmdet.models import RegNet
+ >>> import torch
+ >>> self = RegNet(
+ arch=dict(
+ w0=88,
+ wa=26.31,
+ wm=2.25,
+ group_w=48,
+ depth=25,
+ bot_mul=1.0))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 96, 8, 8)
+ (1, 192, 4, 4)
+ (1, 432, 2, 2)
+ (1, 1008, 1, 1)
+ """
+ arch_settings = {
+ 'regnetx_400mf':
+ dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
+ 'regnetx_800mf':
+ dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
+ 'regnetx_1.6gf':
+ dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
+ 'regnetx_3.2gf':
+ dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
+ 'regnetx_4.0gf':
+ dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
+ 'regnetx_6.4gf':
+ dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
+ 'regnetx_8.0gf':
+ dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
+ 'regnetx_12gf':
+ dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
+ }
+
+ def __init__(self,
+ arch,
+ in_channels=3,
+ stem_channels=32,
+ base_channels=32,
+ strides=(2, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ with_cp=False,
+ zero_init_residual=True,
+ pretrained=None,
+ init_cfg=None):
+ super(ResNet, self).__init__(init_cfg)
+
+ # Generate RegNet parameters first
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'"arch": "{arch}" is not one of the' \
+ ' arch_settings'
+ arch = self.arch_settings[arch]
+ elif not isinstance(arch, dict):
+ raise ValueError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ widths, num_stages = self.generate_regnet(
+ arch['w0'],
+ arch['wa'],
+ arch['wm'],
+ arch['depth'],
+ )
+ # Convert to per stage format
+ stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
+ # Generate group widths and bot muls
+ group_widths = [arch['group_w'] for _ in range(num_stages)]
+ self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
+ # Adjust the compatibility of stage_widths and group_widths
+ stage_widths, group_widths = self.adjust_width_group(
+ stage_widths, self.bottleneck_ratio, group_widths)
+
+ # Group params by stage
+ self.stage_widths = stage_widths
+ self.group_widths = group_widths
+ self.depth = sum(stage_blocks)
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.zero_init_residual = zero_init_residual
+ self.block = Bottleneck
+ expansion_bak = self.block.expansion
+ self.block.expansion = 1
+ self.stage_blocks = stage_blocks[:num_stages]
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ block_init_cfg = None
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]
+ if self.zero_init_residual:
+ block_init_cfg = dict(
+ type='Constant', val=0, override=dict(name='norm3'))
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ self.inplanes = stem_channels
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ group_width = self.group_widths[i]
+ width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
+ stage_groups = width // group_width
+
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if self.plugins is not None:
+ stage_plugins = self.make_stage_plugins(self.plugins, i)
+ else:
+ stage_plugins = None
+
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=self.stage_widths[i],
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ groups=stage_groups,
+ base_width=group_width,
+ base_channels=self.stage_widths[i],
+ init_cfg=block_init_cfg)
+ self.inplanes = self.stage_widths[i]
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = stage_widths[-1]
+ self.block.expansion = expansion_bak
+
+ def _make_stem_layer(self, in_channels, base_channels):
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ base_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, base_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def generate_regnet(self,
+ initial_width,
+ width_slope,
+ width_parameter,
+ depth,
+ divisor=8):
+ """Generates per block width from RegNet parameters.
+
+ Args:
+ initial_width ([int]): Initial width of the backbone
+ width_slope ([float]): Slope of the quantized linear function
+ width_parameter ([int]): Parameter used to quantize the width.
+ depth ([int]): Depth of the backbone.
+ divisor (int, optional): The divisor of channels. Defaults to 8.
+
+ Returns:
+ list, int: return a list of widths of each stage and the number \
+ of stages
+ """
+ assert width_slope >= 0
+ assert initial_width > 0
+ assert width_parameter > 1
+ assert initial_width % divisor == 0
+ widths_cont = np.arange(depth) * width_slope + initial_width
+ ks = np.round(
+ np.log(widths_cont / initial_width) / np.log(width_parameter))
+ widths = initial_width * np.power(width_parameter, ks)
+ widths = np.round(np.divide(widths, divisor)) * divisor
+ num_stages = len(np.unique(widths))
+ widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
+ return widths, num_stages
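+
+    # Worked example for ``generate_regnet`` (illustrative, editor-added; the
+    # numbers are hand-computed and should be treated as approximate):
+    #
+    #     >>> self.generate_regnet(24, 24.48, 2.54, 4)
+    #     ([24, 64, 64, 152], 3)
+    #
+    # i.e. the linearly growing widths 24, 48.5, 73.0, 97.4 are snapped to
+    # powers of wm=2.54 times w0=24, rounded to multiples of divisor=8, and
+    # runs of equal widths later collapse into a single stage.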
+
+ @staticmethod
+ def quantize_float(number, divisor):
+ """Converts a float to closest non-zero int divisible by divisor.
+
+ Args:
+ number (int): Original number to be quantized.
+ divisor (int): Divisor used to quantize the number.
+
+ Returns:
+            int: Quantized number that is divisible by divisor.
+ """
+ return int(round(number / divisor) * divisor)
+
+ def adjust_width_group(self, widths, bottleneck_ratio, groups):
+ """Adjusts the compatibility of widths and groups.
+
+ Args:
+ widths (list[int]): Width of each stage.
+            bottleneck_ratio (list[float]): Bottleneck ratio of each stage.
+            groups (list[int]): Number of groups in each stage.
+
+ Returns:
+ tuple(list): The adjusted widths and groups of each stage.
+ """
+ bottleneck_width = [
+ int(w * b) for w, b in zip(widths, bottleneck_ratio)
+ ]
+ groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
+ bottleneck_width = [
+ self.quantize_float(w_bot, g)
+ for w_bot, g in zip(bottleneck_width, groups)
+ ]
+ widths = [
+ int(w_bot / b)
+ for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
+ ]
+ return widths, groups
+
+ def get_stages_from_blocks(self, widths):
+ """Gets widths/stage_blocks of network at each stage.
+
+ Args:
+ widths (list[int]): Width in each stage.
+
+ Returns:
+ tuple(list): width and depth of each stage
+ """
+ width_diff = [
+ width != width_prev
+ for width, width_prev in zip(widths + [0], [0] + widths)
+ ]
+ stage_widths = [
+ width for width, diff in zip(widths, width_diff[:-1]) if diff
+ ]
+ stage_blocks = np.diff([
+ depth for depth, diff in zip(range(len(width_diff)), width_diff)
+ if diff
+ ]).tolist()
+ return stage_widths, stage_blocks
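+
+    # Worked example for ``get_stages_from_blocks`` (illustrative,
+    # editor-added): consecutive equal per-block widths are merged into one
+    # stage, e.g.
+    #
+    #     >>> self.get_stages_from_blocks([48, 48, 96, 96, 96])
+    #     ([48, 96], [2, 3])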
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
diff --git a/mmdet/models/backbones/res2net.py b/mmdet/models/backbones/res2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..96afb2fb2892f6e3973d48509071671bc8a5b7e0
--- /dev/null
+++ b/mmdet/models/backbones/res2net.py
@@ -0,0 +1,327 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import Sequential
+
+from ..builder import BACKBONES
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottle2neck(_Bottleneck):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ scales=4,
+ base_width=26,
+ base_channels=64,
+ stage_type='normal',
+ **kwargs):
+ """Bottle2neck block for Res2Net.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottle2neck, self).__init__(inplanes, planes, **kwargs)
+ assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.'
+ width = int(math.floor(self.planes * (base_width / base_channels)))
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width * scales, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width * scales,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+
+ if stage_type == 'stage' and self.conv2_stride != 1:
+ self.pool = nn.AvgPool2d(
+ kernel_size=3, stride=self.conv2_stride, padding=1)
+ convs = []
+ bns = []
+
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ for i in range(scales - 1):
+ convs.append(
+ build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ bias=False))
+ bns.append(
+ build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ for i in range(scales - 1):
+ convs.append(
+ build_conv_layer(
+ self.dcn,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ bias=False))
+ bns.append(
+ build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width * scales,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.stage_type = stage_type
+ self.scales = scales
+ self.width = width
+ delattr(self, 'conv2')
+ delattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ spx = torch.split(out, self.width, 1)
+ sp = self.convs[0](spx[0].contiguous())
+ sp = self.relu(self.bns[0](sp))
+ out = sp
+ for i in range(1, self.scales - 1):
+ if self.stage_type == 'stage':
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ sp = self.convs[i](sp.contiguous())
+ sp = self.relu(self.bns[i](sp))
+ out = torch.cat((out, sp), 1)
+
+ if self.stage_type == 'normal' or self.conv2_stride == 1:
+ out = torch.cat((out, spx[self.scales - 1]), 1)
+ elif self.stage_type == 'stage':
+ out = torch.cat((out, self.pool(spx[self.scales - 1])), 1)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
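+
+# Channel bookkeeping sketch for Bottle2neck (illustrative, editor-added; the
+# numbers are assumptions): with planes=64, scales=4, base_width=26 and
+# base_channels=64, each scale branch has width = floor(64 * 26 / 64) = 26
+# channels, conv1 expands the input to width * scales = 104 channels, the
+# split/concat loop keeps those 104 channels, and conv3 projects them to
+# planes * expansion = 256 output channels.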
+
+
+class Res2Layer(Sequential):
+ """Res2Layer to build Res2Net style backbone.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottle2neck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ scales (int): Scales used in Res2Net. Default: 4
+ base_width (int): Basic width of each scale. Default: 26
+ """
+
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ avg_down=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ scales=4,
+ base_width=26,
+ **kwargs):
+ self.block = block
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False),
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=1,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1],
+ )
+
+ layers = []
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ scales=scales,
+ base_width=base_width,
+ stage_type='stage',
+ **kwargs))
+ inplanes = planes * block.expansion
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ scales=scales,
+ base_width=base_width,
+ **kwargs))
+ super(Res2Layer, self).__init__(*layers)
+
+
+@BACKBONES.register_module()
+class Res2Net(ResNet):
+ """Res2Net backbone.
+
+ Args:
+ scales (int): Scales used in Res2Net. Default: 4
+ base_width (int): Basic width of each scale. Default: 26
+ depth (int): Depth of res2net, from {50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Res2net stages. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottle2neck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+
+ - cfg (dict, required): Cfg dict to build plugin.
+ - position (str, required): Position inside block to insert
+ plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
+ - stages (tuple[bool], optional): Stages to apply plugin, length
+ should be same as 'num_stages'.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+
+ Example:
+ >>> from mmdet.models import Res2Net
+ >>> import torch
+ >>> self = Res2Net(depth=50, scales=4, base_width=26)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 8, 8)
+ (1, 512, 4, 4)
+ (1, 1024, 2, 2)
+ (1, 2048, 1, 1)
+ """
+
+ arch_settings = {
+ 50: (Bottle2neck, (3, 4, 6, 3)),
+ 101: (Bottle2neck, (3, 4, 23, 3)),
+ 152: (Bottle2neck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ scales=4,
+ base_width=26,
+ style='pytorch',
+ deep_stem=True,
+ avg_down=True,
+ pretrained=None,
+ init_cfg=None,
+ **kwargs):
+ self.scales = scales
+ self.base_width = base_width
+ super(Res2Net, self).__init__(
+ style='pytorch',
+ deep_stem=True,
+ avg_down=True,
+ pretrained=pretrained,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return Res2Layer(
+ scales=self.scales,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
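+
+
+# Config-style sketch (illustrative, editor-added; the field values are
+# assumptions, not an official config): Res2Net is usually selected from a
+# detector config through the BACKBONES registry, e.g.
+#
+#     >>> from mmdet.models import build_backbone
+#     >>> backbone = build_backbone(
+#     ...     dict(type='Res2Net', depth=50, scales=4, base_width=26))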
diff --git a/mmdet/models/backbones/resnest.py b/mmdet/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..69629b96dfd44e4cbe53701fb14fb83fda4b6440
--- /dev/null
+++ b/mmdet/models/backbones/resnest.py
@@ -0,0 +1,322 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNetV1d
+
+
+class RSoftmax(nn.Module):
+ """Radix Softmax module in ``SplitAttentionConv2d``.
+
+ Args:
+ radix (int): Radix of input.
+ groups (int): Groups of input.
+ """
+
+ def __init__(self, radix, groups):
+ super().__init__()
+ self.radix = radix
+ self.groups = groups
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
+ x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SplitAttentionConv2d(BaseModule):
+ """Split-Attention Conv2d in ResNeSt.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ channels (int): Number of intermediate channels.
+ kernel_size (int | tuple[int]): Size of the convolution kernel.
+ stride (int | tuple[int]): Stride of the convolution.
+        padding (int | tuple[int]): Zero-padding added to both sides of
+            the input.
+        dilation (int | tuple[int]): Spacing between kernel elements.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Same as nn.Conv2d.
+        radix (int): Radix of SplitAttentionConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels. Default: 4.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ dcn (dict): Config dict for DCN. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ radix=2,
+ reduction_factor=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ init_cfg=None):
+ super(SplitAttentionConv2d, self).__init__(init_cfg)
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.groups = groups
+ self.channels = channels
+ self.with_dcn = dcn is not None
+ self.dcn = dcn
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if self.with_dcn and not fallback_on_stride:
+ assert conv_cfg is None, 'conv_cfg must be None for DCN'
+ conv_cfg = dcn
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ channels * radix,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ bias=False)
+ # To be consistent with original implementation, starting from 0
+ self.norm0_name, norm0 = build_norm_layer(
+ norm_cfg, channels * radix, postfix=0)
+ self.add_module(self.norm0_name, norm0)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc1 = build_conv_layer(
+ None, channels, inter_channels, 1, groups=self.groups)
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, inter_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.fc2 = build_conv_layer(
+ None, inter_channels, channels * radix, 1, groups=self.groups)
+ self.rsoftmax = RSoftmax(radix, groups)
+
+ @property
+ def norm0(self):
+ """nn.Module: the normalization layer named "norm0" """
+ return getattr(self, self.norm0_name)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm0(x)
+ x = self.relu(x)
+
+        batch = x.size(0)
+ if self.radix > 1:
+ splits = x.view(batch, self.radix, -1, *x.shape[2:])
+ gap = splits.sum(dim=1)
+ else:
+ gap = x
+ gap = F.adaptive_avg_pool2d(gap, 1)
+ gap = self.fc1(gap)
+
+ gap = self.norm1(gap)
+ gap = self.relu(gap)
+
+ atten = self.fc2(gap)
+ atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+ if self.radix > 1:
+ attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+ out = torch.sum(attens * splits, dim=1)
+ else:
+ out = atten * x
+ return out.contiguous()
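+
+# Shape sketch for SplitAttentionConv2d (illustrative, editor-added; numbers
+# are assumptions): with channels=64 and radix=2 the grouped conv produces
+# 128 feature maps, which are viewed as (B, radix, 64, H, W); the attention
+# weights produced by fc2 + RSoftmax then re-weight and sum the radix splits
+# back down to 64 output channels.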
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeSt.
+
+ Args:
+        inplanes (int): Input planes of this block.
+ planes (int): Middle planes of this block.
+ groups (int): Groups of conv2.
+ base_width (int): Base of width in terms of base channels. Default: 4.
+ base_channels (int): Base of channels for calculating width.
+ Default: 64.
+        radix (int): Radix of SplitAttentionConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Key word arguments for base class.
+ """
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ """Bottleneck block for ResNeSt."""
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.with_modulated_dcn = False
+ self.conv2 = SplitAttentionConv2d(
+ width,
+ width,
+ kernel_size=3,
+ stride=1 if self.avg_down_stride else self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ radix=radix,
+ reduction_factor=reduction_factor,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=self.dcn)
+ delattr(self, self.norm2_name)
+
+ if self.avg_down_stride:
+ self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+
+ if self.avg_down_stride:
+ out = self.avd_layer(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNeSt(ResNetV1d):
+ """ResNeSt backbone.
+
+ Args:
+ groups (int): Number of groups of Bottleneck. Default: 1
+ base_width (int): Base width of Bottleneck. Default: 4
+ radix (int): Radix of SplitAttentionConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Keyword arguments for ResNet.
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ 200: (Bottleneck, (3, 24, 36, 3))
+ }
+
+ def __init__(self,
+ groups=1,
+ base_width=4,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ self.radix = radix
+ self.reduction_factor = reduction_factor
+ self.avg_down_stride = avg_down_stride
+ super(ResNeSt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ radix=self.radix,
+ reduction_factor=self.reduction_factor,
+ avg_down_stride=self.avg_down_stride,
+ **kwargs)
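+
+
+# Config-style sketch (illustrative, editor-added; the field values are
+# assumptions, not an official config): a ResNeSt-50 style backbone can be
+# built through the BACKBONES registry, e.g.
+#
+#     >>> from mmdet.models import build_backbone
+#     >>> backbone = build_backbone(
+#     ...     dict(type='ResNeSt', depth=50, radix=2, reduction_factor=4))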
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eaaae67c9dfab9458ce60d7ca1d7cbfe651a664
--- /dev/null
+++ b/mmdet/models/backbones/resnet.py
@@ -0,0 +1,672 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
+from mmcv.runner import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+
+
+class BasicBlock(BaseModule):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None,
+ init_cfg=None):
+ super(BasicBlock, self).__init__(init_cfg)
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg, planes, planes, 3, padding=1, bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(BaseModule):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None,
+ init_cfg=None):
+ """Bottleneck block for ResNet.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__(init_cfg)
+ assert style in ['pytorch', 'caffe']
+ assert dcn is None or isinstance(dcn, dict)
+ assert plugins is None or isinstance(plugins, list)
+ if plugins is not None:
+ allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
+ assert all(p['position'] in allowed_position for p in plugins)
+
+ self.inplanes = inplanes
+ self.planes = planes
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.dcn = dcn
+ self.with_dcn = dcn is not None
+ self.plugins = plugins
+ self.with_plugins = plugins is not None
+
+ if self.with_plugins:
+ # collect plugins for conv1/conv2/conv3
+ self.after_conv1_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv1'
+ ]
+ self.after_conv2_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv2'
+ ]
+ self.after_conv3_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv3'
+ ]
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ dcn,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ if self.with_plugins:
+ self.after_conv1_plugin_names = self.make_block_plugins(
+ planes, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(
+ planes, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(
+ planes * self.expansion, self.after_conv3_plugins)
+
+ def make_block_plugins(self, in_channels, plugins):
+ """make plugins for block.
+
+ Args:
+ in_channels (int): Input channels of plugin.
+ plugins (list[dict]): List of plugins cfg to build.
+
+ Returns:
+ list[str]: List of the names of plugin.
+ """
+ assert isinstance(plugins, list)
+ plugin_names = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ name, layer = build_plugin_layer(
+ plugin,
+ in_channels=in_channels,
+ postfix=plugin.pop('postfix', ''))
+ assert not hasattr(self, name), f'duplicate plugin {name}'
+ self.add_module(name, layer)
+ plugin_names.append(name)
+ return plugin_names
+
+ def forward_plugin(self, x, plugin_names):
+ out = x
+ for name in plugin_names:
+ out = getattr(self, name)(out)
+ return out
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: normalization layer after the third convolution layer"""
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
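+
+# Note on the ``style`` option (editor-added summary of the logic in
+# __init__ above): style='pytorch' places the stride on the 3x3 conv
+# (conv1_stride=1, conv2_stride=stride), while style='caffe' places it on the
+# first 1x1 conv (conv1_stride=stride, conv2_stride=1). For example,
+# Bottleneck(256, 64, stride=2, style='caffe') downsamples in its conv1.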
+
+
+@BACKBONES.register_module()
+class ResNet(BaseModule):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ stem_channels (int | None): Number of stem channels. If not specified,
+ it will be the same as `base_channels`. Default: None.
+ base_channels (int): Number of base channels of res layer. Default: 64.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Resnet stages. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+
+ - cfg (dict, required): Cfg dict to build plugin.
+ - position (str, required): Position inside block to insert
+ plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
+ - stages (tuple[bool], optional): Stages to apply plugin, length
+ should be same as 'num_stages'.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+
+ Example:
+ >>> from mmdet.models import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
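+    # Mapping from depth to (block type, number of blocks per stage).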
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ stem_channels=None,
+ base_channels=64,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ with_cp=False,
+ zero_init_residual=True,
+ pretrained=None,
+ init_cfg=None):
+ super(ResNet, self).__init__(init_cfg)
+ self.zero_init_residual = zero_init_residual
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+
+ block_init_cfg = None
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]
+ block = self.arch_settings[depth][0]
+ if self.zero_init_residual:
+ if block is BasicBlock:
+ block_init_cfg = dict(
+ type='Constant',
+ val=0,
+ override=dict(name='norm2'))
+ elif block is Bottleneck:
+ block_init_cfg = dict(
+ type='Constant',
+ val=0,
+ override=dict(name='norm3'))
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ self.depth = depth
+ if stem_channels is None:
+ stem_channels = base_channels
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.inplanes = stem_channels
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if plugins is not None:
+ stage_plugins = self.make_stage_plugins(plugins, i)
+ else:
+ stage_plugins = None
+ planes = base_channels * 2**i
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ init_cfg=block_init_cfg)
+ self.inplanes = planes * self.block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
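+        # Channel number of the last stage's output feature map, e.g.
+        # 4 * 64 * 2**3 = 2048 for ResNet-50.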
+ self.feat_dim = self.block.expansion * base_channels * 2**(
+ len(self.stage_blocks) - 1)
+
+ def make_stage_plugins(self, plugins, stage_idx):
+        """Make plugins for the ResNet ``stage_idx``-th stage.
+
+        Currently we support inserting ``context_block``,
+        ``empirical_attention_block`` and ``nonlocal_block`` into backbones
+        like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
+        Bottleneck.
+
+ An example of plugins format could be:
+
+ Examples:
+ >>> plugins=[
+ ... dict(cfg=dict(type='xxx', arg1='xxx'),
+ ... stages=(False, True, True, True),
+ ... position='after_conv2'),
+ ... dict(cfg=dict(type='yyy'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='1'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='2'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3')
+ ... ]
+ >>> self = ResNet(depth=18)
+ >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+ >>> assert len(stage_plugins) == 3
+
+ Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
+
+ .. code-block:: none
+
+            conv1->conv2->conv3->yyy->zzz1->zzz2
+
+        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
+
+        .. code-block:: none
+
+            conv1->conv2->xxx->conv3->yyy->zzz1->zzz2
+
+ If stages is missing, the plugin would be applied to all stages.
+
+ Args:
+            plugins (list[dict]): List of plugin cfgs to build. A postfix is
+                required if multiple plugins of the same type are inserted.
+            stage_idx (int): Index of the stage to build.
+
+ Returns:
+            list[dict]: Plugins for the current stage.
+ """
+ stage_plugins = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ stages = plugin.pop('stages', None)
+ assert stages is None or len(stages) == self.num_stages
+ # whether to insert plugin into current stage
+ if stages is None or stages[stage_idx]:
+ stage_plugins.append(plugin)
+
+ return stage_plugins
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels):
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels)[1],
+ nn.ReLU(inplace=True))
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+        """Convert the model into training mode while keeping normalization
+        layers frozen."""
+ super(ResNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval has an effect on BatchNorm only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+
+@BACKBONES.register_module()
+class ResNetV1d(ResNet):
+    r"""ResNetV1d variant described in `Bag of Tricks
+    <https://arxiv.org/abs/1812.01187>`_.
+
+ Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
+ the input stem with three 3x3 convs. And in the downsampling block, a 2x2
+ avg_pool with stride 2 is added before conv, whose stride is changed to 1.
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1d, self).__init__(
+ deep_stem=True, avg_down=True, **kwargs)
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..8675d7c1149a321cbbba45fa93ea3cc3b79d0bd1
--- /dev/null
+++ b/mmdet/models/backbones/resnext.py
@@ -0,0 +1,154 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottleneck(_Bottleneck):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ **kwargs):
+ """Bottleneck block for ResNeXt.
+
+        If style is "pytorch", the stride-two layer is the 3x3 conv layer;
+        if it is "caffe", the stride-two layer is the first 1x1 conv layer.
+        """
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
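+        # Per-branch bottleneck width. For example, with the common ResNeXt-50
+        # 32x4d setting (groups=32, base_width=4, base_channels=64), the first
+        # stage (planes=64) uses width = floor(64 * 4 / 64) * 32 = 128.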
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, width, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ self.with_modulated_dcn = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ self.dcn,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ if self.with_plugins:
+ self._del_block_plugins(self.after_conv1_plugin_names +
+ self.after_conv2_plugin_names +
+ self.after_conv3_plugin_names)
+ self.after_conv1_plugin_names = self.make_block_plugins(
+ width, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(
+ width, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(
+ self.planes * self.expansion, self.after_conv3_plugins)
+
+ def _del_block_plugins(self, plugin_names):
+        """Delete plugins for the block if they exist.
+
+ Args:
+ plugin_names (list[str]): List of plugins name to delete.
+ """
+ assert isinstance(plugin_names, list)
+ for plugin_name in plugin_names:
+ del self._modules[plugin_name]
+
+
+@BACKBONES.register_module()
+class ResNeXt(ResNet):
+ """ResNeXt backbone.
+
+ Args:
+        depth (int): Depth of resnext, from {50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Resnet stages. Default: 4.
+        groups (int): Number of groups (cardinality) of ResNeXt. Default: 1.
+        base_width (int): Base width of each group. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
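+
+    Example:
+        The following is an illustrative usage sketch (mirroring the ResNet
+        example above) and assumes the standard ResNeXt-50 32x4d setting:
+
+        >>> from mmdet.models import ResNeXt
+        >>> import torch
+        >>> self = ResNeXt(depth=50, groups=32, base_width=4)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 8, 8)
+        (1, 512, 4, 4)
+        (1, 1024, 2, 2)
+        (1, 2048, 1, 1)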
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, groups=1, base_width=4, **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ super(ResNeXt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``"""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..c15aeac00d004418a2a2c46e53add41b95a44815
--- /dev/null
+++ b/mmdet/models/backbones/ssd_vgg.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+from mmcv.cnn import VGG
+from mmcv.runner import BaseModule
+
+from ..builder import BACKBONES
+from ..necks import ssd_neck
+
+
+@BACKBONES.register_module()
+class SSDVGG(VGG, BaseModule):
+ """VGG Backbone network for single-shot-detection.
+
+ Args:
+ depth (int): Depth of vgg, from {11, 13, 16, 19}.
+ with_last_pool (bool): Whether to add a pooling layer at the last
+ of the model
+ ceil_mode (bool): When True, will use `ceil` instead of `floor`
+ to compute the output shape.
+ out_indices (Sequence[int]): Output from which stages.
+ out_feature_indices (Sequence[int]): Output from which feature map.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+        input_size (int, optional): Deprecated argument.
+            Width and height of input, from {300, 512}.
+        l2_norm_scale (float, optional): Deprecated argument.
+            L2 normalization layer init scale.
+
+ Example:
+        >>> import torch
+        >>> self = SSDVGG(input_size=300, depth=11)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 300, 300)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 1024, 19, 19)
+ (1, 512, 10, 10)
+ (1, 256, 5, 5)
+ (1, 256, 3, 3)
+ (1, 256, 1, 1)
+ """
+ extra_setting = {
+ 300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
+ 512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128),
+ }
+
+ def __init__(self,
+ depth,
+ with_last_pool=False,
+ ceil_mode=True,
+ out_indices=(3, 4),
+ out_feature_indices=(22, 34),
+ pretrained=None,
+ init_cfg=None,
+ input_size=None,
+ l2_norm_scale=None):
+ # TODO: in_channels for mmcv.VGG
+ super(SSDVGG, self).__init__(
+ depth,
+ with_last_pool=with_last_pool,
+ ceil_mode=ceil_mode,
+ out_indices=out_indices)
+
+ self.features.add_module(
+ str(len(self.features)),
+ nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
+ self.features.add_module(
+ str(len(self.features)),
+ nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6))
+ self.features.add_module(
+ str(len(self.features)), nn.ReLU(inplace=True))
+ self.features.add_module(
+ str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1))
+ self.features.add_module(
+ str(len(self.features)), nn.ReLU(inplace=True))
+ self.out_feature_indices = out_feature_indices
+
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+
+ if init_cfg is not None:
+ self.init_cfg = init_cfg
+ elif isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(type='Constant', val=1, layer='BatchNorm2d'),
+ dict(type='Normal', std=0.01, layer='Linear'),
+ ]
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ if input_size is not None:
+ warnings.warn('DeprecationWarning: input_size is deprecated')
+ if l2_norm_scale is not None:
+ warnings.warn('DeprecationWarning: l2_norm_scale in VGG is '
+ 'deprecated, it has been moved to SSDNeck.')
+
+ def init_weights(self, pretrained=None):
+ super(VGG, self).init_weights()
+
+ def forward(self, x):
+ """Forward function."""
+ outs = []
+ for i, layer in enumerate(self.features):
+ x = layer(x)
+ if i in self.out_feature_indices:
+ outs.append(x)
+
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+
+class L2Norm(ssd_neck.L2Norm):
+
+ def __init__(self, **kwargs):
+ super(L2Norm, self).__init__(**kwargs)
+ warnings.warn('DeprecationWarning: L2Norm in ssd_vgg.py '
+ 'is deprecated, please use L2Norm in '
+ 'mmdet/models/necks/ssd_neck.py instead')
diff --git a/mmdet/models/backbones/swin.py b/mmdet/models/backbones/swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8eccfca195f5d76865d10d7220546eb297ecc99
--- /dev/null
+++ b/mmdet/models/backbones/swin.py
@@ -0,0 +1,772 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
+from mmcv.cnn.bricks.transformer import FFN, build_dropout
+from mmcv.cnn.utils.weight_init import trunc_normal_
+from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from mmcv.utils import to_2tuple
+
+from ...utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils.ckpt_convert import swin_converter
+from ..utils.transformer import PatchEmbed, PatchMerging
+
+
+class WindowMSA(BaseModule):
+ """Window based multi-head self-attention (W-MSA) module with relative
+ position bias.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): The height and width of the window.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Default: 0.0
+ proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
+ init_cfg (dict | None, optional): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ init_cfg=None):
+
+ super().__init__()
+ self.embed_dims = embed_dims
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_embed_dims = embed_dims // num_heads
+ self.scale = qk_scale or head_embed_dims**-0.5
+ self.init_cfg = init_cfg
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
+ num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # About 2x faster than original impl
+ Wh, Ww = self.window_size
+ rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
+ rel_position_index = rel_index_coords + rel_index_coords.T
+ rel_position_index = rel_position_index.flip(1).contiguous()
+ self.register_buffer('relative_position_index', rel_position_index)
+
+ self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop_rate)
+ self.proj = nn.Linear(embed_dims, embed_dims)
+ self.proj_drop = nn.Dropout(proj_drop_rate)
+
+ self.softmax = nn.Softmax(dim=-1)
+
+ def init_weights(self):
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+ def forward(self, x, mask=None):
+ """
+        Args:
+            x (tensor): input features with shape of (num_windows*B, N, C)
+            mask (tensor | None, optional): mask with shape of (num_windows,
+                Wh*Ww, Wh*Ww), value should be between (-inf, 0].
+ """
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ # make torchscript happy (cannot use tensor as tuple)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B // nW, nW, self.num_heads, N,
+ N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ @staticmethod
+ def double_step_seq(step1, len1, step2, len2):
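+        # Builds the relative-index offsets used above: an outer sum of
+        # arange(0, step1*len1, step1) and arange(0, step2*len2, step2),
+        # flattened to shape (1, len1 * len2).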
+ seq1 = torch.arange(0, step1 * len1, step1)
+ seq2 = torch.arange(0, step2 * len2, step2)
+ return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
+
+
+class ShiftWindowMSA(BaseModule):
+ """Shifted Window Multihead Self-Attention Module.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): The height and width of the window.
+ shift_size (int, optional): The shift step of each window towards
+ right-bottom. If zero, act as regular window-msa. Defaults to 0.
+ qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+ Default: True
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Defaults: None.
+ attn_drop_rate (float, optional): Dropout ratio of attention weight.
+ Defaults: 0.
+ proj_drop_rate (float, optional): Dropout ratio of output.
+ Defaults: 0.
+ dropout_layer (dict, optional): The dropout_layer used before output.
+ Defaults: dict(type='DropPath', drop_prob=0.).
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size,
+ shift_size=0,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0,
+ proj_drop_rate=0,
+ dropout_layer=dict(type='DropPath', drop_prob=0.),
+ init_cfg=None):
+ super().__init__(init_cfg)
+
+ self.window_size = window_size
+ self.shift_size = shift_size
+ assert 0 <= self.shift_size < self.window_size
+
+ self.w_msa = WindowMSA(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ window_size=to_2tuple(window_size),
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=proj_drop_rate,
+ init_cfg=None)
+
+ self.drop = build_dropout(dropout_layer)
+
+ def forward(self, query, hw_shape):
+ B, L, C = query.shape
+ H, W = hw_shape
+ assert L == H * W, 'input feature has wrong size'
+ query = query.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
+ H_pad, W_pad = query.shape[1], query.shape[2]
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_query = torch.roll(
+ query,
+ shifts=(-self.shift_size, -self.shift_size),
+ dims=(1, 2))
+
+ # calculate attention mask for SW-MSA
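+            # After the cyclic shift, some windows contain pixels from
+            # disjoint image regions; pairs of positions from different
+            # regions receive a large negative bias (-100) so that softmax
+            # effectively zeroes their attention weights.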
+ img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ # nW, window_size, window_size, 1
+ mask_windows = self.window_partition(img_mask)
+ mask_windows = mask_windows.view(
+ -1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0,
+ float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0))
+ else:
+ shifted_query = query
+ attn_mask = None
+
+ # nW*B, window_size, window_size, C
+ query_windows = self.window_partition(shifted_query)
+ # nW*B, window_size*window_size, C
+ query_windows = query_windows.view(-1, self.window_size**2, C)
+
+ # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
+ attn_windows = self.w_msa(query_windows, mask=attn_mask)
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size,
+ self.window_size, C)
+
+ # B H' W' C
+ shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x,
+ shifts=(self.shift_size, self.shift_size),
+ dims=(1, 2))
+ else:
+ x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ x = self.drop(x)
+ return x
+
+ def window_reverse(self, windows, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ window_size = self.window_size
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+ def window_partition(self, x):
+ """
+ Args:
+ x: (B, H, W, C)
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ window_size = self.window_size
+ x = x.view(B, H // window_size, window_size, W // window_size,
+ window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ windows = windows.view(-1, window_size, window_size, C)
+ return windows
+
+
+class SwinBlock(BaseModule):
+    """Swin Transformer block.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ window_size (int, optional): The local window scale. Default: 7.
+ shift (bool, optional): whether to shift window or not. Default False.
+ qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop_rate (float, optional): Dropout rate. Default: 0.
+ attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
+ drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
+ act_cfg (dict, optional): The config dict of activation function.
+ Default: dict(type='GELU').
+ norm_cfg (dict, optional): The config dict of normalization.
+ Default: dict(type='LN').
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ init_cfg (dict | list | None, optional): The init config.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ window_size=7,
+ shift=False,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ init_cfg=None):
+
+ super(SwinBlock, self).__init__()
+
+ self.init_cfg = init_cfg
+ self.with_cp = with_cp
+
+ self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.attn = ShiftWindowMSA(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=window_size // 2 if shift else 0,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ init_cfg=None)
+
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+ self.ffn = FFN(
+ embed_dims=embed_dims,
+ feedforward_channels=feedforward_channels,
+ num_fcs=2,
+ ffn_drop=drop_rate,
+ dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
+ act_cfg=act_cfg,
+ add_identity=True,
+ init_cfg=None)
+
+ def forward(self, x, hw_shape):
+
+ def _inner_forward(x):
+ identity = x
+ x = self.norm1(x)
+ x = self.attn(x, hw_shape)
+
+ x = x + identity
+
+ identity = x
+ x = self.norm2(x)
+ x = self.ffn(x, identity=identity)
+
+ return x
+
+ if self.with_cp and x.requires_grad:
+ x = cp.checkpoint(_inner_forward, x)
+ else:
+ x = _inner_forward(x)
+
+ return x
+
+
+class SwinBlockSequence(BaseModule):
+ """Implements one stage in Swin Transformer.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ depth (int): The number of blocks in this stage.
+ window_size (int, optional): The local window scale. Default: 7.
+ qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop_rate (float, optional): Dropout rate. Default: 0.
+ attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
+ drop_path_rate (float | list[float], optional): Stochastic depth
+ rate. Default: 0.
+ downsample (BaseModule | None, optional): The downsample operation
+ module. Default: None.
+ act_cfg (dict, optional): The config dict of activation function.
+ Default: dict(type='GELU').
+ norm_cfg (dict, optional): The config dict of normalization.
+ Default: dict(type='LN').
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ init_cfg (dict | list | None, optional): The init config.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ depth,
+ window_size=7,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ downsample=None,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ if isinstance(drop_path_rate, list):
+ drop_path_rates = drop_path_rate
+ assert len(drop_path_rates) == depth
+ else:
+ drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
+
+ self.blocks = ModuleList()
+ for i in range(depth):
+ block = SwinBlock(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=feedforward_channels,
+ window_size=window_size,
+ shift=False if i % 2 == 0 else True,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rates[i],
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ with_cp=with_cp,
+ init_cfg=None)
+ self.blocks.append(block)
+
+ self.downsample = downsample
+
+ def forward(self, x, hw_shape):
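+        # Returns a 4-tuple: the (optionally downsampled) tensor and shape fed
+        # to the next stage, plus the pre-downsample tensor and shape that the
+        # backbone normalizes and reshapes as this stage's output.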
+ for block in self.blocks:
+ x = block(x, hw_shape)
+
+ if self.downsample:
+ x_down, down_hw_shape = self.downsample(x, hw_shape)
+ return x_down, down_hw_shape, x, hw_shape
+ else:
+ return x, hw_shape, x, hw_shape
+
+
+@BACKBONES.register_module()
+class SwinTransformer(BaseModule):
+    """Swin Transformer backbone.
+
+    A PyTorch implementation of `Swin Transformer:
+    Hierarchical Vision Transformer using Shifted Windows
+    <https://arxiv.org/abs/2103.14030>`_.
+
+ Inspiration from
+ https://github.com/microsoft/Swin-Transformer
+
+ Args:
+ pretrain_img_size (int | tuple[int]): The size of input image when
+ pretrain. Defaults: 224.
+ in_channels (int): The num of input channels.
+ Defaults: 3.
+ embed_dims (int): The feature dimension. Default: 96.
+ patch_size (int | tuple[int]): Patch size. Default: 4.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (int | float): Ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ Default: (2, 2, 6, 2).
+ num_heads (tuple[int]): Parallel attention heads of each Swin
+ Transformer stage. Default: (3, 6, 12, 24).
+ strides (tuple[int]): The patch merging or patch embedding stride of
+ each Swin Transformer stage. (In swin, we set kernel size equal to
+ stride.) Default: (4, 2, 2, 2).
+ out_indices (tuple[int]): Output from which stages.
+ Default: (0, 1, 2, 3).
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key,
+ value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ patch_norm (bool): If add a norm layer for patch embed and patch
+ merging. Default: True.
+ drop_rate (float): Dropout rate. Defaults: 0.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
+ use_abs_pos_embed (bool): If True, add absolute position embedding to
+ the patch embedding. Defaults: False.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer at
+            output of backbone. Defaults: dict(type='LN').
+ with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ pretrained (str, optional): model pretrained path. Default: None.
+ convert_weights (bool): The flag indicates whether the
+ pre-trained model is from the original repo. We may need
+ to convert some keys to make it compatible.
+ Default: False.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ Default: -1 (-1 means not freezing any parameters).
+ init_cfg (dict, optional): The Config for initialization.
+ Defaults to None.
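+
+    Example:
+        A minimal usage sketch (added for illustration), assuming the default
+        Swin-T settings and a 224x224 input:
+
+        >>> from mmdet.models import SwinTransformer
+        >>> import torch
+        >>> self = SwinTransformer()
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 96, 56, 56)
+        (1, 192, 28, 28)
+        (1, 384, 14, 14)
+        (1, 768, 7, 7)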
+ """
+
+ def __init__(self,
+ pretrain_img_size=224,
+ in_channels=3,
+ embed_dims=96,
+ patch_size=4,
+ window_size=7,
+ mlp_ratio=4,
+ depths=(2, 2, 6, 2),
+ num_heads=(3, 6, 12, 24),
+ strides=(4, 2, 2, 2),
+ out_indices=(0, 1, 2, 3),
+ qkv_bias=True,
+ qk_scale=None,
+ patch_norm=True,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ use_abs_pos_embed=False,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ with_cp=False,
+ pretrained=None,
+ convert_weights=False,
+ frozen_stages=-1,
+ init_cfg=None):
+ self.convert_weights = convert_weights
+ self.frozen_stages = frozen_stages
+ if isinstance(pretrain_img_size, int):
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ elif isinstance(pretrain_img_size, tuple):
+ if len(pretrain_img_size) == 1:
+ pretrain_img_size = to_2tuple(pretrain_img_size[0])
+ assert len(pretrain_img_size) == 2, \
+ f'The size of image should have length 1 or 2, ' \
+ f'but got {len(pretrain_img_size)}'
+
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ self.init_cfg = init_cfg
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ super(SwinTransformer, self).__init__(init_cfg=init_cfg)
+
+ num_layers = len(depths)
+ self.out_indices = out_indices
+ self.use_abs_pos_embed = use_abs_pos_embed
+
+ assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
+
+ self.patch_embed = PatchEmbed(
+ in_channels=in_channels,
+ embed_dims=embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=strides[0],
+ norm_cfg=norm_cfg if patch_norm else None,
+ init_cfg=None)
+
+ if self.use_abs_pos_embed:
+ patch_row = pretrain_img_size[0] // patch_size
+ patch_col = pretrain_img_size[1] // patch_size
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros((1, embed_dims, patch_row, patch_col)))
+
+ self.drop_after_pos = nn.Dropout(p=drop_rate)
+
+ # set stochastic depth decay rule
+ total_depth = sum(depths)
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+ ]
+
+ self.stages = ModuleList()
+ in_channels = embed_dims
+ for i in range(num_layers):
+ if i < num_layers - 1:
+ downsample = PatchMerging(
+ in_channels=in_channels,
+ out_channels=2 * in_channels,
+ stride=strides[i + 1],
+ norm_cfg=norm_cfg if patch_norm else None,
+ init_cfg=None)
+ else:
+ downsample = None
+
+ stage = SwinBlockSequence(
+ embed_dims=in_channels,
+ num_heads=num_heads[i],
+ feedforward_channels=int(mlp_ratio * in_channels),
+ depth=depths[i],
+ window_size=window_size,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
+ downsample=downsample,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ with_cp=with_cp,
+ init_cfg=None)
+ self.stages.append(stage)
+ if downsample:
+ in_channels = downsample.out_channels
+
+ self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
+ # Add a norm layer for each output
+ for i in out_indices:
+ layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
+ layer_name = f'norm{i}'
+ self.add_module(layer_name, layer)
+
+ def train(self, mode=True):
+        """Convert the model to training mode while keeping layers frozen."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ if self.use_abs_pos_embed:
+ self.absolute_pos_embed.requires_grad = False
+ self.drop_after_pos.eval()
+
+ for i in range(1, self.frozen_stages + 1):
+
+ if (i - 1) in self.out_indices:
+ norm_layer = getattr(self, f'norm{i-1}')
+ norm_layer.eval()
+ for param in norm_layer.parameters():
+ param.requires_grad = False
+
+ m = self.stages[i - 1]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ logger = get_root_logger()
+ if self.init_cfg is None:
+            logger.warning(
+                f'No pre-trained weights for {self.__class__.__name__}, '
+                f'training starts from scratch')
+ if self.use_abs_pos_embed:
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ constant_init(m, 1.0)
+ else:
+            assert 'checkpoint' in self.init_cfg, \
+                f'Only support specifying `Pretrained` in `init_cfg` in ' \
+                f'{self.__class__.__name__}'
+ ckpt = _load_checkpoint(
+ self.init_cfg.checkpoint, logger=logger, map_location='cpu')
+ if 'state_dict' in ckpt:
+ _state_dict = ckpt['state_dict']
+ elif 'model' in ckpt:
+ _state_dict = ckpt['model']
+ else:
+ _state_dict = ckpt
+ if self.convert_weights:
+                # Support loading weights from the original repo.
+ _state_dict = swin_converter(_state_dict)
+
+ state_dict = OrderedDict()
+ for k, v in _state_dict.items():
+ if k.startswith('backbone.'):
+ state_dict[k[9:]] = v
+
+ # strip prefix of state_dict
+ if list(state_dict.keys())[0].startswith('module.'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ # reshape absolute position embedding
+ if state_dict.get('absolute_pos_embed') is not None:
+ absolute_pos_embed = state_dict['absolute_pos_embed']
+ N1, L, C1 = absolute_pos_embed.size()
+ N2, C2, H, W = self.absolute_pos_embed.size()
+ if N1 != N2 or C1 != C2 or L != H * W:
+ logger.warning('Error in loading absolute_pos_embed, pass')
+ else:
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
+ N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
+
+ # interpolate position bias table if needed
+ relative_position_bias_table_keys = [
+ k for k in state_dict.keys()
+ if 'relative_position_bias_table' in k
+ ]
+ for table_key in relative_position_bias_table_keys:
+ table_pretrained = state_dict[table_key]
+ table_current = self.state_dict()[table_key]
+ L1, nH1 = table_pretrained.size()
+ L2, nH2 = table_current.size()
+ if nH1 != nH2:
+ logger.warning(f'Error in loading {table_key}, pass')
+ elif L1 != L2:
+ S1 = int(L1**0.5)
+ S2 = int(L2**0.5)
+ table_pretrained_resized = F.interpolate(
+ table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
+ size=(S2, S2),
+ mode='bicubic')
+ state_dict[table_key] = table_pretrained_resized.view(
+ nH2, L2).permute(1, 0).contiguous()
+
+ # load state_dict
+ self.load_state_dict(state_dict, False)
+
+ def forward(self, x):
+ x, hw_shape = self.patch_embed(x)
+
+ if self.use_abs_pos_embed:
+ h, w = self.absolute_pos_embed.shape[1:3]
+ if hw_shape[0] != h or hw_shape[1] != w:
+ absolute_pos_embed = F.interpolate(
+ self.absolute_pos_embed,
+ size=hw_shape,
+ mode='bicubic',
+ align_corners=False).flatten(2).transpose(1, 2)
+ else:
+ absolute_pos_embed = self.absolute_pos_embed.flatten(
+ 2).transpose(1, 2)
+ x = x + absolute_pos_embed
+ x = self.drop_after_pos(x)
+
+ outs = []
+ for i, stage in enumerate(self.stages):
+ x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ out = norm_layer(out)
+ out = out.view(-1, *out_hw_shape,
+ self.num_features[i]).permute(0, 3, 1,
+ 2).contiguous()
+ outs.append(out)
+
+ return outs
diff --git a/mmdet/models/backbones/trident_resnet.py b/mmdet/models/backbones/trident_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..013ba64b59d81e5be3a3f00b65c6a76915247c9d
--- /dev/null
+++ b/mmdet/models/backbones/trident_resnet.py
@@ -0,0 +1,298 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule
+from torch.nn.modules.utils import _pair
+
+from mmdet.models.backbones.resnet import Bottleneck, ResNet
+from mmdet.models.builder import BACKBONES
+
+
+class TridentConv(BaseModule):
+ """Trident Convolution Module.
+
+ Args:
+ in_channels (int): Number of channels in input.
+ out_channels (int): Number of channels in output.
+ kernel_size (int): Size of convolution kernel.
+ stride (int, optional): Convolution stride. Default: 1.
+ trident_dilations (tuple[int, int, int], optional): Dilations of
+ different trident branch. Default: (1, 2, 3).
+ test_branch_idx (int, optional): In inference, all 3 branches will
+ be used if `test_branch_idx==-1`, otherwise only branch with
+ index `test_branch_idx` will be used. Default: 1.
+ bias (bool, optional): Whether to use bias in convolution or not.
+ Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ trident_dilations=(1, 2, 3),
+ test_branch_idx=1,
+ bias=False,
+ init_cfg=None):
+ super(TridentConv, self).__init__(init_cfg)
+ self.num_branch = len(trident_dilations)
+ self.with_bias = bias
+ self.test_branch_idx = test_branch_idx
+ self.stride = _pair(stride)
+ self.kernel_size = _pair(kernel_size)
+ self.paddings = _pair(trident_dilations)
+ self.dilations = trident_dilations
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.bias = bias
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.bias = None
+
+ def extra_repr(self):
+ tmpstr = f'in_channels={self.in_channels}'
+ tmpstr += f', out_channels={self.out_channels}'
+ tmpstr += f', kernel_size={self.kernel_size}'
+ tmpstr += f', num_branch={self.num_branch}'
+ tmpstr += f', test_branch_idx={self.test_branch_idx}'
+ tmpstr += f', stride={self.stride}'
+ tmpstr += f', paddings={self.paddings}'
+ tmpstr += f', dilations={self.dilations}'
+ tmpstr += f', bias={self.bias}'
+ return tmpstr
+
+ def forward(self, inputs):
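+        # In training (or when test_branch_idx == -1) `inputs` is expected to
+        # be a sequence with one tensor per branch, and every branch applies
+        # the shared weight with its own dilation/padding; otherwise a single
+        # input is convolved with the test branch's dilation only.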
+ if self.training or self.test_branch_idx == -1:
+ outputs = [
+ F.conv2d(input, self.weight, self.bias, self.stride, padding,
+ dilation) for input, dilation, padding in zip(
+ inputs, self.dilations, self.paddings)
+ ]
+ else:
+ assert len(inputs) == 1
+ outputs = [
+ F.conv2d(inputs[0], self.weight, self.bias, self.stride,
+ self.paddings[self.test_branch_idx],
+ self.dilations[self.test_branch_idx])
+ ]
+
+ return outputs
+
+
+# Since TridentNet is defined over ResNet50 and ResNet101, here we
+# only support TridentBottleneckBlock.
+class TridentBottleneck(Bottleneck):
+    """Bottleneck block for TridentResNet.
+
+ Args:
+ trident_dilations (tuple[int, int, int]): Dilations of different
+ trident branch.
+ test_branch_idx (int): In inference, all 3 branches will be used
+ if `test_branch_idx==-1`, otherwise only branch with index
+ `test_branch_idx` will be used.
+ concat_output (bool): Whether to concat the output list to a Tensor.
+ `True` only in the last Block.
+ """
+
+ def __init__(self, trident_dilations, test_branch_idx, concat_output,
+ **kwargs):
+
+ super(TridentBottleneck, self).__init__(**kwargs)
+ self.trident_dilations = trident_dilations
+ self.num_branch = len(trident_dilations)
+ self.concat_output = concat_output
+ self.test_branch_idx = test_branch_idx
+ self.conv2 = TridentConv(
+ self.planes,
+ self.planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ bias=False,
+ trident_dilations=self.trident_dilations,
+ test_branch_idx=test_branch_idx,
+ init_cfg=dict(
+ type='Kaiming',
+ distribution='uniform',
+ mode='fan_in',
+ override=dict(name='conv2')))
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ num_branch = (
+ self.num_branch
+ if self.training or self.test_branch_idx == -1 else 1)
+ identity = x
+ if not isinstance(x, list):
+ x = (x, ) * num_branch
+ identity = x
+ if self.downsample is not None:
+ identity = [self.downsample(b) for b in x]
+
+ out = [self.conv1(b) for b in x]
+ out = [self.norm1(b) for b in out]
+ out = [self.relu(b) for b in out]
+
+ if self.with_plugins:
+ for k in range(len(out)):
+ out[k] = self.forward_plugin(out[k],
+ self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = [self.norm2(b) for b in out]
+ out = [self.relu(b) for b in out]
+ if self.with_plugins:
+ for k in range(len(out)):
+ out[k] = self.forward_plugin(out[k],
+ self.after_conv2_plugin_names)
+
+ out = [self.conv3(b) for b in out]
+ out = [self.norm3(b) for b in out]
+
+ if self.with_plugins:
+ for k in range(len(out)):
+ out[k] = self.forward_plugin(out[k],
+ self.after_conv3_plugin_names)
+
+ out = [
+ out_b + identity_b for out_b, identity_b in zip(out, identity)
+ ]
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = [self.relu(b) for b in out]
+ if self.concat_output:
+ out = torch.cat(out, dim=0)
+ return out
+
+
+def make_trident_res_layer(block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ trident_dilations=(1, 2, 3),
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None,
+ test_branch_idx=-1):
+ """Build Trident Res Layers."""
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ for i in range(num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride if i == 0 else 1,
+ trident_dilations=trident_dilations,
+ downsample=downsample if i == 0 else None,
+ style=style,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=plugins,
+ test_branch_idx=test_branch_idx,
+ concat_output=True if i == num_blocks - 1 else False))
+ inplanes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+
+@BACKBONES.register_module()
+class TridentResNet(ResNet):
+    """The stem layer, stage 1 and stage 2 in Trident ResNet are identical to
+    ResNet, while in stage 3, TridentBottleneck blocks replace the normal
+    bottleneck blocks to yield trident outputs. The different branches share
+    convolution weights but use different dilations to achieve multi-scale
+    outputs.
+
+ / stage3(b0) \
+ x - stem - stage1 - stage2 - stage3(b1) - output
+ \ stage3(b2) /
+
+ Args:
+ depth (int): Depth of resnet, from {50, 101, 152}.
+ num_branch (int): Number of branches in TridentNet.
+ test_branch_idx (int): In inference, all 3 branches will be used
+ if `test_branch_idx==-1`, otherwise only branch with index
+ `test_branch_idx` will be used.
+ trident_dilations (tuple[int]): Dilations of different trident branch.
+ len(trident_dilations) should be equal to num_branch.
+ """ # noqa
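+
+    # A minimal, illustrative config sketch (hedged; only the fields required
+    # by the asserts in ``__init__`` below are shown, the remaining options
+    # follow the usual ResNet arguments):
+    #   backbone=dict(
+    #       type='TridentResNet',
+    #       depth=50,
+    #       num_branch=3,
+    #       test_branch_idx=1,
+    #       trident_dilations=(1, 2, 3),
+    #       num_stages=3,
+    #       strides=(1, 2, 2),
+    #       dilations=(1, 1, 1),
+    #       out_indices=(2, ))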
+
+ def __init__(self, depth, num_branch, test_branch_idx, trident_dilations,
+ **kwargs):
+
+ assert num_branch == len(trident_dilations)
+ assert depth in (50, 101, 152)
+ super(TridentResNet, self).__init__(depth, **kwargs)
+ assert self.num_stages == 3
+ self.test_branch_idx = test_branch_idx
+ self.num_branch = num_branch
+
+ last_stage_idx = self.num_stages - 1
+ stride = self.strides[last_stage_idx]
+ dilation = trident_dilations
+ dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None
+ if self.plugins is not None:
+ stage_plugins = self.make_stage_plugins(self.plugins,
+ last_stage_idx)
+ else:
+ stage_plugins = None
+ planes = self.base_channels * 2**last_stage_idx
+ res_layer = make_trident_res_layer(
+ TridentBottleneck,
+ inplanes=(self.block.expansion * self.base_channels *
+ 2**(last_stage_idx - 1)),
+ planes=planes,
+ num_blocks=self.stage_blocks[last_stage_idx],
+ stride=stride,
+ trident_dilations=dilation,
+ style=self.style,
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ test_branch_idx=self.test_branch_idx)
+
+ layer_name = f'layer{last_stage_idx + 1}'
+
+ self.__setattr__(layer_name, res_layer)
+ self.res_layers.pop(last_stage_idx)
+ self.res_layers.insert(last_stage_idx, layer_name)
+
+ self._freeze_stages()
diff --git a/mmdet/models/builder.py b/mmdet/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace6209f71f96676b87a6c046a4fc77bed100062
--- /dev/null
+++ b/mmdet/models/builder.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+from mmcv.cnn import MODELS as MMCV_MODELS
+from mmcv.utils import Registry
+
+MODELS = Registry('models', parent=MMCV_MODELS)
+
+BACKBONES = MODELS
+NECKS = MODELS
+ROI_EXTRACTORS = MODELS
+SHARED_HEADS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+DETECTORS = MODELS
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ return NECKS.build(cfg)
+
+
+def build_roi_extractor(cfg):
+ """Build roi extractor."""
+ return ROI_EXTRACTORS.build(cfg)
+
+
+def build_shared_head(cfg):
+ """Build shared head."""
+ return SHARED_HEADS.build(cfg)
+
+
+def build_head(cfg):
+ """Build head."""
+ return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ return LOSSES.build(cfg)
+
+
+def build_detector(cfg, train_cfg=None, test_cfg=None):
+ """Build detector."""
+ if train_cfg is not None or test_cfg is not None:
+ warnings.warn(
+            'train_cfg and test_cfg are deprecated, '
+ 'please specify them in model', UserWarning)
+ assert cfg.get('train_cfg') is None or train_cfg is None, \
+ 'train_cfg specified in both outer field and model field '
+ assert cfg.get('test_cfg') is None or test_cfg is None, \
+ 'test_cfg specified in both outer field and model field '
+ return DETECTORS.build(
+ cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
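+
+
+# Example usage (an illustrative sketch, not part of the original module):
+# `cfg` is assumed to be an mmcv ``Config`` whose ``model`` field already
+# contains ``train_cfg``/``test_cfg``, e.g. one of the configs shipped with
+# MMDetection:
+#
+#   from mmcv import Config
+#   cfg = Config.fromfile('configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')
+#   model = build_detector(cfg.model)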
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c60ae14796ef0bf95e6da6da6da452b6aed7870
--- /dev/null
+++ b/mmdet/models/dense_heads/__init__.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .anchor_free_head import AnchorFreeHead
+from .anchor_head import AnchorHead
+from .ascend_anchor_head import AscendAnchorHead
+from .ascend_retina_head import AscendRetinaHead
+from .ascend_ssd_head import AscendSSDHead
+from .atss_head import ATSSHead
+from .autoassign_head import AutoAssignHead
+from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead
+from .centernet_head import CenterNetHead
+from .centripetal_head import CentripetalHead
+from .corner_head import CornerHead
+from .ddod_head import DDODHead
+from .deformable_detr_head import DeformableDETRHead
+from .detr_head import DETRHead
+from .embedding_rpn_head import EmbeddingRPNHead
+from .fcos_head import FCOSHead
+from .fovea_head import FoveaHead
+from .free_anchor_retina_head import FreeAnchorRetinaHead
+from .fsaf_head import FSAFHead
+from .ga_retina_head import GARetinaHead
+from .ga_rpn_head import GARPNHead
+from .gfl_head import GFLHead
+from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
+from .lad_head import LADHead
+from .ld_head import LDHead
+from .mask2former_head import Mask2FormerHead
+from .maskformer_head import MaskFormerHead
+from .nasfcos_head import NASFCOSHead
+from .paa_head import PAAHead
+from .pisa_retinanet_head import PISARetinaHead
+from .pisa_ssd_head import PISASSDHead
+from .reppoints_head import RepPointsHead
+from .retina_head import RetinaHead
+from .retina_sepbn_head import RetinaSepBNHead
+from .rpn_head import RPNHead
+from .sabl_retina_head import SABLRetinaHead
+from .solo_head import DecoupledSOLOHead, DecoupledSOLOLightHead, SOLOHead
+from .solov2_head import SOLOV2Head
+from .ssd_head import SSDHead
+from .tood_head import TOODHead
+from .vfnet_head import VFNetHead
+from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
+from .yolo_head import YOLOV3Head
+from .yolof_head import YOLOFHead
+from .yolox_head import YOLOXHead
+
+__all__ = [
+ 'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
+ 'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead',
+ 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
+ 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
+ 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
+ 'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
+ 'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
+ 'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'AutoAssignHead',
+ 'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead',
+ 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
+ 'DecoupledSOLOLightHead', 'LADHead', 'TOODHead', 'MaskFormerHead',
+ 'Mask2FormerHead', 'SOLOV2Head', 'DDODHead', 'AscendAnchorHead',
+ 'AscendRetinaHead', 'AscendSSDHead'
+]
diff --git a/mmdet/models/dense_heads/anchor_free_head.py b/mmdet/models/dense_heads/anchor_free_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0460b945ca43b663553ab081d100edb76d8496a
--- /dev/null
+++ b/mmdet/models/dense_heads/anchor_free_head.py
@@ -0,0 +1,350 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import force_fp32
+
+from mmdet.core import build_bbox_coder, multi_apply
+from mmdet.core.anchor.point_generator import MlvlPointGenerator
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+@HEADS.register_module()
+class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
+ """Anchor-free head (FCOS, Fovea, RepPoints, etc.).
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels. Used in child classes.
+ stacked_convs (int): Number of stacking convs of the head.
+ strides (tuple): Downsample factor of each feature map.
+ dcn_on_last_conv (bool): If true, use dcn in the last layer of
+ towers. Default: False.
+ conv_bias (bool | str): If specified as `auto`, it will be decided by
+ the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
+ None, otherwise False. Default: "auto".
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+        bbox_coder (dict): Config of bbox coder. Defaults to
+            'DistancePointBBoxCoder'.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605
+
+ _version = 1
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ stacked_convs=4,
+ strides=(4, 8, 16, 32, 64),
+ dcn_on_last_conv=False,
+ conv_bias='auto',
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+ bbox_coder=dict(type='DistancePointBBoxCoder'),
+ conv_cfg=None,
+ norm_cfg=None,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='conv_cls',
+ std=0.01,
+ bias_prob=0.01))):
+ super(AnchorFreeHead, self).__init__(init_cfg)
+ self.num_classes = num_classes
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = num_classes
+ else:
+ self.cls_out_channels = num_classes + 1
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.stacked_convs = stacked_convs
+ self.strides = strides
+ self.dcn_on_last_conv = dcn_on_last_conv
+ assert conv_bias == 'auto' or isinstance(conv_bias, bool)
+ self.conv_bias = conv_bias
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+
+ self.prior_generator = MlvlPointGenerator(strides)
+
+        # In order to keep a more general interface and be consistent with
+        # anchor_head, we can think of a point as a single anchor.
+ self.num_base_priors = self.prior_generator.num_base_priors[0]
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.fp16_enabled = False
+
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self._init_cls_convs()
+ self._init_reg_convs()
+ self._init_predictor()
+
+ def _init_cls_convs(self):
+ """Initialize classification conv layers of the head."""
+ self.cls_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ if self.dcn_on_last_conv and i == self.stacked_convs - 1:
+ conv_cfg = dict(type='DCNv2')
+ else:
+ conv_cfg = self.conv_cfg
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias))
+
+ def _init_reg_convs(self):
+ """Initialize bbox regression conv layers of the head."""
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ if self.dcn_on_last_conv and i == self.stacked_convs - 1:
+ conv_cfg = dict(type='DCNv2')
+ else:
+ conv_cfg = self.conv_cfg
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias))
+
+ def _init_predictor(self):
+ """Initialize predictor layers of the head."""
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+        """Hack some keys of the model state dict so that checkpoints of
+        previous versions can be loaded."""
+ version = local_metadata.get('version', None)
+ if version is None:
+ # the key is different in early versions
+            # for example, 'fcos_cls' became 'conv_cls'
+ bbox_head_keys = [
+ k for k in state_dict.keys() if k.startswith(prefix)
+ ]
+ ori_predictor_keys = []
+ new_predictor_keys = []
+ # e.g. 'fcos_cls' or 'fcos_reg'
+ for key in bbox_head_keys:
+ ori_predictor_keys.append(key)
+ key = key.split('.')
+ conv_name = None
+ if key[1].endswith('cls'):
+ conv_name = 'conv_cls'
+ elif key[1].endswith('reg'):
+ conv_name = 'conv_reg'
+ elif key[1].endswith('centerness'):
+ conv_name = 'conv_centerness'
+ else:
+                    # keys that are not predictor convs keep conv_name as
+                    # None and are skipped below
+                    pass
+ if conv_name is not None:
+ key[1] = conv_name
+ new_predictor_keys.append('.'.join(key))
+ else:
+ ori_predictor_keys.pop(-1)
+ for i in range(len(new_predictor_keys)):
+ state_dict[new_predictor_keys[i]] = state_dict.pop(
+ ori_predictor_keys[i])
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually contain classification scores and bbox predictions.
+ cls_scores (list[Tensor]): Box scores for each scale level,
+ each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ """
+ return multi_apply(self.forward_single, feats)[:2]
+
+ def forward_single(self, x):
+ """Forward features of a single scale level.
+
+ Args:
+ x (Tensor): FPN feature maps of the specified stride.
+
+ Returns:
+            tuple: Scores for each class, bbox predictions, and features
+                after the classification and regression conv layers; some
+                models, such as FCOS, need these features.
+ """
+ cls_feat = x
+ reg_feat = x
+
+ for cls_layer in self.cls_convs:
+ cls_feat = cls_layer(cls_feat)
+ cls_score = self.conv_cls(cls_feat)
+
+ for reg_layer in self.reg_convs:
+ reg_feat = reg_layer(reg_feat)
+ bbox_pred = self.conv_reg(reg_feat)
+ return cls_score, bbox_pred, cls_feat, reg_feat
+
+ @abstractmethod
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level,
+ each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ """
+
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_targets(self, points, gt_bboxes_list, gt_labels_list):
+ """Compute regression, classification and centerness targets for points
+ in multiple images.
+
+ Args:
+ points (list[Tensor]): Points of each fpn level, each has shape
+ (num_points, 2).
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels_list (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+ """
+ raise NotImplementedError
+
+ def _get_points_single(self,
+ featmap_size,
+ stride,
+ dtype,
+ device,
+ flatten=False):
+ """Get points of a single scale level.
+
+ This function will be deprecated soon.
+ """
+
+ warnings.warn(
+ '`_get_points_single` in `AnchorFreeHead` will be '
+            'deprecated soon, we support a multi level point generator now. '
+            'You can get points of a single level feature map '
+ 'with `self.prior_generator.single_level_grid_priors` ')
+
+ h, w = featmap_size
+        # First create the range with the default dtype, then convert to
+        # the target `dtype` for ONNX exporting.
+ x_range = torch.arange(w, device=device).to(dtype)
+ y_range = torch.arange(h, device=device).to(dtype)
+ y, x = torch.meshgrid(y_range, x_range)
+ if flatten:
+ y = y.flatten()
+ x = x.flatten()
+ return y, x
+
+ def get_points(self, featmap_sizes, dtype, device, flatten=False):
+ """Get points according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ dtype (torch.dtype): Type of points.
+ device (torch.device): Device of points.
+
+ Returns:
+            list[Tensor]: Points of each feature level.
+ """
+ warnings.warn(
+ '`get_points` in `AnchorFreeHead` will be '
+            'deprecated soon, we support a multi level point generator now. '
+            'You can get points of all levels '
+ 'with `self.prior_generator.grid_priors` ')
+
+ mlvl_points = []
+ for i in range(len(featmap_sizes)):
+ mlvl_points.append(
+ self._get_points_single(featmap_sizes[i], self.strides[i],
+ dtype, device, flatten))
+ return mlvl_points
+
+ def aug_test(self, feats, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+                images in a batch. Each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[ndarray]: bbox results of each class
+ """
+ return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
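As the deprecation warnings in `get_points` and `_get_points_single` indicate, the multi-level point generator is the intended replacement for the per-level helpers above. A rough sketch of that usage, assuming `MlvlPointGenerator.grid_priors` takes a list of feature-map sizes and a `device` keyword as the warning text implies; the strides and sizes are example values:

    # Illustrative sketch of the replacement API referenced by the warnings
    # in AnchorFreeHead; the exact signature may vary across mmdet versions.
    from mmdet.core.anchor.point_generator import MlvlPointGenerator

    prior_generator = MlvlPointGenerator(strides=(8, 16, 32))
    featmap_sizes = [(100, 152), (50, 76), (25, 38)]
    # One (H*W, 2) tensor of (x, y) point coordinates per feature level.
    mlvl_points = prior_generator.grid_priors(featmap_sizes, device='cpu')
    print([p.shape for p in mlvl_points])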
diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1bfab62de230feaccc83b935573b87d1d8061df
--- /dev/null
+++ b/mmdet/models/dense_heads/anchor_head.py
@@ -0,0 +1,542 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_assigner, build_bbox_coder,
+ build_prior_generator, build_sampler, images_to_levels,
+ multi_apply, unmap)
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+@HEADS.register_module()
+class AnchorHead(BaseDenseHead, BBoxTestMixin):
+ """Anchor-based head (RPN, RetinaNet, SSD, etc.).
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels. Used in child classes.
+ anchor_generator (dict): Config dict for anchor generator
+ bbox_coder (dict): Config of bounding box coder.
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8, 16, 32],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ clip_border=True,
+ target_means=(.0, .0, .0, .0),
+ target_stds=(1.0, 1.0, 1.0, 1.0)),
+ reg_decoded_bbox=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
+ super(AnchorHead, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.feat_channels = feat_channels
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = num_classes
+ else:
+ self.cls_out_channels = num_classes + 1
+
+ if self.cls_out_channels <= 0:
+ raise ValueError(f'num_classes={num_classes} is too small')
+ self.reg_decoded_bbox = reg_decoded_bbox
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ if hasattr(self.train_cfg,
+ 'sampler') and self.train_cfg.sampler.type.split(
+ '.')[-1] != 'PseudoSampler':
+ self.sampling = True
+ sampler_cfg = self.train_cfg.sampler
+ # avoid BC-breaking
+ if loss_cls['type'] in [
+ 'FocalLoss', 'GHMC', 'QualityFocalLoss'
+ ]:
+ warnings.warn(
+                        'DeprecationWarning: Determining whether to sample '
+                        'by loss type is deprecated, please delete sampler '
+                        'in your config when using `FocalLoss`, `GHMC`, '
+ '`QualityFocalLoss` or other FocalLoss variant.')
+ self.sampling = False
+ sampler_cfg = dict(type='PseudoSampler')
+ else:
+ self.sampling = False
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.fp16_enabled = False
+
+ self.prior_generator = build_prior_generator(anchor_generator)
+
+        # Usually the numbers of anchors for each level are the same
+        # except for SSD detectors, so it is an int in most dense
+        # heads but a list of ints in SSDHead.
+ self.num_base_priors = self.prior_generator.num_base_priors[0]
+ self._init_layers()
+
+ @property
+ def num_anchors(self):
+ warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
+                      'please use `num_base_priors` instead '
+                      'for consistency')
+ return self.prior_generator.num_base_priors[0]
+
+ @property
+ def anchor_generator(self):
+ warnings.warn('DeprecationWarning: anchor_generator is deprecated, '
+ 'please use "prior_generator" instead')
+ return self.prior_generator
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.conv_cls = nn.Conv2d(self.in_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 1)
+ self.conv_reg = nn.Conv2d(self.in_channels, self.num_base_priors * 4,
+ 1)
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+
+ Returns:
+ tuple:
+                cls_score (Tensor): Cls scores for a single scale level, \
+                    the channels number is num_base_priors * num_classes.
+ bbox_pred (Tensor): Box energies / deltas for a single scale \
+ level, the channels number is num_base_priors * 4.
+ """
+ cls_score = self.conv_cls(x)
+ bbox_pred = self.conv_reg(x)
+ return cls_score, bbox_pred
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: A tuple of classification scores and bbox prediction.
+
+ - cls_scores (list[Tensor]): Classification scores for all \
+ scale levels, each is a 4D-tensor, the channels number \
+ is num_base_priors * num_classes.
+ - bbox_preds (list[Tensor]): Box energies / deltas for all \
+ scale levels, each is a 4D-tensor, the channels number \
+ is num_base_priors * 4.
+ """
+ return multi_apply(self.forward_single, feats)
+
+ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
+ """Get anchors according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+ device (torch.device | str): Device for returned tensors
+
+ Returns:
+ tuple:
+ anchor_list (list[Tensor]): Anchors of each image.
+ valid_flag_list (list[Tensor]): Valid flags of each image.
+ """
+ num_imgs = len(img_metas)
+
+ # since feature map sizes of all images are the same, we only compute
+        # anchors once
+ multi_level_anchors = self.prior_generator.grid_priors(
+ featmap_sizes, device=device)
+ anchor_list = [multi_level_anchors for _ in range(num_imgs)]
+
+ # for each image, we compute valid flags of multi level anchors
+ valid_flag_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_flags = self.prior_generator.valid_flags(
+ featmap_sizes, img_meta['pad_shape'], device)
+ valid_flag_list.append(multi_level_flags)
+
+ return anchor_list, valid_flag_list
+
+ def _get_targets_single(self,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ Args:
+ flat_anchors (Tensor): Multi-level anchors of the image, which are
+                concatenated into a single tensor of shape (num_anchors, 4)
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_anchors,).
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple:
+                labels (Tensor): Labels of all anchors in the image.
+                label_weights (Tensor): Label weights of all anchors.
+                bbox_targets (Tensor): BBox targets of all anchors.
+                bbox_weights (Tensor): BBox weights of all anchors.
+                pos_inds (Tensor): Indices of positive anchors.
+                neg_inds (Tensor): Indices of negative anchors.
+                sampling_result (:obj:`SamplingResult`): Sampling result.
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ assign_result = self.assigner.assign(
+ anchors, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags,
+ fill=self.num_classes) # fill bg label
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds, sampling_result)
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True,
+ return_sampling_results=False):
+ """Compute regression and classification targets for anchors in
+ multiple images.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, 4).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each
+ level.
+ - bbox_targets_list (list[Tensor]): BBox targets of each level.
+ - bbox_weights_list (list[Tensor]): BBox weights of each level.
+ - num_total_pos (int): Number of positive samples in all
+ images.
+ - num_total_neg (int): Number of negative samples in all
+ images.
+
+            additional_returns: This function enables user-defined returns
+                from `self._get_targets_single`. These returns are currently
+                split back to properties at each feature map level (i.e.
+                having an HxW dimension) and are appended after the standard
+                returns.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors to a single tensor
+ concat_anchor_list = []
+ concat_valid_flag_list = []
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ results = multi_apply(
+ self._get_targets_single,
+ concat_anchor_list,
+ concat_valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
+ pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
+ rest_results = list(results[7:]) # user-added return values
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ res = (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+ if return_sampling_results:
+ res = res + (sampling_results_list, )
+ for i, r in enumerate(rest_results): # user-added return values
+ rest_results[i] = images_to_levels(r, num_level_anchors)
+
+ return res + tuple(rest_results)
+
+ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
+ bbox_targets, bbox_weights, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ cls_score (Tensor): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W).
+ bbox_pred (Tensor): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W).
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+ bbox_targets (Tensor): BBox regression targets of each anchor
+                with shape (N, num_total_anchors, 4).
+ bbox_weights (Tensor): BBox regression loss weights of each anchor
+ with shape (N, num_total_anchors, 4).
+            num_total_samples (int): If sampling is used, it is the total
+                number of sampled anchors (positives plus negatives);
+                otherwise, it is the number of positive anchors.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ # regression loss
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ bbox_weights = bbox_weights.reshape(-1, 4)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ if self.reg_decoded_bbox:
+            # When the regression loss (e.g. `IoULoss`, `GIoULoss`) is
+            # applied directly on decoded bounding boxes, the predicted
+            # deltas are first decoded to absolute coordinates.
+ anchors = anchors.reshape(-1, 4)
+ bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+ loss_bbox = self.loss_bbox(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_samples)
+ return loss_cls, loss_bbox
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+ return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+
+ def aug_test(self, feats, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+                images in a batch. Each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+            list[tuple[Tensor, Tensor]]: Each item in the list is a 2-tuple.
+                The first item is ``bboxes`` with shape (n, 5), where
+                5 represents (tl_x, tl_y, br_x, br_y, score).
+                The second item is ``labels`` with shape (n,).
+                The length of the list is always 1.
+ """
+ return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
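A quick way to sanity-check the interface documented in `forward` is to run the head on random multi-level features: with the default anchor generator (3 scales x 3 ratios = 9 base priors) and the default sigmoid classification loss, the output channel counts are 9 * num_classes and 9 * 4. A minimal sketch, assuming an installed mmdet 2.x environment; the feature shapes are arbitrary:

    # Illustrative sketch: run AnchorHead.forward on dummy FPN features.
    # The channel counts follow from the default anchor generator (9 priors).
    import torch
    from mmdet.models.dense_heads import AnchorHead

    head = AnchorHead(num_classes=3, in_channels=16)
    feats = [torch.rand(2, 16, size, size) for size in (64, 32, 16, 8, 4)]
    cls_scores, bbox_preds = head(feats)
    assert cls_scores[0].shape[1] == 9 * 3  # num_base_priors * num_classes
    assert bbox_preds[0].shape[1] == 9 * 4  # num_base_priors * 4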
diff --git a/mmdet/models/dense_heads/ascend_anchor_head.py b/mmdet/models/dense_heads/ascend_anchor_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d100ba9218f123c90feb54e5ec7a43b060356a1
--- /dev/null
+++ b/mmdet/models/dense_heads/ascend_anchor_head.py
@@ -0,0 +1,389 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ...core.bbox.assigners import AscendMaxIoUAssigner
+from ...core.bbox.samplers import PseudoSampler
+from ...utils import (batch_images_to_levels, get_max_num_gt_division_factor,
+ masked_fill)
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class AscendAnchorHead(AnchorHead):
+ """Ascend Anchor-based head (RetinaNet, SSD, etc.).
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels. Used in child classes.
+ anchor_generator (dict): Config dict for anchor generator
+ bbox_coder (dict): Config of bounding box coder.
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8, 16, 32],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ clip_border=True,
+ target_means=(.0, .0, .0, .0),
+ target_stds=(1.0, 1.0, 1.0, 1.0)),
+ reg_decoded_bbox=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
+ super(AscendAnchorHead, self).__init__(
+ num_classes=num_classes,
+ in_channels=in_channels,
+ feat_channels=feat_channels,
+ anchor_generator=anchor_generator,
+ bbox_coder=bbox_coder,
+ reg_decoded_bbox=reg_decoded_bbox,
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ init_cfg=init_cfg)
+
+ def get_batch_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, device,
+ max_gt_labels):
+        """Get ground truth bboxes of all images.
+
+        Args:
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+            num_images (int): The number of images.
+            gt_nums (list[int]): The number of ground truth bboxes of each
+                image.
+            device (torch.device | str): Device for returned tensors.
+            max_gt_labels (int): The max number of ground truth bboxes of
+                all images.
+        Returns:
+            batch_gt_bboxes (Tensor): Ground truth bboxes of all images.
+ """
+        # Cache static ground truth boxes. Keeping a fixed-shape gt tensor
+        # (related to Ascend) helps improve performance.
+ if not hasattr(self, 'batch_gt_bboxes'):
+ self.batch_gt_bboxes = {}
+        # a sentinel box used to fill the excess (padding) gt slots
+ if not hasattr(self, 'min_anchor'):
+ self.min_anchor = (-1354, -1344)
+ if gt_bboxes_list is None:
+ batch_gt_bboxes = None
+ else:
+ if self.batch_gt_bboxes.get(max_gt_labels) is None:
+ batch_gt_bboxes = torch.zeros((num_images, max_gt_labels, 4),
+ dtype=gt_bboxes_list[0].dtype,
+ device=device)
+ batch_gt_bboxes[:, :, :2] = self.min_anchor[0]
+ batch_gt_bboxes[:, :, 2:] = self.min_anchor[1]
+ self.batch_gt_bboxes[max_gt_labels] = batch_gt_bboxes.clone()
+ else:
+ batch_gt_bboxes = self.batch_gt_bboxes.get(
+ max_gt_labels).clone()
+ for index_imgs, gt_bboxes in enumerate(gt_bboxes_list):
+ batch_gt_bboxes[index_imgs, :gt_nums[index_imgs]] = gt_bboxes
+ return batch_gt_bboxes
+
+ def get_batch_gt_bboxes_ignore(self, gt_bboxes_ignore_list, num_images,
+ gt_nums, device):
+        """Get ground truth bboxes to be ignored of all images.
+
+ Args:
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+            num_images (int): The number of images.
+            gt_nums (list[int]): The number of ground truth bboxes of each
+                image.
+            device (torch.device | str): Device for returned tensors.
+        Returns:
+            batch_gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored of all images.
+ """
+ # TODO: support gt_bboxes_ignore_list
+ if gt_bboxes_ignore_list is None:
+ batch_gt_bboxes_ignore = None
+ else:
+            raise RuntimeError('gt_bboxes_ignore is not supported yet')
+ return batch_gt_bboxes_ignore
+
+ def get_batch_gt_labels(self, gt_labels_list, num_images, gt_nums, device,
+ max_gt_labels):
+        """Get ground truth labels of all images.
+
+ Args:
+ gt_labels_list (list[Tensor]): Ground truth labels.
+            num_images (int): The number of images.
+            gt_nums (list[int]): The number of ground truth bboxes of each
+                image.
+            device (torch.device | str): Device for returned tensors.
+            max_gt_labels (int): The max number of ground truth bboxes of
+                all images.
+        Returns:
+            batch_gt_labels (Tensor): Ground truth labels of all images.
+ """
+ if gt_labels_list is None:
+ batch_gt_labels = None
+ else:
+ batch_gt_labels = torch.zeros((num_images, max_gt_labels),
+ dtype=gt_labels_list[0].dtype,
+ device=device)
+ for index_imgs, gt_labels in enumerate(gt_labels_list):
+ batch_gt_labels[index_imgs, :gt_nums[index_imgs]] = gt_labels
+
+ return batch_gt_labels
+
+ def _get_targets_concat(self,
+ batch_anchors,
+ batch_valid_flags,
+ batch_gt_bboxes,
+ batch_gt_bboxes_ignore,
+ batch_gt_labels,
+ img_metas,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in all
+ images.
+
+ Args:
+            batch_anchors (Tensor): Anchors of all images, which are
+                concatenated into a single tensor of
+                shape (num_imgs, num_anchors, 4).
+            batch_valid_flags (Tensor): Valid flags of all images,
+                which are concatenated into a single tensor of
+                shape (num_imgs, num_anchors,).
+            batch_gt_bboxes (Tensor): Ground truth bboxes of all images,
+ shape (num_imgs, max_gt_nums, 4).
+ batch_gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_imgs, num_ignored_gts, 4).
+ batch_gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_imgs, max_gt_nums,).
+ img_metas (list[dict]): Meta info of each image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple:
+                batch_labels (Tensor): Labels of all images.
+                batch_label_weights (Tensor): Label weights of all images.
+                batch_bbox_targets (Tensor): BBox targets of all images.
+                batch_bbox_weights (Tensor): BBox weights of all images.
+ batch_pos_mask (Tensor): Positive samples mask in all images
+ batch_neg_mask (Tensor): Negative samples mask in all images
+ sampling_result (Sampling): The result of sampling,
+ default: None.
+ """
+ num_imgs, num_anchors, _ = batch_anchors.size()
+ # assign gt and sample batch_anchors
+ assign_result = self.assigner.assign(
+ batch_anchors,
+ batch_gt_bboxes,
+ batch_gt_bboxes_ignore,
+ None if self.sampling else batch_gt_labels,
+ batch_bboxes_ignore_mask=batch_valid_flags)
+ # TODO: support sampling_result
+ sampling_result = None
+ batch_pos_mask = assign_result.batch_pos_mask
+ batch_neg_mask = assign_result.batch_neg_mask
+ batch_anchor_gt_indes = assign_result.batch_anchor_gt_indes
+ batch_anchor_gt_labels = assign_result.batch_anchor_gt_labels
+
+ batch_anchor_gt_bboxes = torch.zeros(
+ batch_anchors.size(),
+ dtype=batch_anchors.dtype,
+ device=batch_anchors.device)
+ for index_imgs in range(num_imgs):
+ batch_anchor_gt_bboxes[index_imgs] = torch.index_select(
+ batch_gt_bboxes[index_imgs], 0,
+ batch_anchor_gt_indes[index_imgs])
+
+ batch_bbox_targets = torch.zeros_like(batch_anchors)
+ batch_bbox_weights = torch.zeros_like(batch_anchors)
+ batch_labels = batch_anchors.new_full((num_imgs, num_anchors),
+ self.num_classes,
+ dtype=torch.int)
+ batch_label_weights = batch_anchors.new_zeros((num_imgs, num_anchors),
+ dtype=torch.float)
+
+ if not self.reg_decoded_bbox:
+ batch_pos_bbox_targets = self.bbox_coder.encode(
+ batch_anchors, batch_anchor_gt_bboxes)
+ else:
+ batch_pos_bbox_targets = batch_anchor_gt_bboxes
+
+ batch_bbox_targets = masked_fill(batch_bbox_targets,
+ batch_pos_mask.unsqueeze(2),
+ batch_pos_bbox_targets)
+ batch_bbox_weights = masked_fill(batch_bbox_weights,
+ batch_pos_mask.unsqueeze(2), 1.0)
+ if batch_gt_labels is None:
+ batch_labels = masked_fill(batch_labels, batch_pos_mask, 0.0)
+ else:
+ batch_labels = masked_fill(batch_labels, batch_pos_mask,
+ batch_anchor_gt_labels)
+ if self.train_cfg.pos_weight <= 0:
+ batch_label_weights = masked_fill(batch_label_weights,
+ batch_pos_mask, 1.0)
+ else:
+ batch_label_weights = masked_fill(batch_label_weights,
+ batch_pos_mask,
+ self.train_cfg.pos_weight)
+ batch_label_weights = masked_fill(batch_label_weights, batch_neg_mask,
+ 1.0)
+ return (batch_labels, batch_label_weights, batch_bbox_targets,
+ batch_bbox_weights, batch_pos_mask, batch_neg_mask,
+ sampling_result)
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True,
+ return_sampling_results=False,
+ return_level=True):
+ """Compute regression and classification targets for anchors in
+ multiple images.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, 4).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+ return_sampling_results (bool): Whether to return the result of
+ sample.
+ return_level (bool): Whether to map outputs back to the levels
+ of feature map sizes.
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each
+ level.
+ - bbox_targets_list (list[Tensor]): BBox targets of each level.
+ - bbox_weights_list (list[Tensor]): BBox weights of each level.
+ - num_total_pos (int): Number of positive samples in all
+ images.
+ - num_total_neg (int): Number of negative samples in all
+ images.
+
+            additional_returns: This function enables user-defined returns
+                from `self._get_targets_single`. These returns are currently
+                split back to properties at each feature map level (i.e.
+                having an HxW dimension) and are appended after the standard
+                returns.
+ """
+ assert gt_bboxes_ignore_list is None
+ assert unmap_outputs is True
+ assert return_sampling_results is False
+ assert self.train_cfg.allowed_border < 0
+ assert isinstance(self.assigner, AscendMaxIoUAssigner)
+ assert isinstance(self.sampler, PseudoSampler)
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ device = anchor_list[0][0].device
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+ batch_anchor_list = []
+ batch_valid_flag_list = []
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ batch_anchor_list.append(torch.cat(anchor_list[i]))
+ batch_valid_flag_list.append(torch.cat(valid_flag_list[i]))
+ batch_anchors = torch.cat(
+ [torch.unsqueeze(anchor, 0) for anchor in batch_anchor_list], 0)
+ batch_valid_flags = torch.cat([
+ torch.unsqueeze(batch_valid_flag, 0)
+ for batch_valid_flag in batch_valid_flag_list
+ ], 0)
+
+ gt_nums = [len(gt_bbox) for gt_bbox in gt_bboxes_list]
+ max_gt_nums = get_max_num_gt_division_factor(gt_nums)
+ batch_gt_bboxes = self.get_batch_gt_bboxes(gt_bboxes_list, num_imgs,
+ gt_nums, device,
+ max_gt_nums)
+ batch_gt_bboxes_ignore = self.get_batch_gt_bboxes_ignore(
+ gt_bboxes_ignore_list, num_imgs, gt_nums, device)
+ batch_gt_labels = self.get_batch_gt_labels(gt_labels_list, num_imgs,
+ gt_nums, device,
+ max_gt_nums)
+
+ results = self._get_targets_concat(
+ batch_anchors,
+ batch_valid_flags,
+ batch_gt_bboxes,
+ batch_gt_bboxes_ignore,
+ batch_gt_labels,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+
+ (batch_labels, batch_label_weights, batch_bbox_targets,
+ batch_bbox_weights, batch_pos_mask, batch_neg_mask,
+ sampling_result) = results[:7]
+ rest_results = list(results[7:]) # user-added return values
+
+ # sampled anchors of all images
+ min_num = torch.ones((num_imgs, ),
+ dtype=torch.long,
+ device=batch_pos_mask.device)
+ num_total_pos = torch.sum(
+ torch.max(torch.sum(batch_pos_mask, dim=1), min_num))
+ num_total_neg = torch.sum(
+ torch.max(torch.sum(batch_neg_mask, dim=1), min_num))
+ if return_level is True:
+ labels_list = batch_images_to_levels(batch_labels,
+ num_level_anchors)
+ label_weights_list = batch_images_to_levels(
+ batch_label_weights, num_level_anchors)
+ bbox_targets_list = batch_images_to_levels(batch_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = batch_images_to_levels(batch_bbox_weights,
+ num_level_anchors)
+ res = (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+ if return_sampling_results:
+ res = res + (sampling_result, )
+ for i, r in enumerate(rest_results): # user-added return values
+ rest_results[i] = batch_images_to_levels(r, num_level_anchors)
+
+ return res + tuple(rest_results)
+ else:
+ res = (batch_labels, batch_label_weights, batch_bbox_targets,
+ batch_bbox_weights, batch_pos_mask, batch_neg_mask,
+ sampling_result, num_total_pos, num_total_neg,
+ batch_anchors)
+ return res
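The key idea behind `get_batch_gt_bboxes` above is to pad the per-image ground-truth boxes into a single fixed-shape batch tensor, filling unused slots with a far-away sentinel box so that downstream Ascend kernels always see static shapes. A self-contained sketch of just that padding step in plain torch; the box counts, `max_gt_num` and sentinel values are illustrative:

    # Standalone sketch of the static-shape padding used in
    # get_batch_gt_bboxes; all values here are toy examples.
    import torch

    gt_bboxes_list = [torch.rand(2, 4) * 100, torch.rand(5, 4) * 100]
    max_gt_num = 8                  # fixed (static) number of gt slots
    sentinel = (-1354.0, -1344.0)   # far-away box that matches no anchor

    batch_gt = torch.zeros(len(gt_bboxes_list), max_gt_num, 4)
    batch_gt[:, :, :2] = sentinel[0]
    batch_gt[:, :, 2:] = sentinel[1]
    for i, gt in enumerate(gt_bboxes_list):
        batch_gt[i, :gt.size(0)] = gt  # real boxes first, padding after
    print(batch_gt.shape)              # torch.Size([2, 8, 4])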
diff --git a/mmdet/models/dense_heads/ascend_retina_head.py b/mmdet/models/dense_heads/ascend_retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..159fe75c1cafa8fcbcda5affe9b442fa9018bdef
--- /dev/null
+++ b/mmdet/models/dense_heads/ascend_retina_head.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import HEADS
+from .ascend_anchor_head import AscendAnchorHead
+from .retina_head import RetinaHead
+
+
+@HEADS.register_module()
+class AscendRetinaHead(RetinaHead, AscendAnchorHead):
+    r"""An anchor-based head used in `RetinaNet
+    <https://arxiv.org/abs/1708.02002>`_.
+
+ The head contains two subnetworks. The first classifies anchor boxes and
+ the second regresses deltas for the anchors.
+
+ Example:
+ >>> import torch
+ >>> self = RetinaHead(11, 7)
+ >>> x = torch.rand(1, 7, 32, 32)
+ >>> cls_score, bbox_pred = self.forward_single(x)
+ >>> # Each anchor predicts a score for each class except background
+ >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
+ >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
+ >>> assert cls_per_anchor == (self.num_classes)
+ >>> assert box_per_anchor == 4
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=4,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128]),
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='retina_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ super(AscendRetinaHead, self).__init__(
+ num_classes=num_classes,
+ in_channels=in_channels,
+ stacked_convs=stacked_convs,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ anchor_generator=anchor_generator,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True,
+ return_sampling_results=False,
+ return_level=True):
+ """Compute regression and classification targets for anchors in
+ multiple images.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, 4).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+ return_sampling_results (bool): Whether to return the result of
+ sample.
+ return_level (bool): Whether to map outputs back to the levels
+ of feature map sizes.
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each
+ level.
+ - bbox_targets_list (list[Tensor]): BBox targets of each level.
+ - bbox_weights_list (list[Tensor]): BBox weights of each level.
+ - num_total_pos (int): Number of positive samples in all
+ images.
+ - num_total_neg (int): Number of negative samples in all
+ images.
+
+            additional_returns: This function enables user-defined returns
+                from `self._get_targets_single`. These returns are currently
+                split back to properties at each feature map level (i.e.
+                having an HxW dimension) and are appended after the standard
+                returns.
+ """
+ return AscendAnchorHead.get_targets(
+ self, anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
+ gt_bboxes_ignore_list, gt_labels_list, label_channels,
+ unmap_outputs, return_sampling_results, return_level)
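Because `AscendRetinaHead` keeps the `RetinaHead` constructor and only redirects target computation to `AscendAnchorHead`, moving an existing RetinaNet config onto the Ascend path mostly means changing the head type. A hedged, config-style sketch; the surrounding detector config is assumed, and the field values simply mirror the defaults above:

    # Illustrative config fragment: compared with a stock RetinaNet config,
    # only the head type changes; the other fields mirror the defaults above.
    bbox_head = dict(
        type='AscendRetinaHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]))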
diff --git a/mmdet/models/dense_heads/ascend_ssd_head.py b/mmdet/models/dense_heads/ascend_ssd_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e326b48bc18a6f446437333b38d8ef558234be7
--- /dev/null
+++ b/mmdet/models/dense_heads/ascend_ssd_head.py
@@ -0,0 +1,328 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+from mmcv.runner import force_fp32
+
+from ..builder import HEADS
+from ..losses import smooth_l1_loss
+from .ascend_anchor_head import AscendAnchorHead
+from .ssd_head import SSDHead
+
+
+@HEADS.register_module()
+class AscendSSDHead(SSDHead, AscendAnchorHead):
+ """Ascend SSD head used in https://arxiv.org/abs/1512.02325.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): Number of conv layers in cls and reg tower.
+ Default: 0.
+ feat_channels (int): Number of hidden channels when stacked_convs
+ > 0. Default: 256.
+ use_depthwise (bool): Whether to use DepthwiseSeparableConv.
+ Default: False.
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: None.
+ act_cfg (dict): Dictionary to construct and config activation layer.
+ Default: None.
+ anchor_generator (dict): Config dict for anchor generator
+ bbox_coder (dict): Config of bounding box coder.
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes=80,
+ in_channels=(512, 1024, 512, 256, 256, 256),
+ stacked_convs=0,
+ feat_channels=256,
+ use_depthwise=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ anchor_generator=dict(
+ type='SSDAnchorGenerator',
+ scale_major=False,
+ input_size=300,
+ strides=[8, 16, 32, 64, 100, 300],
+ ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
+ basesize_ratio_range=(0.1, 0.9)),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ clip_border=True,
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0],
+ ),
+ reg_decoded_bbox=False,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=dict(
+ type='Xavier',
+ layer='Conv2d',
+ distribution='uniform',
+ bias=0)):
+ super(AscendSSDHead, self).__init__(
+ num_classes=num_classes,
+ in_channels=in_channels,
+ stacked_convs=stacked_convs,
+ feat_channels=feat_channels,
+ use_depthwise=use_depthwise,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ anchor_generator=anchor_generator,
+ bbox_coder=bbox_coder,
+ reg_decoded_bbox=reg_decoded_bbox,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ init_cfg=init_cfg)
+ assert self.reg_decoded_bbox is False, \
+            'reg_decoded_bbox only supports False for now.'
+
+ def get_static_anchors(self, featmap_sizes, img_metas, device='cuda'):
+ """Get static anchors according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+ device (torch.device | str): Device for returned tensors
+
+ Returns:
+ tuple:
+ anchor_list (list[Tensor]): Anchors of each image.
+ valid_flag_list (list[Tensor]): Valid flags of each image.
+ """
+ if not hasattr(self, 'static_anchors') or \
+ not hasattr(self, 'static_valid_flags'):
+ static_anchors, static_valid_flags = self.get_anchors(
+ featmap_sizes, img_metas, device)
+ self.static_anchors = static_anchors
+ self.static_valid_flags = static_valid_flags
+ return self.static_anchors, self.static_valid_flags
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True,
+ return_sampling_results=False,
+ return_level=True):
+ """Compute regression and classification targets for anchors in
+ multiple images.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, 4).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+ return_sampling_results (bool): Whether to return the result of
+ sample.
+ return_level (bool): Whether to map outputs back to the levels
+ of feature map sizes.
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each
+ level.
+ - bbox_targets_list (list[Tensor]): BBox targets of each level.
+ - bbox_weights_list (list[Tensor]): BBox weights of each level.
+ - num_total_pos (int): Number of positive samples in all
+ images.
+ - num_total_neg (int): Number of negative samples in all
+ images.
+
+            additional_returns: This function enables user-defined returns
+                from `self._get_targets_single`. These returns are currently
+                split back to properties at each feature map level (i.e.
+                having an HxW dimension) and are appended after the standard
+                returns.
+ """
+ return AscendAnchorHead.get_targets(
+ self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ label_channels,
+ unmap_outputs,
+ return_sampling_results,
+ return_level,
+ )
+
+ def batch_loss(self, batch_cls_score, batch_bbox_pred, batch_anchor,
+ batch_labels, batch_label_weights, batch_bbox_targets,
+ batch_bbox_weights, batch_pos_mask, batch_neg_mask,
+ num_total_samples):
+ """Compute loss of all images.
+
+ Args:
+ batch_cls_score (Tensor): Box scores for all image
+ Has shape (num_imgs, num_total_anchors, num_classes).
+ batch_bbox_pred (Tensor): Box energies / deltas for all image
+ level with shape (num_imgs, num_total_anchors, 4).
+ batch_anchor (Tensor): Box reference for all image with shape
+ (num_imgs, num_total_anchors, 4).
+ batch_labels (Tensor): Labels of all anchors with shape
+ (num_imgs, num_total_anchors,).
+            batch_label_weights (Tensor): Label weights of all anchors with
+                shape (num_imgs, num_total_anchors,).
+            batch_bbox_targets (Tensor): BBox regression targets of all
+                anchors with shape (num_imgs, num_total_anchors, 4).
+            batch_bbox_weights (Tensor): BBox regression loss weights of
+                all anchors with shape (num_imgs, num_total_anchors, 4).
+            batch_pos_mask (Tensor): Positive samples mask in all images.
+            batch_neg_mask (Tensor): Negative samples mask in all images.
+            num_total_samples (int): If sampling is used, it is the total
+                number of sampled anchors (positives plus negatives);
+                otherwise, it is the number of positive anchors.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_images, num_anchors, _ = batch_anchor.size()
+
+ batch_loss_cls_all = F.cross_entropy(
+ batch_cls_score.view((-1, self.cls_out_channels)),
+ batch_labels.view(-1),
+ reduction='none').view(
+ batch_label_weights.size()) * batch_label_weights
+        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
+ batch_num_pos_samples = torch.sum(batch_pos_mask, dim=1)
+ batch_num_neg_samples = \
+ self.train_cfg.neg_pos_ratio * batch_num_pos_samples
+
+ batch_num_neg_samples_max = torch.sum(batch_neg_mask, dim=1)
+ batch_num_neg_samples = torch.min(batch_num_neg_samples,
+ batch_num_neg_samples_max)
+
+ batch_topk_loss_cls_neg, _ = torch.topk(
+ batch_loss_cls_all * batch_neg_mask, k=num_anchors, dim=1)
+ batch_loss_cls_pos = torch.sum(
+ batch_loss_cls_all * batch_pos_mask, dim=1)
+
+ anchor_index = torch.arange(
+ end=num_anchors, dtype=torch.float,
+ device=batch_anchor.device).view((1, -1))
+ topk_loss_neg_mask = (anchor_index < batch_num_neg_samples.view(
+ -1, 1)).float()
+
+ batch_loss_cls_neg = torch.sum(
+ batch_topk_loss_cls_neg * topk_loss_neg_mask, dim=1)
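+        # Hard negative mining sketch (hypothetical numbers): with
+        # num_anchors=4, per-image negative losses [0.9, 0.1, 0.5, 0.3] and
+        # batch_num_neg_samples=2, the topk above sorts them to
+        # [0.9, 0.5, 0.3, 0.1] and topk_loss_neg_mask keeps only the first
+        # two entries, so just the two hardest negatives (0.9 and 0.5)
+        # contribute to loss_cls below.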
+ loss_cls = \
+ (batch_loss_cls_pos + batch_loss_cls_neg) / num_total_samples
+
+ if self.reg_decoded_bbox:
+ # TODO: support self.reg_decoded_bbox is True
+ raise RuntimeError
+
+ loss_bbox_all = smooth_l1_loss(
+ batch_bbox_pred,
+ batch_bbox_targets,
+ batch_bbox_weights,
+ reduction='none',
+ beta=self.train_cfg.smoothl1_beta,
+ avg_factor=num_total_samples)
+ eps = torch.finfo(torch.float32).eps
+
+        sum_dim = tuple(range(1, loss_bbox_all.dim()))
+        loss_bbox = loss_bbox_all.sum(sum_dim) / (num_total_samples + eps)
+ return loss_cls[None], loss_bbox
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+            gt_bboxes (list[Tensor]): Ground truth boxes for each image in
+                [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=1,
+ unmap_outputs=True,
+ return_level=False)
+ if cls_reg_targets is None:
+ return None
+
+ (batch_labels, batch_label_weights, batch_bbox_targets,
+ batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result,
+ num_total_pos, num_total_neg, batch_anchors) = cls_reg_targets
+
+ num_imgs = len(img_metas)
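+        # Flatten each level from (N, A*C, H, W) to (N, H*W*A, C) and
+        # concatenate the levels along the anchor dimension, matching the
+        # (num_imgs, num_total_anchors, ...) layout expected by batch_loss.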
+ batch_cls_score = torch.cat([
+ s.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels)
+ for s in cls_scores
+ ], 1)
+
+ batch_bbox_pred = torch.cat([
+ b.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for b in bbox_preds
+ ], -2)
+
+ batch_losses_cls, batch_losses_bbox = self.batch_loss(
+ batch_cls_score, batch_bbox_pred, batch_anchors, batch_labels,
+ batch_label_weights, batch_bbox_targets, batch_bbox_weights,
+ batch_pos_mask, batch_neg_mask, num_total_pos)
+ losses_cls = [
+ batch_losses_cls[:, index_imgs] for index_imgs in range(num_imgs)
+ ]
+        losses_bbox = [loss_bbox for loss_bbox in batch_losses_bbox]
+ return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
diff --git a/mmdet/models/dense_heads/atss_head.py b/mmdet/models/dense_heads/atss_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8f401caa1a83cf6f6b62a642fb1d42c379a4e11
--- /dev/null
+++ b/mmdet/models/dense_heads/atss_head.py
@@ -0,0 +1,501 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler,
+ images_to_levels, multi_apply, reduce_mean, unmap)
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class ATSSHead(AnchorHead):
+ """Bridging the Gap Between Anchor-based and Anchor-free Detection via
+ Adaptive Training Sample Selection.
+
+    The ATSS head structure is similar to that of FCOS; however, ATSS uses
+    anchor boxes and assigns labels by Adaptive Training Sample Selection
+    instead of max IoU.
+
+ https://arxiv.org/abs/1912.02424
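+
+    Example (a minimal, hedged sketch; it is not part of the original file
+    and the channel/feature sizes are illustrative assumptions):
+        >>> import torch
+        >>> from mmcv.cnn import Scale
+        >>> self = ATSSHead(num_classes=80, in_channels=64)
+        >>> feat = torch.rand(1, 64, 32, 32)
+        >>> cls_score, bbox_pred, centerness = self.forward_single(
+        ...     feat, Scale(1.0))
+        >>> # cls_score: (1, num_anchors * 80, 32, 32)
+        >>> # bbox_pred: (1, num_anchors * 4, 32, 32)
+        >>> # centerness: (1, num_anchors * 1, 32, 32)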
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ pred_kernel_size=3,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ reg_decoded_bbox=True,
+ loss_centerness=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='atss_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ self.pred_kernel_size = pred_kernel_size
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ super(ATSSHead, self).__init__(
+ num_classes,
+ in_channels,
+ reg_decoded_bbox=reg_decoded_bbox,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+            # ATSS does not perform sampling, so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.loss_centerness = build_loss(loss_centerness)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ pred_pad_size = self.pred_kernel_size // 2
+ self.atss_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ self.pred_kernel_size,
+ padding=pred_pad_size)
+ self.atss_reg = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * 4,
+ self.pred_kernel_size,
+ padding=pred_pad_size)
+ self.atss_centerness = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * 1,
+ self.pred_kernel_size,
+ padding=pred_pad_size)
+ self.scales = nn.ModuleList(
+ [Scale(1.0) for _ in self.prior_generator.strides])
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+            tuple: Usually a tuple of classification scores, bbox predictions
+                and centerness predictions.
+                cls_scores (list[Tensor]): Classification scores for all scale
+                    levels, each is a 4D-tensor, the channels number is
+                    num_anchors * num_classes.
+                bbox_preds (list[Tensor]): Box energies / deltas for all scale
+                    levels, each is a 4D-tensor, the channels number is
+                    num_anchors * 4.
+                centernesses (list[Tensor]): Centerness predictions for all
+                    scale levels, each is a 4D-tensor, the channels number is
+                    num_anchors * 1.
+ """
+ return multi_apply(self.forward_single, feats, self.scales)
+
+ def forward_single(self, x, scale):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+
+ Returns:
+ tuple:
+ cls_score (Tensor): Cls scores for a single scale level
+ the channels number is num_anchors * num_classes.
+ bbox_pred (Tensor): Box energies / deltas for a single scale
+ level, the channels number is num_anchors * 4.
+                centerness (Tensor): Centerness for a single scale level
+                    with shape (N, num_anchors * 1, H, W).
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.atss_cls(cls_feat)
+        # following ATSS, we do not apply exp to bbox_pred
+ bbox_pred = scale(self.atss_reg(reg_feat)).float()
+ centerness = self.atss_centerness(reg_feat)
+ return cls_score, bbox_pred, centerness
+
+ def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
+ label_weights, bbox_targets, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+            cls_score (Tensor): Box scores for each scale level
+                with shape (N, num_anchors * num_classes, H, W).
+            bbox_pred (Tensor): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            anchors (Tensor): Box reference for each scale level with shape
+                (N, num_total_anchors, 4).
+            centerness (Tensor): Centerness predictions for each scale level
+                with shape (N, num_anchors * 1, H, W).
+            labels (Tensor): Labels of each anchor with shape
+                (N, num_total_anchors).
+            label_weights (Tensor): Label weights of each anchor with shape
+                (N, num_total_anchors).
+            bbox_targets (Tensor): BBox regression targets of each anchor
+                with shape (N, num_total_anchors, 4).
+            num_total_samples (int): Number of positive samples that is
+                reduced over all GPUs.
+
+ Returns:
+            tuple[Tensor]: loss_cls, loss_bbox, loss_centerness and the sum
+                of the centerness targets (used as the bbox average factor).
+ """
+
+ anchors = anchors.reshape(-1, 4)
+ cls_score = cls_score.permute(0, 2, 3, 1).reshape(
+ -1, self.cls_out_channels).contiguous()
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+
+ # classification loss
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+ pos_centerness = centerness[pos_inds]
+
+ centerness_targets = self.centerness_target(
+ pos_anchors, pos_bbox_targets)
+ pos_decode_bbox_pred = self.bbox_coder.decode(
+ pos_anchors, pos_bbox_pred)
+
+ # regression loss
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_bbox_targets,
+ weight=centerness_targets,
+ avg_factor=1.0)
+
+ # centerness loss
+ loss_centerness = self.loss_centerness(
+ pos_centerness,
+ centerness_targets,
+ avg_factor=num_total_samples)
+
+ else:
+ loss_bbox = bbox_pred.sum() * 0
+ loss_centerness = centerness.sum() * 0
+ centerness_targets = bbox_targets.new_tensor(0.)
+
+ return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ centernesses (list[Tensor]): Centerness for each scale
+ level with shape (N, num_anchors * 1, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ losses_cls, losses_bbox, loss_centerness,\
+ bbox_avg_factor = multi_apply(
+ self.loss_single,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ num_total_samples=num_total_samples)
+
+ bbox_avg_factor = sum(bbox_avg_factor)
+ bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
+ losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ loss_centerness=loss_centerness)
+
+ def centerness_target(self, anchors, gts):
+ # only calculate pos centerness targets, otherwise there may be nan
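+        # FCOS-style centerness:
+        #   sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))
+        # where l, t, r, b are the distances from the anchor centre to the
+        # gt box borders; a centre coinciding with the gt centre gives 1.0,
+        # a centre on a gt border gives 0.0.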
+ anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
+ anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
+ l_ = anchors_cx - gts[:, 0]
+ t_ = anchors_cy - gts[:, 1]
+ r_ = gts[:, 2] - anchors_cx
+ b_ = gts[:, 3] - anchors_cy
+
+ left_right = torch.stack([l_, r_], dim=1)
+ top_bottom = torch.stack([t_, b_], dim=1)
+ centerness = torch.sqrt(
+ (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
+ (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
+ assert not torch.isnan(centerness).any()
+ return centerness
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Get targets for ATSS head.
+
+ This method is almost the same as `AnchorHead.get_targets()`. Besides
+ returning the targets as the parent method does, it also returns the
+ anchors as the first element of the returned tuple.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ num_level_anchors_list = [num_level_anchors] * num_imgs
+
+ # concat all level anchors and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ anchor_list[i] = torch.cat(anchor_list[i])
+ valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors)
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, bbox_weights_list, num_total_pos,
+ num_total_neg)
+
+ def _get_target_single(self,
+ flat_anchors,
+ valid_flags,
+ num_level_anchors,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression, classification targets for anchors in a single
+ image.
+
+ Args:
+            flat_anchors (Tensor): Multi-level anchors of the image, which
+                are concatenated into a single tensor of shape
+                (num_anchors, 4).
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_anchors,).
+            num_level_anchors (list[int]): Number of anchors of each scale
+                level.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: N is the number of total anchors in the image.
+ labels (Tensor): Labels of all anchors in the image with shape
+ (N,).
+                label_weights (Tensor): Label weights of all anchors in the
+                    image with shape (N,).
+                bbox_targets (Tensor): BBox targets of all anchors in the
+                    image with shape (N, 4).
+                bbox_weights (Tensor): BBox weights of all anchors in the
+                    image with shape (N, 4).
+ pos_inds (Tensor): Indices of positive anchor with shape
+ (num_pos,).
+ neg_inds (Tensor): Indices of negative anchor with shape
+ (num_neg,).
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ num_level_anchors_inside = self.get_num_level_anchors_inside(
+ num_level_anchors, inside_flags)
+ assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
+ gt_bboxes, gt_bboxes_ignore,
+ gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if self.reg_decoded_bbox:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ else:
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ anchors = unmap(anchors, num_total_anchors, inside_flags)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+ pos_inds, neg_inds)
+
+ def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
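+        # Split the flat inside_flags back per level and count each level's
+        # anchors that lie inside the image, e.g. (hypothetical)
+        # num_level_anchors = [400, 100] -> num_level_anchors_inside such
+        # as [380, 70].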
+ split_inside_flags = torch.split(inside_flags, num_level_anchors)
+ num_level_anchors_inside = [
+ int(flags.sum()) for flags in split_inside_flags
+ ]
+ return num_level_anchors_inside
diff --git a/mmdet/models/dense_heads/autoassign_head.py b/mmdet/models/dense_heads/autoassign_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..446da244b9e78a4e64d8633477600ad6d732e327
--- /dev/null
+++ b/mmdet/models/dense_heads/autoassign_head.py
@@ -0,0 +1,527 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import bias_init_with_prob, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import multi_apply
+from mmdet.core.anchor.point_generator import MlvlPointGenerator
+from mmdet.core.bbox import bbox_overlaps
+from mmdet.models import HEADS
+from mmdet.models.dense_heads.atss_head import reduce_mean
+from mmdet.models.dense_heads.fcos_head import FCOSHead
+from mmdet.models.dense_heads.paa_head import levels_to_images
+
+EPS = 1e-12
+
+
+class CenterPrior(nn.Module):
+ """Center Weighting module to adjust the category-specific prior
+ distributions.
+
+ Args:
+ force_topk (bool): When no point falls into gt_bbox, forcibly
+ select the k points closest to the center to calculate
+ the center prior. Defaults to False.
+ topk (int): The number of points used to calculate the
+            center prior when no point falls in gt_bbox. Only works when
+            force_topk is True. Defaults to 9.
+ num_classes (int): The class number of dataset. Defaults to 80.
+ strides (tuple[int]): The stride of each input feature map. Defaults
+ to (8, 16, 32, 64, 128).
+ """
+
+ def __init__(self,
+ force_topk=False,
+ topk=9,
+ num_classes=80,
+ strides=(8, 16, 32, 64, 128)):
+ super(CenterPrior, self).__init__()
+ self.mean = nn.Parameter(torch.zeros(num_classes, 2))
+ self.sigma = nn.Parameter(torch.ones(num_classes, 2))
+ self.strides = strides
+ self.force_topk = force_topk
+ self.topk = topk
+
+ def forward(self, anchor_points_list, gt_bboxes, labels,
+ inside_gt_bbox_mask):
+ """Get the center prior of each point on the feature map for each
+ instance.
+
+ Args:
+ anchor_points_list (list[Tensor]): list of coordinate
+ of points on feature map. Each with shape
+ (num_points, 2).
+ gt_bboxes (Tensor): The gt_bboxes with shape of
+ (num_gt, 4).
+ labels (Tensor): The gt_labels with shape of (num_gt).
+ inside_gt_bbox_mask (Tensor): Tensor of bool type,
+ with shape of (num_points, num_gt), each
+ value is used to mark whether this point falls
+ within a certain gt.
+
+ Returns:
+ tuple(Tensor):
+
+            - center_prior_weights (Tensor): Float tensor with shape \
+ of (num_points, num_gt). Each value represents \
+ the center weighting coefficient.
+ - inside_gt_bbox_mask (Tensor): Tensor of bool type, \
+ with shape of (num_points, num_gt), each \
+ value is used to mark whether this point falls \
+ within a certain gt or is the topk nearest points for \
+ a specific gt_bbox.
+ """
+ inside_gt_bbox_mask = inside_gt_bbox_mask.clone()
+ num_gts = len(labels)
+ num_points = sum([len(item) for item in anchor_points_list])
+ if num_gts == 0:
+ return gt_bboxes.new_zeros(num_points,
+ num_gts), inside_gt_bbox_mask
+ center_prior_list = []
+ for slvl_points, stride in zip(anchor_points_list, self.strides):
+ # slvl_points: points from single level in FPN, has shape (h*w, 2)
+ # single_level_points has shape (h*w, num_gt, 2)
+ single_level_points = slvl_points[:, None, :].expand(
+ (slvl_points.size(0), len(gt_bboxes), 2))
+ gt_center_x = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2)
+ gt_center_y = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2)
+ gt_center = torch.stack((gt_center_x, gt_center_y), dim=1)
+ gt_center = gt_center[None]
+ # instance_center has shape (1, num_gt, 2)
+ instance_center = self.mean[labels][None]
+ # instance_sigma has shape (1, num_gt, 2)
+ instance_sigma = self.sigma[labels][None]
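+            # The per-category center prior below is a Gaussian over the
+            # stride-normalized offset d = (p - c) / stride of each point p
+            # from the gt centre c:
+            #   G(d) = exp(-(d - mu)^2 / (2 * sigma^2)),
+            # taken as a product over the x and y dimensions.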
+ # distance has shape (num_points, num_gt, 2)
+ distance = (((single_level_points - gt_center) / float(stride) -
+ instance_center)**2)
+ center_prior = torch.exp(-distance /
+ (2 * instance_sigma**2)).prod(dim=-1)
+ center_prior_list.append(center_prior)
+ center_prior_weights = torch.cat(center_prior_list, dim=0)
+
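+        # When force_topk is set, gts that contain no point at all still get
+        # candidates: the topk points with the largest center prior are
+        # marked as "inside" for those gts.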
+ if self.force_topk:
+ gt_inds_no_points_inside = torch.nonzero(
+ inside_gt_bbox_mask.sum(0) == 0).reshape(-1)
+ if gt_inds_no_points_inside.numel():
+ topk_center_index = \
+ center_prior_weights[:, gt_inds_no_points_inside].topk(
+ self.topk,
+ dim=0)[1]
+ temp_mask = inside_gt_bbox_mask[:, gt_inds_no_points_inside]
+ inside_gt_bbox_mask[:, gt_inds_no_points_inside] = \
+ torch.scatter(temp_mask,
+ dim=0,
+ index=topk_center_index,
+ src=torch.ones_like(
+ topk_center_index,
+ dtype=torch.bool))
+
+ center_prior_weights[~inside_gt_bbox_mask] = 0
+ return center_prior_weights, inside_gt_bbox_mask
+
+
+@HEADS.register_module()
+class AutoAssignHead(FCOSHead):
+ """AutoAssignHead head used in AutoAssign.
+
+ More details can be found in the `paper
+    <https://arxiv.org/abs/2007.03496>`_ .
+
+ Args:
+ force_topk (bool): Used in center prior initialization to
+ handle extremely small gt. Default is False.
+ topk (int): The number of points used to calculate the
+            center prior when no point falls in gt_bbox. Only works when
+            force_topk is True. Defaults to 9.
+        pos_loss_weight (float): The loss weight of the positive loss.
+            Defaults to 0.25.
+        neg_loss_weight (float): The loss weight of the negative loss.
+            Defaults to 0.75.
+        center_loss_weight (float): The loss weight of the center prior
+            loss. Defaults to 0.75.
+ """
+
+ def __init__(self,
+ *args,
+ force_topk=False,
+ topk=9,
+ pos_loss_weight=0.25,
+ neg_loss_weight=0.75,
+ center_loss_weight=0.75,
+ **kwargs):
+ super().__init__(*args, conv_bias=True, **kwargs)
+ self.center_prior = CenterPrior(
+ force_topk=force_topk,
+ topk=topk,
+ num_classes=self.num_classes,
+ strides=self.strides)
+ self.pos_loss_weight = pos_loss_weight
+ self.neg_loss_weight = neg_loss_weight
+ self.center_loss_weight = center_loss_weight
+ self.prior_generator = MlvlPointGenerator(self.strides, offset=0)
+
+ def init_weights(self):
+ """Initialize weights of the head.
+
+        In particular, we have special initialization for the classification
+        conv's and regression conv's bias.
+ """
+
+ super(AutoAssignHead, self).init_weights()
+ bias_cls = bias_init_with_prob(0.02)
+ normal_init(self.conv_cls, std=0.01, bias=bias_cls)
+ normal_init(self.conv_reg, std=0.01, bias=4.0)
+
+ def forward_single(self, x, scale, stride):
+ """Forward features of a single scale level.
+
+ Args:
+ x (Tensor): FPN feature maps of the specified stride.
+            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+ stride (int): The corresponding stride for feature maps, only
+ used to normalize the bbox prediction when self.norm_on_bbox
+ is True.
+
+ Returns:
+ tuple: scores for each class, bbox predictions and centerness \
+ predictions of input feature maps.
+ """
+ cls_score, bbox_pred, cls_feat, reg_feat = super(
+ FCOSHead, self).forward_single(x)
+ centerness = self.conv_centerness(reg_feat)
+ # scale the bbox_pred of different level
+ # float to avoid overflow when enabling FP16
+ bbox_pred = scale(bbox_pred).float()
+ # bbox_pred needed for gradient computation has been modified
+ # by F.relu(bbox_pred) when run with PyTorch 1.10. So replace
+ # F.relu(bbox_pred) with bbox_pred.clamp(min=0)
+ bbox_pred = bbox_pred.clamp(min=0)
+ bbox_pred *= stride
+ return cls_score, bbox_pred, centerness
+
+ def get_pos_loss_single(self, cls_score, objectness, reg_loss, gt_labels,
+ center_prior_weights):
+ """Calculate the positive loss of all points in gt_bboxes.
+
+ Args:
+ cls_score (Tensor): All category scores for each point on
+ the feature map. The shape is (num_points, num_class).
+ objectness (Tensor): Foreground probability of all points,
+ has shape (num_points, 1).
+ reg_loss (Tensor): The regression loss of each gt_bbox and each
+ prediction box, has shape of (num_points, num_gt).
+            gt_labels (Tensor): The zero-based gt_labels of all gts
+                with shape (num_gt,).
+ center_prior_weights (Tensor): Float tensor with shape
+ of (num_points, num_gt). Each value represents
+ the center weighting coefficient.
+
+ Returns:
+ tuple[Tensor]:
+
+ - pos_loss (Tensor): The positive loss of all points
+ in the gt_bboxes.
+ """
+ # p_loc: localization confidence
+ p_loc = torch.exp(-reg_loss)
+ # p_cls: classification confidence
+ p_cls = (cls_score * objectness)[:, gt_labels]
+ # p_pos: joint confidence indicator
+ p_pos = p_cls * p_loc
+
+ # 3 is a hyper-parameter to control the contributions of high and
+ # low confidence locations towards positive losses.
+ confidence_weight = torch.exp(p_pos * 3)
+ p_pos_weight = (confidence_weight * center_prior_weights) / (
+ (confidence_weight * center_prior_weights).sum(
+ 0, keepdim=True)).clamp(min=EPS)
+ reweighted_p_pos = (p_pos * p_pos_weight).sum(0)
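+        # p_pos_weight normalizes exp(3 * p_pos) * center_prior over the
+        # points of each gt (a softmax-like weighting), so confident,
+        # centre-weighted points dominate reweighted_p_pos for their gt.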
+ pos_loss = F.binary_cross_entropy(
+ reweighted_p_pos,
+ torch.ones_like(reweighted_p_pos),
+ reduction='none')
+ pos_loss = pos_loss.sum() * self.pos_loss_weight
+ return pos_loss,
+
+ def get_neg_loss_single(self, cls_score, objectness, gt_labels, ious,
+ inside_gt_bbox_mask):
+ """Calculate the negative loss of all points in feature map.
+
+ Args:
+ cls_score (Tensor): All category scores for each point on
+ the feature map. The shape is (num_points, num_class).
+ objectness (Tensor): Foreground probability of all points
+ and is shape of (num_points, 1).
+            gt_labels (Tensor): The zero-based labels of all gts with shape
+                (num_gt,).
+ ious (Tensor): Float tensor with shape of (num_points, num_gt).
+                Each value represents the IoU between pred_bbox and gt_bboxes.
+ inside_gt_bbox_mask (Tensor): Tensor of bool type,
+ with shape of (num_points, num_gt), each
+ value is used to mark whether this point falls
+ within a certain gt.
+
+ Returns:
+ tuple[Tensor]:
+
+ - neg_loss (Tensor): The negative loss of all points
+ in the feature map.
+ """
+ num_gts = len(gt_labels)
+ joint_conf = (cls_score * objectness)
+ p_neg_weight = torch.ones_like(joint_conf)
+ if num_gts > 0:
+            # The order of dimensions would affect the value of
+            # p_neg_weight, so we strictly follow the original
+            # implementation.
+ inside_gt_bbox_mask = inside_gt_bbox_mask.permute(1, 0)
+ ious = ious.permute(1, 0)
+
+ foreground_idxs = torch.nonzero(inside_gt_bbox_mask, as_tuple=True)
+ temp_weight = (1 / (1 - ious[foreground_idxs]).clamp_(EPS))
+
+ def normalize(x):
+ return (x - x.min() + EPS) / (x.max() - x.min() + EPS)
+
+ for instance_idx in range(num_gts):
+ idxs = foreground_idxs[0] == instance_idx
+ if idxs.any():
+ temp_weight[idxs] = normalize(temp_weight[idxs])
+
+ p_neg_weight[foreground_idxs[1],
+ gt_labels[foreground_idxs[0]]] = 1 - temp_weight
+
+ logits = (joint_conf * p_neg_weight)
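+        # Focal-style negative loss: BCE against an all-zero target,
+        # modulated by logits**2, so points already suppressed (small
+        # logits) contribute little.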
+ neg_loss = (
+ logits**2 * F.binary_cross_entropy(
+ logits, torch.zeros_like(logits), reduction='none'))
+ neg_loss = neg_loss.sum() * self.neg_loss_weight
+ return neg_loss,
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'objectnesses'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ objectnesses,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level,
+ each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ objectnesses (list[Tensor]): objectness for each scale level, each
+ is a 4D-tensor, the channel number is num_points * 1.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
+ all_num_gt = sum([len(item) for item in gt_bboxes])
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ all_level_points = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=bbox_preds[0].dtype,
+ device=bbox_preds[0].device)
+ inside_gt_bbox_mask_list, bbox_targets_list = self.get_targets(
+ all_level_points, gt_bboxes)
+
+ center_prior_weight_list = []
+ temp_inside_gt_bbox_mask_list = []
+ for gt_bboxe, gt_label, inside_gt_bbox_mask in zip(
+ gt_bboxes, gt_labels, inside_gt_bbox_mask_list):
+ center_prior_weight, inside_gt_bbox_mask = \
+ self.center_prior(all_level_points, gt_bboxe, gt_label,
+ inside_gt_bbox_mask)
+ center_prior_weight_list.append(center_prior_weight)
+ temp_inside_gt_bbox_mask_list.append(inside_gt_bbox_mask)
+ inside_gt_bbox_mask_list = temp_inside_gt_bbox_mask_list
+ mlvl_points = torch.cat(all_level_points, dim=0)
+ bbox_preds = levels_to_images(bbox_preds)
+ cls_scores = levels_to_images(cls_scores)
+ objectnesses = levels_to_images(objectnesses)
+
+ reg_loss_list = []
+ ious_list = []
+ num_points = len(mlvl_points)
+
+ for bbox_pred, encoded_targets, inside_gt_bbox_mask in zip(
+ bbox_preds, bbox_targets_list, inside_gt_bbox_mask_list):
+ temp_num_gt = encoded_targets.size(1)
+ expand_mlvl_points = mlvl_points[:, None, :].expand(
+ num_points, temp_num_gt, 2).reshape(-1, 2)
+ encoded_targets = encoded_targets.reshape(-1, 4)
+ expand_bbox_pred = bbox_pred[:, None, :].expand(
+ num_points, temp_num_gt, 4).reshape(-1, 4)
+ decoded_bbox_preds = self.bbox_coder.decode(
+ expand_mlvl_points, expand_bbox_pred)
+ decoded_target_preds = self.bbox_coder.decode(
+ expand_mlvl_points, encoded_targets)
+ with torch.no_grad():
+ ious = bbox_overlaps(
+ decoded_bbox_preds, decoded_target_preds, is_aligned=True)
+ ious = ious.reshape(num_points, temp_num_gt)
+ if temp_num_gt:
+ ious = ious.max(
+ dim=-1, keepdim=True).values.repeat(1, temp_num_gt)
+ else:
+ ious = ious.new_zeros(num_points, temp_num_gt)
+ ious[~inside_gt_bbox_mask] = 0
+ ious_list.append(ious)
+ loss_bbox = self.loss_bbox(
+ decoded_bbox_preds,
+ decoded_target_preds,
+ weight=None,
+ reduction_override='none')
+ reg_loss_list.append(loss_bbox.reshape(num_points, temp_num_gt))
+
+ cls_scores = [item.sigmoid() for item in cls_scores]
+ objectnesses = [item.sigmoid() for item in objectnesses]
+ pos_loss_list, = multi_apply(self.get_pos_loss_single, cls_scores,
+ objectnesses, reg_loss_list, gt_labels,
+ center_prior_weight_list)
+ pos_avg_factor = reduce_mean(
+ bbox_pred.new_tensor(all_num_gt)).clamp_(min=1)
+ pos_loss = sum(pos_loss_list) / pos_avg_factor
+
+ neg_loss_list, = multi_apply(self.get_neg_loss_single, cls_scores,
+ objectnesses, gt_labels, ious_list,
+ inside_gt_bbox_mask_list)
+ neg_avg_factor = sum(item.data.sum()
+ for item in center_prior_weight_list)
+ neg_avg_factor = reduce_mean(neg_avg_factor).clamp_(min=1)
+ neg_loss = sum(neg_loss_list) / neg_avg_factor
+
+ center_loss = []
+ for i in range(len(img_metas)):
+
+ if inside_gt_bbox_mask_list[i].any():
+ center_loss.append(
+ len(gt_bboxes[i]) /
+ center_prior_weight_list[i].sum().clamp_(min=EPS))
+ # when width or height of gt_bbox is smaller than stride of p3
+ else:
+ center_loss.append(center_prior_weight_list[i].sum() * 0)
+
+ center_loss = torch.stack(center_loss).mean() * self.center_loss_weight
+
+ # avoid dead lock in DDP
+ if all_num_gt == 0:
+ pos_loss = bbox_preds[0].sum() * 0
+ dummy_center_prior_loss = self.center_prior.mean.sum(
+ ) * 0 + self.center_prior.sigma.sum() * 0
+ center_loss = objectnesses[0].sum() * 0 + dummy_center_prior_loss
+
+ loss = dict(
+ loss_pos=pos_loss, loss_neg=neg_loss, loss_center=center_loss)
+
+ return loss
+
+ def get_targets(self, points, gt_bboxes_list):
+ """Compute regression targets and each point inside or outside gt_bbox
+ in multiple images.
+
+ Args:
+ points (list[Tensor]): Points of all fpn level, each has shape
+ (num_points, 2).
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+
+ Returns:
+ tuple(list[Tensor]):
+
+ - inside_gt_bbox_mask_list (list[Tensor]): Each
+ Tensor is with bool type and shape of
+ (num_points, num_gt), each value
+ is used to mark whether this point falls
+ within a certain gt.
+ - concat_lvl_bbox_targets (list[Tensor]): BBox
+ targets of each level. Each tensor has shape
+ (num_points, num_gt, 4).
+ """
+
+ concat_points = torch.cat(points, dim=0)
+ # the number of points per img, per lvl
+ inside_gt_bbox_mask_list, bbox_targets_list = multi_apply(
+ self._get_target_single, gt_bboxes_list, points=concat_points)
+ return inside_gt_bbox_mask_list, bbox_targets_list
+
+ def _get_target_single(self, gt_bboxes, points):
+ """Compute regression targets and each point inside or outside gt_bbox
+ for a single image.
+
+ Args:
+ gt_bboxes (Tensor): gt_bbox of single image, has shape
+ (num_gt, 4).
+ points (Tensor): Points of all fpn level, has shape
+ (num_points, 2).
+
+ Returns:
+ tuple[Tensor]: Containing the following Tensors:
+
+ - inside_gt_bbox_mask (Tensor): Bool tensor with shape
+ (num_points, num_gt), each value is used to mark
+ whether this point falls within a certain gt.
+ - bbox_targets (Tensor): BBox targets of each points with
+ each gt_bboxes, has shape (num_points, num_gt, 4).
+ """
+ num_points = points.size(0)
+ num_gts = gt_bboxes.size(0)
+ gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
+ xs, ys = points[:, 0], points[:, 1]
+ xs = xs[:, None]
+ ys = ys[:, None]
+ left = xs - gt_bboxes[..., 0]
+ right = gt_bboxes[..., 2] - xs
+ top = ys - gt_bboxes[..., 1]
+ bottom = gt_bboxes[..., 3] - ys
+ bbox_targets = torch.stack((left, top, right, bottom), -1)
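+        # bbox_targets holds the (left, top, right, bottom) distances from
+        # each point to the gt borders; a point lies inside a gt iff all
+        # four distances are positive, which is exactly the check below.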
+ if num_gts:
+ inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
+ else:
+ inside_gt_bbox_mask = bbox_targets.new_zeros((num_points, num_gts),
+ dtype=torch.bool)
+
+ return inside_gt_bbox_mask, bbox_targets
+
+ def _get_points_single(self,
+ featmap_size,
+ stride,
+ dtype,
+ device,
+ flatten=False):
+ """Almost the same as the implementation in fcos, we remove half stride
+ offset to align with the original implementation.
+
+ This function will be deprecated soon.
+ """
+ warnings.warn(
+ '`_get_points_single` in `AutoAssignHead` will be '
+            'deprecated soon, we support a multi level point generator now. '
+            'You can get points of a single level feature map '
+ 'with `self.prior_generator.single_level_grid_priors` ')
+ y, x = super(FCOSHead,
+ self)._get_points_single(featmap_size, stride, dtype,
+ device)
+ points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
+ dim=-1)
+ return points
diff --git a/mmdet/models/dense_heads/base_dense_head.py b/mmdet/models/dense_heads/base_dense_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c7abb7b9b83f034afe06482a659b39ac1d63139
--- /dev/null
+++ b/mmdet/models/dense_heads/base_dense_head.py
@@ -0,0 +1,526 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+import torch
+from mmcv.cnn.utils.weight_init import constant_init
+from mmcv.ops import batched_nms
+from mmcv.runner import BaseModule, force_fp32
+
+from mmdet.core.utils import filter_scores_and_topk, select_single_mlvl
+
+
+class BaseDenseHead(BaseModule, metaclass=ABCMeta):
+ """Base class for DenseHeads."""
+
+ def __init__(self, init_cfg=None):
+ super(BaseDenseHead, self).__init__(init_cfg)
+
+ def init_weights(self):
+ super(BaseDenseHead, self).init_weights()
+ # avoid init_cfg overwrite the initialization of `conv_offset`
+ for m in self.modules():
+ # DeformConv2dPack, ModulatedDeformConv2dPack
+ if hasattr(m, 'conv_offset'):
+ constant_init(m.conv_offset, 0)
+
+ @abstractmethod
+ def loss(self, **kwargs):
+ """Compute losses of the head."""
+ pass
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ score_factors=None,
+ img_metas=None,
+ cfg=None,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform network outputs of a batch into bbox results.
+
+        Note: When score_factors is not None, the cls_scores are usually
+        multiplied by it to obtain the real scores used in NMS,
+        such as the centerness in FCOS and the IoU branch in ATSS.
+
+ Args:
+ cls_scores (list[Tensor]): Classification scores for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * 4, H, W).
+ score_factors (list[Tensor], Optional): Score factor for
+ all scale level, each is a 4D-tensor, has shape
+ (batch_size, num_priors * 1, H, W). Default None.
+ img_metas (list[dict], Optional): Image meta info. Default None.
+ cfg (mmcv.Config, Optional): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default None.
+ rescale (bool): If True, return boxes in original image space.
+ Default False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default True.
+
+ Returns:
+ list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 5) tensor, where the first 4 columns
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the
+ 5-th column is a score between 0 and 1. The second item is a
+ (n,) tensor where each item is the predicted class label of
+ the corresponding box.
+ """
+ assert len(cls_scores) == len(bbox_preds)
+
+ if score_factors is None:
+ # e.g. Retina, FreeAnchor, Foveabox, etc.
+ with_score_factors = False
+ else:
+ # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
+ with_score_factors = True
+ assert len(cls_scores) == len(score_factors)
+
+ num_levels = len(cls_scores)
+
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_priors = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=cls_scores[0].dtype,
+ device=cls_scores[0].device)
+
+ result_list = []
+
+ for img_id in range(len(img_metas)):
+ img_meta = img_metas[img_id]
+ cls_score_list = select_single_mlvl(cls_scores, img_id)
+ bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
+ if with_score_factors:
+ score_factor_list = select_single_mlvl(score_factors, img_id)
+ else:
+ score_factor_list = [None for _ in range(num_levels)]
+
+ results = self._get_bboxes_single(cls_score_list, bbox_pred_list,
+ score_factor_list, mlvl_priors,
+ img_meta, cfg, rescale, with_nms,
+ **kwargs)
+ result_list.append(results)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ score_factor_list,
+ mlvl_priors,
+ img_meta,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform outputs of a single image into bbox predictions.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_priors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has shape
+ (num_priors * 4, H, W).
+ score_factor_list (list[Tensor]): Score factor from all scale
+ levels of a single image, each item has shape
+ (num_priors * 1, H, W).
+            mlvl_priors (list[Tensor]): Each element in the list is
+                the priors of a single level in feature pyramid. In all
+                anchor-based methods, it has shape (num_priors, 4). In
+                all anchor-free methods, it has shape (num_priors, 2)
+                when `with_stride=False`, otherwise it has shape
+                (num_priors, 4).
+ img_meta (dict): Image meta info.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+                mlvl_score_factor. Usually with_nms is False for aug
+                test. If with_nms is True, then return the following format
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column are scores between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+ if score_factor_list[0] is None:
+ # e.g. Retina, FreeAnchor, etc.
+ with_score_factors = False
+ else:
+ # e.g. FCOS, PAA, ATSS, etc.
+ with_score_factors = True
+
+ cfg = self.test_cfg if cfg is None else cfg
+ img_shape = img_meta['img_shape']
+ nms_pre = cfg.get('nms_pre', -1)
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_labels = []
+ if with_score_factors:
+ mlvl_score_factors = []
+ else:
+ mlvl_score_factors = None
+ for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
+ enumerate(zip(cls_score_list, bbox_pred_list,
+ score_factor_list, mlvl_priors)):
+
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ if with_score_factors:
+ score_factor = score_factor.permute(1, 2,
+ 0).reshape(-1).sigmoid()
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ scores = cls_score.softmax(-1)[:, :-1]
+
+ # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+ # this operation keeps fewer bboxes under the same `nms_pre`.
+ # There is no difference in performance for most models. If you
+ # find a slight drop in performance, you can set a larger
+ # `nms_pre` than before.
+ results = filter_scores_and_topk(
+ scores, cfg.score_thr, nms_pre,
+ dict(bbox_pred=bbox_pred, priors=priors))
+ scores, labels, keep_idxs, filtered_results = results
+
+ bbox_pred = filtered_results['bbox_pred']
+ priors = filtered_results['priors']
+
+ if with_score_factors:
+ score_factor = score_factor[keep_idxs]
+
+ bboxes = self.bbox_coder.decode(
+ priors, bbox_pred, max_shape=img_shape)
+
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_labels.append(labels)
+ if with_score_factors:
+ mlvl_score_factors.append(score_factor)
+
+ return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes,
+ img_meta['scale_factor'], cfg, rescale,
+ with_nms, mlvl_score_factors, **kwargs)
+
+ def _bbox_post_process(self,
+ mlvl_scores,
+ mlvl_labels,
+ mlvl_bboxes,
+ scale_factor,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ mlvl_score_factors=None,
+ **kwargs):
+ """bbox post-processing method.
+
+        The boxes are rescaled to the original image scale and NMS is
+        applied. Usually `with_nms` is False for aug test.
+
+ Args:
+ mlvl_scores (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_bboxes, ).
+ mlvl_labels (list[Tensor]): Box class labels from all scale
+ levels of a single image, each item has shape
+ (num_bboxes, ).
+ mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale
+ levels of a single image, each item has shape (num_bboxes, 4).
+            scale_factor (ndarray, optional): Scale factor of the image,
+                arranged as (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+ mlvl_score_factors (list[Tensor], optional): Score factor from
+ all scale levels of a single image, each item has shape
+ (num_bboxes, ). Default: None.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+                mlvl_score_factor. Usually with_nms is False for aug
+                test. If with_nms is True, then return the following format
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column are scores between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+ assert len(mlvl_scores) == len(mlvl_bboxes) == len(mlvl_labels)
+
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ mlvl_labels = torch.cat(mlvl_labels)
+
+ if mlvl_score_factors is not None:
+ # TODO: Add sqrt operation in order to be consistent with
+ # the paper.
+ mlvl_score_factors = torch.cat(mlvl_score_factors)
+ mlvl_scores = mlvl_scores * mlvl_score_factors
+
+ if with_nms:
+ if mlvl_bboxes.numel() == 0:
+ det_bboxes = torch.cat([mlvl_bboxes, mlvl_scores[:, None]], -1)
+ return det_bboxes, mlvl_labels
+
+ det_bboxes, keep_idxs = batched_nms(mlvl_bboxes, mlvl_scores,
+ mlvl_labels, cfg.nms)
+ det_bboxes = det_bboxes[:cfg.max_per_img]
+ det_labels = mlvl_labels[keep_idxs][:cfg.max_per_img]
+ return det_bboxes, det_labels
+ else:
+ return mlvl_bboxes, mlvl_scores, mlvl_labels
+
+ def forward_train(self,
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None,
+ **kwargs):
+ """
+ Args:
+ x (list[Tensor]): Features from FPN.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ proposal_cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used
+
+ Returns:
+ tuple:
+                losses (dict[str, Tensor]): A dictionary of loss components.
+ proposal_list (list[Tensor]): Proposals of each image.
+ """
+ outs = self(x)
+ if gt_labels is None:
+ loss_inputs = outs + (gt_bboxes, img_metas)
+ else:
+ loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+ losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+ if proposal_cfg is None:
+ return losses
+ else:
+ proposal_list = self.get_bboxes(
+ *outs, img_metas=img_metas, cfg=proposal_cfg)
+ return losses, proposal_list
+
+ def simple_test(self, feats, img_metas, rescale=False):
+ """Test function without test-time augmentation.
+
+ Args:
+ feats (tuple[torch.Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is ``bboxes`` with shape (n, 5),
+ where 5 represent (tl_x, tl_y, br_x, br_y, score).
+ The shape of the second tensor in the tuple is ``labels``
+ with shape (n, ).
+ """
+ return self.simple_test_bboxes(feats, img_metas, rescale=rescale)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def onnx_export(self,
+ cls_scores,
+ bbox_preds,
+ score_factors=None,
+ img_metas=None,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ with shape (N, num_points * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_points * 4, H, W).
+            score_factors (list[Tensor]): score_factors for each scale
+                level with shape (N, num_points * 1, H, W).
+ Default: None.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc. Default: None.
+ with_nms (bool): Whether apply nms to the bboxes. Default: True.
+
+ Returns:
+            tuple[Tensor, Tensor] | list[tuple]: When `with_nms` is True,
+            it is tuple[Tensor, Tensor], the first tensor is bboxes with
+            shape [N, num_det, 5], 5 arranged as (x1, y1, x2, y2, score),
+            and the second element is class labels of shape [N, num_det].
+            When `with_nms` is False, the first tensor is bboxes with
+            shape [N, num_det, 4], and the second tensor is the raw scores
+            with shape [N, num_det, num_classes].
+ """
+ assert len(cls_scores) == len(bbox_preds)
+
+ num_levels = len(cls_scores)
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ mlvl_priors = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=bbox_preds[0].dtype,
+ device=bbox_preds[0].device)
+
+ mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
+ mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
+
+ assert len(
+ img_metas
+ ) == 1, 'Only support one input image while in exporting to ONNX'
+ img_shape = img_metas[0]['img_shape_for_onnx']
+
+ cfg = self.test_cfg
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
+ device = cls_scores[0].device
+ batch_size = cls_scores[0].shape[0]
+ # convert to tensor to keep tracing
+ nms_pre_tensor = torch.tensor(
+ cfg.get('nms_pre', -1), device=device, dtype=torch.long)
+
+ # e.g. Retina, FreeAnchor, etc.
+ if score_factors is None:
+ with_score_factors = False
+ mlvl_score_factor = [None for _ in range(num_levels)]
+ else:
+ # e.g. FCOS, PAA, ATSS, etc.
+ with_score_factors = True
+ mlvl_score_factor = [
+ score_factors[i].detach() for i in range(num_levels)
+ ]
+ mlvl_score_factors = []
+
+ mlvl_batch_bboxes = []
+ mlvl_scores = []
+
+ for cls_score, bbox_pred, score_factors, priors in zip(
+ mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor,
+ mlvl_priors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+ scores = cls_score.permute(0, 2, 3,
+ 1).reshape(batch_size, -1,
+ self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = scores.sigmoid()
+ nms_pre_score = scores
+ else:
+ scores = scores.softmax(-1)
+ nms_pre_score = scores
+
+ if with_score_factors:
+ score_factors = score_factors.permute(0, 2, 3, 1).reshape(
+ batch_size, -1).sigmoid()
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(batch_size, -1, 4)
+ priors = priors.expand(batch_size, -1, priors.size(-1))
+ # Get top-k predictions
+ from mmdet.core.export import get_k_for_topk
+ nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+ if nms_pre > 0:
+
+ if with_score_factors:
+ nms_pre_score = (nms_pre_score * score_factors[..., None])
+ else:
+ nms_pre_score = nms_pre_score
+
+ # Get maximum scores for foreground classes.
+ if self.use_sigmoid_cls:
+ max_scores, _ = nms_pre_score.max(-1)
+ else:
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ max_scores, _ = nms_pre_score[..., :-1].max(-1)
+ _, topk_inds = max_scores.topk(nms_pre)
+
+ batch_inds = torch.arange(
+ batch_size, device=bbox_pred.device).view(
+ -1, 1).expand_as(topk_inds).long()
+ # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+ transformed_inds = bbox_pred.shape[1] * batch_inds + topk_inds
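+                # Flattened gather: element (b, k) of the
+                # (batch_size, num_priors) index grid maps to row
+                # b * num_priors + k after reshaping to
+                # (batch_size * num_priors, ...), which avoids
+                # two-dimensional fancy indexing during export.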
+ priors = priors.reshape(
+ -1, priors.size(-1))[transformed_inds, :].reshape(
+ batch_size, -1, priors.size(-1))
+ bbox_pred = bbox_pred.reshape(-1,
+ 4)[transformed_inds, :].reshape(
+ batch_size, -1, 4)
+ scores = scores.reshape(
+ -1, self.cls_out_channels)[transformed_inds, :].reshape(
+ batch_size, -1, self.cls_out_channels)
+ if with_score_factors:
+ score_factors = score_factors.reshape(
+ -1, 1)[transformed_inds].reshape(batch_size, -1)
+
+ bboxes = self.bbox_coder.decode(
+ priors, bbox_pred, max_shape=img_shape)
+
+ mlvl_batch_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ if with_score_factors:
+ mlvl_score_factors.append(score_factors)
+
+ batch_bboxes = torch.cat(mlvl_batch_bboxes, dim=1)
+ batch_scores = torch.cat(mlvl_scores, dim=1)
+ if with_score_factors:
+ batch_score_factors = torch.cat(mlvl_score_factors, dim=1)
+
+ # Replace multiclass_nms with ONNX::NonMaxSuppression in deployment
+
+ from mmdet.core.export import add_dummy_nms_for_onnx
+
+ if not self.use_sigmoid_cls:
+ batch_scores = batch_scores[..., :self.num_classes]
+
+ if with_score_factors:
+ batch_scores = batch_scores * (batch_score_factors.unsqueeze(2))
+
+ if with_nms:
+ max_output_boxes_per_class = cfg.nms.get(
+ 'max_output_boxes_per_class', 200)
+ iou_threshold = cfg.nms.get('iou_threshold', 0.5)
+ score_threshold = cfg.score_thr
+ nms_pre = cfg.get('deploy_nms_pre', -1)
+ return add_dummy_nms_for_onnx(batch_bboxes, batch_scores,
+ max_output_boxes_per_class,
+ iou_threshold, score_threshold,
+ nms_pre, cfg.max_per_img)
+ else:
+ return batch_bboxes, batch_scores
diff --git a/mmdet/models/dense_heads/base_mask_head.py b/mmdet/models/dense_heads/base_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eb94fb287e223888c0181f1debae0d84b306bf2
--- /dev/null
+++ b/mmdet/models/dense_heads/base_mask_head.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+from mmcv.runner import BaseModule
+
+
+class BaseMaskHead(BaseModule, metaclass=ABCMeta):
+ """Base class for mask heads used in One-Stage Instance Segmentation."""
+
+ def __init__(self, init_cfg):
+ super(BaseMaskHead, self).__init__(init_cfg)
+
+ @abstractmethod
+ def loss(self, **kwargs):
+ pass
+
+ @abstractmethod
+ def get_results(self, **kwargs):
+ """Get precessed :obj:`InstanceData` of multiple images."""
+ pass
+
+ def forward_train(self,
+ x,
+ gt_labels,
+ gt_masks,
+ img_metas,
+ gt_bboxes=None,
+ gt_bboxes_ignore=None,
+ positive_infos=None,
+ **kwargs):
+ """
+ Args:
+ x (list[Tensor] | tuple[Tensor]): Features from FPN.
+ Each has a shape (B, C, H, W).
+            gt_labels (list[Tensor]): Ground truth labels of all images,
+                each has a shape (num_gts,).
+            gt_masks (list[Tensor]): Masks for each bbox, each has a shape
+                (num_gts, h, w).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
+ each item has a shape (num_gts, 4).
+ gt_bboxes_ignore (list[Tensor], None): Ground truth bboxes to be
+ ignored, each item has a shape (num_ignored_gts, 4).
+ positive_infos (list[:obj:`InstanceData`], optional): Information
+ of positive samples. Used when the label assignment is
+ done outside the MaskHead, e.g., in BboxHead in
+ YOLACT or CondInst, etc. When the label assignment is done in
+ MaskHead, it would be None, like SOLO. All values
+ in it should have shape (num_positive_samples, *).
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ if positive_infos is None:
+ outs = self(x)
+ else:
+ outs = self(x, positive_infos)
+
+ assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \
+ 'even if only one item is returned'
+ loss = self.loss(
+ *outs,
+ gt_labels=gt_labels,
+ gt_masks=gt_masks,
+ img_metas=img_metas,
+ gt_bboxes=gt_bboxes,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ positive_infos=positive_infos,
+ **kwargs)
+ return loss
+
+ def simple_test(self,
+ feats,
+ img_metas,
+ rescale=False,
+ instances_list=None,
+ **kwargs):
+ """Test function without test-time augmentation.
+
+ Args:
+ feats (tuple[torch.Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+            instances_list (list[:obj:`InstanceData`], optional): Detection
+                results of each image after the post process. Only exists
+                if there is a `bbox_head`, like `YOLACT`, `CondInst`, etc.
+
+ Returns:
+            list[:obj:`InstanceData`]: Instance segmentation \
+                results of each image after the post process. \
+                Each item usually contains the following keys. \
+
+                - scores (Tensor): Classification scores, has a shape
+                    (num_instances,).
+ - labels (Tensor): Has a shape (num_instances,).
+ - masks (Tensor): Processed mask results, has a
+ shape (num_instances, h, w).
+ """
+ if instances_list is None:
+ outs = self(feats)
+ else:
+ outs = self(feats, instances_list=instances_list)
+ mask_inputs = outs + (img_metas, )
+ results_list = self.get_results(
+ *mask_inputs,
+ rescale=rescale,
+ instances_list=instances_list,
+ **kwargs)
+ return results_list
+
+ def onnx_export(self, img, img_metas):
+        raise NotImplementedError(f'{self.__class__.__name__} does '
+                                  'not support ONNX export')
diff --git a/mmdet/models/dense_heads/cascade_rpn_head.py b/mmdet/models/dense_heads/cascade_rpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..69347e00c436430b57413a81cb5cb49bb52f1841
--- /dev/null
+++ b/mmdet/models/dense_heads/cascade_rpn_head.py
@@ -0,0 +1,801 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv import ConfigDict
+from mmcv.ops import DeformConv2d, batched_nms
+from mmcv.runner import BaseModule, ModuleList
+
+from mmdet.core import (RegionAssigner, build_assigner, build_sampler,
+ images_to_levels, multi_apply)
+from mmdet.core.utils import select_single_mlvl
+from ..builder import HEADS, build_head
+from .base_dense_head import BaseDenseHead
+from .rpn_head import RPNHead
+
+
+class AdaptiveConv(BaseModule):
+ """AdaptiveConv used to adapt the sampling location with the anchors.
+
+ Args:
+        in_channels (int): Number of channels in the input feature map.
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the conv kernel. Default: 3
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 1
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 3
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If set True, adds a learnable bias to the
+ output. Default: False.
+ type (str, optional): Type of adaptive conv, can be either 'offset'
+ (arbitrary anchors) or 'dilation' (uniform anchor).
+ Default: 'dilation'.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
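+
+    Example:
+        A minimal usage sketch of the default 'dilation' mode (the channel
+        and spatial sizes below are illustrative assumptions):
+
+        >>> import torch
+        >>> conv = AdaptiveConv(256, 256, type='dilation', dilation=3)
+        >>> x = torch.rand(1, 256, 32, 32)
+        >>> out = conv(x, offset=None)  # 'dilation' mode takes no offset
+        >>> tuple(out.shape)
+        (1, 256, 32, 32)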
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ dilation=3,
+ groups=1,
+ bias=False,
+ type='dilation',
+ init_cfg=dict(
+ type='Normal', std=0.01, override=dict(name='conv'))):
+ super(AdaptiveConv, self).__init__(init_cfg)
+ assert type in ['offset', 'dilation']
+ self.adapt_type = type
+
+        assert kernel_size == 3, 'Adaptive conv only supports kernel_size 3'
+        if self.adapt_type == 'offset':
+            assert stride == 1 and padding == 1 and groups == 1, \
+                'Adaptive conv offset mode only supports padding: 1, ' \
+                'stride: 1, groups: 1'
+ self.conv = DeformConv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=padding,
+ stride=stride,
+ groups=groups,
+ bias=bias)
+ else:
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=dilation,
+ dilation=dilation)
+
+ def forward(self, x, offset):
+ """Forward function."""
+ if self.adapt_type == 'offset':
+ N, _, H, W = x.shape
+ assert offset is not None
+ assert H * W == offset.shape[1]
+ # reshape [N, NA, 18] to (N, 18, H, W)
+ offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
+ offset = offset.contiguous()
+ x = self.conv(x, offset)
+ else:
+ assert offset is None
+ x = self.conv(x)
+ return x
+
+
+@HEADS.register_module()
+class StageCascadeRPNHead(RPNHead):
+ """Stage of CascadeRPNHead.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+        anchor_generator (dict): Anchor generator config.
+        adapt_cfg (dict): Adaptation config.
+        bridged_feature (bool, optional): Whether to update the rpn feature.
+            Default: False.
+        with_cls (bool, optional): Whether to use the classification branch.
+            Default: True.
+        sampling (bool, optional): Whether to use sampling. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8],
+ ratios=[1.0],
+ strides=[4, 8, 16, 32, 64]),
+ adapt_cfg=dict(type='dilation', dilation=3),
+ bridged_feature=False,
+ with_cls=True,
+ sampling=True,
+ init_cfg=None,
+ **kwargs):
+ self.with_cls = with_cls
+ self.anchor_strides = anchor_generator['strides']
+ self.anchor_scales = anchor_generator['scales']
+ self.bridged_feature = bridged_feature
+ self.adapt_cfg = adapt_cfg
+ super(StageCascadeRPNHead, self).__init__(
+ in_channels,
+ anchor_generator=anchor_generator,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ # override sampling and sampler
+ self.sampling = sampling
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ if init_cfg is None:
+ self.init_cfg = dict(
+ type='Normal', std=0.01, override=[dict(name='rpn_reg')])
+ if self.with_cls:
+ self.init_cfg['override'].append(dict(name='rpn_cls'))
+
+ def _init_layers(self):
+ """Init layers of a CascadeRPN stage."""
+ self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
+ **self.adapt_cfg)
+ if self.with_cls:
+ self.rpn_cls = nn.Conv2d(self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 1)
+ self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward_single(self, x, offset):
+ """Forward function of single scale."""
+ bridged_x = x
+ x = self.relu(self.rpn_conv(x, offset))
+ if self.bridged_feature:
+ bridged_x = x # update feature
+ cls_score = self.rpn_cls(x) if self.with_cls else None
+ bbox_pred = self.rpn_reg(x)
+ return bridged_x, cls_score, bbox_pred
+
+ def forward(self, feats, offset_list=None):
+ """Forward function."""
+ if offset_list is None:
+ offset_list = [None for _ in range(len(feats))]
+ return multi_apply(self.forward_single, feats, offset_list)
+
+ def _region_targets_single(self,
+ anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ featmap_sizes,
+ label_channels=1):
+ """Get anchor targets based on region for single level."""
+ assign_result = self.assigner.assign(
+ anchors,
+ valid_flags,
+ gt_bboxes,
+ img_meta,
+ featmap_sizes,
+ self.anchor_scales[0],
+ self.anchor_strides,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ gt_labels=None,
+ allowed_border=self.train_cfg.allowed_border)
+ flat_anchors = torch.cat(anchors)
+ sampling_result = self.sampler.sample(assign_result, flat_anchors,
+ gt_bboxes)
+
+ num_anchors = flat_anchors.shape[0]
+ bbox_targets = torch.zeros_like(flat_anchors)
+ bbox_weights = torch.zeros_like(flat_anchors)
+ labels = flat_anchors.new_zeros(num_anchors, dtype=torch.long)
+ label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ labels[pos_inds] = 1
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds)
+
+ def region_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ featmap_sizes,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """See :func:`StageCascadeRPNHead.get_targets`."""
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
+ pos_inds_list, neg_inds_list) = multi_apply(
+ self._region_targets_single,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ featmap_sizes=featmap_sizes,
+ label_channels=label_channels)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ return (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ featmap_sizes,
+ gt_bboxes_ignore=None,
+ label_channels=1):
+ """Compute regression and classification targets for anchors.
+
+ Args:
+ anchor_list (list[list]): Multi level anchors of each image.
+ valid_flag_list (list[list]): Multi level valid flags of each
+ image.
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+            featmap_sizes (list[Tensor]): Feature map size of each level.
+            gt_bboxes_ignore (list[Tensor]): Ignored bboxes of each image.
+ label_channels (int): Channel of label.
+
+ Returns:
+            tuple: Regression and classification targets of the anchors.
+ """
+ if isinstance(self.assigner, RegionAssigner):
+ cls_reg_targets = self.region_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ featmap_sizes,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ label_channels=label_channels)
+ else:
+ cls_reg_targets = super(StageCascadeRPNHead, self).get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ label_channels=label_channels)
+ return cls_reg_targets
+
+ def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes):
+ """ Get offset for deformable conv based on anchor shape
+ NOTE: currently support deformable kernel_size=3 and dilation=1
+
+ Args:
+ anchor_list (list[list[tensor])): [NI, NLVL, NA, 4] list of
+ multi-level anchors
+ anchor_strides (list[int]): anchor stride of each level
+
+ Returns:
+ offset_list (list[tensor]): [NLVL, NA, 2, 18]: offset of DeformConv
+ kernel.
+ """
+
+ def _shape_offset(anchors, stride, ks=3, dilation=1):
+ # currently support kernel_size=3 and dilation=1
+ assert ks == 3 and dilation == 1
+ pad = (ks - 1) // 2
+ idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
+ yy, xx = torch.meshgrid(idx, idx) # return order matters
+ xx = xx.reshape(-1)
+ yy = yy.reshape(-1)
+ w = (anchors[:, 2] - anchors[:, 0]) / stride
+ h = (anchors[:, 3] - anchors[:, 1]) / stride
+ w = w / (ks - 1) - dilation
+ h = h / (ks - 1) - dilation
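+            # Illustrative example: an anchor 96 px wide on a stride-8
+            # level covers w = 12 cells, so w / (ks - 1) - dilation = 5
+            # and the 3x3 sampling points get an extra +/-5 cell spread.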
+ offset_x = w[:, None] * xx # (NA, ks**2)
+ offset_y = h[:, None] * yy # (NA, ks**2)
+ return offset_x, offset_y
+
+ def _ctr_offset(anchors, stride, featmap_size):
+ feat_h, feat_w = featmap_size
+ assert len(anchors) == feat_h * feat_w
+
+ x = (anchors[:, 0] + anchors[:, 2]) * 0.5
+ y = (anchors[:, 1] + anchors[:, 3]) * 0.5
+ # compute centers on feature map
+ x = x / stride
+ y = y / stride
+ # compute predefine centers
+ xx = torch.arange(0, feat_w, device=anchors.device)
+ yy = torch.arange(0, feat_h, device=anchors.device)
+ yy, xx = torch.meshgrid(yy, xx)
+ xx = xx.reshape(-1).type_as(x)
+ yy = yy.reshape(-1).type_as(y)
+
+ offset_x = x - xx # (NA, )
+ offset_y = y - yy # (NA, )
+ return offset_x, offset_y
+
+ num_imgs = len(anchor_list)
+ num_lvls = len(anchor_list[0])
+ dtype = anchor_list[0][0].dtype
+ device = anchor_list[0][0].device
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+ offset_list = []
+ for i in range(num_imgs):
+ mlvl_offset = []
+ for lvl in range(num_lvls):
+ c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
+ anchor_strides[lvl],
+ featmap_sizes[lvl])
+ s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
+ anchor_strides[lvl])
+
+ # offset = ctr_offset + shape_offset
+ offset_x = s_offset_x + c_offset_x[:, None]
+ offset_y = s_offset_y + c_offset_y[:, None]
+
+                # offset order (y0, x0, y1, x1, ..., y8, x8)
+ offset = torch.stack([offset_y, offset_x], dim=-1)
+ offset = offset.reshape(offset.size(0), -1) # [NA, 2*ks**2]
+ mlvl_offset.append(offset)
+ offset_list.append(torch.cat(mlvl_offset)) # [totalNA, 2*ks**2]
+ offset_list = images_to_levels(offset_list, num_level_anchors)
+ return offset_list
+
+ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
+ bbox_targets, bbox_weights, num_total_samples):
+ """Loss function on single scale."""
+ # classification loss
+ if self.with_cls:
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ # regression loss
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ bbox_weights = bbox_weights.reshape(-1, 4)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ if self.reg_decoded_bbox:
+            # When the regression loss (e.g. `IoULoss`, `GIoULoss`)
+ # is applied directly on the decoded bounding boxes, it
+ # decodes the already encoded coordinates to absolute format.
+ anchors = anchors.reshape(-1, 4)
+ bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+ loss_reg = self.loss_bbox(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_samples)
+ if self.with_cls:
+ return loss_cls, loss_reg
+ return None, loss_reg
+
+ def loss(self,
+ anchor_list,
+ valid_flag_list,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ anchor_list (list[list]): Multi level anchors of each image.
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ featmap_sizes,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ if self.sampling:
+ num_total_samples = num_total_pos + num_total_neg
+ else:
+ # 200 is hard-coded average factor,
+ # which follows guided anchoring.
+ num_total_samples = sum([label.numel()
+ for label in labels_list]) / 200.0
+
+ # change per image, per level anchor_list to per_level, per_image
+ mlvl_anchor_list = list(zip(*anchor_list))
+ # concat mlvl_anchor_list
+ mlvl_anchor_list = [
+ torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
+ ]
+
+ losses = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+ if self.with_cls:
+ return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
+ return dict(loss_rpn_reg=losses[1])
+
+ def get_bboxes(self,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ img_metas,
+ cfg,
+ rescale=False):
+ """Get proposal predict.
+
+ Args:
+ anchor_list (list[list]): Multi level anchors of each image.
+ cls_scores (list[Tensor]): Classification scores for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * 4, H, W).
+ img_metas (list[dict], Optional): Image meta info. Default None.
+ cfg (mmcv.Config, Optional): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+
+ Returns:
+            list[Tensor]: Proposals of each image, each is a Tensor of shape
+                (n, 5), where the first 4 columns are bounding box positions
+                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
+                between 0 and 1.
+ """
+ assert len(cls_scores) == len(bbox_preds)
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = select_single_mlvl(cls_scores, img_id)
+ bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
+ anchor_list[img_id], img_shape,
+ scale_factor, cfg, rescale)
+ result_list.append(proposals)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ """Transform outputs of a single image into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_anchors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has
+ shape (num_anchors * 4, H, W).
+ mlvl_anchors (list[Tensor]): Box reference from all scale
+ levels of a single image, each item has shape
+ (num_total_anchors, 4).
+ img_shape (tuple[int]): Shape of the input image,
+ (height, width, 3).
+            scale_factor (ndarray): Scale factor of the image, arranged as
+                (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default False.
+
+ Returns:
+ Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the
+ 5-th column is a score between 0 and 1.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ cfg = copy.deepcopy(cfg)
+ # bboxes from different level should be independent during NMS,
+ # level_ids are used as labels for batched NMS to separate them
+ level_ids = []
+ mlvl_scores = []
+ mlvl_bbox_preds = []
+ mlvl_valid_anchors = []
+ nms_pre = cfg.get('nms_pre', -1)
+ for idx in range(len(cls_scores)):
+ rpn_cls_score = cls_scores[idx]
+ rpn_bbox_pred = bbox_preds[idx]
+ assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
+ rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
+ if self.use_sigmoid_cls:
+ rpn_cls_score = rpn_cls_score.reshape(-1)
+ scores = rpn_cls_score.sigmoid()
+ else:
+ rpn_cls_score = rpn_cls_score.reshape(-1, 2)
+ # We set FG labels to [0, num_class-1] and BG label to
+ # num_class in RPN head since mmdet v2.5, which is unified to
+ # be consistent with other head since mmdet v2.0. In mmdet v2.0
+ # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
+ scores = rpn_cls_score.softmax(dim=1)[:, 0]
+ rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ anchors = mlvl_anchors[idx]
+
+ if 0 < nms_pre < scores.shape[0]:
+ # sort is faster than topk
+ # _, topk_inds = scores.topk(cfg.nms_pre)
+ ranked_scores, rank_inds = scores.sort(descending=True)
+ topk_inds = rank_inds[:nms_pre]
+ scores = ranked_scores[:nms_pre]
+ rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
+ anchors = anchors[topk_inds, :]
+ mlvl_scores.append(scores)
+ mlvl_bbox_preds.append(rpn_bbox_pred)
+ mlvl_valid_anchors.append(anchors)
+ level_ids.append(
+ scores.new_full((scores.size(0), ), idx, dtype=torch.long))
+
+ scores = torch.cat(mlvl_scores)
+ anchors = torch.cat(mlvl_valid_anchors)
+ rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
+ proposals = self.bbox_coder.decode(
+ anchors, rpn_bbox_pred, max_shape=img_shape)
+ ids = torch.cat(level_ids)
+
+ if cfg.min_bbox_size >= 0:
+ w = proposals[:, 2] - proposals[:, 0]
+ h = proposals[:, 3] - proposals[:, 1]
+ valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
+ if not valid_mask.all():
+ proposals = proposals[valid_mask]
+ scores = scores[valid_mask]
+ ids = ids[valid_mask]
+
+ # deprecate arguments warning
+ if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
+ warnings.warn(
+ 'In rpn_proposal or test_cfg, '
+ 'nms_thr has been moved to a dict named nms as '
+ 'iou_threshold, max_num has been renamed as max_per_img, '
+ 'name of original arguments and the way to specify '
+ 'iou_threshold of NMS will be deprecated.')
+ if 'nms' not in cfg:
+ cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
+ if 'max_num' in cfg:
+ if 'max_per_img' in cfg:
+                assert cfg.max_num == cfg.max_per_img, 'You set max_num ' \
+                    'and max_per_img at the same time, but get ' \
+                    f'{cfg.max_num} and {cfg.max_per_img} respectively. ' \
+                    'Please delete max_num, which will be deprecated.'
+ else:
+ cfg.max_per_img = cfg.max_num
+ if 'nms_thr' in cfg:
+ assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set' \
+ f' iou_threshold in nms and ' \
+ f'nms_thr at the same time, but get' \
+ f' {cfg.nms.iou_threshold} and {cfg.nms_thr}' \
+ f' respectively. Please delete the nms_thr ' \
+ f'which will be deprecated.'
+
+ if proposals.numel() > 0:
+ dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
+ else:
+ return proposals.new_zeros(0, 5)
+
+ return dets[:cfg.max_per_img]
+
+ def refine_bboxes(self, anchor_list, bbox_preds, img_metas):
+ """Refine bboxes through stages."""
+ num_levels = len(bbox_preds)
+ new_anchor_list = []
+ for img_id in range(len(img_metas)):
+ mlvl_anchors = []
+ for i in range(num_levels):
+ bbox_pred = bbox_preds[i][img_id].detach()
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ img_shape = img_metas[img_id]['img_shape']
+ bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
+ bbox_pred, img_shape)
+ mlvl_anchors.append(bboxes)
+ new_anchor_list.append(mlvl_anchors)
+ return new_anchor_list
+
+
+@HEADS.register_module()
+class CascadeRPNHead(BaseDenseHead):
+ """The CascadeRPNHead will predict more accurate region proposals, which is
+ required for two-stage detectors (such as Fast/Faster R-CNN). CascadeRPN
+ consists of a sequence of RPNStage to progressively improve the accuracy of
+ the detected proposals.
+
+ More details can be found in ``https://arxiv.org/abs/1909.06720``.
+
+ Args:
+        num_stages (int): Number of CascadeRPN stages.
+        stages (list[dict]): List of configs to build the stages.
+        train_cfg (list[dict]): List of training configs, one for each stage.
+        test_cfg (dict): Config at testing time.
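+
+    Example:
+        A hedged config sketch of the expected structure (the values are
+        illustrative, not a tuned setup; loss, coder and anchor settings
+        are omitted for brevity):
+
+        >>> cascade_rpn_head = dict(
+        ...     type='CascadeRPNHead',
+        ...     num_stages=2,
+        ...     stages=[
+        ...         dict(type='StageCascadeRPNHead', in_channels=256,
+        ...              adapt_cfg=dict(type='dilation', dilation=3),
+        ...              bridged_feature=True, with_cls=False),
+        ...         dict(type='StageCascadeRPNHead', in_channels=256,
+        ...              adapt_cfg=dict(type='offset'),
+        ...              with_cls=True)])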
+ """
+
+ def __init__(self, num_stages, stages, train_cfg, test_cfg, init_cfg=None):
+ super(CascadeRPNHead, self).__init__(init_cfg)
+ assert num_stages == len(stages)
+ self.num_stages = num_stages
+ # Be careful! Pretrained weights cannot be loaded when use
+ # nn.ModuleList
+ self.stages = ModuleList()
+ for i in range(len(stages)):
+ train_cfg_i = train_cfg[i] if train_cfg is not None else None
+ stages[i].update(train_cfg=train_cfg_i)
+ stages[i].update(test_cfg=test_cfg)
+ self.stages.append(build_head(stages[i]))
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ def loss(self):
+ """loss() is implemented in StageCascadeRPNHead."""
+ pass
+
+ def get_bboxes(self):
+ """get_bboxes() is implemented in StageCascadeRPNHead."""
+ pass
+
+ def forward_train(self,
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None):
+ """Forward train function."""
+ assert gt_labels is None, 'RPN does not require gt_labels'
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in x]
+ device = x[0].device
+ anchor_list, valid_flag_list = self.stages[0].get_anchors(
+ featmap_sizes, img_metas, device=device)
+
+ losses = dict()
+
+ for i in range(self.num_stages):
+ stage = self.stages[i]
+
+ if stage.adapt_cfg['type'] == 'offset':
+ offset_list = stage.anchor_offset(anchor_list,
+ stage.anchor_strides,
+ featmap_sizes)
+ else:
+ offset_list = None
+ x, cls_score, bbox_pred = stage(x, offset_list)
+ rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
+ bbox_pred, gt_bboxes, img_metas)
+ stage_loss = stage.loss(*rpn_loss_inputs)
+ for name, value in stage_loss.items():
+ losses['s{}.{}'.format(i, name)] = value
+
+ # refine boxes
+ if i < self.num_stages - 1:
+ anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
+ img_metas)
+ if proposal_cfg is None:
+ return losses
+ else:
+ proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
+ bbox_pred, img_metas,
+ self.test_cfg)
+ return losses, proposal_list
+
+ def simple_test_rpn(self, x, img_metas):
+ """Simple forward test function."""
+ featmap_sizes = [featmap.size()[-2:] for featmap in x]
+ device = x[0].device
+ anchor_list, _ = self.stages[0].get_anchors(
+ featmap_sizes, img_metas, device=device)
+
+ for i in range(self.num_stages):
+ stage = self.stages[i]
+ if stage.adapt_cfg['type'] == 'offset':
+ offset_list = stage.anchor_offset(anchor_list,
+ stage.anchor_strides,
+ featmap_sizes)
+ else:
+ offset_list = None
+ x, cls_score, bbox_pred = stage(x, offset_list)
+ if i < self.num_stages - 1:
+ anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
+ img_metas)
+
+ proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
+ bbox_pred, img_metas,
+ self.test_cfg)
+ return proposal_list
+
+ def aug_test_rpn(self, x, img_metas):
+ """Augmented forward test function."""
+ raise NotImplementedError(
+ 'CascadeRPNHead does not support test-time augmentation')
diff --git a/mmdet/models/dense_heads/centernet_head.py b/mmdet/models/dense_heads/centernet_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9d5d2f01fb1cc2494739262517082f6a52b7297
--- /dev/null
+++ b/mmdet/models/dense_heads/centernet_head.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import bias_init_with_prob, normal_init
+from mmcv.ops import batched_nms
+from mmcv.runner import force_fp32
+
+from mmdet.core import multi_apply
+from mmdet.models import HEADS, build_loss
+from mmdet.models.utils import gaussian_radius, gen_gaussian_target
+from ..utils.gaussian_target import (get_local_maximum, get_topk_from_heatmap,
+ transpose_and_gather_feat)
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+@HEADS.register_module()
+class CenterNetHead(BaseDenseHead, BBoxTestMixin):
+ """Objects as Points Head. CenterHead use center_point to indicate object's
+ position. Paper link
+
+ Args:
+        in_channel (int): Number of channels in the input feature map.
+        feat_channel (int): Number of channels in the intermediate feature
+            map.
+ num_classes (int): Number of categories excluding the background
+ category.
+ loss_center_heatmap (dict | None): Config of center heatmap loss.
+ Default: GaussianFocalLoss.
+ loss_wh (dict | None): Config of wh loss. Default: L1Loss.
+ loss_offset (dict | None): Config of offset loss. Default: L1Loss.
+ train_cfg (dict | None): Training config. Useless in CenterNet,
+ but we keep this variable for SingleStageDetector. Default: None.
+ test_cfg (dict | None): Testing config of CenterNet. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
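+
+    Example:
+        A minimal construction sketch (the channel sizes and the feature
+        shape below are illustrative assumptions):
+
+        >>> import torch
+        >>> head = CenterNetHead(
+        ...     in_channel=16, feat_channel=16, num_classes=4)
+        >>> feats = [torch.rand(1, 16, 32, 32)]
+        >>> center_heatmap_preds, wh_preds, offset_preds = head(feats)
+        >>> center_heatmap_preds[0].shape
+        torch.Size([1, 4, 32, 32])
+        >>> wh_preds[0].shape, offset_preds[0].shape
+        (torch.Size([1, 2, 32, 32]), torch.Size([1, 2, 32, 32]))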
+ """
+
+ def __init__(self,
+ in_channel,
+ feat_channel,
+ num_classes,
+ loss_center_heatmap=dict(
+ type='GaussianFocalLoss', loss_weight=1.0),
+ loss_wh=dict(type='L1Loss', loss_weight=0.1),
+ loss_offset=dict(type='L1Loss', loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=None):
+ super(CenterNetHead, self).__init__(init_cfg)
+ self.num_classes = num_classes
+ self.heatmap_head = self._build_head(in_channel, feat_channel,
+ num_classes)
+ self.wh_head = self._build_head(in_channel, feat_channel, 2)
+ self.offset_head = self._build_head(in_channel, feat_channel, 2)
+
+ self.loss_center_heatmap = build_loss(loss_center_heatmap)
+ self.loss_wh = build_loss(loss_wh)
+ self.loss_offset = build_loss(loss_offset)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.fp16_enabled = False
+
+ def _build_head(self, in_channel, feat_channel, out_channel):
+ """Build head for each branch."""
+ layer = nn.Sequential(
+ nn.Conv2d(in_channel, feat_channel, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(feat_channel, out_channel, kernel_size=1))
+ return layer
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ bias_init = bias_init_with_prob(0.1)
+ self.heatmap_head[-1].bias.data.fill_(bias_init)
+ for head in [self.wh_head, self.offset_head]:
+ for m in head.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.001)
+
+ def forward(self, feats):
+ """Forward features. Notice CenterNet head does not use FPN.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+            center_heatmap_preds (list[Tensor]): center heatmap predictions
+                for all levels, the number of channels is num_classes.
+            wh_preds (list[Tensor]): wh predictions for all levels, the
+                number of channels is 2.
+            offset_preds (list[Tensor]): offset predictions for all levels,
+                the number of channels is 2.
+ """
+ return multi_apply(self.forward_single, feats)
+
+ def forward_single(self, feat):
+ """Forward feature of a single level.
+
+ Args:
+ feat (Tensor): Feature of a single level.
+
+ Returns:
+            center_heatmap_pred (Tensor): center heatmap prediction, the
+                number of channels is num_classes.
+            wh_pred (Tensor): wh prediction, the number of channels is 2.
+            offset_pred (Tensor): offset prediction, the number of channels
+                is 2.
+ """
+ center_heatmap_pred = self.heatmap_head(feat).sigmoid()
+ wh_pred = self.wh_head(feat)
+ offset_pred = self.offset_head(feat)
+ return center_heatmap_pred, wh_pred, offset_pred
+
+ @force_fp32(apply_to=('center_heatmap_preds', 'wh_preds', 'offset_preds'))
+ def loss(self,
+ center_heatmap_preds,
+ wh_preds,
+ offset_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ center_heatmap_preds (list[Tensor]): center predict heatmaps for
+ all levels with shape (B, num_classes, H, W).
+ wh_preds (list[Tensor]): wh predicts for all levels with
+ shape (B, 2, H, W).
+ offset_preds (list[Tensor]): offset predicts for all levels
+ with shape (B, 2, H, W).
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+            dict[str, Tensor]: A dictionary with the components below:
+                - loss_center_heatmap (Tensor): loss of the center heatmap.
+                - loss_wh (Tensor): loss of the wh prediction.
+                - loss_offset (Tensor): loss of the offset prediction.
+ """
+ assert len(center_heatmap_preds) == len(wh_preds) == len(
+ offset_preds) == 1
+ center_heatmap_pred = center_heatmap_preds[0]
+ wh_pred = wh_preds[0]
+ offset_pred = offset_preds[0]
+
+ target_result, avg_factor = self.get_targets(gt_bboxes, gt_labels,
+ center_heatmap_pred.shape,
+ img_metas[0]['pad_shape'])
+
+ center_heatmap_target = target_result['center_heatmap_target']
+ wh_target = target_result['wh_target']
+ offset_target = target_result['offset_target']
+ wh_offset_target_weight = target_result['wh_offset_target_weight']
+
+ # Since the channel of wh_target and offset_target is 2, the avg_factor
+ # of loss_center_heatmap is always 1/2 of loss_wh and loss_offset.
+ loss_center_heatmap = self.loss_center_heatmap(
+ center_heatmap_pred, center_heatmap_target, avg_factor=avg_factor)
+ loss_wh = self.loss_wh(
+ wh_pred,
+ wh_target,
+ wh_offset_target_weight,
+ avg_factor=avg_factor * 2)
+ loss_offset = self.loss_offset(
+ offset_pred,
+ offset_target,
+ wh_offset_target_weight,
+ avg_factor=avg_factor * 2)
+ return dict(
+ loss_center_heatmap=loss_center_heatmap,
+ loss_wh=loss_wh,
+ loss_offset=loss_offset)
+
+ def get_targets(self, gt_bboxes, gt_labels, feat_shape, img_shape):
+ """Compute regression and classification targets in multiple images.
+
+ Args:
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box.
+ feat_shape (list[int]): feature map shape with value [B, _, H, W]
+ img_shape (list[int]): image shape in [h, w] format.
+
+ Returns:
+            tuple[dict, float]: The float value is the avg_factor, the dict
+                has the components below:
+ - center_heatmap_target (Tensor): targets of center heatmap, \
+ shape (B, num_classes, H, W).
+ - wh_target (Tensor): targets of wh predict, shape \
+ (B, 2, H, W).
+ - offset_target (Tensor): targets of offset predict, shape \
+ (B, 2, H, W).
+ - wh_offset_target_weight (Tensor): weights of wh and offset \
+ predict, shape (B, 2, H, W).
+ """
+ img_h, img_w = img_shape[:2]
+ bs, _, feat_h, feat_w = feat_shape
+
+ width_ratio = float(feat_w / img_w)
+ height_ratio = float(feat_h / img_h)
+
+ center_heatmap_target = gt_bboxes[-1].new_zeros(
+ [bs, self.num_classes, feat_h, feat_w])
+ wh_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w])
+ offset_target = gt_bboxes[-1].new_zeros([bs, 2, feat_h, feat_w])
+ wh_offset_target_weight = gt_bboxes[-1].new_zeros(
+ [bs, 2, feat_h, feat_w])
+
+ for batch_id in range(bs):
+ gt_bbox = gt_bboxes[batch_id]
+ gt_label = gt_labels[batch_id]
+ center_x = (gt_bbox[:, [0]] + gt_bbox[:, [2]]) * width_ratio / 2
+ center_y = (gt_bbox[:, [1]] + gt_bbox[:, [3]]) * height_ratio / 2
+ gt_centers = torch.cat((center_x, center_y), dim=1)
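+            # Illustrative example: with img_w = 512 and feat_w = 128,
+            # width_ratio = 0.25, so a gt box spanning x = [180, 224] has
+            # center_x = 50.5 -> int location 50 with offset target 0.5.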
+
+ for j, ct in enumerate(gt_centers):
+ ctx_int, cty_int = ct.int()
+ ctx, cty = ct
+ scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio
+ scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio
+ radius = gaussian_radius([scale_box_h, scale_box_w],
+ min_overlap=0.3)
+ radius = max(0, int(radius))
+ ind = gt_label[j]
+ gen_gaussian_target(center_heatmap_target[batch_id, ind],
+ [ctx_int, cty_int], radius)
+
+ wh_target[batch_id, 0, cty_int, ctx_int] = scale_box_w
+ wh_target[batch_id, 1, cty_int, ctx_int] = scale_box_h
+
+ offset_target[batch_id, 0, cty_int, ctx_int] = ctx - ctx_int
+ offset_target[batch_id, 1, cty_int, ctx_int] = cty - cty_int
+
+ wh_offset_target_weight[batch_id, :, cty_int, ctx_int] = 1
+
+ avg_factor = max(1, center_heatmap_target.eq(1).sum())
+ target_result = dict(
+ center_heatmap_target=center_heatmap_target,
+ wh_target=wh_target,
+ offset_target=offset_target,
+ wh_offset_target_weight=wh_offset_target_weight)
+ return target_result, avg_factor
+
+ @force_fp32(apply_to=('center_heatmap_preds', 'wh_preds', 'offset_preds'))
+ def get_bboxes(self,
+ center_heatmap_preds,
+ wh_preds,
+ offset_preds,
+ img_metas,
+ rescale=True,
+ with_nms=False):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ center_heatmap_preds (list[Tensor]): Center predict heatmaps for
+ all levels with shape (B, num_classes, H, W).
+ wh_preds (list[Tensor]): WH predicts for all levels with
+ shape (B, 2, H, W).
+ offset_preds (list[Tensor]): Offset predicts for all levels
+ with shape (B, 2, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: True.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: False.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 5) tensor, where 5 represent
+ (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ assert len(center_heatmap_preds) == len(wh_preds) == len(
+ offset_preds) == 1
+ result_list = []
+ for img_id in range(len(img_metas)):
+ result_list.append(
+ self._get_bboxes_single(
+ center_heatmap_preds[0][img_id:img_id + 1, ...],
+ wh_preds[0][img_id:img_id + 1, ...],
+ offset_preds[0][img_id:img_id + 1, ...],
+ img_metas[img_id],
+ rescale=rescale,
+ with_nms=with_nms))
+ return result_list
+
+ def _get_bboxes_single(self,
+ center_heatmap_pred,
+ wh_pred,
+ offset_pred,
+ img_meta,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs of a single image into bbox results.
+
+ Args:
+ center_heatmap_pred (Tensor): Center heatmap for current level with
+ shape (1, num_classes, H, W).
+            wh_pred (Tensor): WH heatmap for current level with shape
+                (1, 2, H, W).
+            offset_pred (Tensor): Offset for current level with shape
+                (1, 2, H, W).
+ img_meta (dict): Meta information of current image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ tuple[Tensor, Tensor]: The first item is an (n, 5) tensor, where
+ 5 represent (tl_x, tl_y, br_x, br_y, score) and the score
+ between 0 and 1. The shape of the second tensor in the tuple
+ is (n,), and each element represents the class label of the
+ corresponding box.
+ """
+ batch_det_bboxes, batch_labels = self.decode_heatmap(
+ center_heatmap_pred,
+ wh_pred,
+ offset_pred,
+ img_meta['batch_input_shape'],
+ k=self.test_cfg.topk,
+ kernel=self.test_cfg.local_maximum_kernel)
+
+ det_bboxes = batch_det_bboxes.view([-1, 5])
+ det_labels = batch_labels.view(-1)
+
+ batch_border = det_bboxes.new_tensor(img_meta['border'])[...,
+ [2, 0, 2, 0]]
+ det_bboxes[..., :4] -= batch_border
+
+ if rescale:
+ det_bboxes[..., :4] /= det_bboxes.new_tensor(
+ img_meta['scale_factor'])
+
+ if with_nms:
+ det_bboxes, det_labels = self._bboxes_nms(det_bboxes, det_labels,
+ self.test_cfg)
+ return det_bboxes, det_labels
+
+ def decode_heatmap(self,
+ center_heatmap_pred,
+ wh_pred,
+ offset_pred,
+ img_shape,
+ k=100,
+ kernel=3):
+ """Transform outputs into detections raw bbox prediction.
+
+ Args:
+ center_heatmap_pred (Tensor): center predict heatmap,
+ shape (B, num_classes, H, W).
+ wh_pred (Tensor): wh predict, shape (B, 2, H, W).
+ offset_pred (Tensor): offset predict, shape (B, 2, H, W).
+ img_shape (list[int]): image shape in [h, w] format.
+ k (int): Get top k center keypoints from heatmap. Default 100.
+            kernel (int): Max pooling kernel used to extract local maximum
+                pixels. Default 3.
+
+ Returns:
+ tuple[torch.Tensor]: Decoded output of CenterNetHead, containing
+ the following Tensors:
+
+ - batch_bboxes (Tensor): Coords of each box with shape (B, k, 5)
+ - batch_topk_labels (Tensor): Categories of each box with \
+ shape (B, k)
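+
+        Example of the rescaling arithmetic (with illustrative numbers and
+        zero center offset): with a 128x128 heatmap and a 512x512 input, a
+        peak at (x, y) = (10, 20) with wh = (8, 6) decodes to
+        (tl_x, tl_y, br_x, br_y) = (24, 68, 56, 92) in input-image
+        coordinates.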
+ """
+ height, width = center_heatmap_pred.shape[2:]
+ inp_h, inp_w = img_shape
+
+ center_heatmap_pred = get_local_maximum(
+ center_heatmap_pred, kernel=kernel)
+
+ *batch_dets, topk_ys, topk_xs = get_topk_from_heatmap(
+ center_heatmap_pred, k=k)
+ batch_scores, batch_index, batch_topk_labels = batch_dets
+
+ wh = transpose_and_gather_feat(wh_pred, batch_index)
+ offset = transpose_and_gather_feat(offset_pred, batch_index)
+ topk_xs = topk_xs + offset[..., 0]
+ topk_ys = topk_ys + offset[..., 1]
+ tl_x = (topk_xs - wh[..., 0] / 2) * (inp_w / width)
+ tl_y = (topk_ys - wh[..., 1] / 2) * (inp_h / height)
+ br_x = (topk_xs + wh[..., 0] / 2) * (inp_w / width)
+ br_y = (topk_ys + wh[..., 1] / 2) * (inp_h / height)
+
+ batch_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], dim=2)
+ batch_bboxes = torch.cat((batch_bboxes, batch_scores[..., None]),
+ dim=-1)
+ return batch_bboxes, batch_topk_labels
+
+ def _bboxes_nms(self, bboxes, labels, cfg):
+ if labels.numel() > 0:
+ max_num = cfg.max_per_img
+ bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:,
+ -1].contiguous(),
+ labels, cfg.nms)
+ if max_num > 0:
+ bboxes = bboxes[:max_num]
+ labels = labels[keep][:max_num]
+
+ return bboxes, labels
diff --git a/mmdet/models/dense_heads/centripetal_head.py b/mmdet/models/dense_heads/centripetal_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc721b7623236c0b95679c762725574687ee56f
--- /dev/null
+++ b/mmdet/models/dense_heads/centripetal_head.py
@@ -0,0 +1,430 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule, normal_init
+from mmcv.ops import DeformConv2d
+from mmcv.runner import force_fp32
+
+from mmdet.core import multi_apply
+from ..builder import HEADS, build_loss
+from .corner_head import CornerHead
+
+
+@HEADS.register_module()
+class CentripetalHead(CornerHead):
+ """Head of CentripetalNet: Pursuing High-quality Keypoint Pairs for Object
+ Detection.
+
+ CentripetalHead inherits from :class:`CornerHead`. It removes the
+ embedding branch and adds guiding shift and centripetal shift branches.
+    More details can be found in the CentripetalNet paper.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ num_feat_levels (int): Levels of feature from the previous module. 2
+ for HourglassNet-104 and 1 for HourglassNet-52. HourglassNet-104
+ outputs the final feature and intermediate supervision feature and
+ HourglassNet-52 only outputs the final feature. Default: 2.
+ corner_emb_channels (int): Channel of embedding vector. Default: 1.
+ train_cfg (dict | None): Training config. Useless in CornerHead,
+ but we keep this variable for SingleStageDetector. Default: None.
+ test_cfg (dict | None): Testing config of CornerHead. Default: None.
+ loss_heatmap (dict | None): Config of corner heatmap loss. Default:
+ GaussianFocalLoss.
+ loss_embedding (dict | None): Config of corner embedding loss. Default:
+ AssociativeEmbeddingLoss.
+ loss_offset (dict | None): Config of corner offset loss. Default:
+ SmoothL1Loss.
+ loss_guiding_shift (dict): Config of guiding shift loss. Default:
+ SmoothL1Loss.
+ loss_centripetal_shift (dict): Config of centripetal shift loss.
+ Default: SmoothL1Loss.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ *args,
+ centripetal_shift_channels=2,
+ guiding_shift_channels=2,
+ feat_adaption_conv_kernel=3,
+ loss_guiding_shift=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=0.05),
+ loss_centripetal_shift=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=1),
+ init_cfg=None,
+ **kwargs):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ assert centripetal_shift_channels == 2, (
+ 'CentripetalHead only support centripetal_shift_channels == 2')
+ self.centripetal_shift_channels = centripetal_shift_channels
+ assert guiding_shift_channels == 2, (
+ 'CentripetalHead only support guiding_shift_channels == 2')
+ self.guiding_shift_channels = guiding_shift_channels
+ self.feat_adaption_conv_kernel = feat_adaption_conv_kernel
+ super(CentripetalHead, self).__init__(
+ *args, init_cfg=init_cfg, **kwargs)
+ self.loss_guiding_shift = build_loss(loss_guiding_shift)
+ self.loss_centripetal_shift = build_loss(loss_centripetal_shift)
+
+ def _init_centripetal_layers(self):
+ """Initialize centripetal layers.
+
+ Including feature adaption deform convs (feat_adaption), deform offset
+        prediction convs (dcn_offset), guiding shift (guiding_shift) and
+        centripetal shift (centripetal_shift). Each branch has two parts:
+ prefix `tl_` for top-left and `br_` for bottom-right.
+ """
+ self.tl_feat_adaption = nn.ModuleList()
+ self.br_feat_adaption = nn.ModuleList()
+ self.tl_dcn_offset = nn.ModuleList()
+ self.br_dcn_offset = nn.ModuleList()
+ self.tl_guiding_shift = nn.ModuleList()
+ self.br_guiding_shift = nn.ModuleList()
+ self.tl_centripetal_shift = nn.ModuleList()
+ self.br_centripetal_shift = nn.ModuleList()
+
+ for _ in range(self.num_feat_levels):
+ self.tl_feat_adaption.append(
+ DeformConv2d(self.in_channels, self.in_channels,
+ self.feat_adaption_conv_kernel, 1, 1))
+ self.br_feat_adaption.append(
+ DeformConv2d(self.in_channels, self.in_channels,
+ self.feat_adaption_conv_kernel, 1, 1))
+
+ self.tl_guiding_shift.append(
+ self._make_layers(
+ out_channels=self.guiding_shift_channels,
+ in_channels=self.in_channels))
+ self.br_guiding_shift.append(
+ self._make_layers(
+ out_channels=self.guiding_shift_channels,
+ in_channels=self.in_channels))
+
+ self.tl_dcn_offset.append(
+ ConvModule(
+ self.guiding_shift_channels,
+ self.feat_adaption_conv_kernel**2 *
+ self.guiding_shift_channels,
+ 1,
+ bias=False,
+ act_cfg=None))
+ self.br_dcn_offset.append(
+ ConvModule(
+ self.guiding_shift_channels,
+ self.feat_adaption_conv_kernel**2 *
+ self.guiding_shift_channels,
+ 1,
+ bias=False,
+ act_cfg=None))
+
+ self.tl_centripetal_shift.append(
+ self._make_layers(
+ out_channels=self.centripetal_shift_channels,
+ in_channels=self.in_channels))
+ self.br_centripetal_shift.append(
+ self._make_layers(
+ out_channels=self.centripetal_shift_channels,
+ in_channels=self.in_channels))
+
+ def _init_layers(self):
+ """Initialize layers for CentripetalHead.
+
+ Including two parts: CornerHead layers and CentripetalHead layers
+ """
+ super()._init_layers() # using _init_layers in CornerHead
+ self._init_centripetal_layers()
+
+ def init_weights(self):
+ super(CentripetalHead, self).init_weights()
+ for i in range(self.num_feat_levels):
+ normal_init(self.tl_feat_adaption[i], std=0.01)
+ normal_init(self.br_feat_adaption[i], std=0.01)
+ normal_init(self.tl_dcn_offset[i].conv, std=0.1)
+ normal_init(self.br_dcn_offset[i].conv, std=0.1)
+ _ = [x.conv.reset_parameters() for x in self.tl_guiding_shift[i]]
+ _ = [x.conv.reset_parameters() for x in self.br_guiding_shift[i]]
+ _ = [
+ x.conv.reset_parameters() for x in self.tl_centripetal_shift[i]
+ ]
+ _ = [
+ x.conv.reset_parameters() for x in self.br_centripetal_shift[i]
+ ]
+
+ def forward_single(self, x, lvl_ind):
+ """Forward feature of a single level.
+
+ Args:
+ x (Tensor): Feature of a single level.
+ lvl_ind (int): Level index of current feature.
+
+ Returns:
+ tuple[Tensor]: A tuple of CentripetalHead's output for current
+ feature level. Containing the following Tensors:
+
+ - tl_heat (Tensor): Predicted top-left corner heatmap.
+ - br_heat (Tensor): Predicted bottom-right corner heatmap.
+ - tl_off (Tensor): Predicted top-left offset heatmap.
+ - br_off (Tensor): Predicted bottom-right offset heatmap.
+ - tl_guiding_shift (Tensor): Predicted top-left guiding shift
+ heatmap.
+ - br_guiding_shift (Tensor): Predicted bottom-right guiding
+ shift heatmap.
+ - tl_centripetal_shift (Tensor): Predicted top-left centripetal
+ shift heatmap.
+ - br_centripetal_shift (Tensor): Predicted bottom-right
+ centripetal shift heatmap.
+ """
+ tl_heat, br_heat, _, _, tl_off, br_off, tl_pool, br_pool = super(
+ ).forward_single(
+ x, lvl_ind, return_pool=True)
+
+ tl_guiding_shift = self.tl_guiding_shift[lvl_ind](tl_pool)
+ br_guiding_shift = self.br_guiding_shift[lvl_ind](br_pool)
+
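+        # The guiding shifts are detached before predicting the DCN offsets,
+        # so gradients from the feature adaption branch do not flow back
+        # into the guiding-shift branch through this path.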
+ tl_dcn_offset = self.tl_dcn_offset[lvl_ind](tl_guiding_shift.detach())
+ br_dcn_offset = self.br_dcn_offset[lvl_ind](br_guiding_shift.detach())
+
+ tl_feat_adaption = self.tl_feat_adaption[lvl_ind](tl_pool,
+ tl_dcn_offset)
+ br_feat_adaption = self.br_feat_adaption[lvl_ind](br_pool,
+ br_dcn_offset)
+
+ tl_centripetal_shift = self.tl_centripetal_shift[lvl_ind](
+ tl_feat_adaption)
+ br_centripetal_shift = self.br_centripetal_shift[lvl_ind](
+ br_feat_adaption)
+
+ result_list = [
+ tl_heat, br_heat, tl_off, br_off, tl_guiding_shift,
+ br_guiding_shift, tl_centripetal_shift, br_centripetal_shift
+ ]
+ return result_list
+
+ @force_fp32()
+ def loss(self,
+ tl_heats,
+ br_heats,
+ tl_offs,
+ br_offs,
+ tl_guiding_shifts,
+ br_guiding_shifts,
+ tl_centripetal_shifts,
+ br_centripetal_shifts,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
+ level with shape (N, guiding_shift_channels, H, W).
+ br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
+ each level with shape (N, guiding_shift_channels, H, W).
+ tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
+ for each level with shape (N, centripetal_shift_channels, H,
+ W).
+ br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
+ shifts for each level with shape (N,
+ centripetal_shift_channels, H, W).
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [left, top, right, bottom] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components. Containing the
+ following losses:
+
+ - det_loss (list[Tensor]): Corner keypoint losses of all
+ feature levels.
+ - off_loss (list[Tensor]): Corner offset losses of all feature
+ levels.
+ - guiding_loss (list[Tensor]): Guiding shift losses of all
+ feature levels.
+ - centripetal_loss (list[Tensor]): Centripetal shift losses of
+ all feature levels.
+ """
+ targets = self.get_targets(
+ gt_bboxes,
+ gt_labels,
+ tl_heats[-1].shape,
+ img_metas[0]['pad_shape'],
+ with_corner_emb=self.with_corner_emb,
+ with_guiding_shift=True,
+ with_centripetal_shift=True)
+ mlvl_targets = [targets for _ in range(self.num_feat_levels)]
+ [det_losses, off_losses, guiding_losses, centripetal_losses
+ ] = multi_apply(self.loss_single, tl_heats, br_heats, tl_offs,
+ br_offs, tl_guiding_shifts, br_guiding_shifts,
+ tl_centripetal_shifts, br_centripetal_shifts,
+ mlvl_targets)
+ loss_dict = dict(
+ det_loss=det_losses,
+ off_loss=off_losses,
+ guiding_loss=guiding_losses,
+ centripetal_loss=centripetal_losses)
+ return loss_dict
+
+ def loss_single(self, tl_hmp, br_hmp, tl_off, br_off, tl_guiding_shift,
+ br_guiding_shift, tl_centripetal_shift,
+ br_centripetal_shift, targets):
+ """Compute losses for single level.
+
+ Args:
+ tl_hmp (Tensor): Top-left corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ br_hmp (Tensor): Bottom-right corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ tl_off (Tensor): Top-left corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ br_off (Tensor): Bottom-right corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ tl_guiding_shift (Tensor): Top-left guiding shift for current level
+ with shape (N, guiding_shift_channels, H, W).
+ br_guiding_shift (Tensor): Bottom-right guiding shift for current
+ level with shape (N, guiding_shift_channels, H, W).
+ tl_centripetal_shift (Tensor): Top-left centripetal shift for
+ current level with shape (N, centripetal_shift_channels, H, W).
+ br_centripetal_shift (Tensor): Bottom-right centripetal shift for
+ current level with shape (N, centripetal_shift_channels, H, W).
+ targets (dict): Corner target generated by `get_targets`.
+
+ Returns:
+ tuple[torch.Tensor]: Losses of the head's different branches
+ containing the following losses:
+
+ - det_loss (Tensor): Corner keypoint loss.
+ - off_loss (Tensor): Corner offset loss.
+ - guiding_loss (Tensor): Guiding shift loss.
+ - centripetal_loss (Tensor): Centripetal shift loss.
+ """
+ targets['corner_embedding'] = None
+
+ det_loss, _, _, off_loss = super().loss_single(tl_hmp, br_hmp, None,
+ None, tl_off, br_off,
+ targets)
+
+ gt_tl_guiding_shift = targets['topleft_guiding_shift']
+ gt_br_guiding_shift = targets['bottomright_guiding_shift']
+ gt_tl_centripetal_shift = targets['topleft_centripetal_shift']
+ gt_br_centripetal_shift = targets['bottomright_centripetal_shift']
+
+ gt_tl_heatmap = targets['topleft_heatmap']
+ gt_br_heatmap = targets['bottomright_heatmap']
+ # We only compute the guiding and centripetal shift losses at real
+ # corner positions, whose value is 1 in the heatmap ground truth.
+ # The mask is computed in a class-agnostic way and its shape is
+ # (batch, 1, height, width).
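+ # For example (sizes assumed for illustration), a ground-truth heatmap
+ # of shape (2, 80, 128, 128) collapses to a (2, 1, 128, 128) mask that
+ # is 1 only at pixels where some class has a ground-truth corner.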
+ tl_mask = gt_tl_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
+ gt_tl_heatmap)
+ br_mask = gt_br_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
+ gt_br_heatmap)
+
+ # Guiding shift loss
+ tl_guiding_loss = self.loss_guiding_shift(
+ tl_guiding_shift,
+ gt_tl_guiding_shift,
+ tl_mask,
+ avg_factor=tl_mask.sum())
+ br_guiding_loss = self.loss_guiding_shift(
+ br_guiding_shift,
+ gt_br_guiding_shift,
+ br_mask,
+ avg_factor=br_mask.sum())
+ guiding_loss = (tl_guiding_loss + br_guiding_loss) / 2.0
+ # Centripetal shift loss
+ tl_centripetal_loss = self.loss_centripetal_shift(
+ tl_centripetal_shift,
+ gt_tl_centripetal_shift,
+ tl_mask,
+ avg_factor=tl_mask.sum())
+ br_centripetal_loss = self.loss_centripetal_shift(
+ br_centripetal_shift,
+ gt_br_centripetal_shift,
+ br_mask,
+ avg_factor=br_mask.sum())
+ centripetal_loss = (tl_centripetal_loss + br_centripetal_loss) / 2.0
+
+ return det_loss, off_loss, guiding_loss, centripetal_loss
+
+ @force_fp32()
+ def get_bboxes(self,
+ tl_heats,
+ br_heats,
+ tl_offs,
+ br_offs,
+ tl_guiding_shifts,
+ br_guiding_shifts,
+ tl_centripetal_shifts,
+ br_centripetal_shifts,
+ img_metas,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
+ level with shape (N, guiding_shift_channels, H, W). Not used in
+ this function; the argument is kept because it is part of the
+ raw output of CentripetalHead.
+ br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
+ each level with shape (N, guiding_shift_channels, H, W).
+ Not used in this function; the argument is kept because it is
+ part of the raw output of CentripetalHead.
+ tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
+ for each level with shape (N, centripetal_shift_channels, H,
+ W).
+ br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
+ shifts for each level with shape (N,
+ centripetal_shift_channels, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+ """
+ assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(img_metas)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ result_list.append(
+ self._get_bboxes_single(
+ tl_heats[-1][img_id:img_id + 1, :],
+ br_heats[-1][img_id:img_id + 1, :],
+ tl_offs[-1][img_id:img_id + 1, :],
+ br_offs[-1][img_id:img_id + 1, :],
+ img_metas[img_id],
+ tl_emb=None,
+ br_emb=None,
+ tl_centripetal_shift=tl_centripetal_shifts[-1][
+ img_id:img_id + 1, :],
+ br_centripetal_shift=br_centripetal_shifts[-1][
+ img_id:img_id + 1, :],
+ rescale=rescale,
+ with_nms=with_nms))
+
+ return result_list
diff --git a/mmdet/models/dense_heads/corner_head.py b/mmdet/models/dense_heads/corner_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6a2866f94ab25922bc47db0ef0df530f93f6f79
--- /dev/null
+++ b/mmdet/models/dense_heads/corner_head.py
@@ -0,0 +1,1086 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from logging import warning
+from math import ceil, log
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, bias_init_with_prob
+from mmcv.ops import CornerPool, batched_nms
+from mmcv.runner import BaseModule, force_fp32
+
+from mmdet.core import multi_apply
+from ..builder import HEADS, build_loss
+from ..utils import gaussian_radius, gen_gaussian_target
+from ..utils.gaussian_target import (gather_feat, get_local_maximum,
+ get_topk_from_heatmap,
+ transpose_and_gather_feat)
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+class BiCornerPool(BaseModule):
+ """Bidirectional Corner Pooling Module (TopLeft, BottomRight, etc.)
+
+ Args:
+ in_channels (int): Input channels of module.
+ directions (list[str]): Directions of the two CornerPools.
+ feat_channels (int): Feature channels of module. Default: 128.
+ out_channels (int): Output channels of module. Default: 128.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
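+
+ Example:
+ A minimal usage sketch. The channel count and spatial size below
+ are assumed; `in_channels` is set equal to `out_channels` because
+ `conv2` consumes a feature that already has `out_channels`
+ channels (as CornerHead does when it instantiates this module).
+
+ >>> import torch
+ >>> pool = BiCornerPool(128, ['top', 'left'])
+ >>> x = torch.rand(1, 128, 16, 16)
+ >>> y = pool(x) # shape (1, 128, 16, 16)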
+ """
+
+ def __init__(self,
+ in_channels,
+ directions,
+ feat_channels=128,
+ out_channels=128,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ init_cfg=None):
+ super(BiCornerPool, self).__init__(init_cfg)
+ self.direction1_conv = ConvModule(
+ in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg)
+ self.direction2_conv = ConvModule(
+ in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg)
+
+ self.aftpool_conv = ConvModule(
+ feat_channels,
+ out_channels,
+ 3,
+ padding=1,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ self.conv1 = ConvModule(
+ in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+ self.conv2 = ConvModule(
+ in_channels, out_channels, 3, padding=1, norm_cfg=norm_cfg)
+
+ self.direction1_pool = CornerPool(directions[0])
+ self.direction2_pool = CornerPool(directions[1])
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward features from the upstream network.
+
+ Args:
+ x (tensor): Input feature of BiCornerPool.
+
+ Returns:
+ conv2 (tensor): Output feature of BiCornerPool.
+ """
+ direction1_conv = self.direction1_conv(x)
+ direction2_conv = self.direction2_conv(x)
+ direction1_feat = self.direction1_pool(direction1_conv)
+ direction2_feat = self.direction2_pool(direction2_conv)
+ aftpool_conv = self.aftpool_conv(direction1_feat + direction2_feat)
+ conv1 = self.conv1(x)
+ relu = self.relu(aftpool_conv + conv1)
+ conv2 = self.conv2(relu)
+ return conv2
+
+
+@HEADS.register_module()
+class CornerHead(BaseDenseHead, BBoxTestMixin):
+ """Head of CornerNet: Detecting Objects as Paired Keypoints.
+
+ Code is modified from the `official github repo
+ `_ .
+
+ More details can be found in the `paper
+ `_ .
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ num_feat_levels (int): Levels of feature from the previous module. 2
+ for HourglassNet-104 and 1 for HourglassNet-52. Because
+ HourglassNet-104 outputs the final feature and intermediate
+ supervision feature and HourglassNet-52 only outputs the final
+ feature. Default: 2.
+ corner_emb_channels (int): Channel of embedding vector. Default: 1.
+ train_cfg (dict | None): Training config. Useless in CornerHead,
+ but we keep this variable for SingleStageDetector. Default: None.
+ test_cfg (dict | None): Testing config of CornerHead. Default: None.
+ loss_heatmap (dict | None): Config of corner heatmap loss. Default:
+ GaussianFocalLoss.
+ loss_embedding (dict | None): Config of corner embedding loss. Default:
+ AssociativeEmbeddingLoss.
+ loss_offset (dict | None): Config of corner offset loss. Default:
+ SmoothL1Loss.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ num_feat_levels=2,
+ corner_emb_channels=1,
+ train_cfg=None,
+ test_cfg=None,
+ loss_heatmap=dict(
+ type='GaussianFocalLoss',
+ alpha=2.0,
+ gamma=4.0,
+ loss_weight=1),
+ loss_embedding=dict(
+ type='AssociativeEmbeddingLoss',
+ pull_weight=0.25,
+ push_weight=0.25),
+ loss_offset=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=1),
+ init_cfg=None):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super(CornerHead, self).__init__(init_cfg)
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.corner_emb_channels = corner_emb_channels
+ self.with_corner_emb = self.corner_emb_channels > 0
+ self.corner_offset_channels = 2
+ self.num_feat_levels = num_feat_levels
+ self.loss_heatmap = build_loss(
+ loss_heatmap) if loss_heatmap is not None else None
+ self.loss_embedding = build_loss(
+ loss_embedding) if loss_embedding is not None else None
+ self.loss_offset = build_loss(
+ loss_offset) if loss_offset is not None else None
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ self.fp16_enabled = False
+ self._init_layers()
+
+ def _make_layers(self, out_channels, in_channels=256, feat_channels=256):
+ """Initialize conv sequential for CornerHead."""
+ return nn.Sequential(
+ ConvModule(in_channels, feat_channels, 3, padding=1),
+ ConvModule(
+ feat_channels, out_channels, 1, norm_cfg=None, act_cfg=None))
+
+ def _init_corner_kpt_layers(self):
+ """Initialize corner keypoint layers.
+
+ Including corner heatmap branch and corner offset branch. Each branch
+ has two parts: prefix `tl_` for top-left and `br_` for bottom-right.
+ """
+ self.tl_pool, self.br_pool = nn.ModuleList(), nn.ModuleList()
+ self.tl_heat, self.br_heat = nn.ModuleList(), nn.ModuleList()
+ self.tl_off, self.br_off = nn.ModuleList(), nn.ModuleList()
+
+ for _ in range(self.num_feat_levels):
+ self.tl_pool.append(
+ BiCornerPool(
+ self.in_channels, ['top', 'left'],
+ out_channels=self.in_channels))
+ self.br_pool.append(
+ BiCornerPool(
+ self.in_channels, ['bottom', 'right'],
+ out_channels=self.in_channels))
+
+ self.tl_heat.append(
+ self._make_layers(
+ out_channels=self.num_classes,
+ in_channels=self.in_channels))
+ self.br_heat.append(
+ self._make_layers(
+ out_channels=self.num_classes,
+ in_channels=self.in_channels))
+
+ self.tl_off.append(
+ self._make_layers(
+ out_channels=self.corner_offset_channels,
+ in_channels=self.in_channels))
+ self.br_off.append(
+ self._make_layers(
+ out_channels=self.corner_offset_channels,
+ in_channels=self.in_channels))
+
+ def _init_corner_emb_layers(self):
+ """Initialize corner embedding layers.
+
+ Only include corner embedding branch with two parts: prefix `tl_` for
+ top-left and `br_` for bottom-right.
+ """
+ self.tl_emb, self.br_emb = nn.ModuleList(), nn.ModuleList()
+
+ for _ in range(self.num_feat_levels):
+ self.tl_emb.append(
+ self._make_layers(
+ out_channels=self.corner_emb_channels,
+ in_channels=self.in_channels))
+ self.br_emb.append(
+ self._make_layers(
+ out_channels=self.corner_emb_channels,
+ in_channels=self.in_channels))
+
+ def _init_layers(self):
+ """Initialize layers for CornerHead.
+
+ Including two parts: corner keypoint layers and corner embedding layers
+ """
+ self._init_corner_kpt_layers()
+ if self.with_corner_emb:
+ self._init_corner_emb_layers()
+
+ def init_weights(self):
+ super(CornerHead, self).init_weights()
+ bias_init = bias_init_with_prob(0.1)
+ for i in range(self.num_feat_levels):
+ # The parameter initialization differs between nn.Conv2d and
+ # ConvModule. Our experiments show that using the original
+ # initialization of nn.Conv2d increases the final mAP by
+ # about 0.2%.
+ self.tl_heat[i][-1].conv.reset_parameters()
+ self.tl_heat[i][-1].conv.bias.data.fill_(bias_init)
+ self.br_heat[i][-1].conv.reset_parameters()
+ self.br_heat[i][-1].conv.bias.data.fill_(bias_init)
+ self.tl_off[i][-1].conv.reset_parameters()
+ self.br_off[i][-1].conv.reset_parameters()
+ if self.with_corner_emb:
+ self.tl_emb[i][-1].conv.reset_parameters()
+ self.br_emb[i][-1].conv.reset_parameters()
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of corner heatmaps, offset heatmaps and
+ embedding heatmaps.
+ - tl_heats (list[Tensor]): Top-left corner heatmaps for all
+ levels, each is a 4D-tensor, the channels number is
+ num_classes.
+ - br_heats (list[Tensor]): Bottom-right corner heatmaps for all
+ levels, each is a 4D-tensor, the channels number is
+ num_classes.
+ - tl_embs (list[Tensor] | list[None]): Top-left embedding
+ heatmaps for all levels, each is a 4D-tensor or None.
+ If not None, the channels number is corner_emb_channels.
+ - br_embs (list[Tensor] | list[None]): Bottom-right embedding
+ heatmaps for all levels, each is a 4D-tensor or None.
+ If not None, the channels number is corner_emb_channels.
+ - tl_offs (list[Tensor]): Top-left offset heatmaps for all
+ levels, each is a 4D-tensor. The channels number is
+ corner_offset_channels.
+ - br_offs (list[Tensor]): Bottom-right offset heatmaps for all
+ levels, each is a 4D-tensor. The channels number is
+ corner_offset_channels.
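+
+ Example:
+ A minimal forward sketch; the class count, channel count and
+ feature sizes below are assumed for illustration:
+
+ >>> import torch
+ >>> self = CornerHead(11, 96, num_feat_levels=2)
+ >>> feats = [torch.rand(1, 96, 16, 16) for _ in range(2)]
+ >>> outs = self.forward(feats)
+ >>> tl_heats, br_heats, tl_embs, br_embs, tl_offs, br_offs = outs
+ >>> assert tl_heats[0].shape == (1, 11, 16, 16)
+ >>> assert tl_offs[0].shape == (1, 2, 16, 16)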
+ """
+ lvl_ind = list(range(self.num_feat_levels))
+ return multi_apply(self.forward_single, feats, lvl_ind)
+
+ def forward_single(self, x, lvl_ind, return_pool=False):
+ """Forward feature of a single level.
+
+ Args:
+ x (Tensor): Feature of a single level.
+ lvl_ind (int): Level index of current feature.
+ return_pool (bool): Return corner pool feature or not.
+
+ Returns:
+ tuple[Tensor]: A tuple of CornerHead's output for current feature
+ level. Containing the following Tensors:
+
+ - tl_heat (Tensor): Predicted top-left corner heatmap.
+ - br_heat (Tensor): Predicted bottom-right corner heatmap.
+ - tl_emb (Tensor | None): Predicted top-left embedding heatmap.
+ None for `self.with_corner_emb == False`.
+ - br_emb (Tensor | None): Predicted bottom-right embedding
+ heatmap. None for `self.with_corner_emb == False`.
+ - tl_off (Tensor): Predicted top-left offset heatmap.
+ - br_off (Tensor): Predicted bottom-right offset heatmap.
+ - tl_pool (Tensor): Top-left corner pool feature. Only
+ returned when `return_pool` is True.
+ - br_pool (Tensor): Bottom-right corner pool feature. Only
+ returned when `return_pool` is True.
+ """
+ tl_pool = self.tl_pool[lvl_ind](x)
+ tl_heat = self.tl_heat[lvl_ind](tl_pool)
+ br_pool = self.br_pool[lvl_ind](x)
+ br_heat = self.br_heat[lvl_ind](br_pool)
+
+ tl_emb, br_emb = None, None
+ if self.with_corner_emb:
+ tl_emb = self.tl_emb[lvl_ind](tl_pool)
+ br_emb = self.br_emb[lvl_ind](br_pool)
+
+ tl_off = self.tl_off[lvl_ind](tl_pool)
+ br_off = self.br_off[lvl_ind](br_pool)
+
+ result_list = [tl_heat, br_heat, tl_emb, br_emb, tl_off, br_off]
+ if return_pool:
+ result_list.append(tl_pool)
+ result_list.append(br_pool)
+
+ return result_list
+
+ def get_targets(self,
+ gt_bboxes,
+ gt_labels,
+ feat_shape,
+ img_shape,
+ with_corner_emb=False,
+ with_guiding_shift=False,
+ with_centripetal_shift=False):
+ """Generate corner targets.
+
+ Including corner heatmap, corner offset.
+
+ Optional: corner embedding, corner guiding shift, centripetal shift.
+
+ For CornerNet, we generate corner heatmap, corner offset and corner
+ embedding from this function.
+
+ For CentripetalNet, we generate corner heatmap, corner offset, guiding
+ shift and centripetal shift from this function.
+
+ Args:
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image, each
+ has shape (num_gt, 4).
+ gt_labels (list[Tensor]): Ground truth labels of each box, each has
+ shape (num_gt,).
+ feat_shape (list[int]): Shape of output feature,
+ [batch, channel, height, width].
+ img_shape (list[int]): Shape of input image,
+ [height, width, channel].
+ with_corner_emb (bool): Generate corner embedding target or not.
+ Default: False.
+ with_guiding_shift (bool): Generate guiding shift target or not.
+ Default: False.
+ with_centripetal_shift (bool): Generate centripetal shift target or
+ not. Default: False.
+
+ Returns:
+ dict: Ground truth of corner heatmap, corner offset, corner
+ embedding, guiding shift and centripetal shift. Containing the
+ following keys:
+
+ - topleft_heatmap (Tensor): Ground truth top-left corner
+ heatmap.
+ - bottomright_heatmap (Tensor): Ground truth bottom-right
+ corner heatmap.
+ - topleft_offset (Tensor): Ground truth top-left corner offset.
+ - bottomright_offset (Tensor): Ground truth bottom-right corner
+ offset.
+ - corner_embedding (list[list[list[int]]]): Ground truth corner
+ embedding. Only present when `with_corner_emb` is True.
+ - topleft_guiding_shift (Tensor): Ground truth top-left corner
+ guiding shift. Only present when `with_guiding_shift` is True.
+ - bottomright_guiding_shift (Tensor): Ground truth bottom-right
+ corner guiding shift. Only present when `with_guiding_shift`
+ is True.
+ - topleft_centripetal_shift (Tensor): Ground truth top-left
+ corner centripetal shift. Only present when
+ `with_centripetal_shift` is True.
+ - bottomright_centripetal_shift (Tensor): Ground truth
+ bottom-right corner centripetal shift. Only present when
+ `with_centripetal_shift` is True.
+ """
+ batch_size, _, height, width = feat_shape
+ img_h, img_w = img_shape[:2]
+
+ width_ratio = float(width / img_w)
+ height_ratio = float(height / img_h)
+
+ gt_tl_heatmap = gt_bboxes[-1].new_zeros(
+ [batch_size, self.num_classes, height, width])
+ gt_br_heatmap = gt_bboxes[-1].new_zeros(
+ [batch_size, self.num_classes, height, width])
+ gt_tl_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width])
+ gt_br_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width])
+
+ if with_corner_emb:
+ match = []
+
+ # Guiding shift is a kind of offset, from center to corner
+ if with_guiding_shift:
+ gt_tl_guiding_shift = gt_bboxes[-1].new_zeros(
+ [batch_size, 2, height, width])
+ gt_br_guiding_shift = gt_bboxes[-1].new_zeros(
+ [batch_size, 2, height, width])
+ # Centripetal shift is also a kind of offset, from center to corner
+ # and normalized by log.
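+ # For illustration (numbers assumed): a scaled box center at x = 25.6
+ # with a scaled left edge at x = 20.0 stores log(25.6 - 20.0) ~= 1.72
+ # as the target; `decode_heatmap` later applies exp() to undo this
+ # normalization.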
+ if with_centripetal_shift:
+ gt_tl_centripetal_shift = gt_bboxes[-1].new_zeros(
+ [batch_size, 2, height, width])
+ gt_br_centripetal_shift = gt_bboxes[-1].new_zeros(
+ [batch_size, 2, height, width])
+
+ for batch_id in range(batch_size):
+ # Ground truth of corner embedding per image is a list of coord set
+ corner_match = []
+ for box_id in range(len(gt_labels[batch_id])):
+ left, top, right, bottom = gt_bboxes[batch_id][box_id]
+ center_x = (left + right) / 2.0
+ center_y = (top + bottom) / 2.0
+ label = gt_labels[batch_id][box_id]
+
+ # Use coords in the feature level to generate ground truth
+ scale_left = left * width_ratio
+ scale_right = right * width_ratio
+ scale_top = top * height_ratio
+ scale_bottom = bottom * height_ratio
+ scale_center_x = center_x * width_ratio
+ scale_center_y = center_y * height_ratio
+
+ # Int coords on feature map/ground truth tensor
+ left_idx = int(min(scale_left, width - 1))
+ right_idx = int(min(scale_right, width - 1))
+ top_idx = int(min(scale_top, height - 1))
+ bottom_idx = int(min(scale_bottom, height - 1))
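+ # Worked example (numbers assumed): with width_ratio = 0.25, a
+ # ground-truth left edge at x = 100.3 gives scale_left = 25.075, so
+ # left_idx = 25 and the fractional remainder 0.075 becomes the offset
+ # target computed below.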
+
+ # Generate gaussian heatmap
+ scale_box_width = ceil(scale_right - scale_left)
+ scale_box_height = ceil(scale_bottom - scale_top)
+ radius = gaussian_radius((scale_box_height, scale_box_width),
+ min_overlap=0.3)
+ radius = max(0, int(radius))
+ gt_tl_heatmap[batch_id, label] = gen_gaussian_target(
+ gt_tl_heatmap[batch_id, label], [left_idx, top_idx],
+ radius)
+ gt_br_heatmap[batch_id, label] = gen_gaussian_target(
+ gt_br_heatmap[batch_id, label], [right_idx, bottom_idx],
+ radius)
+
+ # Generate corner offset
+ left_offset = scale_left - left_idx
+ top_offset = scale_top - top_idx
+ right_offset = scale_right - right_idx
+ bottom_offset = scale_bottom - bottom_idx
+ gt_tl_offset[batch_id, 0, top_idx, left_idx] = left_offset
+ gt_tl_offset[batch_id, 1, top_idx, left_idx] = top_offset
+ gt_br_offset[batch_id, 0, bottom_idx, right_idx] = right_offset
+ gt_br_offset[batch_id, 1, bottom_idx,
+ right_idx] = bottom_offset
+
+ # Generate corner embedding
+ if with_corner_emb:
+ corner_match.append([[top_idx, left_idx],
+ [bottom_idx, right_idx]])
+ # Generate guiding shift
+ if with_guiding_shift:
+ gt_tl_guiding_shift[batch_id, 0, top_idx,
+ left_idx] = scale_center_x - left_idx
+ gt_tl_guiding_shift[batch_id, 1, top_idx,
+ left_idx] = scale_center_y - top_idx
+ gt_br_guiding_shift[batch_id, 0, bottom_idx,
+ right_idx] = right_idx - scale_center_x
+ gt_br_guiding_shift[
+ batch_id, 1, bottom_idx,
+ right_idx] = bottom_idx - scale_center_y
+ # Generate centripetal shift
+ if with_centripetal_shift:
+ gt_tl_centripetal_shift[batch_id, 0, top_idx,
+ left_idx] = log(scale_center_x -
+ scale_left)
+ gt_tl_centripetal_shift[batch_id, 1, top_idx,
+ left_idx] = log(scale_center_y -
+ scale_top)
+ gt_br_centripetal_shift[batch_id, 0, bottom_idx,
+ right_idx] = log(scale_right -
+ scale_center_x)
+ gt_br_centripetal_shift[batch_id, 1, bottom_idx,
+ right_idx] = log(scale_bottom -
+ scale_center_y)
+
+ if with_corner_emb:
+ match.append(corner_match)
+
+ target_result = dict(
+ topleft_heatmap=gt_tl_heatmap,
+ topleft_offset=gt_tl_offset,
+ bottomright_heatmap=gt_br_heatmap,
+ bottomright_offset=gt_br_offset)
+
+ if with_corner_emb:
+ target_result.update(corner_embedding=match)
+ if with_guiding_shift:
+ target_result.update(
+ topleft_guiding_shift=gt_tl_guiding_shift,
+ bottomright_guiding_shift=gt_br_guiding_shift)
+ if with_centripetal_shift:
+ target_result.update(
+ topleft_centripetal_shift=gt_tl_centripetal_shift,
+ bottomright_centripetal_shift=gt_br_centripetal_shift)
+
+ return target_result
+
+ @force_fp32()
+ def loss(self,
+ tl_heats,
+ br_heats,
+ tl_embs,
+ br_embs,
+ tl_offs,
+ br_offs,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_embs (list[Tensor]): Top-left corner embeddings for each level
+ with shape (N, corner_emb_channels, H, W).
+ br_embs (list[Tensor]): Bottom-right corner embeddings for each
+ level with shape (N, corner_emb_channels, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [left, top, right, bottom] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components. Containing the
+ following losses:
+
+ - det_loss (list[Tensor]): Corner keypoint losses of all
+ feature levels.
+ - pull_loss (list[Tensor]): Part one of AssociativeEmbedding
+ losses of all feature levels.
+ - push_loss (list[Tensor]): Part two of AssociativeEmbedding
+ losses of all feature levels.
+ - off_loss (list[Tensor]): Corner offset losses of all feature
+ levels.
+ """
+ targets = self.get_targets(
+ gt_bboxes,
+ gt_labels,
+ tl_heats[-1].shape,
+ img_metas[0]['pad_shape'],
+ with_corner_emb=self.with_corner_emb)
+ mlvl_targets = [targets for _ in range(self.num_feat_levels)]
+ det_losses, pull_losses, push_losses, off_losses = multi_apply(
+ self.loss_single, tl_heats, br_heats, tl_embs, br_embs, tl_offs,
+ br_offs, mlvl_targets)
+ loss_dict = dict(det_loss=det_losses, off_loss=off_losses)
+ if self.with_corner_emb:
+ loss_dict.update(pull_loss=pull_losses, push_loss=push_losses)
+ return loss_dict
+
+ def loss_single(self, tl_hmp, br_hmp, tl_emb, br_emb, tl_off, br_off,
+ targets):
+ """Compute losses for single level.
+
+ Args:
+ tl_hmp (Tensor): Top-left corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ br_hmp (Tensor): Bottom-right corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ tl_emb (Tensor): Top-left corner embedding for current level with
+ shape (N, corner_emb_channels, H, W).
+ br_emb (Tensor): Bottom-right corner embedding for current level
+ with shape (N, corner_emb_channels, H, W).
+ tl_off (Tensor): Top-left corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ br_off (Tensor): Bottom-right corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ targets (dict): Corner target generated by `get_targets`.
+
+ Returns:
+ tuple[torch.Tensor]: Losses of the head's different branches
+ containing the following losses:
+
+ - det_loss (Tensor): Corner keypoint loss.
+ - pull_loss (Tensor): Part one of AssociativeEmbedding loss.
+ - push_loss (Tensor): Part two of AssociativeEmbedding loss.
+ - off_loss (Tensor): Corner offset loss.
+ """
+ gt_tl_hmp = targets['topleft_heatmap']
+ gt_br_hmp = targets['bottomright_heatmap']
+ gt_tl_off = targets['topleft_offset']
+ gt_br_off = targets['bottomright_offset']
+ gt_embedding = targets['corner_embedding']
+
+ # Detection loss
+ tl_det_loss = self.loss_heatmap(
+ tl_hmp.sigmoid(),
+ gt_tl_hmp,
+ avg_factor=max(1,
+ gt_tl_hmp.eq(1).sum()))
+ br_det_loss = self.loss_heatmap(
+ br_hmp.sigmoid(),
+ gt_br_hmp,
+ avg_factor=max(1,
+ gt_br_hmp.eq(1).sum()))
+ det_loss = (tl_det_loss + br_det_loss) / 2.0
+
+ # AssociativeEmbedding loss
+ if self.with_corner_emb and self.loss_embedding is not None:
+ pull_loss, push_loss = self.loss_embedding(tl_emb, br_emb,
+ gt_embedding)
+ else:
+ pull_loss, push_loss = None, None
+
+ # Offset loss
+ # We only compute the offset loss at the real corner position.
+ # The value of real corner would be 1 in heatmap ground truth.
+ # The mask is computed in a class-agnostic way and its shape is
+ # (batch, 1, height, width).
+ tl_off_mask = gt_tl_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
+ gt_tl_hmp)
+ br_off_mask = gt_br_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
+ gt_br_hmp)
+ tl_off_loss = self.loss_offset(
+ tl_off,
+ gt_tl_off,
+ tl_off_mask,
+ avg_factor=max(1, tl_off_mask.sum()))
+ br_off_loss = self.loss_offset(
+ br_off,
+ gt_br_off,
+ br_off_mask,
+ avg_factor=max(1, br_off_mask.sum()))
+
+ off_loss = (tl_off_loss + br_off_loss) / 2.0
+
+ return det_loss, pull_loss, push_loss, off_loss
+
+ @force_fp32()
+ def get_bboxes(self,
+ tl_heats,
+ br_heats,
+ tl_embs,
+ br_embs,
+ tl_offs,
+ br_offs,
+ img_metas,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_embs (list[Tensor]): Top-left corner embeddings for each level
+ with shape (N, corner_emb_channels, H, W).
+ br_embs (list[Tensor]): Bottom-right corner embeddings for each
+ level with shape (N, corner_emb_channels, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+ """
+ assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(img_metas)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ result_list.append(
+ self._get_bboxes_single(
+ tl_heats[-1][img_id:img_id + 1, :],
+ br_heats[-1][img_id:img_id + 1, :],
+ tl_offs[-1][img_id:img_id + 1, :],
+ br_offs[-1][img_id:img_id + 1, :],
+ img_metas[img_id],
+ tl_emb=tl_embs[-1][img_id:img_id + 1, :],
+ br_emb=br_embs[-1][img_id:img_id + 1, :],
+ rescale=rescale,
+ with_nms=with_nms))
+
+ return result_list
+
+ def _get_bboxes_single(self,
+ tl_heat,
+ br_heat,
+ tl_off,
+ br_off,
+ img_meta,
+ tl_emb=None,
+ br_emb=None,
+ tl_centripetal_shift=None,
+ br_centripetal_shift=None,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into bbox predictions.
+
+ Args:
+ tl_heat (Tensor): Top-left corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ br_heat (Tensor): Bottom-right corner heatmap for current level
+ with shape (N, num_classes, H, W).
+ tl_off (Tensor): Top-left corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ br_off (Tensor): Bottom-right corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ img_meta (dict): Meta information of current image, e.g.,
+ image size, scaling factor, etc.
+ tl_emb (Tensor): Top-left corner embedding for current level with
+ shape (N, corner_emb_channels, H, W).
+ br_emb (Tensor): Bottom-right corner embedding for current level
+ with shape (N, corner_emb_channels, H, W).
+ tl_centripetal_shift (Tensor | None): Top-left corner's centripetal
+ shift for current level with shape (N, 2, H, W).
+ br_centripetal_shift (Tensor | None): Bottom-right corner's
+ centripetal shift for current level with shape (N, 2, H, W).
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+ """
+ if isinstance(img_meta, (list, tuple)):
+ img_meta = img_meta[0]
+
+ batch_bboxes, batch_scores, batch_clses = self.decode_heatmap(
+ tl_heat=tl_heat.sigmoid(),
+ br_heat=br_heat.sigmoid(),
+ tl_off=tl_off,
+ br_off=br_off,
+ tl_emb=tl_emb,
+ br_emb=br_emb,
+ tl_centripetal_shift=tl_centripetal_shift,
+ br_centripetal_shift=br_centripetal_shift,
+ img_meta=img_meta,
+ k=self.test_cfg.corner_topk,
+ kernel=self.test_cfg.local_maximum_kernel,
+ distance_threshold=self.test_cfg.distance_threshold)
+
+ if rescale:
+ batch_bboxes /= batch_bboxes.new_tensor(img_meta['scale_factor'])
+
+ bboxes = batch_bboxes.view([-1, 4])
+ scores = batch_scores.view(-1)
+ clses = batch_clses.view(-1)
+
+ detections = torch.cat([bboxes, scores.unsqueeze(-1)], -1)
+ keepinds = (detections[:, -1] > -0.1)
+ detections = detections[keepinds]
+ labels = clses[keepinds]
+
+ if with_nms:
+ detections, labels = self._bboxes_nms(detections, labels,
+ self.test_cfg)
+
+ return detections, labels
+
+ def _bboxes_nms(self, bboxes, labels, cfg):
+ if 'nms_cfg' in cfg:
+ warning('nms_cfg in test_cfg will be deprecated. '
+ 'Please rename it as nms')
+ if 'nms' not in cfg:
+ cfg.nms = cfg.nms_cfg
+
+ if labels.numel() > 0:
+ max_num = cfg.max_per_img
+ bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:,
+ -1].contiguous(),
+ labels, cfg.nms)
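+ # `batched_nms` returns the kept boxes with their (possibly updated)
+ # scores appended as a fifth column, plus the indices of kept boxes.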
+ if max_num > 0:
+ bboxes = bboxes[:max_num]
+ labels = labels[keep][:max_num]
+
+ return bboxes, labels
+
+ def decode_heatmap(self,
+ tl_heat,
+ br_heat,
+ tl_off,
+ br_off,
+ tl_emb=None,
+ br_emb=None,
+ tl_centripetal_shift=None,
+ br_centripetal_shift=None,
+ img_meta=None,
+ k=100,
+ kernel=3,
+ distance_threshold=0.5,
+ num_dets=1000):
+ """Transform outputs for a single batch item into raw bbox predictions.
+
+ Args:
+ tl_heat (Tensor): Top-left corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ br_heat (Tensor): Bottom-right corner heatmap for current level
+ with shape (N, num_classes, H, W).
+ tl_off (Tensor): Top-left corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ br_off (Tensor): Bottom-right corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ tl_emb (Tensor | None): Top-left corner embedding for current
+ level with shape (N, corner_emb_channels, H, W).
+ br_emb (Tensor | None): Bottom-right corner embedding for current
+ level with shape (N, corner_emb_channels, H, W).
+ tl_centripetal_shift (Tensor | None): Top-left centripetal shift
+ for current level with shape (N, 2, H, W).
+ br_centripetal_shift (Tensor | None): Bottom-right centripetal
+ shift for current level with shape (N, 2, H, W).
+ img_meta (dict): Meta information of current image, e.g.,
+ image size, scaling factor, etc.
+ k (int): Get top k corner keypoints from heatmap.
+ kernel (int): Max pooling kernel for extract local maximum pixels.
+ distance_threshold (float): Distance threshold. Top-left and
+ bottom-right corner keypoints with feature distance less than
+ the threshold will be regarded as keypoints from same object.
+ num_dets (int): Num of raw boxes before doing nms.
+
+ Returns:
+ tuple[torch.Tensor]: Decoded output of CornerHead, containing the
+ following Tensors:
+
+ - bboxes (Tensor): Coords of each box.
+ - scores (Tensor): Scores of each box.
+ - clses (Tensor): Categories of each box.
+ """
+ with_embedding = tl_emb is not None and br_emb is not None
+ with_centripetal_shift = (
+ tl_centripetal_shift is not None
+ and br_centripetal_shift is not None)
+ assert with_embedding + with_centripetal_shift == 1
+ batch, _, height, width = tl_heat.size()
+ if torch.onnx.is_in_onnx_export():
+ inp_h, inp_w = img_meta['pad_shape_for_onnx'][:2]
+ else:
+ inp_h, inp_w, _ = img_meta['pad_shape']
+
+ # perform nms on heatmaps
+ tl_heat = get_local_maximum(tl_heat, kernel=kernel)
+ br_heat = get_local_maximum(br_heat, kernel=kernel)
+
+ tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = get_topk_from_heatmap(
+ tl_heat, k=k)
+ br_scores, br_inds, br_clses, br_ys, br_xs = get_topk_from_heatmap(
+ br_heat, k=k)
+
+ # We use repeat instead of expand here because expand returns a
+ # shallow (memory-sharing) view, which can lead to unexpected test
+ # results. Using expand decreases mAP by about 10% during testing
+ # compared to repeat.
+ tl_ys = tl_ys.view(batch, k, 1).repeat(1, 1, k)
+ tl_xs = tl_xs.view(batch, k, 1).repeat(1, 1, k)
+ br_ys = br_ys.view(batch, 1, k).repeat(1, k, 1)
+ br_xs = br_xs.view(batch, 1, k).repeat(1, k, 1)
+
+ tl_off = transpose_and_gather_feat(tl_off, tl_inds)
+ tl_off = tl_off.view(batch, k, 1, 2)
+ br_off = transpose_and_gather_feat(br_off, br_inds)
+ br_off = br_off.view(batch, 1, k, 2)
+
+ tl_xs = tl_xs + tl_off[..., 0]
+ tl_ys = tl_ys + tl_off[..., 1]
+ br_xs = br_xs + br_off[..., 0]
+ br_ys = br_ys + br_off[..., 1]
+
+ if with_centripetal_shift:
+ tl_centripetal_shift = transpose_and_gather_feat(
+ tl_centripetal_shift, tl_inds).view(batch, k, 1, 2).exp()
+ br_centripetal_shift = transpose_and_gather_feat(
+ br_centripetal_shift, br_inds).view(batch, 1, k, 2).exp()
+
+ tl_ctxs = tl_xs + tl_centripetal_shift[..., 0]
+ tl_ctys = tl_ys + tl_centripetal_shift[..., 1]
+ br_ctxs = br_xs - br_centripetal_shift[..., 0]
+ br_ctys = br_ys - br_centripetal_shift[..., 1]
+
+ # all possible boxes based on top k corners (ignoring class)
+ tl_xs *= (inp_w / width)
+ tl_ys *= (inp_h / height)
+ br_xs *= (inp_w / width)
+ br_ys *= (inp_h / height)
+
+ if with_centripetal_shift:
+ tl_ctxs *= (inp_w / width)
+ tl_ctys *= (inp_h / height)
+ br_ctxs *= (inp_w / width)
+ br_ctys *= (inp_h / height)
+
+ x_off, y_off = 0, 0 # no crop
+ if not torch.onnx.is_in_onnx_export():
+ # Since `RandomCenterCropPad` runs on CPU with numpy and is not
+ # dynamically traceable when exporting to ONNX, 'border' does not
+ # appear as a key in 'img_meta'. As a temporary solution, the
+ # 'border' handling is moved to the post-processing step after
+ # ONNX export, which is handled in
+ # `mmdet/core/export/model_wrappers.py`. Although the PyTorch and
+ # exported ONNX models therefore differ slightly, the difference
+ # can be ignored since they achieve comparable performance (e.g.
+ # 40.4 vs 40.6 on COCO val2017 for CornerNet without test-time
+ # flip).
+ if 'border' in img_meta:
+ x_off = img_meta['border'][2]
+ y_off = img_meta['border'][0]
+
+ tl_xs -= x_off
+ tl_ys -= y_off
+ br_xs -= x_off
+ br_ys -= y_off
+
+ zeros = tl_xs.new_zeros(*tl_xs.size())
+ tl_xs = torch.where(tl_xs > 0.0, tl_xs, zeros)
+ tl_ys = torch.where(tl_ys > 0.0, tl_ys, zeros)
+ br_xs = torch.where(br_xs > 0.0, br_xs, zeros)
+ br_ys = torch.where(br_ys > 0.0, br_ys, zeros)
+
+ bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)
+ area_bboxes = ((br_xs - tl_xs) * (br_ys - tl_ys)).abs()
+
+ if with_centripetal_shift:
+ tl_ctxs -= x_off
+ tl_ctys -= y_off
+ br_ctxs -= x_off
+ br_ctys -= y_off
+
+ tl_ctxs *= tl_ctxs.gt(0.0).type_as(tl_ctxs)
+ tl_ctys *= tl_ctys.gt(0.0).type_as(tl_ctys)
+ br_ctxs *= br_ctxs.gt(0.0).type_as(br_ctxs)
+ br_ctys *= br_ctys.gt(0.0).type_as(br_ctys)
+
+ ct_bboxes = torch.stack((tl_ctxs, tl_ctys, br_ctxs, br_ctys),
+ dim=3)
+ area_ct_bboxes = ((br_ctxs - tl_ctxs) * (br_ctys - tl_ctys)).abs()
+
+ rcentral = torch.zeros_like(ct_bboxes)
+ # magic numbers from section 4.1 of the CentripetalNet paper
+ mu = torch.ones_like(area_bboxes) / 2.4
+ mu[area_bboxes > 3500] = 1 / 2.1 # larger bboxes use a smaller mu
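+ # Worked example (numbers assumed): a 40 x 40 box has area
+ # 1600 <= 3500, so mu = 1 / 2.4 and the central region built below is
+ # about mu * 40 ~= 16.7 pixels wide, centered on the box; shifted
+ # corners must fall inside it to keep their score.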
+
+ bboxes_center_x = (bboxes[..., 0] + bboxes[..., 2]) / 2
+ bboxes_center_y = (bboxes[..., 1] + bboxes[..., 3]) / 2
+ rcentral[..., 0] = bboxes_center_x - mu * (bboxes[..., 2] -
+ bboxes[..., 0]) / 2
+ rcentral[..., 1] = bboxes_center_y - mu * (bboxes[..., 3] -
+ bboxes[..., 1]) / 2
+ rcentral[..., 2] = bboxes_center_x + mu * (bboxes[..., 2] -
+ bboxes[..., 0]) / 2
+ rcentral[..., 3] = bboxes_center_y + mu * (bboxes[..., 3] -
+ bboxes[..., 1]) / 2
+ area_rcentral = ((rcentral[..., 2] - rcentral[..., 0]) *
+ (rcentral[..., 3] - rcentral[..., 1])).abs()
+ dists = area_ct_bboxes / area_rcentral
+
+ tl_ctx_inds = (ct_bboxes[..., 0] <= rcentral[..., 0]) | (
+ ct_bboxes[..., 0] >= rcentral[..., 2])
+ tl_cty_inds = (ct_bboxes[..., 1] <= rcentral[..., 1]) | (
+ ct_bboxes[..., 1] >= rcentral[..., 3])
+ br_ctx_inds = (ct_bboxes[..., 2] <= rcentral[..., 0]) | (
+ ct_bboxes[..., 2] >= rcentral[..., 2])
+ br_cty_inds = (ct_bboxes[..., 3] <= rcentral[..., 1]) | (
+ ct_bboxes[..., 3] >= rcentral[..., 3])
+
+ if with_embedding:
+ tl_emb = transpose_and_gather_feat(tl_emb, tl_inds)
+ tl_emb = tl_emb.view(batch, k, 1)
+ br_emb = transpose_and_gather_feat(br_emb, br_inds)
+ br_emb = br_emb.view(batch, 1, k)
+ dists = torch.abs(tl_emb - br_emb)
+
+ tl_scores = tl_scores.view(batch, k, 1).repeat(1, 1, k)
+ br_scores = br_scores.view(batch, 1, k).repeat(1, k, 1)
+
+ scores = (tl_scores + br_scores) / 2 # scores for all possible boxes
+
+ # tl and br should have same class
+ tl_clses = tl_clses.view(batch, k, 1).repeat(1, 1, k)
+ br_clses = br_clses.view(batch, 1, k).repeat(1, k, 1)
+ cls_inds = (tl_clses != br_clses)
+
+ # reject boxes based on distances
+ dist_inds = dists > distance_threshold
+
+ # reject boxes based on widths and heights
+ width_inds = (br_xs <= tl_xs)
+ height_inds = (br_ys <= tl_ys)
+
+ # Do not use `scores[cls_inds]`; use `torch.where` instead.
+ # Only 1-D boolean indices are supported when exporting to ONNX,
+ # so boolean indices with more dimensions (e.g. a 2-D bool tensor)
+ # would make the exported node invalid.
+ negative_scores = -1 * torch.ones_like(scores)
+ scores = torch.where(cls_inds, negative_scores, scores)
+ scores = torch.where(width_inds, negative_scores, scores)
+ scores = torch.where(height_inds, negative_scores, scores)
+ scores = torch.where(dist_inds, negative_scores, scores)
+
+ if with_centripetal_shift:
+ scores[tl_ctx_inds] = -1
+ scores[tl_cty_inds] = -1
+ scores[br_ctx_inds] = -1
+ scores[br_cty_inds] = -1
+
+ scores = scores.view(batch, -1)
+ scores, inds = torch.topk(scores, num_dets)
+ scores = scores.unsqueeze(2)
+
+ bboxes = bboxes.view(batch, -1, 4)
+ bboxes = gather_feat(bboxes, inds)
+
+ clses = tl_clses.contiguous().view(batch, -1, 1)
+ clses = gather_feat(clses, inds).float()
+
+ return bboxes, scores, clses
+
+ def onnx_export(self,
+ tl_heats,
+ br_heats,
+ tl_embs,
+ br_embs,
+ tl_offs,
+ br_offs,
+ img_metas,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_embs (list[Tensor]): Top-left corner embeddings for each level
+ with shape (N, corner_emb_channels, H, W).
+ br_embs (list[Tensor]): Bottom-right corner embeddings for each
+ level with shape (N, corner_emb_channels, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ tuple[Tensor, Tensor]: The first tensor contains bboxes with shape
+ [N, num_det, 5], where the 5 values are arranged as
+ (x1, y1, x2, y2, score); the second is class labels of shape
+ [N, num_det].
+ """
+ assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(
+ img_metas) == 1
+ result_list = []
+ for img_id in range(len(img_metas)):
+ result_list.append(
+ self._get_bboxes_single(
+ tl_heats[-1][img_id:img_id + 1, :],
+ br_heats[-1][img_id:img_id + 1, :],
+ tl_offs[-1][img_id:img_id + 1, :],
+ br_offs[-1][img_id:img_id + 1, :],
+ img_metas[img_id],
+ tl_emb=tl_embs[-1][img_id:img_id + 1, :],
+ br_emb=br_embs[-1][img_id:img_id + 1, :],
+ rescale=rescale,
+ with_nms=with_nms))
+
+ detections, labels = result_list[0]
+ # batch_size 1 here, [1, num_det, 5], [1, num_det]
+ return detections.unsqueeze(0), labels.unsqueeze(0)
diff --git a/mmdet/models/dense_heads/ddod_head.py b/mmdet/models/dense_heads/ddod_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2ff223348753b1338cccfefefd370dba0f38672
--- /dev/null
+++ b/mmdet/models/dense_heads/ddod_head.py
@@ -0,0 +1,778 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler,
+ images_to_levels, multi_apply, reduce_mean, unmap)
+from mmdet.core.bbox import bbox_overlaps
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+EPS = 1e-12
+
+
+@HEADS.register_module()
+class DDODHead(AnchorHead):
+ """DDOD head decomposes conjunctions lying in most current one-stage
+ detectors via label assignment disentanglement, spatial feature
+ disentanglement, and pyramid supervision disentanglement.
+
+ https://arxiv.org/abs/2107.02963
+
+ Args:
+ num_classes (int): Number of categories excluding the
+ background category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): The number of stacked conv layers. Default: 4.
+ conv_cfg (dict): Conv config of ddod head. Default: None.
+ use_dcn (bool): Whether to use DCN; the head is the same as ATSS
+ when False. Default: True.
+ norm_cfg (dict): Normalization config of ddod head. Default:
+ dict(type='GN', num_groups=32, requires_grad=True).
+ loss_iou (dict): Config of IoU loss. Default:
+ dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0).
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ use_dcn=True,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ loss_iou=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.use_dcn = use_dcn
+ super(DDODHead, self).__init__(num_classes, in_channels, **kwargs)
+
+ self.sampling = False
+ if self.train_cfg:
+ self.cls_assigner = build_assigner(self.train_cfg.assigner)
+ self.reg_assigner = build_assigner(self.train_cfg.reg_assigner)
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.loss_iou = build_loss(loss_iou)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=dict(type='DCN', deform_groups=1)
+ if i == 0 and self.use_dcn else self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=dict(type='DCN', deform_groups=1)
+ if i == 0 and self.use_dcn else self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.atss_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.atss_reg = nn.Conv2d(
+ self.feat_channels, self.num_base_priors * 4, 3, padding=1)
+ self.atss_iou = nn.Conv2d(
+ self.feat_channels, self.num_base_priors * 1, 3, padding=1)
+ self.scales = nn.ModuleList(
+ [Scale(1.0) for _ in self.prior_generator.strides])
+
+ # global per-level counters of positive samples, used in loss
+ self.cls_num_pos_samples_per_level = [
+ 0. for _ in range(len(self.prior_generator.strides))
+ ]
+ self.reg_num_pos_samples_per_level = [
+ 0. for _ in range(len(self.prior_generator.strides))
+ ]
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ normal_init(m.conv, std=0.01)
+ normal_init(self.atss_reg, std=0.01)
+ normal_init(self.atss_iou, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.atss_cls, std=0.01, bias=bias_cls)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of classification scores and bbox prediction
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ iou_preds (list[Tensor]): IoU scores for all scale levels,
+ each is a 4D-tensor, the channels number is
+ num_base_priors * 1.
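+
+ Example:
+ A minimal sketch; feature sizes are assumed and `use_dcn` is
+ disabled so the example does not require the deformable-conv op:
+
+ >>> import torch
+ >>> self = DDODHead(11, 7, use_dcn=False)
+ >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+ >>> cls_scores, bbox_preds, iou_preds = self.forward(feats)
+ >>> assert cls_scores[0].shape[1] == self.num_base_priors * 11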
+ """
+ return multi_apply(self.forward_single, feats, self.scales)
+
+ def forward_single(self, x, scale):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+ scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+
+ Returns:
+ tuple:
+ - cls_score (Tensor): Cls scores for a single scale level \
+ the channels number is num_base_priors * num_classes.
+ - bbox_pred (Tensor): Box energies / deltas for a single \
+ scale level, the channels number is num_base_priors * 4.
+ - iou_pred (Tensor): IoU prediction for a single scale level \
+ with shape (N, num_base_priors * 1, H, W).
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.atss_cls(cls_feat)
+ # we just follow atss, not apply exp in bbox_pred
+ bbox_pred = scale(self.atss_reg(reg_feat)).float()
+ iou_pred = self.atss_iou(reg_feat)
+ return cls_score, bbox_pred, iou_pred
+
+ def loss_cls_single(self, cls_score, labels, label_weights,
+ reweight_factor, num_total_samples):
+ """Compute cls loss of a single scale level.
+
+ Args:
+ cls_score (Tensor): Box scores for each scale level
+ Has shape (N, num_base_priors * num_classes, H, W).
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+ reweight_factor (float): Reweight factor of the current level for
+ the cls and reg losses.
+ num_total_samples (int): Number of positive samples that is
+ reduced over all GPUs.
+
+ Returns:
+ tuple[Tensor]: A tuple of loss components.
+ """
+ cls_score = cls_score.permute(0, 2, 3, 1).reshape(
+ -1, self.cls_out_channels).contiguous()
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ return reweight_factor * loss_cls,
+
+ def loss_reg_single(self, anchors, bbox_pred, iou_pred, labels,
+ label_weights, bbox_targets, bbox_weights,
+ reweight_factor, num_total_samples):
+ """Compute reg loss of a single scale level.
+
+ Args:
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ bbox_pred (Tensor): Box energies / deltas for each scale
+ level with shape (N, num_base_priors * 4, H, W).
+ iou_pred (Tensor): IoU prediction for a single scale level
+ with shape (N, num_base_priors * 1, H, W).
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+ bbox_targets (Tensor): BBox regression targets of each anchor
+ with shape (N, num_total_anchors, 4).
+ bbox_weights (Tensor): BBox weights of all anchors in the
+ image with shape (N, 4)
+ reweight_factor (float): Reweight factor of the current level for
+ the cls and reg losses.
+ num_total_samples (int): Number of positive samples that is
+ reduced over all GPUs.
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ anchors = anchors.reshape(-1, 4)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ iou_pred = iou_pred.permute(0, 2, 3, 1).reshape(-1, )
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ bbox_weights = bbox_weights.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+
+ iou_targets = label_weights.new_zeros(labels.shape)
+ iou_weights = label_weights.new_zeros(labels.shape)
+ iou_weights[(bbox_weights.sum(axis=1) > 0).nonzero(
+ as_tuple=False)] = 1.
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ &
+ (labels < bg_class_ind)).nonzero(as_tuple=False).squeeze(1)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+
+ pos_decode_bbox_pred = self.bbox_coder.decode(
+ pos_anchors, pos_bbox_pred)
+ pos_decode_bbox_targets = self.bbox_coder.decode(
+ pos_anchors, pos_bbox_targets)
+
+ # regression loss
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_decode_bbox_targets,
+ avg_factor=num_total_samples)
+
+ iou_targets[pos_inds] = bbox_overlaps(
+ pos_decode_bbox_pred.detach(),
+ pos_decode_bbox_targets,
+ is_aligned=True)
+ loss_iou = self.loss_iou(
+ iou_pred,
+ iou_targets,
+ iou_weights,
+ avg_factor=num_total_samples)
+ else:
+ loss_bbox = bbox_pred.sum() * 0
+ loss_iou = iou_pred.sum() * 0
+
+ return reweight_factor * loss_bbox, reweight_factor * loss_iou
+
+ def calc_reweight_factor(self, labels_list):
+ """Compute reweight_factor for regression and classification loss."""
+ # get pos samples for each level
+ bg_class_ind = self.num_classes
+ for ii, each_level_label in enumerate(labels_list):
+ pos_inds = ((each_level_label >= 0) &
+ (each_level_label < bg_class_ind)).nonzero(
+ as_tuple=False).squeeze(1)
+ self.cls_num_pos_samples_per_level[ii] += len(pos_inds)
+ # get a reweight factor between 1 and 2 by linear interpolation
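+ # For illustration (counts assumed): with min = 10 and max = 110
+ # positive samples across levels, a level that accumulated 60
+ # positives gets a factor of 2 - (60 - 10) / 100 = 1.5.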
+ min_pos_samples = min(self.cls_num_pos_samples_per_level)
+ max_pos_samples = max(self.cls_num_pos_samples_per_level)
+ interval = 1. / (max_pos_samples - min_pos_samples + 1e-10)
+ reweight_factor_per_level = []
+ for pos_samples in self.cls_num_pos_samples_per_level:
+ factor = 2. - (pos_samples - min_pos_samples) * interval
+ reweight_factor_per_level.append(factor)
+ return reweight_factor_per_level
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ iou_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_base_priors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_base_priors * 4, H, W)
+ iou_preds (list[Tensor]): IoU predictions for all scale levels,
+ each is a 4D-tensor with shape
+ (batch_size, num_base_priors * 1, H, W).
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ # calculate common vars for cls and reg assigners at once
+ targets_com = self.process_predictions_and_anchors(
+ anchor_list, valid_flag_list, cls_scores, bbox_preds, img_metas,
+ gt_bboxes_ignore)
+ (anchor_list, valid_flag_list, num_level_anchors_list, cls_score_list,
+ bbox_pred_list, gt_bboxes_ignore_list) = targets_com
+
+ # classification branch assigner
+ cls_targets = self.get_cls_targets(
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ cls_score_list,
+ bbox_pred_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore_list,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_targets is None:
+ return None
+
+ (cls_anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ reweight_factor_per_level = self.calc_reweight_factor(labels_list)
+
+ cls_losses_cls, = multi_apply(
+ self.loss_cls_single,
+ cls_scores,
+ labels_list,
+ label_weights_list,
+ reweight_factor_per_level,
+ num_total_samples=num_total_samples)
+
+ # regression branch assigner
+ reg_targets = self.get_reg_targets(
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ cls_score_list,
+ bbox_pred_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore_list,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if reg_targets is None:
+ return None
+
+ (reg_anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = reg_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ reweight_factor_per_level = self.calc_reweight_factor(labels_list)
+
+ reg_losses_bbox, reg_losses_iou = multi_apply(
+ self.loss_reg_single,
+ reg_anchor_list,
+ bbox_preds,
+ iou_preds,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ reweight_factor_per_level,
+ num_total_samples=num_total_samples)
+
+ return dict(
+ loss_cls=cls_losses_cls,
+ loss_bbox=reg_losses_bbox,
+ loss_iou=reg_losses_iou)
+
+ def process_predictions_and_anchors(self, anchor_list, valid_flag_list,
+ cls_scores, bbox_preds, img_metas,
+ gt_bboxes_ignore_list):
+ """Compute common vars for regression and classification targets.
+
+ Args:
+ anchor_list (list[Tensor]): anchors of each image.
+ valid_flag_list (list[Tensor]): Valid flags of each image.
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Return:
+ tuple[Tensor]: A tuple of common loss vars.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ num_level_anchors_list = [num_level_anchors] * num_imgs
+
+ anchor_list_ = []
+ valid_flag_list_ = []
+ # concat all level anchors and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ anchor_list_.append(torch.cat(anchor_list[i]))
+ valid_flag_list_.append(torch.cat(valid_flag_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+
+ num_levels = len(cls_scores)
+ cls_score_list = []
+ bbox_pred_list = []
+
+ mlvl_cls_score_list = [
+ cls_score.permute(0, 2, 3, 1).reshape(
+ num_imgs, -1, self.num_base_priors * self.cls_out_channels)
+ for cls_score in cls_scores
+ ]
+ mlvl_bbox_pred_list = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.num_base_priors * 4)
+ for bbox_pred in bbox_preds
+ ]
+
+ for i in range(num_imgs):
+ mlvl_cls_tensor_list = [
+ mlvl_cls_score_list[j][i] for j in range(num_levels)
+ ]
+ mlvl_bbox_tensor_list = [
+ mlvl_bbox_pred_list[j][i] for j in range(num_levels)
+ ]
+ cat_mlvl_cls_score = torch.cat(mlvl_cls_tensor_list, dim=0)
+ cat_mlvl_bbox_pred = torch.cat(mlvl_bbox_tensor_list, dim=0)
+ cls_score_list.append(cat_mlvl_cls_score)
+ bbox_pred_list.append(cat_mlvl_bbox_pred)
+ return (anchor_list_, valid_flag_list_, num_level_anchors_list,
+ cls_score_list, bbox_pred_list, gt_bboxes_ignore_list)
+
+ def get_cls_targets(self,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ cls_score_list,
+ bbox_pred_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Get cls targets for DDOD head.
+
+ This method is almost the same as `AnchorHead.get_targets()`.
+ Besides returning the targets as the parent method does,
+ it also returns the anchors as the first element of the
+ returned tuple.
+
+ Args:
+ anchor_list (list[Tensor]): anchors of each image.
+ valid_flag_list (list[Tensor]): Valid flags of each image.
+ num_level_anchors_list (list[Tensor]): Number of anchors of each
+                scale level of all images.
+ cls_score_list (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * num_classes.
+ bbox_pred_list (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_base_priors * 4.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_labels_list (list[Tensor]): class indices corresponding to
+ each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+        Returns:
+ tuple[Tensor]: A tuple of cls targets components.
+ """
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ cls_score_list,
+ bbox_pred_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs,
+ is_cls_assigner=True)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0])
+ labels_list = images_to_levels(all_labels, num_level_anchors_list[0])
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors_list[0])
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors_list[0])
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors_list[0])
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, bbox_weights_list, num_total_pos,
+ num_total_neg)
+
+ def get_reg_targets(self,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ cls_score_list,
+ bbox_pred_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Get reg targets for DDOD head.
+
+ This method is almost the same as `AnchorHead.get_targets()` when
+ is_cls_assigner is False. Besides returning the targets as the parent
+ method does, it also returns the anchors as the first element of the
+ returned tuple.
+
+        Args:
+            anchor_list (list[Tensor]): anchors of each image.
+            valid_flag_list (list[Tensor]): Valid flags of each image.
+            num_level_anchors_list (list[Tensor]): Number of anchors of each
+                scale level of all images.
+            cls_score_list (list[Tensor]): Classification scores for all scale
+                levels, each is a 4D-tensor, the channels number is
+                num_base_priors * num_classes.
+            bbox_pred_list (list[Tensor]): Box energies / deltas for all scale
+                levels, each is a 4D-tensor, the channels number is
+                num_base_priors * 4.
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes_ignore_list (list[Tensor] | None): specify which bounding
+                boxes can be ignored when computing the loss.
+            gt_labels_list (list[Tensor]): class indices corresponding to
+                each box.
+            label_channels (int): Channel of label.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors.
+
+        Returns:
+            tuple[Tensor]: A tuple of reg targets components.
+ """
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ cls_score_list,
+ bbox_pred_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs,
+ is_cls_assigner=False)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors_list[0])
+ labels_list = images_to_levels(all_labels, num_level_anchors_list[0])
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors_list[0])
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors_list[0])
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors_list[0])
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, bbox_weights_list, num_total_pos,
+ num_total_neg)
+
+ def _get_target_single(self,
+ flat_anchors,
+ valid_flags,
+ cls_scores,
+ bbox_preds,
+ num_level_anchors,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True,
+ is_cls_assigner=True):
+ """Compute regression, classification targets for anchors in a single
+ image.
+
+ Args:
+ flat_anchors (Tensor): Multi-level anchors of the image,
+ which are concatenated into a single tensor of shape
+ (num_base_priors, 4).
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_base_priors,).
+ cls_scores (Tensor): Classification scores for all scale
+ levels of the image.
+ bbox_preds (Tensor): Box energies / deltas for all scale
+ levels of the image.
+ num_level_anchors (list[int]): Number of anchors of each
+ scale level.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts, ).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label. Default: 1.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors. Default: True.
+ is_cls_assigner (bool): Classification or regression.
+ Default: True.
+
+ Returns:
+            tuple: N is the number of total anchors in the image.
+                - anchors (Tensor): All anchors in the image with shape \
+                    (N, 4).
+                - labels (Tensor): Labels of all anchors in the image with \
+                    shape (N, ).
+ - label_weights (Tensor): Label weights of all anchor in the \
+ image with shape (N, ).
+ - bbox_targets (Tensor): BBox targets of all anchors in the \
+ image with shape (N, 4).
+ - bbox_weights (Tensor): BBox weights of all anchors in the \
+ image with shape (N, 4)
+ - pos_inds (Tensor): Indices of positive anchor with shape \
+ (num_pos, ).
+ - neg_inds (Tensor): Indices of negative anchor with shape \
+ (num_neg, ).
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ num_level_anchors_inside = self.get_num_level_anchors_inside(
+ num_level_anchors, inside_flags)
+ bbox_preds_valid = bbox_preds[inside_flags, :]
+ cls_scores_valid = cls_scores[inside_flags, :]
+
+ assigner = self.cls_assigner if is_cls_assigner else self.reg_assigner
+
+ # decode prediction out of assigner
+ bbox_preds_valid = self.bbox_coder.decode(anchors, bbox_preds_valid)
+ assign_result = assigner.assign(anchors, num_level_anchors_inside,
+ gt_bboxes, gt_bboxes_ignore, gt_labels,
+ cls_scores_valid, bbox_preds_valid)
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if hasattr(self, 'bbox_coder'):
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ # used in VFNetHead
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ anchors = unmap(anchors, num_total_anchors, inside_flags)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+ pos_inds, neg_inds)
+
+ def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
+ """Get the anchors of each scale level inside.
+
+ Args:
+ num_level_anchors (list[int]): Number of anchors of each
+ scale level.
+ inside_flags (Tensor): Multi level inside flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_base_priors,).
+
+ Returns:
+ list[int]: Number of anchors of each scale level inside.
+ """
+ split_inside_flags = torch.split(inside_flags, num_level_anchors)
+ num_level_anchors_inside = [
+ int(flags.sum()) for flags in split_inside_flags
+ ]
+ return num_level_anchors_inside
diff --git a/mmdet/models/dense_heads/deformable_detr_head.py b/mmdet/models/dense_heads/deformable_detr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..31290dbb51b2991514fe00effadce97d5df6ce01
--- /dev/null
+++ b/mmdet/models/dense_heads/deformable_detr_head.py
@@ -0,0 +1,318 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Linear, bias_init_with_prob, constant_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import multi_apply
+from mmdet.models.utils.transformer import inverse_sigmoid
+from ..builder import HEADS
+from .detr_head import DETRHead
+
+
+@HEADS.register_module()
+class DeformableDETRHead(DETRHead):
+ """Head of DeformDETR: Deformable DETR: Deformable Transformers for End-to-
+ End Object Detection.
+
+    Code is modified from the `official github repo
+    <https://github.com/fundamentalvision/Deformable-DETR>`_.
+
+    More details can be found in the `paper
+    <https://arxiv.org/abs/2010.04159>`_.
+
+ Args:
+ with_box_refine (bool): Whether to refine the reference points
+ in the decoder. Defaults to False.
+        as_two_stage (bool): Whether to generate proposals from the
+            outputs of the encoder. Defaults to False.
+ transformer (obj:`ConfigDict`): ConfigDict is used for building
+ the Encoder and Decoder.
+ """
+
+ def __init__(self,
+ *args,
+ with_box_refine=False,
+ as_two_stage=False,
+ transformer=None,
+ **kwargs):
+ self.with_box_refine = with_box_refine
+ self.as_two_stage = as_two_stage
+ if self.as_two_stage:
+ transformer['as_two_stage'] = self.as_two_stage
+
+ super(DeformableDETRHead, self).__init__(
+ *args, transformer=transformer, **kwargs)
+
+ def _init_layers(self):
+ """Initialize classification branch and regression branch of head."""
+
+ fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+ reg_branch = []
+ for _ in range(self.num_reg_fcs):
+ reg_branch.append(Linear(self.embed_dims, self.embed_dims))
+ reg_branch.append(nn.ReLU())
+ reg_branch.append(Linear(self.embed_dims, 4))
+ reg_branch = nn.Sequential(*reg_branch)
+
+ def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+ # last reg_branch is used to generate proposal from
+ # encode feature map when as_two_stage is True.
+ num_pred = (self.transformer.decoder.num_layers + 1) if \
+ self.as_two_stage else self.transformer.decoder.num_layers
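+        # e.g. with a 6-layer decoder, num_pred is 7 when as_two_stage is
+        # True and 6 otherwise.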
+
+ if self.with_box_refine:
+ self.cls_branches = _get_clones(fc_cls, num_pred)
+ self.reg_branches = _get_clones(reg_branch, num_pred)
+ else:
+
+ self.cls_branches = nn.ModuleList(
+ [fc_cls for _ in range(num_pred)])
+ self.reg_branches = nn.ModuleList(
+ [reg_branch for _ in range(num_pred)])
+
+ if not self.as_two_stage:
+ self.query_embedding = nn.Embedding(self.num_query,
+ self.embed_dims * 2)
+
+ def init_weights(self):
+ """Initialize weights of the DeformDETR head."""
+ self.transformer.init_weights()
+ if self.loss_cls.use_sigmoid:
+ bias_init = bias_init_with_prob(0.01)
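+            # a prior probability of 0.01 gives a bias of about -4.6, the
+            # usual focal-loss style classification bias initialization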
+ for m in self.cls_branches:
+ nn.init.constant_(m.bias, bias_init)
+ for m in self.reg_branches:
+ constant_init(m[-1], 0, bias=0)
+ nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], -2.0)
+ if self.as_two_stage:
+ for m in self.reg_branches:
+ nn.init.constant_(m[-1].bias.data[2:], 0.0)
+
+ def forward(self, mlvl_feats, img_metas):
+ """Forward function.
+
+ Args:
+ mlvl_feats (tuple[Tensor]): Features from the upstream
+ network, each is a 4D-tensor with shape
+ (N, C, H, W).
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+            all_cls_scores (Tensor): Outputs from the classification head, \
+                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
+                cls_out_channels should include background.
+            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
+                head with normalized coordinate format (cx, cy, w, h). \
+                Shape [nb_dec, bs, num_query, 4].
+            enc_outputs_class (Tensor): The score of each point on the \
+                encoder feature map, has shape (N, h*w, num_classes). \
+                Only returned when as_two_stage is True, otherwise `None`.
+            enc_outputs_coord (Tensor): The proposals generated from the \
+                encoder feature map, has shape (N, h*w, 4). Only returned \
+                when as_two_stage is True, otherwise `None`.
+ """
+
+ batch_size = mlvl_feats[0].size(0)
+ input_img_h, input_img_w = img_metas[0]['batch_input_shape']
+ img_masks = mlvl_feats[0].new_ones(
+ (batch_size, input_img_h, input_img_w))
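+        # non-zero entries mark padded positions that the transformer should
+        # ignore; the loop below zeroes out the valid region of each image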
+ for img_id in range(batch_size):
+ img_h, img_w, _ = img_metas[img_id]['img_shape']
+ img_masks[img_id, :img_h, :img_w] = 0
+
+ mlvl_masks = []
+ mlvl_positional_encodings = []
+ for feat in mlvl_feats:
+ mlvl_masks.append(
+ F.interpolate(img_masks[None],
+ size=feat.shape[-2:]).to(torch.bool).squeeze(0))
+ mlvl_positional_encodings.append(
+ self.positional_encoding(mlvl_masks[-1]))
+
+ query_embeds = None
+ if not self.as_two_stage:
+ query_embeds = self.query_embedding.weight
+ hs, init_reference, inter_references, \
+ enc_outputs_class, enc_outputs_coord = self.transformer(
+ mlvl_feats,
+ mlvl_masks,
+ query_embeds,
+ mlvl_positional_encodings,
+ reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
+ cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501
+ )
+ hs = hs.permute(0, 2, 1, 3)
+ outputs_classes = []
+ outputs_coords = []
+
+ for lvl in range(hs.shape[0]):
+ if lvl == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[lvl - 1]
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.cls_branches[lvl](hs[lvl])
+ tmp = self.reg_branches[lvl](hs[lvl])
+ if reference.shape[-1] == 4:
+ tmp += reference
+ else:
+ assert reference.shape[-1] == 2
+ tmp[..., :2] += reference
+ outputs_coord = tmp.sigmoid()
+ outputs_classes.append(outputs_class)
+ outputs_coords.append(outputs_coord)
+
+ outputs_classes = torch.stack(outputs_classes)
+ outputs_coords = torch.stack(outputs_coords)
+ if self.as_two_stage:
+ return outputs_classes, outputs_coords, \
+ enc_outputs_class, \
+ enc_outputs_coord.sigmoid()
+ else:
+ return outputs_classes, outputs_coords, \
+ None, None
+
+ @force_fp32(apply_to=('all_cls_scores', 'all_bbox_preds'))
+ def loss(self,
+ all_cls_scores,
+ all_bbox_preds,
+ enc_cls_scores,
+ enc_bbox_preds,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """"Loss function.
+
+ Args:
+ all_cls_scores (Tensor): Classification score of all
+ decoder layers, has shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds (Tensor): Sigmoid regression
+ outputs of all decode layers. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+            enc_cls_scores (Tensor): Classification scores of points on the
+                encoder feature map, has shape (N, h*w, num_classes). Only
+                passed when as_two_stage is True, otherwise None.
+            enc_bbox_preds (Tensor): Regression results of each point on the
+                encoder feature map, has shape (N, h*w, 4). Only passed when
+                as_two_stage is True, otherwise None.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+ which can be ignored for each image. Default None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert gt_bboxes_ignore is None, \
+            f'{self.__class__.__name__} only supports ' \
+            f'gt_bboxes_ignore being set to None.'
+
+ num_dec_layers = len(all_cls_scores)
+ all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+ all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ all_gt_bboxes_ignore_list = [
+ gt_bboxes_ignore for _ in range(num_dec_layers)
+ ]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+
+ losses_cls, losses_bbox, losses_iou = multi_apply(
+ self.loss_single, all_cls_scores, all_bbox_preds,
+ all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
+ all_gt_bboxes_ignore_list)
+
+ loss_dict = dict()
+ # loss of proposal generated from encode feature map.
+ if enc_cls_scores is not None:
+ binary_labels_list = [
+ torch.zeros_like(gt_labels_list[i])
+ for i in range(len(img_metas))
+ ]
+ enc_loss_cls, enc_losses_bbox, enc_losses_iou = \
+ self.loss_single(enc_cls_scores, enc_bbox_preds,
+ gt_bboxes_list, binary_labels_list,
+ img_metas, gt_bboxes_ignore)
+ loss_dict['enc_loss_cls'] = enc_loss_cls
+ loss_dict['enc_loss_bbox'] = enc_losses_bbox
+ loss_dict['enc_loss_iou'] = enc_losses_iou
+
+ # loss from the last decoder layer
+ loss_dict['loss_cls'] = losses_cls[-1]
+ loss_dict['loss_bbox'] = losses_bbox[-1]
+ loss_dict['loss_iou'] = losses_iou[-1]
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
+ losses_bbox[:-1],
+ losses_iou[:-1]):
+ loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+ loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+ loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
+ num_dec_layer += 1
+ return loss_dict
+
+ @force_fp32(apply_to=('all_cls_scores', 'all_bbox_preds'))
+ def get_bboxes(self,
+ all_cls_scores,
+ all_bbox_preds,
+ enc_cls_scores,
+ enc_bbox_preds,
+ img_metas,
+ rescale=False):
+ """Transform network outputs for a batch into bbox predictions.
+
+ Args:
+ all_cls_scores (Tensor): Classification score of all
+ decoder layers, has shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds (Tensor): Sigmoid regression
+ outputs of all decode layers. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+            enc_cls_scores (Tensor): Classification scores of points on the
+                encoder feature map, has shape (N, h*w, num_classes). Only
+                passed when as_two_stage is True, otherwise None.
+            enc_bbox_preds (Tensor): Regression results of each point on the
+                encoder feature map, has shape (N, h*w, 4). Only passed when
+                as_two_stage is True, otherwise None.
+ img_metas (list[dict]): Meta information of each image.
+ rescale (bool, optional): If True, return boxes in original
+ image space. Default False.
+
+ Returns:
+            list[list[Tensor, Tensor]]: Each item in result_list is a 2-tuple. \
+ The first item is an (n, 5) tensor, where the first 4 columns \
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
+ 5-th column is a score between 0 and 1. The second item is a \
+ (n,) tensor where each item is the predicted class label of \
+ the corresponding box.
+ """
+ cls_scores = all_cls_scores[-1]
+ bbox_preds = all_bbox_preds[-1]
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score = cls_scores[img_id]
+ bbox_pred = bbox_preds[img_id]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score, bbox_pred,
+ img_shape, scale_factor,
+ rescale)
+ result_list.append(proposals)
+ return result_list
diff --git a/mmdet/models/dense_heads/dense_test_mixins.py b/mmdet/models/dense_heads/dense_test_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..3421548955d62652ea3d6e65dec71253d021615a
--- /dev/null
+++ b/mmdet/models/dense_heads/dense_test_mixins.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+from inspect import signature
+
+import torch
+from mmcv.ops import batched_nms
+
+from mmdet.core import bbox_mapping_back, merge_aug_proposals
+
+if sys.version_info >= (3, 7):
+ from mmdet.utils.contextmanagers import completed
+
+
+class BBoxTestMixin(object):
+ """Mixin class for testing det bboxes via DenseHead."""
+
+ def simple_test_bboxes(self, feats, img_metas, rescale=False):
+ """Test det bboxes without test-time augmentation, can be applied in
+ DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``,
+ etc.
+
+ Args:
+ feats (tuple[torch.Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
+ The first item is ``bboxes`` with shape (n, 5),
+ where 5 represent (tl_x, tl_y, br_x, br_y, score).
+ The shape of the second tensor in the tuple is ``labels``
+ with shape (n,)
+ """
+ outs = self.forward(feats)
+ results_list = self.get_bboxes(
+ *outs, img_metas=img_metas, rescale=rescale)
+ return results_list
+
+ def aug_test_bboxes(self, feats, img_metas, rescale=False):
+ """Test det bboxes with test time augmentation, can be applied in
+ DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``,
+ etc.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
+ The first item is ``bboxes`` with shape (n, 5),
+ where 5 represent (tl_x, tl_y, br_x, br_y, score).
+ The shape of the second tensor in the tuple is ``labels``
+ with shape (n,). The length of list should always be 1.
+ """
+ # check with_nms argument
+ gb_sig = signature(self.get_bboxes)
+ gb_args = [p.name for p in gb_sig.parameters.values()]
+ gbs_sig = signature(self._get_bboxes_single)
+ gbs_args = [p.name for p in gbs_sig.parameters.values()]
+ assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \
+ f'{self.__class__.__name__}' \
+ ' does not support test-time augmentation'
+
+ aug_bboxes = []
+ aug_scores = []
+ aug_labels = []
+ for x, img_meta in zip(feats, img_metas):
+ # only one image in the batch
+ outs = self.forward(x)
+ bbox_outputs = self.get_bboxes(
+ *outs,
+ img_metas=img_meta,
+ cfg=self.test_cfg,
+ rescale=False,
+ with_nms=False)[0]
+ aug_bboxes.append(bbox_outputs[0])
+ aug_scores.append(bbox_outputs[1])
+ if len(bbox_outputs) >= 3:
+ aug_labels.append(bbox_outputs[2])
+
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = self.merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas)
+ merged_labels = torch.cat(aug_labels, dim=0) if aug_labels else None
+
+ if merged_bboxes.numel() == 0:
+ det_bboxes = torch.cat([merged_bboxes, merged_scores[:, None]], -1)
+ return [
+ (det_bboxes, merged_labels),
+ ]
+
+ det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
+ merged_labels, self.test_cfg.nms)
+ det_bboxes = det_bboxes[:self.test_cfg.max_per_img]
+ det_labels = merged_labels[keep_idxs][:self.test_cfg.max_per_img]
+
+ if rescale:
+ _det_bboxes = det_bboxes
+ else:
+ _det_bboxes = det_bboxes.clone()
+ _det_bboxes[:, :4] *= det_bboxes.new_tensor(
+ img_metas[0][0]['scale_factor'])
+
+ return [
+ (_det_bboxes, det_labels),
+ ]
+
+ def simple_test_rpn(self, x, img_metas):
+ """Test without augmentation, only for ``RPNHead`` and its variants,
+ e.g., ``GARPNHead``, etc.
+
+ Args:
+ x (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ img_metas (list[dict]): Meta info of each image.
+
+ Returns:
+ list[Tensor]: Proposals of each image, each item has shape (n, 5),
+ where 5 represent (tl_x, tl_y, br_x, br_y, score).
+ """
+ rpn_outs = self(x)
+ proposal_list = self.get_bboxes(*rpn_outs, img_metas=img_metas)
+ return proposal_list
+
+ def aug_test_rpn(self, feats, img_metas):
+ """Test with augmentation for only for ``RPNHead`` and its variants,
+ e.g., ``GARPNHead``, etc.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ img_metas (list[dict]): Meta info of each image.
+
+ Returns:
+ list[Tensor]: Proposals of each image, each item has shape (n, 5),
+ where 5 represent (tl_x, tl_y, br_x, br_y, score).
+ """
+ samples_per_gpu = len(img_metas[0])
+ aug_proposals = [[] for _ in range(samples_per_gpu)]
+ for x, img_meta in zip(feats, img_metas):
+ proposal_list = self.simple_test_rpn(x, img_meta)
+ for i, proposals in enumerate(proposal_list):
+ aug_proposals[i].append(proposals)
+ # reorganize the order of 'img_metas' to match the dimensions
+ # of 'aug_proposals'
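+        # e.g. with 2 augmentations and 2 images per GPU, img_metas is
+        # [[aug0_img0, aug0_img1], [aug1_img0, aug1_img1]] and aug_img_metas
+        # becomes [[aug0_img0, aug1_img0], [aug0_img1, aug1_img1]]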
+ aug_img_metas = []
+ for i in range(samples_per_gpu):
+ aug_img_meta = []
+ for j in range(len(img_metas)):
+ aug_img_meta.append(img_metas[j][i])
+ aug_img_metas.append(aug_img_meta)
+ # after merging, proposals will be rescaled to the original image size
+ merged_proposals = [
+ merge_aug_proposals(proposals, aug_img_meta, self.test_cfg)
+ for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
+ ]
+ return merged_proposals
+
+ if sys.version_info >= (3, 7):
+
+ async def async_simple_test_rpn(self, x, img_metas):
+ sleep_interval = self.test_cfg.pop('async_sleep_interval', 0.025)
+ async with completed(
+ __name__, 'rpn_head_forward',
+ sleep_interval=sleep_interval):
+ rpn_outs = self(x)
+
+ proposal_list = self.get_bboxes(*rpn_outs, img_metas=img_metas)
+ return proposal_list
+
+ def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
+ """Merge augmented detection bboxes and scores.
+
+ Args:
+ aug_bboxes (list[Tensor]): shape (n, 4*#class)
+ aug_scores (list[Tensor] or None): shape (n, #class)
+            img_metas (list[list[dict]]): Meta info of each augmented image.
+
+ Returns:
+ tuple[Tensor]: ``bboxes`` with shape (n,4), where
+ 4 represent (tl_x, tl_y, br_x, br_y)
+ and ``scores`` with shape (n,).
+ """
+ recovered_bboxes = []
+ for bboxes, img_info in zip(aug_bboxes, img_metas):
+ img_shape = img_info[0]['img_shape']
+ scale_factor = img_info[0]['scale_factor']
+ flip = img_info[0]['flip']
+ flip_direction = img_info[0]['flip_direction']
+ bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
+ flip_direction)
+ recovered_bboxes.append(bboxes)
+ bboxes = torch.cat(recovered_bboxes, dim=0)
+ if aug_scores is None:
+ return bboxes
+ else:
+ scores = torch.cat(aug_scores, dim=0)
+ return bboxes, scores
diff --git a/mmdet/models/dense_heads/detr_head.py b/mmdet/models/dense_heads/detr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..de1913c9db19f8dae93e5e1a8e045673c6faa96e
--- /dev/null
+++ b/mmdet/models/dense_heads/detr_head.py
@@ -0,0 +1,844 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, Linear, build_activation_layer
+from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding
+from mmcv.runner import force_fp32
+
+from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
+ build_assigner, build_sampler, multi_apply,
+ reduce_mean)
+from mmdet.models.utils import build_transformer
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+
+
+@HEADS.register_module()
+class DETRHead(AnchorFreeHead):
+ """Implements the DETR transformer head.
+
+    See `paper: End-to-End Object Detection with Transformers
+    <https://arxiv.org/abs/2005.12872>`_ for details.
+
+ Args:
+ num_classes (int): Number of categories excluding the background.
+ in_channels (int): Number of channels in the input feature map.
+ num_query (int): Number of query in Transformer.
+ num_reg_fcs (int, optional): Number of fully-connected layers used in
+ `FFN`, which is then used for the regression head. Default 2.
+ transformer (obj:`mmcv.ConfigDict`|dict): Config for transformer.
+ Default: None.
+ sync_cls_avg_factor (bool): Whether to sync the avg_factor of
+ all ranks. Default to False.
+ positional_encoding (obj:`mmcv.ConfigDict`|dict):
+ Config for position encoding.
+ loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the
+ classification loss. Default `CrossEntropyLoss`.
+ loss_bbox (obj:`mmcv.ConfigDict`|dict): Config of the
+ regression loss. Default `L1Loss`.
+ loss_iou (obj:`mmcv.ConfigDict`|dict): Config of the
+ regression iou loss. Default `GIoULoss`.
+        train_cfg (obj:`mmcv.ConfigDict`|dict): Training config of
+ transformer head.
+ test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of
+ transformer head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ _version = 2
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ num_query=100,
+ num_reg_fcs=2,
+ transformer=None,
+ sync_cls_avg_factor=False,
+ positional_encoding=dict(
+ type='SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ bg_cls_weight=0.1,
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=5.0),
+ loss_iou=dict(type='GIoULoss', loss_weight=2.0),
+ train_cfg=dict(
+ assigner=dict(
+ type='HungarianAssigner',
+ cls_cost=dict(type='ClassificationCost', weight=1.),
+ reg_cost=dict(type='BBoxL1Cost', weight=5.0),
+ iou_cost=dict(
+ type='IoUCost', iou_mode='giou', weight=2.0))),
+ test_cfg=dict(max_per_img=100),
+ init_cfg=None,
+ **kwargs):
+ # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
+ # since it brings inconvenience when the initialization of
+ # `AnchorFreeHead` is called.
+ super(AnchorFreeHead, self).__init__(init_cfg)
+ self.bg_cls_weight = 0
+ self.sync_cls_avg_factor = sync_cls_avg_factor
+ class_weight = loss_cls.get('class_weight', None)
+ if class_weight is not None and (self.__class__ is DETRHead):
+ assert isinstance(class_weight, float), 'Expected ' \
+ 'class_weight to have type float. Found ' \
+ f'{type(class_weight)}.'
+            # NOTE following the official DETR repo, bg_cls_weight means the
+            # relative classification weight of the no-object class.
+ bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
+ assert isinstance(bg_cls_weight, float), 'Expected ' \
+ 'bg_cls_weight to have type float. Found ' \
+ f'{type(bg_cls_weight)}.'
+ class_weight = torch.ones(num_classes + 1) * class_weight
+            # set background class as the last index
+ class_weight[num_classes] = bg_cls_weight
+ loss_cls.update({'class_weight': class_weight})
+ if 'bg_cls_weight' in loss_cls:
+ loss_cls.pop('bg_cls_weight')
+ self.bg_cls_weight = bg_cls_weight
+
+ if train_cfg:
+ assert 'assigner' in train_cfg, 'assigner should be provided '\
+ 'when train_cfg is set.'
+ assigner = train_cfg['assigner']
+ assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
+                'The classification weight for loss and matcher should be ' \
+                'exactly the same.'
+ assert loss_bbox['loss_weight'] == assigner['reg_cost'][
+ 'weight'], 'The regression L1 weight for loss and matcher ' \
+ 'should be exactly the same.'
+ assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'], \
+                'The regression iou weight for loss and matcher should be ' \
+                'exactly the same.'
+ self.assigner = build_assigner(assigner)
+ # DETR sampling=False, so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.num_query = num_query
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.num_reg_fcs = num_reg_fcs
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.fp16_enabled = False
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_iou = build_loss(loss_iou)
+
+ if self.loss_cls.use_sigmoid:
+ self.cls_out_channels = num_classes
+ else:
+ self.cls_out_channels = num_classes + 1
+ self.act_cfg = transformer.get('act_cfg',
+ dict(type='ReLU', inplace=True))
+ self.activate = build_activation_layer(self.act_cfg)
+ self.positional_encoding = build_positional_encoding(
+ positional_encoding)
+ self.transformer = build_transformer(transformer)
+ self.embed_dims = self.transformer.embed_dims
+ assert 'num_feats' in positional_encoding
+ num_feats = positional_encoding['num_feats']
+ assert num_feats * 2 == self.embed_dims, 'embed_dims should' \
+ f' be exactly 2 times of num_feats. Found {self.embed_dims}' \
+ f' and {num_feats}.'
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize layers of the transformer head."""
+ self.input_proj = Conv2d(
+ self.in_channels, self.embed_dims, kernel_size=1)
+ self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+ self.reg_ffn = FFN(
+ self.embed_dims,
+ self.embed_dims,
+ self.num_reg_fcs,
+ self.act_cfg,
+ dropout=0.0,
+ add_residual=False)
+ self.fc_reg = Linear(self.embed_dims, 4)
+ self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
+
+ def init_weights(self):
+ """Initialize weights of the transformer head."""
+ # The initialization for transformer is important
+ self.transformer.init_weights()
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ """load checkpoints."""
+ # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
+ # since `AnchorFreeHead._load_from_state_dict` should not be
+ # called here. Invoking the default `Module._load_from_state_dict`
+ # is enough.
+
+        # Names of some parameters have been changed.
+ version = local_metadata.get('version', None)
+ if (version is None or version < 2) and self.__class__ is DETRHead:
+ convert_dict = {
+ '.self_attn.': '.attentions.0.',
+ '.ffn.': '.ffns.0.',
+ '.multihead_attn.': '.attentions.1.',
+ '.decoder.norm.': '.decoder.post_norm.'
+ }
+ state_dict_keys = list(state_dict.keys())
+ for k in state_dict_keys:
+ for ori_key, convert_key in convert_dict.items():
+ if ori_key in k:
+ convert_key = k.replace(ori_key, convert_key)
+ state_dict[convert_key] = state_dict[k]
+ del state_dict[k]
+
+ super(AnchorFreeHead,
+ self)._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys,
+ unexpected_keys, error_msgs)
+
+ def forward(self, feats, img_metas):
+ """Forward function.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
+
+ - all_cls_scores_list (list[Tensor]): Classification scores \
+ for each scale level. Each is a 4D-tensor with shape \
+ [nb_dec, bs, num_query, cls_out_channels]. Note \
+                    `cls_out_channels` should include background.
+ - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
+ outputs for each scale level. Each is a 4D-tensor with \
+ normalized coordinate format (cx, cy, w, h) and shape \
+ [nb_dec, bs, num_query, 4].
+ """
+ num_levels = len(feats)
+ img_metas_list = [img_metas for _ in range(num_levels)]
+ return multi_apply(self.forward_single, feats, img_metas_list)
+
+ def forward_single(self, x, img_metas):
+ """"Forward function for a single feature level.
+
+ Args:
+ x (Tensor): Input feature from backbone's single stage, shape
+ [bs, c, h, w].
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ all_cls_scores (Tensor): Outputs from the classification head,
+ shape [nb_dec, bs, num_query, cls_out_channels]. Note
+                cls_out_channels should include background.
+ all_bbox_preds (Tensor): Sigmoid outputs from the regression
+ head with normalized coordinate format (cx, cy, w, h).
+ Shape [nb_dec, bs, num_query, 4].
+ """
+        # construct binary masks which are used by the transformer.
+        # NOTE following the official DETR repo, non-zero values represent
+        # ignored positions, while zero values mean valid positions.
+ batch_size = x.size(0)
+ input_img_h, input_img_w = img_metas[0]['batch_input_shape']
+ masks = x.new_ones((batch_size, input_img_h, input_img_w))
+ for img_id in range(batch_size):
+ img_h, img_w, _ = img_metas[img_id]['img_shape']
+ masks[img_id, :img_h, :img_w] = 0
+
+ x = self.input_proj(x)
+ # interpolate masks to have the same spatial shape with x
+ masks = F.interpolate(
+ masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
+ # position encoding
+ pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w]
+ # outs_dec: [nb_dec, bs, num_query, embed_dim]
+ outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
+ pos_embed)
+
+ all_cls_scores = self.fc_cls(outs_dec)
+ all_bbox_preds = self.fc_reg(self.activate(
+ self.reg_ffn(outs_dec))).sigmoid()
+ return all_cls_scores, all_bbox_preds
+
+ @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+ def loss(self,
+ all_cls_scores_list,
+ all_bbox_preds_list,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """"Loss function.
+
+ Only outputs from the last feature level are used for computing
+ losses by default.
+
+ Args:
+ all_cls_scores_list (list[Tensor]): Classification outputs
+ for each feature level. Each is a 4D-tensor with shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds_list (list[Tensor]): Sigmoid regression
+ outputs for each feature level. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+ which can be ignored for each image. Default None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+        # NOTE by default only the outputs from the last feature scale are used.
+ all_cls_scores = all_cls_scores_list[-1]
+ all_bbox_preds = all_bbox_preds_list[-1]
+ assert gt_bboxes_ignore is None, \
+            'Only supports gt_bboxes_ignore being set to None.'
+
+ num_dec_layers = len(all_cls_scores)
+ all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+ all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ all_gt_bboxes_ignore_list = [
+ gt_bboxes_ignore for _ in range(num_dec_layers)
+ ]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+
+ losses_cls, losses_bbox, losses_iou = multi_apply(
+ self.loss_single, all_cls_scores, all_bbox_preds,
+ all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
+ all_gt_bboxes_ignore_list)
+
+ loss_dict = dict()
+ # loss from the last decoder layer
+ loss_dict['loss_cls'] = losses_cls[-1]
+ loss_dict['loss_bbox'] = losses_bbox[-1]
+ loss_dict['loss_iou'] = losses_iou[-1]
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
+ losses_bbox[:-1],
+ losses_iou[:-1]):
+ loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+ loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+ loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
+ num_dec_layer += 1
+ return loss_dict
+
+ def loss_single(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore_list=None):
+ """"Loss function for outputs from a single decoder layer of a single
+ feature level.
+
+ Args:
+ cls_scores (Tensor): Box score logits from a single decoder layer
+ for all images. Shape [bs, num_query, cls_out_channels].
+ bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+ for all images, with normalized coordinate (cx, cy, w, h) and
+ shape [bs, num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+ boxes which can be ignored for each image. Default None.
+
+ Returns:
+            tuple[Tensor]: A tuple of `loss_cls`, `loss_bbox` and `loss_iou`
+                for outputs from a single decoder layer.
+ """
+ num_imgs = cls_scores.size(0)
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+ cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
+ gt_bboxes_list, gt_labels_list,
+ img_metas, gt_bboxes_ignore_list)
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ labels = torch.cat(labels_list, 0)
+ label_weights = torch.cat(label_weights_list, 0)
+ bbox_targets = torch.cat(bbox_targets_list, 0)
+ bbox_weights = torch.cat(bbox_weights_list, 0)
+
+ # classification loss
+ cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
+ # construct weighted avg_factor to match with the official DETR repo
+ cls_avg_factor = num_total_pos * 1.0 + \
+ num_total_neg * self.bg_cls_weight
+ if self.sync_cls_avg_factor:
+ cls_avg_factor = reduce_mean(
+ cls_scores.new_tensor([cls_avg_factor]))
+ cls_avg_factor = max(cls_avg_factor, 1)
+
+ loss_cls = self.loss_cls(
+ cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
+
+ # Compute the average number of gt boxes across all gpus, for
+ # normalization purposes
+ num_total_pos = loss_cls.new_tensor([num_total_pos])
+ num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
+
+ # construct factors used for rescale bboxes
+ factors = []
+ for img_meta, bbox_pred in zip(img_metas, bbox_preds):
+ img_h, img_w, _ = img_meta['img_shape']
+ factor = bbox_pred.new_tensor([img_w, img_h, img_w,
+ img_h]).unsqueeze(0).repeat(
+ bbox_pred.size(0), 1)
+ factors.append(factor)
+ factors = torch.cat(factors, 0)
+
+        # DETR regresses the relative position of boxes (cxcywh) in the image,
+        # thus the learning target is normalized by the image size. So here
+        # we need to re-scale them for calculating the IoU loss.
+ bbox_preds = bbox_preds.reshape(-1, 4)
+ bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
+ bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
+
+        # regression IoU loss, GIoU loss by default
+ loss_iou = self.loss_iou(
+ bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)
+
+ # regression L1 loss
+ loss_bbox = self.loss_bbox(
+ bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
+ return loss_cls, loss_bbox, loss_iou
+
+ def get_targets(self,
+ cls_scores_list,
+ bbox_preds_list,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore_list=None):
+ """"Compute regression and classification targets for a batch image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_scores_list (list[Tensor]): Box score logits from a single
+ decoder layer for each image with shape [num_query,
+ cls_out_channels].
+ bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
+ decoder layer for each image, with normalized coordinate
+ (cx, cy, w, h) and shape [num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+ boxes which can be ignored for each image. Default None.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+
+ - labels_list (list[Tensor]): Labels for all images.
+ - label_weights_list (list[Tensor]): Label weights for all \
+ images.
+ - bbox_targets_list (list[Tensor]): BBox targets for all \
+ images.
+ - bbox_weights_list (list[Tensor]): BBox weights for all \
+ images.
+ - num_total_pos (int): Number of positive samples in all \
+ images.
+ - num_total_neg (int): Number of negative samples in all \
+ images.
+ """
+ assert gt_bboxes_ignore_list is None, \
+            'Only supports gt_bboxes_ignore being set to None.'
+ num_imgs = len(cls_scores_list)
+ gt_bboxes_ignore_list = [
+ gt_bboxes_ignore_list for _ in range(num_imgs)
+ ]
+
+ (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single, cls_scores_list, bbox_preds_list,
+ gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+ return (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+
+ def _get_target_single(self,
+ cls_score,
+ bbox_pred,
+ gt_bboxes,
+ gt_labels,
+ img_meta,
+ gt_bboxes_ignore=None):
+ """"Compute regression and classification targets for one image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_score (Tensor): Box score logits from a single decoder layer
+ for one image. Shape [num_query, cls_out_channels].
+ bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
+ for one image, with normalized coordinate (cx, cy, w, h) and
+ shape [num_query, 4].
+ gt_bboxes (Tensor): Ground truth bboxes for one image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (Tensor): Ground truth class indices for one image
+ with shape (num_gts, ).
+ img_meta (dict): Meta information for one image.
+ gt_bboxes_ignore (Tensor, optional): Bounding boxes
+ which can be ignored. Default None.
+
+ Returns:
+ tuple[Tensor]: a tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image.
+                - label_weights (Tensor): Label weights of each image.
+ - bbox_targets (Tensor): BBox targets of each image.
+ - bbox_weights (Tensor): BBox weights of each image.
+ - pos_inds (Tensor): Sampled positive indices for each image.
+ - neg_inds (Tensor): Sampled negative indices for each image.
+ """
+
+ num_bboxes = bbox_pred.size(0)
+ # assigner and sampler
+ assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
+ gt_labels, img_meta,
+ gt_bboxes_ignore)
+ sampling_result = self.sampler.sample(assign_result, bbox_pred,
+ gt_bboxes)
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ # label targets
+ labels = gt_bboxes.new_full((num_bboxes, ),
+ self.num_classes,
+ dtype=torch.long)
+ labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+ label_weights = gt_bboxes.new_ones(num_bboxes)
+
+ # bbox targets
+ bbox_targets = torch.zeros_like(bbox_pred)
+ bbox_weights = torch.zeros_like(bbox_pred)
+ bbox_weights[pos_inds] = 1.0
+ img_h, img_w, _ = img_meta['img_shape']
+
+        # DETR regresses the relative position of boxes (cxcywh) in the image.
+        # Thus the learning target should be normalized by the image size, and
+        # the box format should be converted from the default x1y1x2y2 to
+        # cxcywh.
+ factor = bbox_pred.new_tensor([img_w, img_h, img_w,
+ img_h]).unsqueeze(0)
+ pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
+ pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
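+        # e.g. a gt box (0, 0, 100, 50) in a 400x200 (w x h) image becomes
+        # (0, 0, 0.25, 0.25) after normalization and the (cx, cy, w, h)
+        # target (0.125, 0.125, 0.25, 0.25) after conversion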
+ bbox_targets[pos_inds] = pos_gt_bboxes_targets
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds)
+
+ # over-write because img_metas are needed as inputs for bbox_head.
+ def forward_train(self,
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None,
+ **kwargs):
+ """Forward function for training mode.
+
+ Args:
+ x (list[Tensor]): Features from backbone.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ proposal_cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert proposal_cfg is None, '"proposal_cfg" must be None'
+ outs = self(x, img_metas)
+ if gt_labels is None:
+ loss_inputs = outs + (gt_bboxes, img_metas)
+ else:
+ loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+ losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+ return losses
+
+ @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+ def get_bboxes(self,
+ all_cls_scores_list,
+ all_bbox_preds_list,
+ img_metas,
+ rescale=False):
+ """Transform network outputs for a batch into bbox predictions.
+
+ Args:
+ all_cls_scores_list (list[Tensor]): Classification outputs
+ for each feature level. Each is a 4D-tensor with shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds_list (list[Tensor]): Sigmoid regression
+ outputs for each feature level. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+ img_metas (list[dict]): Meta information of each image.
+ rescale (bool, optional): If True, return boxes in original
+ image space. Default False.
+
+ Returns:
+            list[list[Tensor, Tensor]]: Each item in result_list is a 2-tuple. \
+ The first item is an (n, 5) tensor, where the first 4 columns \
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
+ 5-th column is a score between 0 and 1. The second item is a \
+ (n,) tensor where each item is the predicted class label of \
+ the corresponding box.
+ """
+        # NOTE by default only the outputs from the last feature level
+        # and the last decoder layer are used.
+ cls_scores = all_cls_scores_list[-1][-1]
+ bbox_preds = all_bbox_preds_list[-1][-1]
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score = cls_scores[img_id]
+ bbox_pred = bbox_preds[img_id]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score, bbox_pred,
+ img_shape, scale_factor,
+ rescale)
+ result_list.append(proposals)
+
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=False):
+ """Transform outputs from the last decoder layer into bbox predictions
+ for each image.
+
+ Args:
+ cls_score (Tensor): Box score logits from the last decoder layer
+ for each image. Shape [num_query, cls_out_channels].
+ bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
+ for each image, with coordinate format (cx, cy, w, h) and
+ shape [num_query, 4].
+ img_shape (tuple[int]): Shape of input image, (height, width, 3).
+            scale_factor (ndarray, optional): Scale factor of the image,
+                arranged as (w_scale, h_scale, w_scale, h_scale).
+ rescale (bool, optional): If True, return boxes in original image
+ space. Default False.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels.
+
+ - det_bboxes: Predicted bboxes with shape [num_query, 5], \
+ where the first 4 columns are bounding box positions \
+ (tl_x, tl_y, br_x, br_y) and the 5-th column are scores \
+ between 0 and 1.
+ - det_labels: Predicted labels of the corresponding box with \
+ shape [num_query].
+ """
+ assert len(cls_score) == len(bbox_pred)
+ max_per_img = self.test_cfg.get('max_per_img', self.num_query)
+ # exclude background
+ if self.loss_cls.use_sigmoid:
+ cls_score = cls_score.sigmoid()
+ scores, indexes = cls_score.view(-1).topk(max_per_img)
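+            # indexes address the flattened (num_query * num_classes) scores,
+            # e.g. with num_classes=80, index 163 maps to query 163 // 80 = 2
+            # and class 163 % 80 = 3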
+ det_labels = indexes % self.num_classes
+ bbox_index = indexes // self.num_classes
+ bbox_pred = bbox_pred[bbox_index]
+ else:
+ scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
+ scores, bbox_index = scores.topk(max_per_img)
+ bbox_pred = bbox_pred[bbox_index]
+ det_labels = det_labels[bbox_index]
+
+ det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
+ det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
+ det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
+ det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
+ det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
+ if rescale:
+ det_bboxes /= det_bboxes.new_tensor(scale_factor)
+ det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)
+
+ return det_bboxes, det_labels
+
+ def simple_test_bboxes(self, feats, img_metas, rescale=False):
+ """Test det bboxes without test-time augmentation.
+
+ Args:
+ feats (tuple[torch.Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
+ The first item is ``bboxes`` with shape (n, 5),
+ where 5 represent (tl_x, tl_y, br_x, br_y, score).
+ The shape of the second tensor in the tuple is ``labels``
+ with shape (n,)
+ """
+ # forward of this head requires img_metas
+ outs = self.forward(feats, img_metas)
+ results_list = self.get_bboxes(*outs, img_metas, rescale=rescale)
+ return results_list
+
+ def forward_onnx(self, feats, img_metas):
+ """Forward function for exporting to ONNX.
+
+        Overrides `forward` because `masks` is directly created with
+        zeros (valid position tag) and has the same spatial size as `x`.
+        Thus the construction of `masks` is different from that in `forward`.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
+
+ - all_cls_scores_list (list[Tensor]): Classification scores \
+ for each scale level. Each is a 4D-tensor with shape \
+ [nb_dec, bs, num_query, cls_out_channels]. Note \
+                `cls_out_channels` should include the background class.
+ - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
+ outputs for each scale level. Each is a 4D-tensor with \
+ normalized coordinate format (cx, cy, w, h) and shape \
+ [nb_dec, bs, num_query, 4].
+ """
+ num_levels = len(feats)
+ img_metas_list = [img_metas for _ in range(num_levels)]
+ return multi_apply(self.forward_single_onnx, feats, img_metas_list)
+
+ def forward_single_onnx(self, x, img_metas):
+ """"Forward function for a single feature level with ONNX exportation.
+
+ Args:
+ x (Tensor): Input feature from backbone's single stage, shape
+ [bs, c, h, w].
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ all_cls_scores (Tensor): Outputs from the classification head,
+ shape [nb_dec, bs, num_query, cls_out_channels]. Note
+                cls_out_channels should include the background class.
+ all_bbox_preds (Tensor): Sigmoid outputs from the regression
+ head with normalized coordinate format (cx, cy, w, h).
+ Shape [nb_dec, bs, num_query, 4].
+ """
+ # Note `img_shape` is not dynamically traceable to ONNX,
+ # since the related augmentation was done with numpy under
+ # CPU. Thus `masks` is directly created with zeros (valid tag)
+ # and the same spatial shape as `x`.
+ # The difference between torch and exported ONNX model may be
+ # ignored, since the same performance is achieved (e.g.
+ # 40.1 vs 40.1 for DETR)
+ batch_size = x.size(0)
+ h, w = x.size()[-2:]
+ masks = x.new_zeros((batch_size, h, w)) # [B,h,w]
+
+ x = self.input_proj(x)
+ # interpolate masks to have the same spatial shape with x
+ masks = F.interpolate(
+ masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
+ pos_embed = self.positional_encoding(masks)
+ outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
+ pos_embed)
+
+ all_cls_scores = self.fc_cls(outs_dec)
+ all_bbox_preds = self.fc_reg(self.activate(
+ self.reg_ffn(outs_dec))).sigmoid()
+ return all_cls_scores, all_bbox_preds
+
+ def onnx_export(self, all_cls_scores_list, all_bbox_preds_list, img_metas):
+ """Transform network outputs into bbox predictions, with ONNX
+ exportation.
+
+ Args:
+ all_cls_scores_list (list[Tensor]): Classification outputs
+ for each feature level. Each is a 4D-tensor with shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds_list (list[Tensor]): Sigmoid regression
+ outputs for each feature level. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+ img_metas (list[dict]): Meta information of each image.
+
+ Returns:
+ tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
+ and class labels of shape [N, num_det].
+ """
+ assert len(img_metas) == 1, \
+            'Only one input image is supported when exporting to ONNX'
+
+ cls_scores = all_cls_scores_list[-1][-1]
+ bbox_preds = all_bbox_preds_list[-1][-1]
+
+ # Note `img_shape` is not dynamically traceable to ONNX,
+ # here `img_shape_for_onnx` (padded shape of image tensor)
+ # is used.
+ img_shape = img_metas[0]['img_shape_for_onnx']
+ max_per_img = self.test_cfg.get('max_per_img', self.num_query)
+ batch_size = cls_scores.size(0)
+        # `batch_index_offset` is used to gather from the concatenated tensor
+ batch_index_offset = torch.arange(batch_size).to(
+ cls_scores.device) * max_per_img
+ batch_index_offset = batch_index_offset.unsqueeze(1).expand(
+ batch_size, max_per_img)
+
+        # supports dynamic batch inference
+ if self.loss_cls.use_sigmoid:
+ cls_scores = cls_scores.sigmoid()
+ scores, indexes = cls_scores.view(batch_size, -1).topk(
+ max_per_img, dim=1)
+ det_labels = indexes % self.num_classes
+ bbox_index = indexes // self.num_classes
+ bbox_index = (bbox_index + batch_index_offset).view(-1)
+ bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
+ bbox_preds = bbox_preds.view(batch_size, -1, 4)
+ else:
+ scores, det_labels = F.softmax(
+ cls_scores, dim=-1)[..., :-1].max(-1)
+ scores, bbox_index = scores.topk(max_per_img, dim=1)
+ bbox_index = (bbox_index + batch_index_offset).view(-1)
+ bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
+ det_labels = det_labels.view(-1)[bbox_index]
+ bbox_preds = bbox_preds.view(batch_size, -1, 4)
+ det_labels = det_labels.view(batch_size, -1)
+
+ det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)
+ # use `img_shape_tensor` for dynamically exporting to ONNX
+ img_shape_tensor = img_shape.flip(0).repeat(2) # [w,h,w,h]
+ img_shape_tensor = img_shape_tensor.unsqueeze(0).unsqueeze(0).expand(
+ batch_size, det_bboxes.size(1), 4)
+ det_bboxes = det_bboxes * img_shape_tensor
+ # dynamically clip bboxes
+ x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
+ from mmdet.core.export import dynamic_clip_for_onnx
+ x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, img_shape)
+ det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
+ det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)
+
+ return det_bboxes, det_labels
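+
+    # A minimal illustrative sketch (made-up numbers): the offset trick above
+    # turns per-image top-k indices into indices over the batch-flattened
+    # tensor, so a single `view(-1, 4)[...]` gather serves any batch size.
+    # Offsetting by `max_per_img` matches the per-image row count only
+    # because the default DETR config uses max_per_img == num_query.
+    # Assuming batch_size=2 and num_query == max_per_img == 4:
+    #
+    #     >>> bbox_preds = torch.rand(2, 4, 4)             # [bs, num_query, 4]
+    #     >>> bbox_index = torch.tensor([[1, 3], [0, 2]])  # per-image top-k
+    #     >>> offset = torch.arange(2).unsqueeze(1) * 4    # rows start at 0, 4
+    #     >>> flat = (bbox_index + offset).view(-1)        # [1, 3, 4, 6]
+    #     >>> picked = bbox_preds.view(-1, 4)[flat].view(2, -1, 4)
+    #     >>> assert torch.equal(picked[1, 0], bbox_preds[1, 0])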
diff --git a/mmdet/models/dense_heads/embedding_rpn_head.py b/mmdet/models/dense_heads/embedding_rpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..22060b964846298cae5a4625a0ffc32d9a139657
--- /dev/null
+++ b/mmdet/models/dense_heads/embedding_rpn_head.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.runner import BaseModule
+
+from mmdet.models.builder import HEADS
+from ...core import bbox_cxcywh_to_xyxy
+
+
+@HEADS.register_module()
+class EmbeddingRPNHead(BaseModule):
+ """RPNHead in the `Sparse R-CNN `_ .
+
+    Unlike a traditional RPNHead, this module does not need FPN input; it just
+    decodes `init_proposal_bboxes` and expands the first dimension of
+    `init_proposal_bboxes` and `init_proposal_features` to the batch size.
+
+ Args:
+ num_proposals (int): Number of init_proposals. Default 100.
+ proposal_feature_channel (int): Channel number of
+ init_proposal_feature. Defaults to 256.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ num_proposals=100,
+ proposal_feature_channel=256,
+ init_cfg=None,
+ **kwargs):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super(EmbeddingRPNHead, self).__init__(init_cfg)
+ self.num_proposals = num_proposals
+ self.proposal_feature_channel = proposal_feature_channel
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize a sparse set of proposal boxes and proposal features."""
+ self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
+ self.init_proposal_features = nn.Embedding(
+ self.num_proposals, self.proposal_feature_channel)
+
+ def init_weights(self):
+ """Initialize the init_proposal_bboxes as normalized.
+
+ [c_x, c_y, w, h], and we initialize it to the size of the entire
+ image.
+ """
+ super(EmbeddingRPNHead, self).init_weights()
+ nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5)
+ nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1)
+
+ def _decode_init_proposals(self, imgs, img_metas):
+ """Decode init_proposal_bboxes according to the size of images and
+ expand dimension of init_proposal_features to batch_size.
+
+ Args:
+ imgs (list[Tensor]): List of FPN features.
+            img_metas (list[dict]): List of meta-information of
+                images. The img_shape is needed to decode the init_proposals.
+
+ Returns:
+ Tuple(Tensor):
+
+ - proposals (Tensor): Decoded proposal bboxes,
+ has shape (batch_size, num_proposals, 4).
+ - init_proposal_features (Tensor): Expanded proposal
+ features, has shape
+ (batch_size, num_proposals, proposal_feature_channel).
+            - imgs_whwh (Tensor): Tensor with shape
+                (batch_size, 1, 4), the last dimension means
+                [img_width, img_height, img_width, img_height].
+ """
+ proposals = self.init_proposal_bboxes.weight.clone()
+ proposals = bbox_cxcywh_to_xyxy(proposals)
+ num_imgs = len(imgs[0])
+ imgs_whwh = []
+ for meta in img_metas:
+ h, w, _ = meta['img_shape']
+ imgs_whwh.append(imgs[0].new_tensor([[w, h, w, h]]))
+ imgs_whwh = torch.cat(imgs_whwh, dim=0)
+ imgs_whwh = imgs_whwh[:, None, :]
+
+ # imgs_whwh has shape (batch_size, 1, 4)
+        # The shape of proposals changes from (num_proposals, 4)
+        # to (batch_size, num_proposals, 4)
+ proposals = proposals * imgs_whwh
+
+ init_proposal_features = self.init_proposal_features.weight.clone()
+ init_proposal_features = init_proposal_features[None].expand(
+ num_imgs, *init_proposal_features.size())
+ return proposals, init_proposal_features, imgs_whwh
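+
+    # A minimal illustrative sketch (made-up image size): the learned
+    # proposals start as normalized (c_x, c_y, w, h) = (0.5, 0.5, 1, 1), a box
+    # centred on the image that covers it entirely, so decoding and scaling by
+    # (w, h, w, h) yields full-image xyxy boxes.
+    #
+    #     >>> from mmdet.core import bbox_cxcywh_to_xyxy
+    #     >>> p = torch.tensor([[0.5, 0.5, 1.0, 1.0]])
+    #     >>> whwh = torch.tensor([[800., 600., 800., 600.]])
+    #     >>> bbox_cxcywh_to_xyxy(p) * whwh
+    #     tensor([[  0.,   0., 800., 600.]])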
+
+ def forward_dummy(self, img, img_metas):
+ """Dummy forward function.
+
+ Used in flops calculation.
+ """
+ return self._decode_init_proposals(img, img_metas)
+
+ def forward_train(self, img, img_metas):
+ """Forward function in training stage."""
+ return self._decode_init_proposals(img, img_metas)
+
+ def simple_test_rpn(self, img, img_metas):
+ """Forward function in testing stage."""
+ return self._decode_init_proposals(img, img_metas)
+
+ def simple_test(self, img, img_metas):
+ """Forward function in testing stage."""
+ raise NotImplementedError
+
+ def aug_test_rpn(self, feats, img_metas):
+ raise NotImplementedError(
+ 'EmbeddingRPNHead does not support test-time augmentation')
diff --git a/mmdet/models/dense_heads/fcos_head.py b/mmdet/models/dense_heads/fcos_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d72fb56caa1599414d67c32445a6f6def44fefdf
--- /dev/null
+++ b/mmdet/models/dense_heads/fcos_head.py
@@ -0,0 +1,455 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import Scale
+from mmcv.runner import force_fp32
+
+from mmdet.core import multi_apply, reduce_mean
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+
+INF = 1e8
+
+
+@HEADS.register_module()
+class FCOSHead(AnchorFreeHead):
+ """Anchor-free head used in `FCOS `_.
+
+    The FCOS head does not use anchor boxes. Instead, bounding boxes are
+    predicted at each pixel, and a centerness measure is used to suppress
+    low-quality predictions.
+    Here, norm_on_bbox, centerness_on_reg and dcn_on_last_conv are training
+    tricks used in the official repo, which bring remarkable mAP gains of
+    up to 4.9. Please see https://github.com/tianzhi0549/FCOS for more
+    details.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ strides (list[int] | list[tuple[int, int]]): Strides of points
+ in multiple feature levels. Default: (4, 8, 16, 32, 64).
+ regress_ranges (tuple[tuple[int, int]]): Regress range of multiple
+ level points.
+ center_sampling (bool): If true, use center sampling. Default: False.
+ center_sample_radius (float): Radius of center sampling. Default: 1.5.
+ norm_on_bbox (bool): If true, normalize the regression targets
+ with FPN strides. Default: False.
+        centerness_on_reg (bool): If true, predict the centerness on the
+            regression branch. Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
+            Default: False.
+        conv_bias (bool | str): If specified as `auto`, it will be decided by
+            the norm_cfg. The bias of conv will be set to True if `norm_cfg`
+            is None, otherwise False. Default: "auto".
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ loss_centerness (dict): Config of centerness loss.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True).
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+
+ Example:
+ >>> self = FCOSHead(11, 7)
+ >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+ >>> cls_score, bbox_pred, centerness = self.forward(feats)
+ >>> assert len(cls_score) == len(self.scales)
+ """ # noqa: E501
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
+ (512, INF)),
+ center_sampling=False,
+ center_sample_radius=1.5,
+ norm_on_bbox=False,
+ centerness_on_reg=False,
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+ loss_centerness=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='conv_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ self.regress_ranges = regress_ranges
+ self.center_sampling = center_sampling
+ self.center_sample_radius = center_sample_radius
+ self.norm_on_bbox = norm_on_bbox
+ self.centerness_on_reg = centerness_on_reg
+ super().__init__(
+ num_classes,
+ in_channels,
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ norm_cfg=norm_cfg,
+ init_cfg=init_cfg,
+ **kwargs)
+ self.loss_centerness = build_loss(loss_centerness)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ super()._init_layers()
+ self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
+ self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple:
+ cls_scores (list[Tensor]): Box scores for each scale level, \
+ each is a 4D-tensor, the channel number is \
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each \
+ scale level, each is a 4D-tensor, the channel number is \
+ num_points * 4.
+ centernesses (list[Tensor]): centerness for each scale level, \
+ each is a 4D-tensor, the channel number is num_points * 1.
+ """
+ return multi_apply(self.forward_single, feats, self.scales,
+ self.strides)
+
+ def forward_single(self, x, scale, stride):
+ """Forward features of a single scale level.
+
+ Args:
+ x (Tensor): FPN feature maps of the specified stride.
+ scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+ stride (int): The corresponding stride for feature maps, only
+ used to normalize the bbox prediction when self.norm_on_bbox
+ is True.
+
+ Returns:
+ tuple: scores for each class, bbox predictions and centerness \
+ predictions of input feature maps.
+ """
+ cls_score, bbox_pred, cls_feat, reg_feat = super().forward_single(x)
+ if self.centerness_on_reg:
+ centerness = self.conv_centerness(reg_feat)
+ else:
+ centerness = self.conv_centerness(cls_feat)
+ # scale the bbox_pred of different level
+ # float to avoid overflow when enabling FP16
+ bbox_pred = scale(bbox_pred).float()
+ if self.norm_on_bbox:
+            # When running with PyTorch 1.10, F.relu(bbox_pred) modifies the
+            # bbox_pred that is needed for gradient computation, so it is
+            # replaced with bbox_pred.clamp(min=0).
+ bbox_pred = bbox_pred.clamp(min=0)
+ if not self.training:
+ bbox_pred *= stride
+ else:
+ bbox_pred = bbox_pred.exp()
+ return cls_score, bbox_pred, centerness
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level,
+ each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ centernesses (list[Tensor]): centerness for each scale level, each
+ is a 4D-tensor, the channel number is num_points * 1.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert len(cls_scores) == len(bbox_preds) == len(centernesses)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ all_level_points = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=bbox_preds[0].dtype,
+ device=bbox_preds[0].device)
+ labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
+ gt_labels)
+
+ num_imgs = cls_scores[0].size(0)
+ # flatten cls_scores, bbox_preds and centerness
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
+ for cls_score in cls_scores
+ ]
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ flatten_centerness = [
+ centerness.permute(0, 2, 3, 1).reshape(-1)
+ for centerness in centernesses
+ ]
+ flatten_cls_scores = torch.cat(flatten_cls_scores)
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds)
+ flatten_centerness = torch.cat(flatten_centerness)
+ flatten_labels = torch.cat(labels)
+ flatten_bbox_targets = torch.cat(bbox_targets)
+ # repeat points to align with bbox_preds
+ flatten_points = torch.cat(
+ [points.repeat(num_imgs, 1) for points in all_level_points])
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((flatten_labels >= 0)
+ & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
+ num_pos = torch.tensor(
+ len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
+ num_pos = max(reduce_mean(num_pos), 1.0)
+ loss_cls = self.loss_cls(
+ flatten_cls_scores, flatten_labels, avg_factor=num_pos)
+
+ pos_bbox_preds = flatten_bbox_preds[pos_inds]
+ pos_centerness = flatten_centerness[pos_inds]
+ pos_bbox_targets = flatten_bbox_targets[pos_inds]
+ pos_centerness_targets = self.centerness_target(pos_bbox_targets)
+ # centerness weighted iou loss
+ centerness_denorm = max(
+ reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)
+
+ if len(pos_inds) > 0:
+ pos_points = flatten_points[pos_inds]
+ pos_decoded_bbox_preds = self.bbox_coder.decode(
+ pos_points, pos_bbox_preds)
+ pos_decoded_target_preds = self.bbox_coder.decode(
+ pos_points, pos_bbox_targets)
+ loss_bbox = self.loss_bbox(
+ pos_decoded_bbox_preds,
+ pos_decoded_target_preds,
+ weight=pos_centerness_targets,
+ avg_factor=centerness_denorm)
+ loss_centerness = self.loss_centerness(
+ pos_centerness, pos_centerness_targets, avg_factor=num_pos)
+ else:
+ loss_bbox = pos_bbox_preds.sum()
+ loss_centerness = pos_centerness.sum()
+
+ return dict(
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ loss_centerness=loss_centerness)
+
+ def get_targets(self, points, gt_bboxes_list, gt_labels_list):
+ """Compute regression, classification and centerness targets for points
+ in multiple images.
+
+ Args:
+ points (list[Tensor]): Points of each fpn level, each has shape
+ (num_points, 2).
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels_list (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+
+ Returns:
+ tuple:
+ concat_lvl_labels (list[Tensor]): Labels of each level. \
+ concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \
+ level.
+ """
+ assert len(points) == len(self.regress_ranges)
+ num_levels = len(points)
+ # expand regress ranges to align with points
+ expanded_regress_ranges = [
+ points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
+ points[i]) for i in range(num_levels)
+ ]
+ # concat all levels points and regress ranges
+ concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
+ concat_points = torch.cat(points, dim=0)
+
+ # the number of points per img, per lvl
+ num_points = [center.size(0) for center in points]
+
+ # get labels and bbox_targets of each image
+ labels_list, bbox_targets_list = multi_apply(
+ self._get_target_single,
+ gt_bboxes_list,
+ gt_labels_list,
+ points=concat_points,
+ regress_ranges=concat_regress_ranges,
+ num_points_per_lvl=num_points)
+
+ # split to per img, per level
+ labels_list = [labels.split(num_points, 0) for labels in labels_list]
+ bbox_targets_list = [
+ bbox_targets.split(num_points, 0)
+ for bbox_targets in bbox_targets_list
+ ]
+
+ # concat per level image
+ concat_lvl_labels = []
+ concat_lvl_bbox_targets = []
+ for i in range(num_levels):
+ concat_lvl_labels.append(
+ torch.cat([labels[i] for labels in labels_list]))
+ bbox_targets = torch.cat(
+ [bbox_targets[i] for bbox_targets in bbox_targets_list])
+ if self.norm_on_bbox:
+ bbox_targets = bbox_targets / self.strides[i]
+ concat_lvl_bbox_targets.append(bbox_targets)
+ return concat_lvl_labels, concat_lvl_bbox_targets
+
+ def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges,
+ num_points_per_lvl):
+ """Compute regression and classification targets for a single image."""
+ num_points = points.size(0)
+ num_gts = gt_labels.size(0)
+ if num_gts == 0:
+ return gt_labels.new_full((num_points,), self.num_classes), \
+ gt_bboxes.new_zeros((num_points, 4))
+
+ areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
+ gt_bboxes[:, 3] - gt_bboxes[:, 1])
+ # TODO: figure out why these two are different
+ # areas = areas[None].expand(num_points, num_gts)
+ areas = areas[None].repeat(num_points, 1)
+ regress_ranges = regress_ranges[:, None, :].expand(
+ num_points, num_gts, 2)
+ gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
+ xs, ys = points[:, 0], points[:, 1]
+ xs = xs[:, None].expand(num_points, num_gts)
+ ys = ys[:, None].expand(num_points, num_gts)
+
+ left = xs - gt_bboxes[..., 0]
+ right = gt_bboxes[..., 2] - xs
+ top = ys - gt_bboxes[..., 1]
+ bottom = gt_bboxes[..., 3] - ys
+ bbox_targets = torch.stack((left, top, right, bottom), -1)
+
+ if self.center_sampling:
+ # condition1: inside a `center bbox`
+ radius = self.center_sample_radius
+ center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
+ center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
+ center_gts = torch.zeros_like(gt_bboxes)
+ stride = center_xs.new_zeros(center_xs.shape)
+
+ # project the points on current lvl back to the `original` sizes
+ lvl_begin = 0
+ for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
+ lvl_end = lvl_begin + num_points_lvl
+ stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
+ lvl_begin = lvl_end
+
+ x_mins = center_xs - stride
+ y_mins = center_ys - stride
+ x_maxs = center_xs + stride
+ y_maxs = center_ys + stride
+ center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
+ x_mins, gt_bboxes[..., 0])
+ center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
+ y_mins, gt_bboxes[..., 1])
+ center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
+ gt_bboxes[..., 2], x_maxs)
+ center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
+ gt_bboxes[..., 3], y_maxs)
+
+ cb_dist_left = xs - center_gts[..., 0]
+ cb_dist_right = center_gts[..., 2] - xs
+ cb_dist_top = ys - center_gts[..., 1]
+ cb_dist_bottom = center_gts[..., 3] - ys
+ center_bbox = torch.stack(
+ (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
+ inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
+ else:
+ # condition1: inside a gt bbox
+ inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
+
+ # condition2: limit the regression range for each location
+ max_regress_distance = bbox_targets.max(-1)[0]
+ inside_regress_range = (
+ (max_regress_distance >= regress_ranges[..., 0])
+ & (max_regress_distance <= regress_ranges[..., 1]))
+
+ # if there are still more than one objects for a location,
+ # we choose the one with minimal area
+ areas[inside_gt_bbox_mask == 0] = INF
+ areas[inside_regress_range == 0] = INF
+ min_area, min_area_inds = areas.min(dim=1)
+
+ labels = gt_labels[min_area_inds]
+ labels[min_area == INF] = self.num_classes # set as BG
+ bbox_targets = bbox_targets[range(num_points), min_area_inds]
+
+ return labels, bbox_targets
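+
+    # A minimal illustrative sketch (made-up coordinates): for a point (x, y)
+    # that falls inside a gt box (x1, y1, x2, y2), the regression target is
+    # the distance to the four sides, stacked as (left, top, right, bottom).
+    #
+    #     >>> point = torch.tensor([100., 80.])
+    #     >>> gt = torch.tensor([40., 30., 160., 200.])
+    #     >>> l, t = point[0] - gt[0], point[1] - gt[1]
+    #     >>> r, b = gt[2] - point[0], gt[3] - point[1]
+    #     >>> torch.stack((l, t, r, b))
+    #     tensor([ 60.,  50.,  60., 120.])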
+
+ def centerness_target(self, pos_bbox_targets):
+ """Compute centerness targets.
+
+ Args:
+ pos_bbox_targets (Tensor): BBox targets of positive bboxes in shape
+ (num_pos, 4)
+
+ Returns:
+ Tensor: Centerness target.
+ """
+ # only calculate pos centerness targets, otherwise there may be nan
+ left_right = pos_bbox_targets[:, [0, 2]]
+ top_bottom = pos_bbox_targets[:, [1, 3]]
+ if len(left_right) == 0:
+ centerness_targets = left_right[..., 0]
+ else:
+ centerness_targets = (
+ left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
+ top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+ return torch.sqrt(centerness_targets)
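+
+    # A minimal illustrative sketch: centerness is
+    # sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))), which is 1 at
+    # the box centre and decays towards the borders. For the (l, t, r, b)
+    # target (60, 50, 60, 120) from the example above:
+    #
+    #     >>> import math
+    #     >>> round(math.sqrt((60 / 60) * (50 / 120)), 4)
+    #     0.6455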
+
+ def _get_points_single(self,
+ featmap_size,
+ stride,
+ dtype,
+ device,
+ flatten=False):
+ """Get points according to feature map size.
+
+ This function will be deprecated soon.
+ """
+ warnings.warn(
+ '`_get_points_single` in `FCOSHead` will be '
+            'deprecated soon. We support a multi level point generator now. '
+            'You can get points of a single level feature map '
+            'with `self.prior_generator.single_level_grid_priors`.')
+
+ y, x = super()._get_points_single(featmap_size, stride, dtype, device)
+ points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
+ dim=-1) + stride // 2
+ return points
diff --git a/mmdet/models/dense_heads/fovea_head.py b/mmdet/models/dense_heads/fovea_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8be7fc94c767005da5d31d201dcc55fb760b5c53
--- /dev/null
+++ b/mmdet/models/dense_heads/fovea_head.py
@@ -0,0 +1,385 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.ops import DeformConv2d
+from mmcv.runner import BaseModule
+
+from mmdet.core import multi_apply
+from mmdet.core.utils import filter_scores_and_topk
+from ..builder import HEADS
+from .anchor_free_head import AnchorFreeHead
+
+INF = 1e8
+
+
+class FeatureAlign(BaseModule):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ deform_groups=4,
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.1,
+ override=dict(
+ type='Normal', name='conv_adaption', std=0.01))):
+ super(FeatureAlign, self).__init__(init_cfg)
+ offset_channels = kernel_size * kernel_size * 2
+ self.conv_offset = nn.Conv2d(
+ 4, deform_groups * offset_channels, 1, bias=False)
+ self.conv_adaption = DeformConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ deform_groups=deform_groups)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x, shape):
+ offset = self.conv_offset(shape)
+ x = self.relu(self.conv_adaption(x, offset))
+ return x
+
+
+@HEADS.register_module()
+class FoveaHead(AnchorFreeHead):
+ """FoveaBox: Beyond Anchor-based Object Detector
+ https://arxiv.org/abs/1904.03797
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ base_edge_list=(16, 32, 64, 128, 256),
+ scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128,
+ 512)),
+ sigma=0.4,
+ with_deform=False,
+ deform_groups=4,
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='conv_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ self.base_edge_list = base_edge_list
+ self.scale_ranges = scale_ranges
+ self.sigma = sigma
+ self.with_deform = with_deform
+ self.deform_groups = deform_groups
+ super().__init__(num_classes, in_channels, init_cfg=init_cfg, **kwargs)
+
+ def _init_layers(self):
+ # box branch
+ super()._init_reg_convs()
+ self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+
+ # cls branch
+ if not self.with_deform:
+ super()._init_cls_convs()
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ else:
+ self.cls_convs = nn.ModuleList()
+ self.cls_convs.append(
+ ConvModule(
+ self.feat_channels, (self.feat_channels * 4),
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.norm_cfg is None))
+ self.cls_convs.append(
+ ConvModule((self.feat_channels * 4), (self.feat_channels * 4),
+ 1,
+ stride=1,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.norm_cfg is None))
+ self.feature_adaption = FeatureAlign(
+ self.feat_channels,
+ self.feat_channels,
+ kernel_size=3,
+ deform_groups=self.deform_groups)
+ self.conv_cls = nn.Conv2d(
+ int(self.feat_channels * 4),
+ self.cls_out_channels,
+ 3,
+ padding=1)
+
+ def forward_single(self, x):
+ cls_feat = x
+ reg_feat = x
+ for reg_layer in self.reg_convs:
+ reg_feat = reg_layer(reg_feat)
+ bbox_pred = self.conv_reg(reg_feat)
+ if self.with_deform:
+ cls_feat = self.feature_adaption(cls_feat, bbox_pred.exp())
+ for cls_layer in self.cls_convs:
+ cls_feat = cls_layer(cls_feat)
+ cls_score = self.conv_cls(cls_feat)
+ return cls_score, bbox_pred
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bbox_list,
+ gt_label_list,
+ img_metas,
+ gt_bboxes_ignore=None):
+ assert len(cls_scores) == len(bbox_preds)
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ points = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=bbox_preds[0].dtype,
+ device=bbox_preds[0].device)
+ num_imgs = cls_scores[0].size(0)
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
+ for cls_score in cls_scores
+ ]
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ flatten_cls_scores = torch.cat(flatten_cls_scores)
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds)
+ flatten_labels, flatten_bbox_targets = self.get_targets(
+ gt_bbox_list, gt_label_list, featmap_sizes, points)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ pos_inds = ((flatten_labels >= 0)
+ & (flatten_labels < self.num_classes)).nonzero().view(-1)
+ num_pos = len(pos_inds)
+
+ loss_cls = self.loss_cls(
+ flatten_cls_scores, flatten_labels, avg_factor=num_pos + num_imgs)
+ if num_pos > 0:
+ pos_bbox_preds = flatten_bbox_preds[pos_inds]
+ pos_bbox_targets = flatten_bbox_targets[pos_inds]
+ pos_weights = pos_bbox_targets.new_zeros(
+ pos_bbox_targets.size()) + 1.0
+ loss_bbox = self.loss_bbox(
+ pos_bbox_preds,
+ pos_bbox_targets,
+ pos_weights,
+ avg_factor=num_pos)
+ else:
+ loss_bbox = torch.tensor(
+ 0,
+ dtype=flatten_bbox_preds.dtype,
+ device=flatten_bbox_preds.device)
+ return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
+
+ def get_targets(self, gt_bbox_list, gt_label_list, featmap_sizes, points):
+ label_list, bbox_target_list = multi_apply(
+ self._get_target_single,
+ gt_bbox_list,
+ gt_label_list,
+ featmap_size_list=featmap_sizes,
+ point_list=points)
+ flatten_labels = [
+ torch.cat([
+ labels_level_img.flatten() for labels_level_img in labels_level
+ ]) for labels_level in zip(*label_list)
+ ]
+ flatten_bbox_targets = [
+ torch.cat([
+ bbox_targets_level_img.reshape(-1, 4)
+ for bbox_targets_level_img in bbox_targets_level
+ ]) for bbox_targets_level in zip(*bbox_target_list)
+ ]
+ flatten_labels = torch.cat(flatten_labels)
+ flatten_bbox_targets = torch.cat(flatten_bbox_targets)
+ return flatten_labels, flatten_bbox_targets
+
+ def _get_target_single(self,
+ gt_bboxes_raw,
+ gt_labels_raw,
+ featmap_size_list=None,
+ point_list=None):
+
+ gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
+ (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
+ label_list = []
+ bbox_target_list = []
+ # for each pyramid, find the cls and box target
+ for base_len, (lower_bound, upper_bound), stride, featmap_size, \
+ points in zip(self.base_edge_list, self.scale_ranges,
+ self.strides, featmap_size_list, point_list):
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ points = points.view(*featmap_size, 2)
+ x, y = points[..., 0], points[..., 1]
+ labels = gt_labels_raw.new_zeros(featmap_size) + self.num_classes
+ bbox_targets = gt_bboxes_raw.new(featmap_size[0], featmap_size[1],
+ 4) + 1
+ # scale assignment
+ hit_indices = ((gt_areas >= lower_bound) &
+ (gt_areas <= upper_bound)).nonzero().flatten()
+ if len(hit_indices) == 0:
+ label_list.append(labels)
+ bbox_target_list.append(torch.log(bbox_targets))
+ continue
+ _, hit_index_order = torch.sort(-gt_areas[hit_indices])
+ hit_indices = hit_indices[hit_index_order]
+ gt_bboxes = gt_bboxes_raw[hit_indices, :] / stride
+ gt_labels = gt_labels_raw[hit_indices]
+ half_w = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0])
+ half_h = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
+ # valid fovea area: left, right, top, down
+ pos_left = torch.ceil(
+ gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long(). \
+ clamp(0, featmap_size[1] - 1)
+ pos_right = torch.floor(
+ gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long(). \
+ clamp(0, featmap_size[1] - 1)
+ pos_top = torch.ceil(
+ gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long(). \
+ clamp(0, featmap_size[0] - 1)
+ pos_down = torch.floor(
+ gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long(). \
+ clamp(0, featmap_size[0] - 1)
+ for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \
+ zip(pos_left, pos_top, pos_right, pos_down, gt_labels,
+ gt_bboxes_raw[hit_indices, :]):
+ labels[py1:py2 + 1, px1:px2 + 1] = label
+ bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \
+ (x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len
+ bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \
+ (y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len
+ bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \
+ (gt_x2 - x[py1:py2 + 1, px1:px2 + 1]) / base_len
+ bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \
+ (gt_y2 - y[py1:py2 + 1, px1:px2 + 1]) / base_len
+ bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.)
+ label_list.append(labels)
+ bbox_target_list.append(torch.log(bbox_targets))
+ return label_list, bbox_target_list
+
+ # Same as base_dense_head/_get_bboxes_single except self._bbox_decode
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ score_factor_list,
+ mlvl_priors,
+ img_meta,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform outputs of a single image into bbox predictions.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_priors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has shape
+ (num_priors * 4, H, W).
+ score_factor_list (list[Tensor]): Score factor from all scale
+ levels of a single image. Fovea head does not need this value.
+ mlvl_priors (list[Tensor]): Each element in the list is
+ the priors of a single level in feature pyramid, has shape
+ (num_priors, 2).
+ img_meta (dict): Image meta info.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+                mlvl_score_factor. Usually with_nms=False is used for aug
+                test. If with_nms is True, then return the following format:
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column are scores between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_score_list) == len(bbox_pred_list)
+ img_shape = img_meta['img_shape']
+ nms_pre = cfg.get('nms_pre', -1)
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_labels = []
+ for level_idx, (cls_score, bbox_pred, stride, base_len, priors) in \
+ enumerate(zip(cls_score_list, bbox_pred_list, self.strides,
+ self.base_edge_list, mlvl_priors)):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+
+ scores = cls_score.permute(1, 2, 0).reshape(
+ -1, self.cls_out_channels).sigmoid()
+
+ # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+ # this operation keeps fewer bboxes under the same `nms_pre`.
+ # There is no difference in performance for most models. If you
+ # find a slight drop in performance, you can set a larger
+ # `nms_pre` than before.
+ results = filter_scores_and_topk(
+ scores, cfg.score_thr, nms_pre,
+ dict(bbox_pred=bbox_pred, priors=priors))
+ scores, labels, _, filtered_results = results
+
+ bbox_pred = filtered_results['bbox_pred']
+ priors = filtered_results['priors']
+
+ bboxes = self._bbox_decode(priors, bbox_pred, base_len, img_shape)
+
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_labels.append(labels)
+
+ return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes,
+ img_meta['scale_factor'], cfg, rescale,
+ with_nms)
+
+ def _bbox_decode(self, priors, bbox_pred, base_len, max_shape):
+ bbox_pred = bbox_pred.exp()
+
+ y = priors[:, 1]
+ x = priors[:, 0]
+ x1 = (x - base_len * bbox_pred[:, 0]). \
+ clamp(min=0, max=max_shape[1] - 1)
+ y1 = (y - base_len * bbox_pred[:, 1]). \
+ clamp(min=0, max=max_shape[0] - 1)
+ x2 = (x + base_len * bbox_pred[:, 2]). \
+ clamp(min=0, max=max_shape[1] - 1)
+ y2 = (y + base_len * bbox_pred[:, 3]). \
+ clamp(min=0, max=max_shape[0] - 1)
+ decoded_bboxes = torch.stack([x1, y1, x2, y2], -1)
+ return decoded_bboxes
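+
+    # A minimal illustrative sketch (made-up numbers): FoveaBox predicts
+    # log-space offsets, so decoding multiplies exp(pred) by the per-level
+    # base edge length and offsets from the prior point (clamping to the
+    # image shape is omitted here). For a prior at (100, 100), base_len=16
+    # and a zero prediction (exp(0) == 1):
+    #
+    #     >>> priors = torch.tensor([[100., 100.]])
+    #     >>> pred = torch.zeros(1, 4)
+    #     >>> x1 = priors[:, 0] - 16 * pred[:, 0].exp()
+    #     >>> y1 = priors[:, 1] - 16 * pred[:, 1].exp()
+    #     >>> x2 = priors[:, 0] + 16 * pred[:, 2].exp()
+    #     >>> y2 = priors[:, 1] + 16 * pred[:, 3].exp()
+    #     >>> torch.stack([x1, y1, x2, y2], -1)
+    #     tensor([[ 84.,  84., 116., 116.]])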
+
+ def _get_points_single(self, *args, **kwargs):
+ """Get points according to feature map size.
+
+ This function will be deprecated soon.
+ """
+ warnings.warn(
+ '`_get_points_single` in `FoveaHead` will be '
+            'deprecated soon. We support a multi level point generator now. '
+            'You can get points of a single level feature map '
+            'with `self.prior_generator.single_level_grid_priors`.')
+ y, x = super()._get_points_single(*args, **kwargs)
+ return y + 0.5, x + 0.5
diff --git a/mmdet/models/dense_heads/free_anchor_retina_head.py b/mmdet/models/dense_heads/free_anchor_retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3acd25ecba414b691b2b00a6bc30faa580dadebc
--- /dev/null
+++ b/mmdet/models/dense_heads/free_anchor_retina_head.py
@@ -0,0 +1,272 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+
+from mmdet.core import bbox_overlaps
+from ..builder import HEADS
+from .retina_head import RetinaHead
+
+EPS = 1e-12
+
+
+@HEADS.register_module()
+class FreeAnchorRetinaHead(RetinaHead):
+ """FreeAnchor RetinaHead used in https://arxiv.org/abs/1909.02466.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): Number of conv layers in cls and reg tower.
+ Default: 4.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: norm_cfg=dict(type='GN', num_groups=32,
+ requires_grad=True).
+        pre_anchor_topk (int): Number of boxes to be taken in each bag.
+        bbox_thr (float): The threshold of the saturated linear function. It
+            is usually the same as the IoU threshold used in NMS.
+ gamma (float): Gamma parameter in focal loss.
+ alpha (float): Alpha parameter in focal loss.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ pre_anchor_topk=50,
+ bbox_thr=0.6,
+ gamma=2.0,
+ alpha=0.5,
+ **kwargs):
+ super(FreeAnchorRetinaHead,
+ self).__init__(num_classes, in_channels, stacked_convs, conv_cfg,
+ norm_cfg, **kwargs)
+
+ self.pre_anchor_topk = pre_anchor_topk
+ self.bbox_thr = bbox_thr
+ self.gamma = gamma
+ self.alpha = alpha
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes for
+                one image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+ device = cls_scores[0].device
+ anchor_list, _ = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ anchors = [torch.cat(anchor) for anchor in anchor_list]
+
+ # concatenate each level
+ cls_scores = [
+ cls.permute(0, 2, 3,
+ 1).reshape(cls.size(0), -1, self.cls_out_channels)
+ for cls in cls_scores
+ ]
+ bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ cls_scores = torch.cat(cls_scores, dim=1)
+ bbox_preds = torch.cat(bbox_preds, dim=1)
+
+ cls_prob = torch.sigmoid(cls_scores)
+ box_prob = []
+ num_pos = 0
+ positive_losses = []
+ for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_,
+ bbox_preds_) in enumerate(
+ zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds)):
+
+ with torch.no_grad():
+ if len(gt_bboxes_) == 0:
+ image_box_prob = torch.zeros(
+ anchors_.size(0),
+ self.cls_out_channels).type_as(bbox_preds_)
+ else:
+ # box_localization: a_{j}^{loc}, shape: [j, 4]
+ pred_boxes = self.bbox_coder.decode(anchors_, bbox_preds_)
+
+ # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
+ object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)
+
+ # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
+ t1 = self.bbox_thr
+ t2 = object_box_iou.max(
+ dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
+ object_box_prob = ((object_box_iou - t1) /
+ (t2 - t1)).clamp(
+ min=0, max=1)
+
+ # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
+ num_obj = gt_labels_.size(0)
+ indices = torch.stack([
+ torch.arange(num_obj).type_as(gt_labels_), gt_labels_
+ ],
+ dim=0)
+ object_cls_box_prob = torch.sparse_coo_tensor(
+ indices, object_box_prob)
+
+ # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
+ """
+ from "start" to "end" implement:
+ image_box_iou = torch.sparse.max(object_cls_box_prob,
+ dim=0).t()
+
+ """
+ # start
+ box_cls_prob = torch.sparse.sum(
+ object_cls_box_prob, dim=0).to_dense()
+
+ indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()
+ if indices.numel() == 0:
+ image_box_prob = torch.zeros(
+ anchors_.size(0),
+ self.cls_out_channels).type_as(object_box_prob)
+ else:
+ nonzero_box_prob = torch.where(
+ (gt_labels_.unsqueeze(dim=-1) == indices[0]),
+ object_box_prob[:, indices[1]],
+ torch.tensor([
+ 0
+ ]).type_as(object_box_prob)).max(dim=0).values
+
+ # upmap to shape [j, c]
+ image_box_prob = torch.sparse_coo_tensor(
+ indices.flip([0]),
+ nonzero_box_prob,
+ size=(anchors_.size(0),
+ self.cls_out_channels)).to_dense()
+ # end
+
+ box_prob.append(image_box_prob)
+
+ # construct bags for objects
+ match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)
+ _, matched = torch.topk(
+ match_quality_matrix,
+ self.pre_anchor_topk,
+ dim=1,
+ sorted=False)
+ del match_quality_matrix
+
+ # matched_cls_prob: P_{ij}^{cls}
+ matched_cls_prob = torch.gather(
+ cls_prob_[matched], 2,
+ gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
+ 1)).squeeze(2)
+
+ # matched_box_prob: P_{ij}^{loc}
+ matched_anchors = anchors_[matched]
+ matched_object_targets = self.bbox_coder.encode(
+ matched_anchors,
+ gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors))
+ loss_bbox = self.loss_bbox(
+ bbox_preds_[matched],
+ matched_object_targets,
+ reduction_override='none').sum(-1)
+ matched_box_prob = torch.exp(-loss_bbox)
+
+ # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
+ num_pos += len(gt_bboxes_)
+ positive_losses.append(
+ self.positive_bag_loss(matched_cls_prob, matched_box_prob))
+ positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)
+
+ # box_prob: P{a_{j} \in A_{+}}
+ box_prob = torch.stack(box_prob, dim=0)
+
+ # negative_loss:
+ # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
+ negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
+ 1, num_pos * self.pre_anchor_topk)
+
+ # avoid the absence of gradients in regression subnet
+ # when no ground-truth in a batch
+ if num_pos == 0:
+ positive_loss = bbox_preds.sum() * 0
+
+ losses = {
+ 'positive_bag_loss': positive_loss,
+ 'negative_bag_loss': negative_loss
+ }
+ return losses
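+
+    # A minimal illustrative sketch (made-up IoUs): the saturated linear
+    # function above maps an anchor's IoU with a gt into a matching
+    # probability that is 0 at bbox_thr and 1 at that gt's best IoU.
+    #
+    #     >>> iou = torch.tensor([[0.55, 0.70, 0.85]])   # one gt, 3 anchors
+    #     >>> t1 = 0.6
+    #     >>> t2 = iou.max(dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
+    #     >>> ((iou - t1) / (t2 - t1)).clamp(min=0, max=1)
+    #     tensor([[0.0000, 0.4000, 1.0000]])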
+
+ def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
+ """Compute positive bag loss.
+
+ :math:`-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )`.
+
+ :math:`P_{ij}^{cls}`: matched_cls_prob, classification probability of matched samples.
+
+ :math:`P_{ij}^{loc}`: matched_box_prob, box probability of matched samples.
+
+ Args:
+ matched_cls_prob (Tensor): Classification probability of matched
+ samples in shape (num_gt, pre_anchor_topk).
+ matched_box_prob (Tensor): BBox probability of matched samples,
+ in shape (num_gt, pre_anchor_topk).
+
+ Returns:
+ Tensor: Positive bag loss in shape (num_gt,).
+ """ # noqa: E501, W605
+ # bag_prob = Mean-max(matched_prob)
+ matched_prob = matched_cls_prob * matched_box_prob
+ weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
+ weight /= weight.sum(dim=1).unsqueeze(dim=-1)
+ bag_prob = (weight * matched_prob).sum(dim=1)
+ # positive_bag_loss = -self.alpha * log(bag_prob)
+ return self.alpha * F.binary_cross_entropy(
+ bag_prob, torch.ones_like(bag_prob), reduction='none')
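+
+    # A minimal illustrative sketch (made-up probabilities): the Mean-max
+    # above weights each anchor in a bag by 1 / (1 - p) (normalized), so
+    # higher-probability anchors dominate the bag score: it is close to max()
+    # when probabilities are high and close to mean() when they are low.
+    #
+    #     >>> p = torch.tensor([[0.9, 0.5]])
+    #     >>> w = 1 / (1 - p)
+    #     >>> w = w / w.sum(dim=1, keepdim=True)
+    #     >>> (w * p).sum(dim=1)
+    #     tensor([0.8333])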
+
+ def negative_bag_loss(self, cls_prob, box_prob):
+ """Compute negative bag loss.
+
+ :math:`FL((1 - P_{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}))`.
+
+ :math:`P_{a_{j} \in A_{+}}`: Box_probability of matched samples.
+
+ :math:`P_{j}^{bg}`: Classification probability of negative samples.
+
+ Args:
+ cls_prob (Tensor): Classification probability, in shape
+ (num_img, num_anchors, num_classes).
+ box_prob (Tensor): Box probability, in shape
+ (num_img, num_anchors, num_classes).
+
+ Returns:
+ Tensor: Negative bag loss in shape (num_img, num_anchors, num_classes).
+ """ # noqa: E501, W605
+ prob = cls_prob * (1 - box_prob)
+ # There are some cases when neg_prob = 0.
+ # This will cause the neg_prob.log() to be inf without clamp.
+ prob = prob.clamp(min=EPS, max=1 - EPS)
+ negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
+ prob, torch.zeros_like(prob), reduction='none')
+ return (1 - self.alpha) * negative_bag_loss
diff --git a/mmdet/models/dense_heads/fsaf_head.py b/mmdet/models/dense_heads/fsaf_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2b78796948bec17a44624106d9022ae2be3e6c
--- /dev/null
+++ b/mmdet/models/dense_heads/fsaf_head.py
@@ -0,0 +1,433 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, images_to_levels, multi_apply,
+ unmap)
+from ..builder import HEADS
+from ..losses.accuracy import accuracy
+from ..losses.utils import weight_reduce_loss
+from .retina_head import RetinaHead
+
+
+@HEADS.register_module()
+class FSAFHead(RetinaHead):
+ """Anchor-free head used in `FSAF `_.
+
+ The head contains two subnetworks. The first classifies anchor boxes and
+ the second regresses deltas for the anchors (num_anchors is 1 for anchor-
+ free methods)
+
+ Args:
+ *args: Same as its base class in :class:`RetinaHead`
+ score_threshold (float, optional): The score_threshold to calculate
+ positive recall. If given, prediction scores lower than this value
+            are counted as incorrect predictions. Defaults to None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ **kwargs: Same as its base class in :class:`RetinaHead`
+
+ Example:
+ >>> import torch
+ >>> self = FSAFHead(11, 7)
+ >>> x = torch.rand(1, 7, 32, 32)
+ >>> cls_score, bbox_pred = self.forward_single(x)
+ >>> # Each anchor predicts a score for each class except background
+ >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
+ >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
+ >>> assert cls_per_anchor == self.num_classes
+ >>> assert box_per_anchor == 4
+ """
+
+ def __init__(self, *args, score_threshold=None, init_cfg=None, **kwargs):
+ # The positive bias in self.retina_reg conv is to prevent predicted \
+ # bbox with 0 area
+ if init_cfg is None:
+ init_cfg = dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=[
+ dict(
+ type='Normal',
+ name='retina_cls',
+ std=0.01,
+ bias_prob=0.01),
+ dict(
+ type='Normal', name='retina_reg', std=0.01, bias=0.25)
+ ])
+ super().__init__(*args, init_cfg=init_cfg, **kwargs)
+ self.score_threshold = score_threshold
+
+ def forward_single(self, x):
+ """Forward feature map of a single scale level.
+
+ Args:
+ x (Tensor): Feature map of a single scale level.
+
+ Returns:
+ tuple (Tensor):
+ cls_score (Tensor): Box scores for each scale level
+ Has shape (N, num_points * num_classes, H, W).
+ bbox_pred (Tensor): Box energies / deltas for each scale
+ level with shape (N, num_points * 4, H, W).
+ """
+ cls_score, bbox_pred = super().forward_single(x)
+ # relu: TBLR encoder only accepts positive bbox_pred
+ return cls_score, self.relu(bbox_pred)
+
+ def _get_targets_single(self,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+        Most of the code is the same as in the base class
+        :obj:`AnchorHead`, except that this method also collects and returns
+ the matched gt index in the image (from 0 to num_gt-1). If the
+ anchor bbox is not matched to any gt, the corresponding value in
+ pos_gt_inds is -1.
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # Assign gt and sample anchors
+ anchors = flat_anchors[inside_flags.type(torch.bool), :]
+ assign_result = self.assigner.assign(
+ anchors, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros((num_valid_anchors, label_channels),
+ dtype=torch.float)
+ pos_gt_inds = anchors.new_full((num_valid_anchors, ),
+ -1,
+ dtype=torch.long)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ if len(pos_inds) > 0:
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+ # is applied directly on the decoded bounding boxes, both
+ # the predicted boxes and regression targets should be with
+ # absolute coordinate format.
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ # The assigned gt_index for each anchor. (0-based)
+ pos_gt_inds[pos_inds] = sampling_result.pos_assigned_gt_inds
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # shadowed_labels is a tensor composed of tuples
+ # (anchor_inds, class_label) that indicate those anchors lying in the
+ # outer region of a gt or overlapped by another gt with a smaller
+ # area.
+ #
+ # Therefore, only the shadowed labels are ignored for loss calculation.
+ # the key `shadowed_labels` is defined in :obj:`CenterRegionAssigner`
+ shadowed_labels = assign_result.get_extra_property('shadowed_labels')
+ if shadowed_labels is not None and shadowed_labels.numel():
+ if len(shadowed_labels.shape) == 2:
+ idx_, label_ = shadowed_labels[:, 0], shadowed_labels[:, 1]
+ assert (labels[idx_] != label_).all(), \
+ 'One label cannot be both positive and ignored'
+ label_weights[idx_, label_] = 0
+ else:
+ label_weights[shadowed_labels] = 0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ labels = unmap(labels, num_total_anchors, inside_flags)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+ pos_gt_inds = unmap(
+ pos_gt_inds, num_total_anchors, inside_flags, fill=-1)
+
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds, sampling_result, pos_gt_inds)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_points * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_points * 4, H, W).
+            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes for
+                one image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ for i in range(len(bbox_preds)): # loop over fpn level
+ # avoid 0 area of the predicted bbox
+ bbox_preds[i] = bbox_preds[i].clamp(min=1e-4)
+ # TODO: It may directly use the base-class loss function.
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+ batch_size = len(gt_bboxes)
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg,
+ pos_assigned_gt_inds_list) = cls_reg_targets
+
+ num_gts = np.array(list(map(len, gt_labels)))
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+
+ # `pos_assigned_gt_inds_list` (length: fpn_levels) stores the assigned
+ # gt index of each anchor bbox in each fpn level.
+ cum_num_gts = list(np.cumsum(num_gts)) # length of batch_size
+ for i, assign in enumerate(pos_assigned_gt_inds_list):
+ # loop over fpn levels
+ for j in range(1, batch_size):
+ # loop over batch size
+ # Convert gt indices in each img to those in the batch
+ assign[j][assign[j] >= 0] += int(cum_num_gts[j - 1])
+ pos_assigned_gt_inds_list[i] = assign.flatten()
+ labels_list[i] = labels_list[i].flatten()
+ num_gts = sum(map(len, gt_labels)) # total number of gt in the batch
+ # The unique label index of each gt in the batch
+ label_sequence = torch.arange(num_gts, device=device)
+ # Collect the average loss of each gt in each level
+ with torch.no_grad():
+ loss_levels, = multi_apply(
+ self.collect_loss_level_single,
+ losses_cls,
+ losses_bbox,
+ pos_assigned_gt_inds_list,
+ labels_seq=label_sequence)
+ # Shape: (fpn_levels, num_gts). Loss of each gt at each fpn level
+ loss_levels = torch.stack(loss_levels, dim=0)
+ # Locate the best fpn level for loss back-propagation
+ if loss_levels.numel() == 0: # zero gt
+ argmin = loss_levels.new_empty((num_gts, ), dtype=torch.long)
+ else:
+ _, argmin = loss_levels.min(dim=0)
+
+ # Reweight the loss of each (anchor, label) pair, so that only those
+ # at the best gt level are back-propagated.
+ losses_cls, losses_bbox, pos_inds = multi_apply(
+ self.reweight_loss_single,
+ losses_cls,
+ losses_bbox,
+ pos_assigned_gt_inds_list,
+ labels_list,
+ list(range(len(losses_cls))),
+ min_levels=argmin)
+ num_pos = torch.cat(pos_inds, 0).sum().float()
+ pos_recall = self.calculate_pos_recall(cls_scores, labels_list,
+ pos_inds)
+
+ if num_pos == 0: # No gt
+ avg_factor = num_pos + float(num_total_neg)
+ else:
+ avg_factor = num_pos
+ for i in range(len(losses_cls)):
+ losses_cls[i] /= avg_factor
+ losses_bbox[i] /= avg_factor
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ num_pos=num_pos / batch_size,
+ pos_recall=pos_recall)
+
+ def calculate_pos_recall(self, cls_scores, labels_list, pos_inds):
+ """Calculate positive recall with score threshold.
+
+ Args:
+ cls_scores (list[Tensor]): Classification scores at all fpn levels.
+ Each tensor is in shape (N, num_classes * num_anchors, H, W)
+ labels_list (list[Tensor]): The label that each anchor is assigned
+ to. Shape (N * H * W * num_anchors, )
+ pos_inds (list[Tensor]): List of bool tensors indicating whether
+ the anchor is assigned to a positive label.
+ Shape (N * H * W * num_anchors, )
+
+ Returns:
+ Tensor: A single float number indicating the positive recall.
+ """
+ with torch.no_grad():
+ num_class = self.num_classes
+ scores = [
+ cls.permute(0, 2, 3, 1).reshape(-1, num_class)[pos]
+ for cls, pos in zip(cls_scores, pos_inds)
+ ]
+ labels = [
+ label.reshape(-1)[pos]
+ for label, pos in zip(labels_list, pos_inds)
+ ]
+ scores = torch.cat(scores, dim=0)
+ labels = torch.cat(labels, dim=0)
+ if self.use_sigmoid_cls:
+ scores = scores.sigmoid()
+ else:
+ scores = scores.softmax(dim=1)
+
+ return accuracy(scores, labels, thresh=self.score_threshold)
+
+ def collect_loss_level_single(self, cls_loss, reg_loss, assigned_gt_inds,
+ labels_seq):
+ """Get the average loss in each FPN level w.r.t. each gt label.
+
+ Args:
+ cls_loss (Tensor): Classification loss of each feature map pixel,
+ shape (num_anchor, num_class)
+ reg_loss (Tensor): Regression loss of each feature map pixel,
+ shape (num_anchor, 4)
+ assigned_gt_inds (Tensor): It indicates which gt the prior is
+ assigned to (0-based; -1 means no assignment). Shape (num_anchor,).
+ labels_seq (Tensor): The unique gt indices in the batch.
+ Shape (num_gt,).
+
+ Returns:
+ Tensor: Average loss of each gt at this level, shape (num_gt,).
+ """
+ if len(reg_loss.shape) == 2: # iou loss has shape (num_prior, 4)
+ reg_loss = reg_loss.sum(dim=-1) # sum loss in tblr dims
+ if len(cls_loss.shape) == 2:
+ cls_loss = cls_loss.sum(dim=-1) # sum loss in class dims
+ loss = cls_loss + reg_loss
+ assert loss.size(0) == assigned_gt_inds.size(0)
+ # Default loss value is 1e6 for a layer where no anchor is positive
+ # to ensure it will not be chosen to back-propagate gradient
+ losses_ = loss.new_full(labels_seq.shape, 1e6)
+ for i, l in enumerate(labels_seq):
+ match = assigned_gt_inds == l
+ if match.any():
+ losses_[i] = loss[match].mean()
+ return losses_,
+
+ def reweight_loss_single(self, cls_loss, reg_loss, assigned_gt_inds,
+ labels, level, min_levels):
+ """Reweight loss values at each level.
+
+ Reassign loss values at each level by zeroing out the anchors whose
+ assigned gt is not best matched at this level. Then return the
+ reduced losses.
+
+ Args:
+ cls_loss (Tensor): Element-wise classification loss.
+ Shape: (num_anchors, num_classes)
+ reg_loss (Tensor): Element-wise regression loss.
+ Shape: (num_anchors, 4)
+ assigned_gt_inds (Tensor): The gt indices that each anchor bbox
+ is assigned to. -1 denotes a negative anchor, otherwise it is the
+ gt index (0-based). Shape: (num_anchors, ),
+ labels (Tensor): Label assigned to anchors. Shape: (num_anchors, ).
+ level (int): The current level index in the pyramid
+ (0-4 for RetinaNet)
+ min_levels (Tensor): The best-matching level for each gt.
+ Shape: (num_gts, ),
+
+ Returns:
+ tuple:
+ - cls_loss: Reduced corrected classification loss. Scalar.
+ - reg_loss: Reduced corrected regression loss. Scalar.
+ - pos_flags (Tensor): Corrected bool tensor indicating the
+ final positive anchors. Shape: (num_anchors, ).
+ """
+ loc_weight = torch.ones_like(reg_loss)
+ cls_weight = torch.ones_like(cls_loss)
+ pos_flags = assigned_gt_inds >= 0 # positive pixel flag
+ pos_indices = torch.nonzero(pos_flags, as_tuple=False).flatten()
+
+ if pos_flags.any(): # pos pixels exist
+ pos_assigned_gt_inds = assigned_gt_inds[pos_flags]
+ zeroing_indices = (min_levels[pos_assigned_gt_inds] != level)
+ neg_indices = pos_indices[zeroing_indices]
+
+ if neg_indices.numel():
+ pos_flags[neg_indices] = 0
+ loc_weight[neg_indices] = 0
+ # Only the weight corresponding to the label is
+ # zeroed out if not selected
+ zeroing_labels = labels[neg_indices]
+ assert (zeroing_labels >= 0).all()
+ cls_weight[neg_indices, zeroing_labels] = 0
+
+ # Weighted loss for both cls and reg loss
+ cls_loss = weight_reduce_loss(cls_loss, cls_weight, reduction='sum')
+ reg_loss = weight_reduce_loss(reg_loss, loc_weight, reduction='sum')
+
+ return cls_loss, reg_loss, pos_flags
diff --git a/mmdet/models/dense_heads/ga_retina_head.py b/mmdet/models/dense_heads/ga_retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9e874c2bfdd07b408d148110eb4dd85c3a9069
--- /dev/null
+++ b/mmdet/models/dense_heads/ga_retina_head.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.ops import MaskedConv2d
+
+from ..builder import HEADS
+from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
+
+
+@HEADS.register_module()
+class GARetinaHead(GuidedAnchorHead):
+ """Guided-Anchor-based RetinaNet head."""
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ init_cfg=None,
+ **kwargs):
+ if init_cfg is None:
+ init_cfg = dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=[
+ dict(
+ type='Normal',
+ name='conv_loc',
+ std=0.01,
+ bias_prob=0.01),
+ dict(
+ type='Normal',
+ name='retina_cls',
+ std=0.01,
+ bias_prob=0.01)
+ ])
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ super(GARetinaHead, self).__init__(
+ num_classes, in_channels, init_cfg=init_cfg, **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+
+ self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
+ self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2,
+ 1)
+ self.feature_adaption_cls = FeatureAdaption(
+ self.feat_channels,
+ self.feat_channels,
+ kernel_size=3,
+ deform_groups=self.deform_groups)
+ self.feature_adaption_reg = FeatureAdaption(
+ self.feat_channels,
+ self.feat_channels,
+ kernel_size=3,
+ deform_groups=self.deform_groups)
+ self.retina_cls = MaskedConv2d(
+ self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.retina_reg = MaskedConv2d(
+ self.feat_channels, self.num_base_priors * 4, 3, padding=1)
+
+ def forward_single(self, x):
+ """Forward feature map of a single scale level."""
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+
+ loc_pred = self.conv_loc(cls_feat)
+ shape_pred = self.conv_shape(reg_feat)
+
+ cls_feat = self.feature_adaption_cls(cls_feat, shape_pred)
+ reg_feat = self.feature_adaption_reg(reg_feat, shape_pred)
+
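+ # Masked convs are only used during inference: locations whose
+ # predicted objectness (loc_pred) falls below `loc_filter_thr` are
+ # masked out for speed-up.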
+ if not self.training:
+ mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
+ else:
+ mask = None
+ cls_score = self.retina_cls(cls_feat, mask)
+ bbox_pred = self.retina_reg(reg_feat, mask)
+ return cls_score, bbox_pred, shape_pred, loc_pred
diff --git a/mmdet/models/dense_heads/ga_rpn_head.py b/mmdet/models/dense_heads/ga_rpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4123c8b3f56f29f94668920d77b7db75ae78d8a2
--- /dev/null
+++ b/mmdet/models/dense_heads/ga_rpn_head.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv import ConfigDict
+from mmcv.ops import nms
+
+from ..builder import HEADS
+from .guided_anchor_head import GuidedAnchorHead
+
+
+@HEADS.register_module()
+class GARPNHead(GuidedAnchorHead):
+ """Guided-Anchor-based RPN head."""
+
+ def __init__(self,
+ in_channels,
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='conv_loc',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ super(GARPNHead, self).__init__(
+ 1, in_channels, init_cfg=init_cfg, **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.rpn_conv = nn.Conv2d(
+ self.in_channels, self.feat_channels, 3, padding=1)
+ super(GARPNHead, self)._init_layers()
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level."""
+
+ x = self.rpn_conv(x)
+ x = F.relu(x, inplace=True)
+ (cls_score, bbox_pred, shape_pred,
+ loc_pred) = super(GARPNHead, self).forward_single(x)
+ return cls_score, bbox_pred, shape_pred, loc_pred
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ shape_preds,
+ loc_preds,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore=None):
+ losses = super(GARPNHead, self).loss(
+ cls_scores,
+ bbox_preds,
+ shape_preds,
+ loc_preds,
+ gt_bboxes,
+ None,
+ img_metas,
+ gt_bboxes_ignore=gt_bboxes_ignore)
+ return dict(
+ loss_rpn_cls=losses['loss_cls'],
+ loss_rpn_bbox=losses['loss_bbox'],
+ loss_anchor_shape=losses['loss_shape'],
+ loss_anchor_loc=losses['loss_loc'])
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ mlvl_masks,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ cfg = self.test_cfg if cfg is None else cfg
+
+ cfg = copy.deepcopy(cfg)
+
+ # deprecate arguments warning
+ if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
+ warnings.warn(
+ 'In rpn_proposal or test_cfg, '
+ 'nms_thr has been moved into a dict named nms as '
+ 'iou_threshold, and max_num has been renamed to max_per_img. '
+ 'The original argument names and the old way of specifying '
+ 'the NMS iou_threshold are deprecated.')
+ if 'nms' not in cfg:
+ cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
+ if 'max_num' in cfg:
+ if 'max_per_img' in cfg:
+ assert cfg.max_num == cfg.max_per_img, f'You ' \
+ f'set max_num and max_per_img at the same time, ' \
+ f'but got {cfg.max_num} ' \
+ f'and {cfg.max_per_img} respectively. ' \
+ 'Please delete max_num, which will be deprecated.'
+ else:
+ cfg.max_per_img = cfg.max_num
+ if 'nms_thr' in cfg:
+ assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
+ f'iou_threshold in nms and ' \
+ f'nms_thr at the same time, but got ' \
+ f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
+ f' respectively. Please delete ' \
+ f'nms_thr, which will be deprecated.'
+
+ assert cfg.nms.get('type', 'nms') == 'nms', 'GARPNHead only supports ' \
+ 'naive nms.'
+
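+ # The cfg consumed below is expected to provide `nms_pre`, `nms_post`,
+ # `max_per_img`, `min_bbox_size` and `nms.iou_threshold` (optionally
+ # `nms_across_levels`). An illustrative GA-RPN test cfg could be
+ # dict(nms_pre=1000, nms_post=1000, max_per_img=300, min_bbox_size=0,
+ #      nms=dict(type='nms', iou_threshold=0.7)) -- example values only,
+ # not mandated by this code.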
+ mlvl_proposals = []
+ for idx in range(len(cls_scores)):
+ rpn_cls_score = cls_scores[idx]
+ rpn_bbox_pred = bbox_preds[idx]
+ anchors = mlvl_anchors[idx]
+ mask = mlvl_masks[idx]
+ assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
+ # if no location is kept, end.
+ if mask.sum() == 0:
+ continue
+ rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
+ if self.use_sigmoid_cls:
+ rpn_cls_score = rpn_cls_score.reshape(-1)
+ scores = rpn_cls_score.sigmoid()
+ else:
+ rpn_cls_score = rpn_cls_score.reshape(-1, 2)
+ # Remember that FG labels are set to [0, num_class - 1]
+ # since mmdet v2.0, and the BG cat_id is num_class.
+ scores = rpn_cls_score.softmax(dim=1)[:, :-1]
+ # filter scores, bbox_pred w.r.t. mask.
+ # anchors are filtered in get_anchors() beforehand.
+ scores = scores[mask]
+ rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1,
+ 4)[mask, :]
+ if scores.dim() == 0:
+ rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
+ anchors = anchors.unsqueeze(0)
+ scores = scores.unsqueeze(0)
+ # filter anchors, bbox_pred, scores w.r.t. scores
+ if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
+ _, topk_inds = scores.topk(cfg.nms_pre)
+ rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
+ anchors = anchors[topk_inds, :]
+ scores = scores[topk_inds]
+ # get proposals w.r.t. anchors and rpn_bbox_pred
+ proposals = self.bbox_coder.decode(
+ anchors, rpn_bbox_pred, max_shape=img_shape)
+ # filter out too small bboxes
+ if cfg.min_bbox_size >= 0:
+ w = proposals[:, 2] - proposals[:, 0]
+ h = proposals[:, 3] - proposals[:, 1]
+ valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
+ if not valid_mask.all():
+ proposals = proposals[valid_mask]
+ scores = scores[valid_mask]
+
+ # NMS in current level
+ proposals, _ = nms(proposals, scores, cfg.nms.iou_threshold)
+ proposals = proposals[:cfg.nms_post, :]
+ mlvl_proposals.append(proposals)
+ proposals = torch.cat(mlvl_proposals, 0)
+ if cfg.get('nms_across_levels', False):
+ # NMS across multi levels
+ proposals, _ = nms(proposals[:, :4], proposals[:, -1],
+ cfg.nms.iou_threshold)
+ proposals = proposals[:cfg.max_per_img, :]
+ else:
+ scores = proposals[:, 4]
+ num = min(cfg.max_per_img, proposals.shape[0])
+ _, topk_inds = scores.topk(num)
+ proposals = proposals[topk_inds, :]
+ return proposals
diff --git a/mmdet/models/dense_heads/gfl_head.py b/mmdet/models/dense_heads/gfl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..12eb89db8c9c9336955d7ef40d6636d122537908
--- /dev/null
+++ b/mmdet/models/dense_heads/gfl_head.py
@@ -0,0 +1,648 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, Scale
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, bbox_overlaps, build_assigner,
+ build_sampler, images_to_levels, multi_apply,
+ reduce_mean, unmap)
+from mmdet.core.utils import filter_scores_and_topk
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+
+class Integral(nn.Module):
+ """A fixed layer for calculating integral result from distribution.
+
+ This layer calculates the target location by :math:`sum{P(y_i) * y_i}`,
+ where P(y_i) denotes the softmax vector that represents the discrete
+ distribution and y_i denotes the discrete set, usually
+ {0, 1, 2, ..., reg_max}.
+
+ Args:
+ reg_max (int): The maximal value of the discrete set. Default: 16. You
+ may want to reset it according to your new dataset or related
+ settings.
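+
+ Example (a minimal sketch, assuming the default ``reg_max``):
+ >>> layer = Integral(reg_max=16)
+ >>> x = torch.rand(3, 4 * (16 + 1))
+ >>> assert layer(x).shape == (3, 4)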
+ """
+
+ def __init__(self, reg_max=16):
+ super(Integral, self).__init__()
+ self.reg_max = reg_max
+ self.register_buffer('project',
+ torch.linspace(0, self.reg_max, self.reg_max + 1))
+
+ def forward(self, x):
+ """Forward feature from the regression head to get integral result of
+ bounding box location.
+
+ Args:
+ x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
+ n is self.reg_max.
+
+ Returns:
+ x (Tensor): Integral result of box locations, i.e., distance
+ offsets from the box center in four directions, shape (N, 4).
+ """
+ x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
+ x = F.linear(x, self.project.type_as(x)).reshape(-1, 4)
+ return x
+
+
+@HEADS.register_module()
+class GFLHead(AnchorHead):
+ """Generalized Focal Loss: Learning Qualified and Distributed Bounding
+ Boxes for Dense Object Detection.
+
+ The GFL head structure is similar to that of ATSS; however, GFL uses
+ 1) a joint representation of classification and localization quality, and
+ 2) a flexible General distribution for bounding box locations,
+ which are supervised by Quality Focal Loss (QFL) and Distribution
+ Focal Loss (DFL), respectively.
+
+ https://arxiv.org/abs/2006.04388
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): Number of conv layers in cls and reg tower.
+ Default: 4.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='GN', num_groups=32, requires_grad=True).
+ loss_qfl (dict): Config of Quality Focal Loss (QFL).
+ bbox_coder (dict): Config of bbox coder. Defaults to
+ 'DistancePointBBoxCoder'.
+ reg_max (int): Max value of integral set :math:`{0, ..., reg_max}`
+ in QFL setting. Default: 16.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Example:
+ >>> self = GFLHead(11, 7)
+ >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+ >>> cls_quality_score, bbox_pred = self.forward(feats)
+ >>> assert len(cls_quality_score) == len(self.scales)
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25),
+ bbox_coder=dict(type='DistancePointBBoxCoder'),
+ reg_max=16,
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='gfl_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.reg_max = reg_max
+ super(GFLHead, self).__init__(
+ num_classes,
+ in_channels,
+ bbox_coder=bbox_coder,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # GFL sets sampling=False, so PseudoSampler is used
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.integral = Integral(self.reg_max)
+ self.loss_dfl = build_loss(loss_dfl)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ assert self.num_anchors == 1, 'anchor free version'
+ self.gfl_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ self.gfl_reg = nn.Conv2d(
+ self.feat_channels, 4 * (self.reg_max + 1), 3, padding=1)
+ self.scales = nn.ModuleList(
+ [Scale(1.0) for _ in self.prior_generator.strides])
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of classification scores and bbox prediction
+ cls_scores (list[Tensor]): Classification and quality (IoU)
+ joint scores for all scale levels, each is a 4D-tensor,
+ the channel number is num_classes.
+ bbox_preds (list[Tensor]): Box distribution logits for all
+ scale levels, each is a 4D-tensor, the channel number is
+ 4*(n+1), n is max value of integral set.
+ """
+ return multi_apply(self.forward_single, feats, self.scales)
+
+ def forward_single(self, x, scale):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+ scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+
+ Returns:
+ tuple:
+ cls_score (Tensor): Cls and quality joint scores for a single
+ scale level; the channel number is num_classes.
+ bbox_pred (Tensor): Box distribution logits for a single scale
+ level, the channel number is 4*(n+1), n is max value of
+ integral set.
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.gfl_cls(cls_feat)
+ bbox_pred = scale(self.gfl_reg(reg_feat)).float()
+ return cls_score, bbox_pred
+
+ def anchor_center(self, anchors):
+ """Get anchor centers from anchors.
+
+ Args:
+ anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format.
+
+ Returns:
+ Tensor: Anchor centers with shape (N, 2), "xy" format.
+ """
+ anchors_cx = (anchors[..., 2] + anchors[..., 0]) / 2
+ anchors_cy = (anchors[..., 3] + anchors[..., 1]) / 2
+ return torch.stack([anchors_cx, anchors_cy], dim=-1)
+
+ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
+ bbox_targets, stride, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ cls_score (Tensor): Cls and quality joint scores for each scale
+ level has shape (N, num_classes, H, W).
+ bbox_pred (Tensor): Box distribution logits for each scale
+ level with shape (N, 4*(n+1), H, W), n is max value of integral
+ set.
+ labels (Tensor): Labels of each anchor with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors).
+ bbox_targets (Tensor): BBox regression targets of each anchor
+ with shape (N, num_total_anchors, 4).
+ stride (tuple): Stride in this scale level.
+ num_total_samples (int): Number of positive samples that is
+ reduced over all GPUs.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert stride[0] == stride[1], 'h stride is not equal to w stride!'
+ anchors = anchors.reshape(-1, 4)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(-1, 4 * (self.reg_max + 1))
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
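+ # `score` holds the soft quality targets used by QFL: the IoU between
+ # each decoded positive prediction and its assigned gt (filled in
+ # below); negative anchors keep a target of 0.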
+ score = label_weights.new_zeros(labels.shape)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+ pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]
+
+ weight_targets = cls_score.detach().sigmoid()
+ weight_targets = weight_targets.max(dim=1)[0][pos_inds]
+ pos_bbox_pred_corners = self.integral(pos_bbox_pred)
+ pos_decode_bbox_pred = self.bbox_coder.decode(
+ pos_anchor_centers, pos_bbox_pred_corners)
+ pos_decode_bbox_targets = pos_bbox_targets / stride[0]
+ score[pos_inds] = bbox_overlaps(
+ pos_decode_bbox_pred.detach(),
+ pos_decode_bbox_targets,
+ is_aligned=True)
+ pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
+ target_corners = self.bbox_coder.encode(pos_anchor_centers,
+ pos_decode_bbox_targets,
+ self.reg_max).reshape(-1)
+
+ # regression loss
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_decode_bbox_targets,
+ weight=weight_targets,
+ avg_factor=1.0)
+
+ # dfl loss
+ loss_dfl = self.loss_dfl(
+ pred_corners,
+ target_corners,
+ weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
+ avg_factor=4.0)
+ else:
+ loss_bbox = bbox_pred.sum() * 0
+ loss_dfl = bbox_pred.sum() * 0
+ weight_targets = bbox_pred.new_tensor(0)
+
+ # cls (qfl) loss
+ loss_cls = self.loss_cls(
+ cls_score, (labels, score),
+ weight=label_weights,
+ avg_factor=num_total_samples)
+
+ return loss_cls, loss_bbox, loss_dfl, weight_targets.sum()
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Cls and quality scores for each scale
+ level has shape (N, num_classes, H, W).
+ bbox_preds (list[Tensor]): Box distribution logits for each scale
+ level with shape (N, 4*(n+1), H, W), n is max value of integral
+ set.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
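+ # Average the positive count across GPUs so the QFL normalizer is
+ # consistent in distributed training; clamp it to at least 1.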
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ losses_cls, losses_bbox, losses_dfl,\
+ avg_factor = multi_apply(
+ self.loss_single,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ self.prior_generator.strides,
+ num_total_samples=num_total_samples)
+
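+ # Re-normalize the bbox and DFL losses by the summed per-positive
+ # weights (max sigmoid cls score of each positive, gathered from all
+ # levels), reduced over GPUs.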
+ avg_factor = sum(avg_factor)
+ avg_factor = reduce_mean(avg_factor).clamp_(min=1).item()
+ losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox))
+ losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl))
+ return dict(
+ loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dfl=losses_dfl)
+
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ score_factor_list,
+ mlvl_priors,
+ img_meta,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform outputs of a single image into bbox predictions.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_priors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has shape
+ (num_priors * 4, H, W).
+ score_factor_list (list[Tensor]): Score factor from all scale
+ levels of a single image. GFL head does not need this value.
+ mlvl_priors (list[Tensor]): Each element in the list is
+ the priors of a single level in feature pyramid, has shape
+ (num_priors, 4).
+ img_meta (dict): Image meta info.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+ mlvl_score_factor. Usually with_nms=False is used for aug
+ test. If with_nms is True, then return the following format:
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column is the score between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ img_shape = img_meta['img_shape']
+ nms_pre = cfg.get('nms_pre', -1)
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_labels = []
+ for level_idx, (cls_score, bbox_pred, stride, priors) in enumerate(
+ zip(cls_score_list, bbox_pred_list,
+ self.prior_generator.strides, mlvl_priors)):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ assert stride[0] == stride[1]
+
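+ # The distribution logits are first converted to distances in
+ # feature-map units by the Integral module, then scaled by the level
+ # stride to obtain distances in image coordinates.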
+ bbox_pred = bbox_pred.permute(1, 2, 0)
+ bbox_pred = self.integral(bbox_pred) * stride[0]
+
+ scores = cls_score.permute(1, 2, 0).reshape(
+ -1, self.cls_out_channels).sigmoid()
+
+ # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+ # this operation keeps fewer bboxes under the same `nms_pre`.
+ # There is no difference in performance for most models. If you
+ # find a slight drop in performance, you can set a larger
+ # `nms_pre` than before.
+ results = filter_scores_and_topk(
+ scores, cfg.score_thr, nms_pre,
+ dict(bbox_pred=bbox_pred, priors=priors))
+ scores, labels, _, filtered_results = results
+
+ bbox_pred = filtered_results['bbox_pred']
+ priors = filtered_results['priors']
+
+ bboxes = self.bbox_coder.decode(
+ self.anchor_center(priors), bbox_pred, max_shape=img_shape)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_labels.append(labels)
+
+ return self._bbox_post_process(
+ mlvl_scores,
+ mlvl_labels,
+ mlvl_bboxes,
+ img_meta['scale_factor'],
+ cfg,
+ rescale=rescale,
+ with_nms=with_nms)
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Get targets for GFL head.
+
+ This method is almost the same as `AnchorHead.get_targets()`. Besides
+ returning the targets as the parent method does, it also returns the
+ anchors as the first element of the returned tuple.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ num_level_anchors_list = [num_level_anchors] * num_imgs
+
+ # concat all level anchors and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ anchor_list[i] = torch.cat(anchor_list[i])
+ valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors)
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, bbox_weights_list, num_total_pos,
+ num_total_neg)
+
+ def _get_target_single(self,
+ flat_anchors,
+ valid_flags,
+ num_level_anchors,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression, classification targets for anchors in a single
+ image.
+
+ Args:
+ flat_anchors (Tensor): Multi-level anchors of the image, which are
+ concatenated into a single tensor of shape (num_anchors, 4)
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_anchors,).
+ num_level_anchors (list[int]): Number of anchors of each scale level.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: N is the number of total anchors in the image.
+ anchors (Tensor): All anchors in the image with shape (N, 4).
+ labels (Tensor): Labels of all anchors in the image with shape
+ (N,).
+ label_weights (Tensor): Label weights of all anchors in the
+ image with shape (N,).
+ bbox_targets (Tensor): BBox targets of all anchors in the
+ image with shape (N, 4).
+ bbox_weights (Tensor): BBox weights of all anchors in the
+ image with shape (N, 4).
+ pos_inds (Tensor): Indices of positive anchor with shape
+ (num_pos,).
+ neg_inds (Tensor): Indices of negative anchor with shape
+ (num_neg,).
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ num_level_anchors_inside = self.get_num_level_anchors_inside(
+ num_level_anchors, inside_flags)
+ assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
+ gt_bboxes, gt_bboxes_ignore,
+ gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ anchors = unmap(anchors, num_total_anchors, inside_flags)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+ pos_inds, neg_inds)
+
+ def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
+ split_inside_flags = torch.split(inside_flags, num_level_anchors)
+ num_level_anchors_inside = [
+ int(flags.sum()) for flags in split_inside_flags
+ ]
+ return num_level_anchors_inside
diff --git a/mmdet/models/dense_heads/guided_anchor_head.py b/mmdet/models/dense_heads/guided_anchor_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..53e8cd8a750287ca60b33a5cdcb9ce2b02e4c2e3
--- /dev/null
+++ b/mmdet/models/dense_heads/guided_anchor_head.py
@@ -0,0 +1,868 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.ops import DeformConv2d, MaskedConv2d
+from mmcv.runner import BaseModule, force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_assigner, build_bbox_coder,
+ build_prior_generator, build_sampler, calc_region,
+ images_to_levels, multi_apply, multiclass_nms, unmap)
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+
+class FeatureAdaption(BaseModule):
+ """Feature Adaption Module.
+
+ Feature Adaption Module is implemented based on DCN v1.
+ It uses the predicted anchor shape, rather than the feature map
+ itself, to predict the offsets of the deform conv layer.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ out_channels (int): Number of channels in the output feature map.
+ kernel_size (int): Deformable conv kernel size.
+ deform_groups (int): Deformable conv group size.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ deform_groups=4,
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.1,
+ override=dict(
+ type='Normal', name='conv_adaption', std=0.01))):
+ super(FeatureAdaption, self).__init__(init_cfg)
+ offset_channels = kernel_size * kernel_size * 2
+ self.conv_offset = nn.Conv2d(
+ 2, deform_groups * offset_channels, 1, bias=False)
+ self.conv_adaption = DeformConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ deform_groups=deform_groups)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x, shape):
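+ # `shape` is the 2-channel (dw, dh) anchor shape prediction; it is
+ # detached so the offset branch does not back-propagate into the shape
+ # branch, and a 1x1 conv maps it to deform-conv offsets.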
+ offset = self.conv_offset(shape.detach())
+ x = self.relu(self.conv_adaption(x, offset))
+ return x
+
+
+@HEADS.register_module()
+class GuidedAnchorHead(AnchorHead):
+ """Guided-Anchor-based head (GA-RPN, GA-RetinaNet, etc.).
+
+ This GuidedAnchorHead will predict high-quality feature guided
+ anchors and locations where anchors will be kept in inference.
+ There are mainly 3 categories of bounding boxes:
+
+ - The sampled anchors (9 per location with the default generator)
+ used for target assignment (approxes).
+ - The square boxes on which the predicted anchors are based (squares).
+ - Guided anchors.
+
+ Please refer to https://arxiv.org/abs/1901.03278 for more details.
+
+ Args:
+ num_classes (int): Number of classes.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels.
+ approx_anchor_generator (dict): Config dict for approx generator
+ square_anchor_generator (dict): Config dict for square generator
+ anchor_coder (dict): Config dict for anchor coder
+ bbox_coder (dict): Config dict for bbox coder
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ deform_groups (int): Group number of DCN in
+ the FeatureAdaption module.
+ loc_filter_thr (float): Threshold to filter out unconcerned regions.
+ loss_loc (dict): Config of location loss.
+ loss_shape (dict): Config of anchor shape loss.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of bbox regression loss.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(
+ self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ approx_anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=8,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ square_anchor_generator=dict(
+ type='AnchorGenerator',
+ ratios=[1.0],
+ scales=[8],
+ strides=[4, 8, 16, 32, 64]),
+ anchor_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]
+ ),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]
+ ),
+ reg_decoded_bbox=False,
+ deform_groups=4,
+ loc_filter_thr=0.01,
+ train_cfg=None,
+ test_cfg=None,
+ loss_loc=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+ loss_weight=1.0),
+ init_cfg=dict(type='Normal', layer='Conv2d', std=0.01,
+ override=dict(type='Normal',
+ name='conv_loc',
+ std=0.01,
+ bias_prob=0.01))): # yapf: disable
+ super(AnchorHead, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.feat_channels = feat_channels
+ self.deform_groups = deform_groups
+ self.loc_filter_thr = loc_filter_thr
+
+ # build approx_anchor_generator and square_anchor_generator
+ assert (approx_anchor_generator['octave_base_scale'] ==
+ square_anchor_generator['scales'][0])
+ assert (approx_anchor_generator['strides'] ==
+ square_anchor_generator['strides'])
+ self.approx_anchor_generator = build_prior_generator(
+ approx_anchor_generator)
+ self.square_anchor_generator = build_prior_generator(
+ square_anchor_generator)
+ self.approxs_per_octave = self.approx_anchor_generator \
+ .num_base_priors[0]
+
+ self.reg_decoded_bbox = reg_decoded_bbox
+
+ # one anchor per location
+ self.num_base_priors = self.square_anchor_generator.num_base_priors[0]
+
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ self.loc_focal_loss = loss_loc['type'] in ['FocalLoss']
+ self.sampling = loss_cls['type'] not in ['FocalLoss']
+ self.ga_sampling = train_cfg is not None and hasattr(
+ train_cfg, 'ga_sampler')
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = self.num_classes
+ else:
+ self.cls_out_channels = self.num_classes + 1
+
+ # build bbox_coder
+ self.anchor_coder = build_bbox_coder(anchor_coder)
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+
+ # build losses
+ self.loss_loc = build_loss(loss_loc)
+ self.loss_shape = build_loss(loss_shape)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.ga_assigner = build_assigner(self.train_cfg.ga_assigner)
+ if self.ga_sampling:
+ ga_sampler_cfg = self.train_cfg.ga_sampler
+ else:
+ ga_sampler_cfg = dict(type='PseudoSampler')
+ self.ga_sampler = build_sampler(ga_sampler_cfg, context=self)
+
+ self.fp16_enabled = False
+
+ self._init_layers()
+
+ @property
+ def num_anchors(self):
+ warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
+ 'please use "num_base_priors" instead')
+ return self.square_anchor_generator.num_base_priors[0]
+
+ def _init_layers(self):
+ self.relu = nn.ReLU(inplace=True)
+ self.conv_loc = nn.Conv2d(self.in_channels, 1, 1)
+ self.conv_shape = nn.Conv2d(self.in_channels, self.num_base_priors * 2,
+ 1)
+ self.feature_adaption = FeatureAdaption(
+ self.in_channels,
+ self.feat_channels,
+ kernel_size=3,
+ deform_groups=self.deform_groups)
+ self.conv_cls = MaskedConv2d(
+ self.feat_channels, self.num_base_priors * self.cls_out_channels,
+ 1)
+ self.conv_reg = MaskedConv2d(self.feat_channels,
+ self.num_base_priors * 4, 1)
+
+ def forward_single(self, x):
+ loc_pred = self.conv_loc(x)
+ shape_pred = self.conv_shape(x)
+ x = self.feature_adaption(x, shape_pred)
+ # masked conv is only used during inference for speed-up
+ if not self.training:
+ mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
+ else:
+ mask = None
+ cls_score = self.conv_cls(x, mask)
+ bbox_pred = self.conv_reg(x, mask)
+ return cls_score, bbox_pred, shape_pred, loc_pred
+
+ def forward(self, feats):
+ return multi_apply(self.forward_single, feats)
+
+ def get_sampled_approxs(self, featmap_sizes, img_metas, device='cuda'):
+ """Get sampled approxs and inside flags according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+ device (torch.device | str): device for returned tensors
+
+ Returns:
+ tuple: approxes of each image, inside flags of each image
+ """
+ num_imgs = len(img_metas)
+
+ # since feature map sizes of all images are the same, we only compute
+ # approxes once
+ multi_level_approxs = self.approx_anchor_generator.grid_priors(
+ featmap_sizes, device=device)
+ approxs_list = [multi_level_approxs for _ in range(num_imgs)]
+
+ # for each image, we compute inside flags of multi level approxes
+ inside_flag_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_flags = []
+ multi_level_approxs = approxs_list[img_id]
+
+ # obtain valid flags for each approx first
+ multi_level_approx_flags = self.approx_anchor_generator \
+ .valid_flags(featmap_sizes,
+ img_meta['pad_shape'],
+ device=device)
+
+ for i, flags in enumerate(multi_level_approx_flags):
+ approxs = multi_level_approxs[i]
+ inside_flags_list = []
+ for j in range(self.approxs_per_octave):
+ split_valid_flags = flags[j::self.approxs_per_octave]
+ split_approxs = approxs[j::self.approxs_per_octave, :]
+ inside_flags = anchor_inside_flags(
+ split_approxs, split_valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ inside_flags_list.append(inside_flags)
+ # inside_flag for a position is true if any anchor in this
+ # position is true
+ inside_flags = (
+ torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
+ multi_level_flags.append(inside_flags)
+ inside_flag_list.append(multi_level_flags)
+ return approxs_list, inside_flag_list
+
+ def get_anchors(self,
+ featmap_sizes,
+ shape_preds,
+ loc_preds,
+ img_metas,
+ use_loc_filter=False,
+ device='cuda'):
+ """Get squares according to feature map sizes and guided anchors.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ shape_preds (list[tensor]): Multi-level shape predictions.
+ loc_preds (list[tensor]): Multi-level location predictions.
+ img_metas (list[dict]): Image meta info.
+ use_loc_filter (bool): Use loc filter or not.
+ device (torch.device | str): device for returned tensors
+
+ Returns:
+ tuple: square approxs of each image, guided anchors of each image,
+ loc masks of each image
+ """
+ num_imgs = len(img_metas)
+ num_levels = len(featmap_sizes)
+
+ # since feature map sizes of all images are the same, we only compute
+ # squares once
+ multi_level_squares = self.square_anchor_generator.grid_priors(
+ featmap_sizes, device=device)
+ squares_list = [multi_level_squares for _ in range(num_imgs)]
+
+ # for each image, we compute multi level guided anchors
+ guided_anchors_list = []
+ loc_mask_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_guided_anchors = []
+ multi_level_loc_mask = []
+ for i in range(num_levels):
+ squares = squares_list[img_id][i]
+ shape_pred = shape_preds[i][img_id]
+ loc_pred = loc_preds[i][img_id]
+ guided_anchors, loc_mask = self._get_guided_anchors_single(
+ squares,
+ shape_pred,
+ loc_pred,
+ use_loc_filter=use_loc_filter)
+ multi_level_guided_anchors.append(guided_anchors)
+ multi_level_loc_mask.append(loc_mask)
+ guided_anchors_list.append(multi_level_guided_anchors)
+ loc_mask_list.append(multi_level_loc_mask)
+ return squares_list, guided_anchors_list, loc_mask_list
+
+ def _get_guided_anchors_single(self,
+ squares,
+ shape_pred,
+ loc_pred,
+ use_loc_filter=False):
+ """Get guided anchors and loc masks for a single level.
+
+ Args:
+ squares (Tensor): Squares of a single level.
+ shape_pred (Tensor): Shape predictions of a single level.
+ loc_pred (Tensor): Loc predictions of a single level.
+ use_loc_filter (bool): Whether to use the loc filter.
+
+ Returns:
+ tuple: guided anchors, location masks
+ """
+ # calculate location filtering mask
+ loc_pred = loc_pred.sigmoid().detach()
+ if use_loc_filter:
+ loc_mask = loc_pred >= self.loc_filter_thr
+ else:
+ loc_mask = loc_pred >= 0.0
+ mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_base_priors)
+ mask = mask.contiguous().view(-1)
+ # calculate guided anchors
+ squares = squares[mask]
+ anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
+ -1, 2).detach()[mask]
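+ # Only the (dw, dh) deltas come from the shape prediction; the (dx, dy)
+ # entries stay zero, so the guided anchors keep the square centers.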
+ bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
+ bbox_deltas[:, 2:] = anchor_deltas
+ guided_anchors = self.anchor_coder.decode(
+ squares, bbox_deltas, wh_ratio_clip=1e-6)
+ return guided_anchors, mask
+
+ def ga_loc_targets(self, gt_bboxes_list, featmap_sizes):
+ """Compute location targets for guided anchoring.
+
+ Each feature map is divided into positive, negative and ignore regions.
+ - positive regions: target 1, weight 1
+ - ignore regions: target 0, weight 0
+ - negative regions: target 0, weight 0.1
+
+ Args:
+ gt_bboxes_list (list[Tensor]): Gt bboxes of each image.
+ featmap_sizes (list[tuple]): Multi level sizes of each feature
+ maps.
+
+ Returns:
+ tuple
+ """
+ anchor_scale = self.approx_anchor_generator.octave_base_scale
+ anchor_strides = self.approx_anchor_generator.strides
+ # Currently only supports same stride in x and y direction.
+ for stride in anchor_strides:
+ assert (stride[0] == stride[1])
+ anchor_strides = [stride[0] for stride in anchor_strides]
+
+ center_ratio = self.train_cfg.center_ratio
+ ignore_ratio = self.train_cfg.ignore_ratio
+ img_per_gpu = len(gt_bboxes_list)
+ num_lvls = len(featmap_sizes)
+ r1 = (1 - center_ratio) / 2
+ r2 = (1 - ignore_ratio) / 2
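+ # For illustration only (these ratios come from train_cfg, not set
+ # here): with center_ratio=0.2 and ignore_ratio=0.5,
+ # r1 = (1 - 0.2) / 2 = 0.4 and r2 = (1 - 0.5) / 2 = 0.25.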
+ all_loc_targets = []
+ all_loc_weights = []
+ all_ignore_map = []
+ for lvl_id in range(num_lvls):
+ h, w = featmap_sizes[lvl_id]
+ loc_targets = torch.zeros(
+ img_per_gpu,
+ 1,
+ h,
+ w,
+ device=gt_bboxes_list[0].device,
+ dtype=torch.float32)
+ loc_weights = torch.full_like(loc_targets, -1)
+ ignore_map = torch.zeros_like(loc_targets)
+ all_loc_targets.append(loc_targets)
+ all_loc_weights.append(loc_weights)
+ all_ignore_map.append(ignore_map)
+ for img_id in range(img_per_gpu):
+ gt_bboxes = gt_bboxes_list[img_id]
+ scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
+ (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
+ min_anchor_size = scale.new_full(
+ (1, ), float(anchor_scale * anchor_strides[0]))
+ # assign gt bboxes to different feature levels w.r.t. their scales
+ target_lvls = torch.floor(
+ torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
+ target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
+ for gt_id in range(gt_bboxes.size(0)):
+ lvl = target_lvls[gt_id].item()
+ # rescaled to corresponding feature map
+ gt_ = gt_bboxes[gt_id, :4] / anchor_strides[lvl]
+ # calculate ignore regions
+ ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
+ gt_, r2, featmap_sizes[lvl])
+ # calculate positive (center) regions
+ ctr_x1, ctr_y1, ctr_x2, ctr_y2 = calc_region(
+ gt_, r1, featmap_sizes[lvl])
+ all_loc_targets[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
+ ctr_x1:ctr_x2 + 1] = 1
+ all_loc_weights[lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
+ ignore_x1:ignore_x2 + 1] = 0
+ all_loc_weights[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
+ ctr_x1:ctr_x2 + 1] = 1
+ # calculate ignore map on nearby low level feature
+ if lvl > 0:
+ d_lvl = lvl - 1
+ # rescaled to corresponding feature map
+ gt_ = gt_bboxes[gt_id, :4] / anchor_strides[d_lvl]
+ ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
+ gt_, r2, featmap_sizes[d_lvl])
+ all_ignore_map[d_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
+ ignore_x1:ignore_x2 + 1] = 1
+ # calculate ignore map on nearby high level feature
+ if lvl < num_lvls - 1:
+ u_lvl = lvl + 1
+ # rescaled to corresponding feature map
+ gt_ = gt_bboxes[gt_id, :4] / anchor_strides[u_lvl]
+ ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
+ gt_, r2, featmap_sizes[u_lvl])
+ all_ignore_map[u_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
+ ignore_x1:ignore_x2 + 1] = 1
+ for lvl_id in range(num_lvls):
+ # ignore negative regions w.r.t. ignore map
+ all_loc_weights[lvl_id][(all_loc_weights[lvl_id] < 0)
+ & (all_ignore_map[lvl_id] > 0)] = 0
+ # set negative regions with weight 0.1
+ all_loc_weights[lvl_id][all_loc_weights[lvl_id] < 0] = 0.1
+ # loc average factor to balance loss
+ loc_avg_factor = sum(
+ [t.size(0) * t.size(-1) * t.size(-2)
+ for t in all_loc_targets]) / 200
+ return all_loc_targets, all_loc_weights, loc_avg_factor
+
+ def _ga_shape_target_single(self,
+ flat_approxs,
+ inside_flags,
+ flat_squares,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ img_meta,
+ unmap_outputs=True):
+ """Compute guided anchoring targets.
+
+ This function returns sampled anchors and gt bboxes directly
+ rather than calculates regression targets.
+
+ Args:
+ flat_approxs (Tensor): flat approxs of a single image,
+ shape (approxs_per_octave * n, 4)
+ inside_flags (Tensor): inside flags of a single image,
+ shape (n, ).
+ flat_squares (Tensor): flat squares of a single image,
+ shape (n, 4)
+ gt_bboxes (Tensor): Ground truth bboxes of a single image.
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be ignored.
+ img_meta (dict): Meta info of a single image.
+ unmap_outputs (bool): unmap outputs or not.
+
+ Returns:
+ tuple
+ """
+ if not inside_flags.any():
+ return (None, ) * 5
+ # assign gt and sample anchors
+ expand_inside_flags = inside_flags[:, None].expand(
+ -1, self.approxs_per_octave).reshape(-1)
+ approxs = flat_approxs[expand_inside_flags, :]
+ squares = flat_squares[inside_flags, :]
+
+ assign_result = self.ga_assigner.assign(approxs, squares,
+ self.approxs_per_octave,
+ gt_bboxes, gt_bboxes_ignore)
+ sampling_result = self.ga_sampler.sample(assign_result, squares,
+ gt_bboxes)
+
+ bbox_anchors = torch.zeros_like(squares)
+ bbox_gts = torch.zeros_like(squares)
+ bbox_weights = torch.zeros_like(squares)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ bbox_anchors[pos_inds, :] = sampling_result.pos_bboxes
+ bbox_gts[pos_inds, :] = sampling_result.pos_gt_bboxes
+ bbox_weights[pos_inds, :] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_squares.size(0)
+ bbox_anchors = unmap(bbox_anchors, num_total_anchors, inside_flags)
+ bbox_gts = unmap(bbox_gts, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (bbox_anchors, bbox_gts, bbox_weights, pos_inds, neg_inds)
+
+ def ga_shape_targets(self,
+ approx_list,
+ inside_flag_list,
+ square_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ unmap_outputs=True):
+ """Compute guided anchoring targets.
+
+ Args:
+ approx_list (list[list]): Multi level approxs of each image.
+ inside_flag_list (list[list]): Multi level inside flags of each
+ image.
+ square_list (list[list]): Multi level squares of each image.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): ignore list of gt bboxes.
+ unmap_outputs (bool): unmap outputs or not.
+
+ Returns:
+ tuple
+ """
+ num_imgs = len(img_metas)
+ assert len(approx_list) == len(inside_flag_list) == len(
+ square_list) == num_imgs
+ # anchor number of multi levels
+ num_level_squares = [squares.size(0) for squares in square_list[0]]
+ # concat all level anchors and flags to a single tensor
+ inside_flag_flat_list = []
+ approx_flat_list = []
+ square_flat_list = []
+ for i in range(num_imgs):
+ assert len(square_list[i]) == len(inside_flag_list[i])
+ inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
+ approx_flat_list.append(torch.cat(approx_list[i]))
+ square_flat_list.append(torch.cat(square_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ (all_bbox_anchors, all_bbox_gts, all_bbox_weights, pos_inds_list,
+ neg_inds_list) = multi_apply(
+ self._ga_shape_target_single,
+ approx_flat_list,
+ inside_flag_flat_list,
+ square_flat_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ img_metas,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([bbox_anchors is None for bbox_anchors in all_bbox_anchors]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ bbox_anchors_list = images_to_levels(all_bbox_anchors,
+ num_level_squares)
+ bbox_gts_list = images_to_levels(all_bbox_gts, num_level_squares)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_squares)
+ return (bbox_anchors_list, bbox_gts_list, bbox_weights_list,
+ num_total_pos, num_total_neg)
+
+ def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
+ anchor_weights, anchor_total_num):
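+        """Compute the anchor shape prediction loss of a single level.
+
+        The predicted (dw, dh) deltas are applied to the square anchors,
+        decoded into anchors and compared against the matched ground-truth
+        boxes, averaged over ``anchor_total_num``.
+        """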
+ shape_pred = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2)
+ bbox_anchors = bbox_anchors.contiguous().view(-1, 4)
+ bbox_gts = bbox_gts.contiguous().view(-1, 4)
+ anchor_weights = anchor_weights.contiguous().view(-1, 4)
+ bbox_deltas = bbox_anchors.new_full(bbox_anchors.size(), 0)
+ bbox_deltas[:, 2:] += shape_pred
+ # filter out negative samples to speed-up weighted_bounded_iou_loss
+ inds = torch.nonzero(
+ anchor_weights[:, 0] > 0, as_tuple=False).squeeze(1)
+ bbox_deltas_ = bbox_deltas[inds]
+ bbox_anchors_ = bbox_anchors[inds]
+ bbox_gts_ = bbox_gts[inds]
+ anchor_weights_ = anchor_weights[inds]
+ pred_anchors_ = self.anchor_coder.decode(
+ bbox_anchors_, bbox_deltas_, wh_ratio_clip=1e-6)
+ loss_shape = self.loss_shape(
+ pred_anchors_,
+ bbox_gts_,
+ anchor_weights_,
+ avg_factor=anchor_total_num)
+ return loss_shape
+
+ def loss_loc_single(self, loc_pred, loc_target, loc_weight,
+ loc_avg_factor):
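+        """Compute the anchor location (objectness) loss of a single level."""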
+ loss_loc = self.loss_loc(
+ loc_pred.reshape(-1, 1),
+ loc_target.reshape(-1).long(),
+ loc_weight.reshape(-1),
+ avg_factor=loc_avg_factor)
+ return loss_loc
+
+ @force_fp32(
+ apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ shape_preds,
+ loc_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.approx_anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ # get loc targets
+ loc_targets, loc_weights, loc_avg_factor = self.ga_loc_targets(
+ gt_bboxes, featmap_sizes)
+
+ # get sampled approxes
+ approxs_list, inside_flag_list = self.get_sampled_approxs(
+ featmap_sizes, img_metas, device=device)
+ # get squares and guided anchors
+ squares_list, guided_anchors_list, _ = self.get_anchors(
+ featmap_sizes, shape_preds, loc_preds, img_metas, device=device)
+
+ # get shape targets
+ shape_targets = self.ga_shape_targets(approxs_list, inside_flag_list,
+ squares_list, gt_bboxes,
+ img_metas)
+ if shape_targets is None:
+ return None
+ (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
+ anchor_bg_num) = shape_targets
+ anchor_total_num = (
+ anchor_fg_num if not self.ga_sampling else anchor_fg_num +
+ anchor_bg_num)
+
+ # get anchor targets
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ guided_anchors_list,
+ inside_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+
+ # anchor number of multi levels
+ num_level_anchors = [
+ anchors.size(0) for anchors in guided_anchors_list[0]
+ ]
+ # concat all level anchors to a single tensor
+ concat_anchor_list = []
+ for i in range(len(guided_anchors_list)):
+ concat_anchor_list.append(torch.cat(guided_anchors_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+
+ # get classification and bbox regression losses
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+
+ # get anchor location loss
+ losses_loc = []
+ for i in range(len(loc_preds)):
+ loss_loc = self.loss_loc_single(
+ loc_preds[i],
+ loc_targets[i],
+ loc_weights[i],
+ loc_avg_factor=loc_avg_factor)
+ losses_loc.append(loss_loc)
+
+ # get anchor shape loss
+ losses_shape = []
+ for i in range(len(shape_preds)):
+ loss_shape = self.loss_shape_single(
+ shape_preds[i],
+ bbox_anchors_list[i],
+ bbox_gts_list[i],
+ anchor_weights_list[i],
+ anchor_total_num=anchor_total_num)
+ losses_shape.append(loss_shape)
+
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ loss_shape=losses_shape,
+ loss_loc=losses_loc)
+
+ @force_fp32(
+ apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ shape_preds,
+ loc_preds,
+ img_metas,
+ cfg=None,
+ rescale=False):
+ assert len(cls_scores) == len(bbox_preds) == len(shape_preds) == len(
+ loc_preds)
+ num_levels = len(cls_scores)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ device = cls_scores[0].device
+ # get guided anchors
+ _, guided_anchors, loc_masks = self.get_anchors(
+ featmap_sizes,
+ shape_preds,
+ loc_preds,
+ img_metas,
+ use_loc_filter=not self.training,
+ device=device)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds[i][img_id].detach() for i in range(num_levels)
+ ]
+ guided_anchor_list = [
+ guided_anchors[img_id][i].detach() for i in range(num_levels)
+ ]
+ loc_mask_list = [
+ loc_masks[img_id][i].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
+ guided_anchor_list,
+ loc_mask_list, img_shape,
+ scale_factor, cfg, rescale)
+ result_list.append(proposals)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ mlvl_masks,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
+ mlvl_anchors,
+ mlvl_masks):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ # if no location is kept, end.
+ if mask.sum() == 0:
+ continue
+ # reshape scores and bbox_pred
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ # filter scores, bbox_pred w.r.t. mask.
+ # anchors are filtered in get_anchors() beforehand.
+ scores = scores[mask, :]
+ bbox_pred = bbox_pred[mask, :]
+ if scores.dim() == 0:
+ anchors = anchors.unsqueeze(0)
+ scores = scores.unsqueeze(0)
+ bbox_pred = bbox_pred.unsqueeze(0)
+ # filter anchors, bbox_pred, scores w.r.t. scores
+ nms_pre = cfg.get('nms_pre', -1)
+ if nms_pre > 0 and scores.shape[0] > nms_pre:
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(dim=1)
+ else:
+                    # note that FG labels are [0, num_class-1] since
+                    # mmdet v2.0, and the BG cat_id is num_class
+ max_scores, _ = scores[:, :-1].max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ anchors = anchors[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shape)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ if self.use_sigmoid_cls:
+            # Append a dummy background class when using sigmoid;
+            # note that FG labels are [0, num_class-1] since mmdet v2.0,
+            # and the BG cat_id is num_class
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+ # multi class NMS
+ det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ return det_bboxes, det_labels
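+
+
+if __name__ == '__main__':
+    # Illustrative sketch (not part of the original head): mimics the
+    # per-level score filtering in `_get_bboxes_single` with random data.
+    # The tensor sizes and the `nms_pre` value below are arbitrary
+    # assumptions for demonstration only.
+    torch.manual_seed(0)
+    num_priors, num_classes, nms_pre = 1000, 3, 100
+    scores = torch.rand(num_priors, num_classes).sigmoid()
+    # keep only the top `nms_pre` candidates ranked by their best class score
+    max_scores, _ = scores.max(dim=1)
+    _, topk_inds = max_scores.topk(nms_pre)
+    scores = scores[topk_inds, :]
+    # with sigmoid classification a dummy background column is appended so
+    # that multiclass NMS sees `num_classes + 1` score columns
+    padding = scores.new_zeros(scores.shape[0], 1)
+    scores = torch.cat([scores, padding], dim=1)
+    print(scores.shape)  # torch.Size([100, 4])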
diff --git a/mmdet/models/dense_heads/lad_head.py b/mmdet/models/dense_heads/lad_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..85273bcb24308dd6f47c8d47362164a6f1393e1e
--- /dev/null
+++ b/mmdet/models/dense_heads/lad_head.py
@@ -0,0 +1,232 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import bbox_overlaps, multi_apply
+from ..builder import HEADS
+from .paa_head import PAAHead, levels_to_images
+
+
+@HEADS.register_module()
+class LADHead(PAAHead):
+ """Label Assignment Head from the paper: `Improving Object Detection by
+    Label Assignment Distillation <https://arxiv.org/abs/2108.10520>`_"""
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
+ def get_label_assignment(self,
+ cls_scores,
+ bbox_preds,
+ iou_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Get label assignment (from teacher).
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level.
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+            iou_preds (list[Tensor]): IoU predictions for each scale
+                level with shape (N, num_anchors * 1, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+                boxes can be ignored when computing the loss.
+
+ Returns:
+ tuple: Returns a tuple containing label assignment variables.
+
+                - labels (Tensor): Labels of all anchors, with
+                    shape (num_anchors,).
+                - labels_weight (Tensor): Label weights of all anchors, with
+                    shape (num_anchors,).
+                - bboxes_target (Tensor): BBox targets of all anchors, with
+                    shape (num_anchors, 4).
+                - bboxes_weight (Tensor): BBox weights of all anchors, with
+                    shape (num_anchors, 4).
+                - pos_inds_flatten (Tensor): Indices of all positive
+                    samples among all anchors.
+ - pos_anchors (Tensor): Positive anchors.
+ - num_pos (int): Number of positive anchors.
+ """
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ )
+ (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
+ pos_gt_index) = cls_reg_targets
+ cls_scores = levels_to_images(cls_scores)
+ cls_scores = [
+ item.reshape(-1, self.cls_out_channels) for item in cls_scores
+ ]
+ bbox_preds = levels_to_images(bbox_preds)
+ bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
+ pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
+ cls_scores, bbox_preds, labels,
+ labels_weight, bboxes_target,
+ bboxes_weight, pos_inds)
+
+ with torch.no_grad():
+ reassign_labels, reassign_label_weight, \
+ reassign_bbox_weights, num_pos = multi_apply(
+ self.paa_reassign,
+ pos_losses_list,
+ labels,
+ labels_weight,
+ bboxes_weight,
+ pos_inds,
+ pos_gt_index,
+ anchor_list)
+ num_pos = sum(num_pos)
+ # convert all tensor list to a flatten tensor
+ labels = torch.cat(reassign_labels, 0).view(-1)
+ flatten_anchors = torch.cat(
+ [torch.cat(item, 0) for item in anchor_list])
+ labels_weight = torch.cat(reassign_label_weight, 0).view(-1)
+ bboxes_target = torch.cat(bboxes_target,
+ 0).view(-1, bboxes_target[0].size(-1))
+
+ pos_inds_flatten = ((labels >= 0)
+ &
+ (labels < self.num_classes)).nonzero().reshape(-1)
+
+ if num_pos:
+ pos_anchors = flatten_anchors[pos_inds_flatten]
+ else:
+ pos_anchors = None
+
+ label_assignment_results = (labels, labels_weight, bboxes_target,
+ bboxes_weight, pos_inds_flatten,
+ pos_anchors, num_pos)
+ return label_assignment_results
+
+ def forward_train(self,
+ x,
+ label_assignment_results,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ **kwargs):
+ """Forward train with the available label assignment (student receives
+ from teacher).
+
+ Args:
+ x (list[Tensor]): Features from FPN.
+ label_assignment_results (tuple): As the outputs defined in the
+ function `self.get_label_assignment`.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+
+ Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+ """
+ outs = self(x)
+ if gt_labels is None:
+ loss_inputs = outs + (gt_bboxes, img_metas)
+ else:
+ loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+ losses = self.loss(
+ *loss_inputs,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ label_assignment_results=label_assignment_results)
+ return losses
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ iou_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None,
+ label_assignment_results=None):
+ """Compute losses of the head.
+
+ Args:
+            cls_scores (list[Tensor]): Box scores for each scale level.
+                Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+            iou_preds (list[Tensor]): IoU predictions for each scale
+                level with shape (N, num_anchors * 1, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+                boxes can be ignored when computing the loss.
+ label_assignment_results (tuple): As the outputs defined in the
+ function `self.get_label_assignment`.
+
+ Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds_flatten,
+ pos_anchors, num_pos) = label_assignment_results
+
+ cls_scores = levels_to_images(cls_scores)
+ cls_scores = [
+ item.reshape(-1, self.cls_out_channels) for item in cls_scores
+ ]
+ bbox_preds = levels_to_images(bbox_preds)
+ bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
+ iou_preds = levels_to_images(iou_preds)
+ iou_preds = [item.reshape(-1, 1) for item in iou_preds]
+
+ # convert all tensor list to a flatten tensor
+ cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
+ bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
+ iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
+
+ losses_cls = self.loss_cls(
+ cls_scores,
+ labels,
+ labels_weight,
+ avg_factor=max(num_pos, len(img_metas))) # avoid num_pos=0
+ if num_pos:
+ pos_bbox_pred = self.bbox_coder.decode(
+ pos_anchors, bbox_preds[pos_inds_flatten])
+ pos_bbox_target = bboxes_target[pos_inds_flatten]
+ iou_target = bbox_overlaps(
+ pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
+ losses_iou = self.loss_centerness(
+ iou_preds[pos_inds_flatten],
+ iou_target.unsqueeze(-1),
+ avg_factor=num_pos)
+ losses_bbox = self.loss_bbox(
+ pos_bbox_pred, pos_bbox_target, avg_factor=num_pos)
+
+ else:
+ losses_iou = iou_preds.sum() * 0
+ losses_bbox = bbox_preds.sum() * 0
+
+ return dict(
+ loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)
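+
+
+if __name__ == '__main__':
+    # Illustrative sketch (not part of LADHead): a minimal pure-PyTorch
+    # version of the aligned IoU that `bbox_overlaps(..., is_aligned=True)`
+    # supplies above as the regression target of the IoU branch. The boxes
+    # below are made-up values.
+    def aligned_iou(boxes1, boxes2, eps=1e-6):
+        # boxes are (N, 4) in [tl_x, tl_y, br_x, br_y] format, matched
+        # row-wise
+        lt = torch.max(boxes1[:, :2], boxes2[:, :2])
+        rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
+        wh = (rb - lt).clamp(min=0)
+        inter = wh[:, 0] * wh[:, 1]
+        area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
+        area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
+        return inter / (area1 + area2 - inter + eps)
+
+    pred = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
+    target = torch.tensor([[0., 0., 10., 10.], [0., 0., 10., 10.]])
+    print(aligned_iou(pred, target))  # tensor([1.0000, 0.1429])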
diff --git a/mmdet/models/dense_heads/ld_head.py b/mmdet/models/dense_heads/ld_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5a945fe2cc9c6f42f9fdc64e278ccdc27bd9e55
--- /dev/null
+++ b/mmdet/models/dense_heads/ld_head.py
@@ -0,0 +1,261 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import bbox_overlaps, multi_apply, reduce_mean
+from ..builder import HEADS, build_loss
+from .gfl_head import GFLHead
+
+
+@HEADS.register_module()
+class LDHead(GFLHead):
+ """Localization distillation Head. (Short description)
+
+ It utilizes the learned bbox distributions to transfer the localization
+ dark knowledge from teacher to student. Original paper: `Localization
+    Distillation for Object Detection <https://arxiv.org/abs/2102.12252>`_.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ loss_ld (dict): Config of Localization Distillation Loss (LD),
+ T is the temperature for distillation.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ loss_ld=dict(
+ type='LocalizationDistillationLoss',
+ loss_weight=0.25,
+ T=10),
+ **kwargs):
+
+ super(LDHead, self).__init__(num_classes, in_channels, **kwargs)
+ self.loss_ld = build_loss(loss_ld)
+
+ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
+ bbox_targets, stride, soft_targets, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ cls_score (Tensor): Cls and quality joint scores for each scale
+ level has shape (N, num_classes, H, W).
+ bbox_pred (Tensor): Box distribution logits for each scale
+ level with shape (N, 4*(n+1), H, W), n is max value of integral
+ set.
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+            bbox_targets (Tensor): BBox regression targets of each anchor
+                with shape (N, num_total_anchors, 4).
+            stride (tuple): Stride in this scale level.
+            soft_targets (Tensor): Box distribution logits of the teacher
+                model for this scale level, with shape (N, 4*(n+1), H, W).
+            num_total_samples (int): Number of positive samples that is
+                reduced over all GPUs.
+
+        Returns:
+            tuple[Tensor]: Loss components and the sum of weight targets.
+ """
+ assert stride[0] == stride[1], 'h stride is not equal to w stride!'
+ anchors = anchors.reshape(-1, 4)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(-1, 4 * (self.reg_max + 1))
+ soft_targets = soft_targets.permute(0, 2, 3,
+ 1).reshape(-1,
+ 4 * (self.reg_max + 1))
+
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
+ score = label_weights.new_zeros(labels.shape)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+ pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]
+
+ weight_targets = cls_score.detach().sigmoid()
+ weight_targets = weight_targets.max(dim=1)[0][pos_inds]
+ pos_bbox_pred_corners = self.integral(pos_bbox_pred)
+ pos_decode_bbox_pred = self.bbox_coder.decode(
+ pos_anchor_centers, pos_bbox_pred_corners)
+ pos_decode_bbox_targets = pos_bbox_targets / stride[0]
+ score[pos_inds] = bbox_overlaps(
+ pos_decode_bbox_pred.detach(),
+ pos_decode_bbox_targets,
+ is_aligned=True)
+ pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
+ pos_soft_targets = soft_targets[pos_inds]
+ soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)
+
+ target_corners = self.bbox_coder.encode(pos_anchor_centers,
+ pos_decode_bbox_targets,
+ self.reg_max).reshape(-1)
+
+ # regression loss
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_decode_bbox_targets,
+ weight=weight_targets,
+ avg_factor=1.0)
+
+ # dfl loss
+ loss_dfl = self.loss_dfl(
+ pred_corners,
+ target_corners,
+ weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
+ avg_factor=4.0)
+
+ # ld loss
+ loss_ld = self.loss_ld(
+ pred_corners,
+ soft_corners,
+ weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
+ avg_factor=4.0)
+
+ else:
+ loss_ld = bbox_pred.sum() * 0
+ loss_bbox = bbox_pred.sum() * 0
+ loss_dfl = bbox_pred.sum() * 0
+ weight_targets = bbox_pred.new_tensor(0)
+
+ # cls (qfl) loss
+ loss_cls = self.loss_cls(
+ cls_score, (labels, score),
+ weight=label_weights,
+ avg_factor=num_total_samples)
+
+ return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()
+
+ def forward_train(self,
+ x,
+ out_teacher,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None,
+ **kwargs):
+ """
+ Args:
+            x (list[Tensor]): Features from FPN.
+            out_teacher (tuple[Tensor]): Outputs of the teacher model; its
+                second element is used as the soft targets (box distribution
+                logits) for localization distillation.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+            proposal_cfg (mmcv.Config): Test / postprocessing configuration;
+                if None, `test_cfg` is used.
+
+ Returns:
+ tuple[dict, list]: The loss components and proposals of each image.
+
+ - losses (dict[str, Tensor]): A dictionary of loss components.
+ - proposal_list (list[Tensor]): Proposals of each image.
+ """
+ outs = self(x)
+ soft_target = out_teacher[1]
+ if gt_labels is None:
+ loss_inputs = outs + (gt_bboxes, soft_target, img_metas)
+ else:
+ loss_inputs = outs + (gt_bboxes, gt_labels, soft_target, img_metas)
+ losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+ if proposal_cfg is None:
+ return losses
+ else:
+ proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
+ return losses, proposal_list
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ soft_target,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Cls and quality scores for each scale
+ level has shape (N, num_classes, H, W).
+ bbox_preds (list[Tensor]): Box distribution logits for each scale
+ level with shape (N, 4*(n+1), H, W), n is max value of integral
+ set.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): class indices corresponding to each box
+            soft_target (list[Tensor]): Box distribution logits of the
+                teacher model for each scale level.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ losses_cls, losses_bbox, losses_dfl, losses_ld, \
+ avg_factor = multi_apply(
+ self.loss_single,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ self.prior_generator.strides,
+ soft_target,
+ num_total_samples=num_total_samples)
+
+ avg_factor = sum(avg_factor) + 1e-6
+ avg_factor = reduce_mean(avg_factor).item()
+ losses_bbox = [x / avg_factor for x in losses_bbox]
+ losses_dfl = [x / avg_factor for x in losses_dfl]
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ loss_dfl=losses_dfl,
+ loss_ld=losses_ld)
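+
+
+if __name__ == '__main__':
+    # Illustrative sketch (not the mmdet loss implementation): the core of
+    # localization distillation is a temperature-scaled KL divergence between
+    # the student's and the teacher's per-edge distributions over the integral
+    # set, i.e. the `pred_corners` / `soft_corners` pair passed to `loss_ld`
+    # above. The shapes and the temperature below are arbitrary assumptions.
+    import torch.nn.functional as F
+    T = 10.0
+    reg_max = 16
+    student_corners = torch.randn(8, reg_max + 1)  # (num_pos * 4, reg_max + 1)
+    teacher_corners = torch.randn(8, reg_max + 1)
+    kd = F.kl_div(
+        F.log_softmax(student_corners / T, dim=1),
+        F.softmax(teacher_corners / T, dim=1),
+        reduction='none').sum(dim=1) * (T * T)
+    print(kd.shape)  # torch.Size([8]), one value per predicted box edge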
diff --git a/mmdet/models/dense_heads/mask2former_head.py b/mmdet/models/dense_heads/mask2former_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..59047bdbb7939ba4fe7bcbdb0d0b165e408ed7be
--- /dev/null
+++ b/mmdet/models/dense_heads/mask2former_head.py
@@ -0,0 +1,430 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+ build_transformer_layer_sequence)
+from mmcv.ops import point_sample
+from mmcv.runner import ModuleList
+
+from mmdet.core import build_assigner, build_sampler, reduce_mean
+from mmdet.models.utils import get_uncertain_point_coords_with_randomness
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+from .maskformer_head import MaskFormerHead
+
+
+@HEADS.register_module()
+class Mask2FormerHead(MaskFormerHead):
+ """Implements the Mask2Former head.
+
+ See `Masked-attention Mask Transformer for Universal Image
+    Segmentation <https://arxiv.org/abs/2112.01527>`_ for details.
+
+ Args:
+ in_channels (list[int]): Number of channels in the input feature map.
+ feat_channels (int): Number of channels for features.
+ out_channels (int): Number of channels for output.
+ num_things_classes (int): Number of things.
+ num_stuff_classes (int): Number of stuff.
+        num_queries (int): Number of queries in Transformer decoder.
+        num_transformer_feat_level (int): Number of feature levels used by
+            the Transformer decoder. Defaults to 3.
+ pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
+ decoder. Defaults to None.
+ enforce_decoder_input_project (bool, optional): Whether to add
+            a layer to change the embed_dim of transformer encoder in
+ pixel decoder to the embed_dim of transformer decoder.
+ Defaults to False.
+ transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer decoder. Defaults to None.
+ positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer decoder position encoding. Defaults to None.
+ loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
+ loss. Defaults to None.
+ loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
+ Defaults to None.
+ loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
+ Defaults to None.
+ train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
+ Mask2Former head.
+ test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
+ Mask2Former head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels,
+ feat_channels,
+ out_channels,
+ num_things_classes=80,
+ num_stuff_classes=53,
+ num_queries=100,
+ num_transformer_feat_level=3,
+ pixel_decoder=None,
+ enforce_decoder_input_project=False,
+ transformer_decoder=None,
+ positional_encoding=None,
+ loss_cls=None,
+ loss_mask=None,
+ loss_dice=None,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=None,
+ **kwargs):
+ super(AnchorFreeHead, self).__init__(init_cfg)
+ self.num_things_classes = num_things_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.num_classes = self.num_things_classes + self.num_stuff_classes
+ self.num_queries = num_queries
+ self.num_transformer_feat_level = num_transformer_feat_level
+ self.num_heads = transformer_decoder.transformerlayers.\
+ attn_cfgs.num_heads
+ self.num_transformer_decoder_layers = transformer_decoder.num_layers
+ assert pixel_decoder.encoder.transformerlayers.\
+ attn_cfgs.num_levels == num_transformer_feat_level
+ pixel_decoder_ = copy.deepcopy(pixel_decoder)
+ pixel_decoder_.update(
+ in_channels=in_channels,
+ feat_channels=feat_channels,
+ out_channels=out_channels)
+ self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
+ self.transformer_decoder = build_transformer_layer_sequence(
+ transformer_decoder)
+ self.decoder_embed_dims = self.transformer_decoder.embed_dims
+
+ self.decoder_input_projs = ModuleList()
+ # from low resolution to high resolution
+ for _ in range(num_transformer_feat_level):
+ if (self.decoder_embed_dims != feat_channels
+ or enforce_decoder_input_project):
+ self.decoder_input_projs.append(
+ Conv2d(
+ feat_channels, self.decoder_embed_dims, kernel_size=1))
+ else:
+ self.decoder_input_projs.append(nn.Identity())
+ self.decoder_positional_encoding = build_positional_encoding(
+ positional_encoding)
+ self.query_embed = nn.Embedding(self.num_queries, feat_channels)
+ self.query_feat = nn.Embedding(self.num_queries, feat_channels)
+ # from low resolution to high resolution
+ self.level_embed = nn.Embedding(self.num_transformer_feat_level,
+ feat_channels)
+
+ self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
+ self.mask_embed = nn.Sequential(
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
+ nn.Linear(feat_channels, out_channels))
+
+ self.test_cfg = test_cfg
+ self.train_cfg = train_cfg
+ if train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ self.sampler = build_sampler(self.train_cfg.sampler, context=self)
+ self.num_points = self.train_cfg.get('num_points', 12544)
+ self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
+ self.importance_sample_ratio = self.train_cfg.get(
+ 'importance_sample_ratio', 0.75)
+
+ self.class_weight = loss_cls.class_weight
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_mask = build_loss(loss_mask)
+ self.loss_dice = build_loss(loss_dice)
+
+ def init_weights(self):
+ for m in self.decoder_input_projs:
+ if isinstance(m, Conv2d):
+ caffe2_xavier_init(m, bias=0)
+
+ self.pixel_decoder.init_weights()
+
+ for p in self.transformer_decoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_normal_(p)
+
+ def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
+ img_metas):
+ """Compute classification and mask targets for one image.
+
+ Args:
+ cls_score (Tensor): Mask score logits from a single decoder layer
+ for one image. Shape (num_queries, cls_out_channels).
+ mask_pred (Tensor): Mask logits for a single decoder layer for one
+ image. Shape (num_queries, h, w).
+ gt_labels (Tensor): Ground truth class indices for one image with
+ shape (num_gts, ).
+ gt_masks (Tensor): Ground truth mask for each image, each with
+ shape (num_gts, h, w).
+            img_metas (dict): Image information.
+
+ Returns:
+ tuple[Tensor]: A tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image. \
+ shape (num_queries, ).
+ - label_weights (Tensor): Label weights of each image. \
+ shape (num_queries, ).
+ - mask_targets (Tensor): Mask targets of each image. \
+ shape (num_queries, h, w).
+ - mask_weights (Tensor): Mask weights of each image. \
+ shape (num_queries, ).
+ - pos_inds (Tensor): Sampled positive indices for each \
+ image.
+ - neg_inds (Tensor): Sampled negative indices for each \
+ image.
+ """
+ # sample points
+ num_queries = cls_score.shape[0]
+ num_gts = gt_labels.shape[0]
+
+ point_coords = torch.rand((1, self.num_points, 2),
+ device=cls_score.device)
+ # shape (num_queries, num_points)
+ mask_points_pred = point_sample(
+ mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
+ 1)).squeeze(1)
+ # shape (num_gts, num_points)
+ gt_points_masks = point_sample(
+ gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
+ 1)).squeeze(1)
+
+ # assign and sample
+ assign_result = self.assigner.assign(cls_score, mask_points_pred,
+ gt_labels, gt_points_masks,
+ img_metas)
+ sampling_result = self.sampler.sample(assign_result, mask_pred,
+ gt_masks)
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ # label target
+ labels = gt_labels.new_full((self.num_queries, ),
+ self.num_classes,
+ dtype=torch.long)
+ labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+ label_weights = gt_labels.new_ones((self.num_queries, ))
+
+ # mask target
+ mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
+ mask_weights = mask_pred.new_zeros((self.num_queries, ))
+ mask_weights[pos_inds] = 1.0
+
+ return (labels, label_weights, mask_targets, mask_weights, pos_inds,
+ neg_inds)
+
+ def loss_single(self, cls_scores, mask_preds, gt_labels_list,
+ gt_masks_list, img_metas):
+ """Loss function for outputs from a single decoder layer.
+
+ Args:
+ cls_scores (Tensor): Mask score logits from a single decoder layer
+ for all images. Shape (batch_size, num_queries,
+                cls_out_channels). Note `cls_out_channels` should include
+ background.
+ mask_preds (Tensor): Mask logits for a pixel decoder for all
+ images. Shape (batch_size, num_queries, h, w).
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image, each with shape (num_gts, ).
+ gt_masks_list (list[Tensor]): Ground truth mask for each image,
+ each with shape (num_gts, h, w).
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ tuple[Tensor]: Loss components for outputs from a single \
+ decoder layer.
+ """
+ num_imgs = cls_scores.size(0)
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
+ (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
+ num_total_pos,
+ num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
+ gt_labels_list, gt_masks_list,
+ img_metas)
+ # shape (batch_size, num_queries)
+ labels = torch.stack(labels_list, dim=0)
+ # shape (batch_size, num_queries)
+ label_weights = torch.stack(label_weights_list, dim=0)
+ # shape (num_total_gts, h, w)
+ mask_targets = torch.cat(mask_targets_list, dim=0)
+ # shape (batch_size, num_queries)
+ mask_weights = torch.stack(mask_weights_list, dim=0)
+
+        # classification loss
+ # shape (batch_size * num_queries, )
+ cls_scores = cls_scores.flatten(0, 1)
+ labels = labels.flatten(0, 1)
+ label_weights = label_weights.flatten(0, 1)
+
+ class_weight = cls_scores.new_tensor(self.class_weight)
+ loss_cls = self.loss_cls(
+ cls_scores,
+ labels,
+ label_weights,
+ avg_factor=class_weight[labels].sum())
+
+ num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
+ num_total_masks = max(num_total_masks, 1)
+
+ # extract positive ones
+ # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
+ mask_preds = mask_preds[mask_weights > 0]
+
+ if mask_targets.shape[0] == 0:
+ # zero match
+ loss_dice = mask_preds.sum()
+ loss_mask = mask_preds.sum()
+ return loss_cls, loss_mask, loss_dice
+
+ with torch.no_grad():
+ points_coords = get_uncertain_point_coords_with_randomness(
+ mask_preds.unsqueeze(1), None, self.num_points,
+ self.oversample_ratio, self.importance_sample_ratio)
+ # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
+ mask_point_targets = point_sample(
+ mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
+ # shape (num_queries, h, w) -> (num_queries, num_points)
+ mask_point_preds = point_sample(
+ mask_preds.unsqueeze(1), points_coords).squeeze(1)
+
+ # dice loss
+ loss_dice = self.loss_dice(
+ mask_point_preds, mask_point_targets, avg_factor=num_total_masks)
+
+ # mask loss
+ # shape (num_queries, num_points) -> (num_queries * num_points, )
+ mask_point_preds = mask_point_preds.reshape(-1)
+ # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
+ mask_point_targets = mask_point_targets.reshape(-1)
+ loss_mask = self.loss_mask(
+ mask_point_preds,
+ mask_point_targets,
+ avg_factor=num_total_masks * self.num_points)
+
+ return loss_cls, loss_mask, loss_dice
+
+ def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
+ """Forward for head part which is called after every decoder layer.
+
+ Args:
+ decoder_out (Tensor): in shape (num_queries, batch_size, c).
+ mask_feature (Tensor): in shape (batch_size, c, h, w).
+ attn_mask_target_size (tuple[int, int]): target attention
+ mask size.
+
+ Returns:
+ tuple: A tuple contain three elements.
+
+ - cls_pred (Tensor): Classification scores in shape \
+ (batch_size, num_queries, cls_out_channels). \
+                Note `cls_out_channels` should include background.
+                - mask_pred (Tensor): Mask scores in shape \
+                    (batch_size, num_queries, h, w).
+ - attn_mask (Tensor): Attention mask in shape \
+ (batch_size * num_heads, num_queries, h, w).
+ """
+ decoder_out = self.transformer_decoder.post_norm(decoder_out)
+ decoder_out = decoder_out.transpose(0, 1)
+ # shape (batch_size, num_queries, c)
+ cls_pred = self.cls_embed(decoder_out)
+ # shape (batch_size, num_queries, c)
+ mask_embed = self.mask_embed(decoder_out)
+ # shape (batch_size, num_queries, h, w)
+ mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
+ attn_mask = F.interpolate(
+ mask_pred,
+ attn_mask_target_size,
+ mode='bilinear',
+ align_corners=False)
+ # shape (batch_size, num_queries, h, w) ->
+        # (batch_size * num_heads, num_queries, h*w)
+ attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
+ (1, self.num_heads, 1, 1)).flatten(0, 1)
+ attn_mask = attn_mask.sigmoid() < 0.5
+ attn_mask = attn_mask.detach()
+
+ return cls_pred, mask_pred, attn_mask
+
+ def forward(self, feats, img_metas):
+ """Forward function.
+
+ Args:
+            feats (list[Tensor]): Multi-scale features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple: A tuple contains two elements.
+
+            - cls_pred_list (list[Tensor]): Classification logits \
+                for each decoder layer. Each is a 3D-tensor with shape \
+                (batch_size, num_queries, cls_out_channels). \
+                Note `cls_out_channels` should include background.
+ - mask_pred_list (list[Tensor]): Mask logits for each \
+ decoder layer. Each with shape (batch_size, num_queries, \
+ h, w).
+ """
+ batch_size = len(img_metas)
+ mask_features, multi_scale_memorys = self.pixel_decoder(feats)
+ # multi_scale_memorys (from low resolution to high resolution)
+ decoder_inputs = []
+ decoder_positional_encodings = []
+ for i in range(self.num_transformer_feat_level):
+ decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
+ # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
+ decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
+ level_embed = self.level_embed.weight[i].view(1, 1, -1)
+ decoder_input = decoder_input + level_embed
+ # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
+ mask = decoder_input.new_zeros(
+ (batch_size, ) + multi_scale_memorys[i].shape[-2:],
+ dtype=torch.bool)
+ decoder_positional_encoding = self.decoder_positional_encoding(
+ mask)
+ decoder_positional_encoding = decoder_positional_encoding.flatten(
+ 2).permute(2, 0, 1)
+ decoder_inputs.append(decoder_input)
+ decoder_positional_encodings.append(decoder_positional_encoding)
+ # shape (num_queries, c) -> (num_queries, batch_size, c)
+ query_feat = self.query_feat.weight.unsqueeze(1).repeat(
+ (1, batch_size, 1))
+ query_embed = self.query_embed.weight.unsqueeze(1).repeat(
+ (1, batch_size, 1))
+
+ cls_pred_list = []
+ mask_pred_list = []
+ cls_pred, mask_pred, attn_mask = self.forward_head(
+ query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
+ cls_pred_list.append(cls_pred)
+ mask_pred_list.append(mask_pred)
+
+ for i in range(self.num_transformer_decoder_layers):
+ level_idx = i % self.num_transformer_feat_level
+ # if a mask is all True(all background), then set it all False.
+ attn_mask[torch.where(
+ attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+
+ # cross_attn + self_attn
+ layer = self.transformer_decoder.layers[i]
+ attn_masks = [attn_mask, None]
+ query_feat = layer(
+ query=query_feat,
+ key=decoder_inputs[level_idx],
+ value=decoder_inputs[level_idx],
+ query_pos=query_embed,
+ key_pos=decoder_positional_encodings[level_idx],
+ attn_masks=attn_masks,
+ query_key_padding_mask=None,
+ # here we do not apply masking on padded region
+ key_padding_mask=None)
+ cls_pred, mask_pred, attn_mask = self.forward_head(
+ query_feat, mask_features, multi_scale_memorys[
+ (i + 1) % self.num_transformer_feat_level].shape[-2:])
+
+ cls_pred_list.append(cls_pred)
+ mask_pred_list.append(mask_pred)
+
+ return cls_pred_list, mask_pred_list
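+
+
+if __name__ == '__main__':
+    # Illustrative sketch (not part of the original head): reproduces the two
+    # key operations of `forward_head` with random tensors -- the
+    # query/feature einsum that yields per-query masks, and the 0.5 sigmoid
+    # threshold that turns the interpolated masks into a boolean
+    # cross-attention mask. All sizes below are arbitrary assumptions.
+    batch_size, num_queries, channels, num_heads = 2, 5, 8, 4
+    mask_embed = torch.randn(batch_size, num_queries, channels)
+    mask_feature = torch.randn(batch_size, channels, 16, 16)
+    mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
+    attn_mask = F.interpolate(
+        mask_pred, (8, 8), mode='bilinear', align_corners=False)
+    attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
+        (1, num_heads, 1, 1)).flatten(0, 1)
+    attn_mask = (attn_mask.sigmoid() < 0.5).detach()
+    print(attn_mask.shape)  # (batch_size * num_heads, num_queries, 8 * 8)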
diff --git a/mmdet/models/dense_heads/maskformer_head.py b/mmdet/models/dense_heads/maskformer_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..566dc074059ef770892d2916e7c44fa54b0f8758
--- /dev/null
+++ b/mmdet/models/dense_heads/maskformer_head.py
@@ -0,0 +1,556 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+ build_transformer_layer_sequence)
+from mmcv.runner import force_fp32
+
+from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
+from mmdet.models.utils import preprocess_panoptic_gt
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+
+
+@HEADS.register_module()
+class MaskFormerHead(AnchorFreeHead):
+ """Implements the MaskFormer head.
+
+ See `Per-Pixel Classification is Not All You Need for Semantic
+    Segmentation <https://arxiv.org/abs/2107.06278>`_ for details.
+
+ Args:
+ in_channels (list[int]): Number of channels in the input feature map.
+ feat_channels (int): Number of channels for feature.
+ out_channels (int): Number of channels for output.
+ num_things_classes (int): Number of things.
+ num_stuff_classes (int): Number of stuff.
+        num_queries (int): Number of queries in Transformer.
+ pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
+ decoder. Defaults to None.
+ enforce_decoder_input_project (bool, optional): Whether to add a layer
+            to change the embed_dim of transformer encoder in pixel decoder to
+ the embed_dim of transformer decoder. Defaults to False.
+ transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer decoder. Defaults to None.
+ positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer decoder position encoding. Defaults to None.
+ loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
+ loss. Defaults to `CrossEntropyLoss`.
+ loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
+ Defaults to `FocalLoss`.
+ loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
+ Defaults to `DiceLoss`.
+ train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
+ Maskformer head.
+ test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of Maskformer
+ head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_channels,
+ feat_channels,
+ out_channels,
+ num_things_classes=80,
+ num_stuff_classes=53,
+ num_queries=100,
+ pixel_decoder=None,
+ enforce_decoder_input_project=False,
+ transformer_decoder=None,
+ positional_encoding=None,
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=[1.0] * 133 + [0.1]),
+ loss_mask=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=20.0),
+ loss_dice=dict(
+ type='DiceLoss',
+ use_sigmoid=True,
+ activate=True,
+ naive_dice=True,
+ loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=None,
+ **kwargs):
+ super(AnchorFreeHead, self).__init__(init_cfg)
+ self.num_things_classes = num_things_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.num_classes = self.num_things_classes + self.num_stuff_classes
+ self.num_queries = num_queries
+
+ pixel_decoder.update(
+ in_channels=in_channels,
+ feat_channels=feat_channels,
+ out_channels=out_channels)
+ self.pixel_decoder = build_plugin_layer(pixel_decoder)[1]
+ self.transformer_decoder = build_transformer_layer_sequence(
+ transformer_decoder)
+ self.decoder_embed_dims = self.transformer_decoder.embed_dims
+ pixel_decoder_type = pixel_decoder.get('type')
+ if pixel_decoder_type == 'PixelDecoder' and (
+ self.decoder_embed_dims != in_channels[-1]
+ or enforce_decoder_input_project):
+ self.decoder_input_proj = Conv2d(
+ in_channels[-1], self.decoder_embed_dims, kernel_size=1)
+ else:
+ self.decoder_input_proj = nn.Identity()
+ self.decoder_pe = build_positional_encoding(positional_encoding)
+ self.query_embed = nn.Embedding(self.num_queries, out_channels)
+
+ self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
+ self.mask_embed = nn.Sequential(
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
+ nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
+ nn.Linear(feat_channels, out_channels))
+
+ self.test_cfg = test_cfg
+ self.train_cfg = train_cfg
+ if train_cfg:
+ self.assigner = build_assigner(train_cfg.get('assigner', None))
+ self.sampler = build_sampler(
+ train_cfg.get('sampler', None), context=self)
+
+ self.class_weight = loss_cls.get('class_weight', None)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_mask = build_loss(loss_mask)
+ self.loss_dice = build_loss(loss_dice)
+
+ def init_weights(self):
+ if isinstance(self.decoder_input_proj, Conv2d):
+ caffe2_xavier_init(self.decoder_input_proj, bias=0)
+
+ self.pixel_decoder.init_weights()
+
+ for p in self.transformer_decoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def preprocess_gt(self, gt_labels_list, gt_masks_list, gt_semantic_segs,
+ img_metas):
+ """Preprocess the ground truth for all images.
+
+ Args:
+ gt_labels_list (list[Tensor]): Each is ground truth
+ labels of each bbox, with shape (num_gts, ).
+            gt_masks_list (list[BitmapMasks]): Each is ground truth
+                masks of the instances of an image, shape
+                (num_gts, h, w).
+            gt_semantic_segs (list[Tensor] | None): Ground truth of semantic
+                segmentation, one map per image.
+                [0, num_thing_class - 1] means things,
+                [num_thing_class, num_class-1] means stuff,
+                255 means VOID. It's None when training instance segmentation.
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+ - labels (list[Tensor]): Ground truth class indices\
+ for all images. Each with shape (n, ), n is the sum of\
+                number of stuff types and number of instances in an image.
+ - masks (list[Tensor]): Ground truth mask for each\
+ image, each with shape (n, h, w).
+ """
+ num_things_list = [self.num_things_classes] * len(gt_labels_list)
+ num_stuff_list = [self.num_stuff_classes] * len(gt_labels_list)
+ if gt_semantic_segs is None:
+ gt_semantic_segs = [None] * len(gt_labels_list)
+
+ targets = multi_apply(preprocess_panoptic_gt, gt_labels_list,
+ gt_masks_list, gt_semantic_segs, num_things_list,
+ num_stuff_list, img_metas)
+ labels, masks = targets
+ return labels, masks
+
+ def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list,
+ gt_masks_list, img_metas):
+ """Compute classification and mask targets for all images for a decoder
+ layer.
+
+ Args:
+ cls_scores_list (list[Tensor]): Mask score logits from a single
+ decoder layer for all images. Each with shape (num_queries,
+ cls_out_channels).
+ mask_preds_list (list[Tensor]): Mask logits from a single decoder
+ layer for all images. Each with shape (num_queries, h, w).
+ gt_labels_list (list[Tensor]): Ground truth class indices for all
+ images. Each with shape (n, ), n is the sum of number of stuff
+                types and number of instances in an image.
+ gt_masks_list (list[Tensor]): Ground truth mask for each image,
+ each with shape (n, h, w).
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ tuple[list[Tensor]]: a tuple containing the following targets.
+ - labels_list (list[Tensor]): Labels of all images.\
+ Each with shape (num_queries, ).
+ - label_weights_list (list[Tensor]): Label weights\
+ of all images. Each with shape (num_queries, ).
+ - mask_targets_list (list[Tensor]): Mask targets of\
+ all images. Each with shape (num_queries, h, w).
+ - mask_weights_list (list[Tensor]): Mask weights of\
+ all images. Each with shape (num_queries, ).
+ - num_total_pos (int): Number of positive samples in\
+ all images.
+ - num_total_neg (int): Number of negative samples in\
+ all images.
+ """
+ (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
+ pos_inds_list,
+ neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list,
+ mask_preds_list, gt_labels_list,
+ gt_masks_list, img_metas)
+
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+ return (labels_list, label_weights_list, mask_targets_list,
+ mask_weights_list, num_total_pos, num_total_neg)
+
+ def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks,
+ img_metas):
+ """Compute classification and mask targets for one image.
+
+ Args:
+ cls_score (Tensor): Mask score logits from a single decoder layer
+ for one image. Shape (num_queries, cls_out_channels).
+ mask_pred (Tensor): Mask logits for a single decoder layer for one
+ image. Shape (num_queries, h, w).
+ gt_labels (Tensor): Ground truth class indices for one image with
+                shape (n, ). n is the sum of number of stuff types and
+                number of instances in an image.
+ gt_masks (Tensor): Ground truth mask for each image, each with
+ shape (n, h, w).
+            img_metas (dict): Image information.
+
+ Returns:
+ tuple[Tensor]: a tuple containing the following for one image.
+ - labels (Tensor): Labels of each image.
+ shape (num_queries, ).
+ - label_weights (Tensor): Label weights of each image.
+ shape (num_queries, ).
+ - mask_targets (Tensor): Mask targets of each image.
+ shape (num_queries, h, w).
+ - mask_weights (Tensor): Mask weights of each image.
+ shape (num_queries, ).
+ - pos_inds (Tensor): Sampled positive indices for each image.
+ - neg_inds (Tensor): Sampled negative indices for each image.
+ """
+ target_shape = mask_pred.shape[-2:]
+ if gt_masks.shape[0] > 0:
+ gt_masks_downsampled = F.interpolate(
+ gt_masks.unsqueeze(1).float(), target_shape,
+ mode='nearest').squeeze(1).long()
+ else:
+ gt_masks_downsampled = gt_masks
+
+ # assign and sample
+ assign_result = self.assigner.assign(cls_score, mask_pred, gt_labels,
+ gt_masks_downsampled, img_metas)
+ sampling_result = self.sampler.sample(assign_result, mask_pred,
+ gt_masks)
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ # label target
+ labels = gt_labels.new_full((self.num_queries, ),
+ self.num_classes,
+ dtype=torch.long)
+ labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+ label_weights = gt_labels.new_ones(self.num_queries)
+
+ # mask target
+ mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
+ mask_weights = mask_pred.new_zeros((self.num_queries, ))
+ mask_weights[pos_inds] = 1.0
+
+ return (labels, label_weights, mask_targets, mask_weights, pos_inds,
+ neg_inds)
+
+ @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds'))
+ def loss(self, all_cls_scores, all_mask_preds, gt_labels_list,
+ gt_masks_list, img_metas):
+ """Loss function.
+
+ Args:
+ all_cls_scores (Tensor): Classification scores for all decoder
+ layers with shape (num_decoder, batch_size, num_queries,
+                cls_out_channels). Note `cls_out_channels` should include
+ background.
+ all_mask_preds (Tensor): Mask scores for all decoder layers with
+ shape (num_decoder, batch_size, num_queries, h, w).
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (n, ). n is the sum of number of stuff
+                types and number of instances in an image.
+ gt_masks_list (list[Tensor]): Ground truth mask for each image with
+ shape (n, h, w).
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_dec_layers = len(all_cls_scores)
+ all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+ losses_cls, losses_mask, losses_dice = multi_apply(
+ self.loss_single, all_cls_scores, all_mask_preds,
+ all_gt_labels_list, all_gt_masks_list, img_metas_list)
+
+ loss_dict = dict()
+ # loss from the last decoder layer
+ loss_dict['loss_cls'] = losses_cls[-1]
+ loss_dict['loss_mask'] = losses_mask[-1]
+ loss_dict['loss_dice'] = losses_dice[-1]
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for loss_cls_i, loss_mask_i, loss_dice_i in zip(
+ losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
+ loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+ loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
+ loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i
+ num_dec_layer += 1
+ return loss_dict
+
+ def loss_single(self, cls_scores, mask_preds, gt_labels_list,
+ gt_masks_list, img_metas):
+ """Loss function for outputs from a single decoder layer.
+
+ Args:
+ cls_scores (Tensor): Mask score logits from a single decoder layer
+ for all images. Shape (batch_size, num_queries,
+                cls_out_channels). Note `cls_out_channels` should include
+ background.
+ mask_preds (Tensor): Mask logits for a pixel decoder for all
+ images. Shape (batch_size, num_queries, h, w).
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image, each with shape (n, ). n is the sum of number of stuff
+                types and number of instances in an image.
+ gt_masks_list (list[Tensor]): Ground truth mask for each image,
+ each with shape (n, h, w).
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ tuple[Tensor]: Loss components for outputs from a single decoder\
+ layer.
+ """
+ num_imgs = cls_scores.size(0)
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
+
+ (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
+ num_total_pos,
+ num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
+ gt_labels_list, gt_masks_list,
+ img_metas)
+ # shape (batch_size, num_queries)
+ labels = torch.stack(labels_list, dim=0)
+ # shape (batch_size, num_queries)
+ label_weights = torch.stack(label_weights_list, dim=0)
+ # shape (num_total_gts, h, w)
+ mask_targets = torch.cat(mask_targets_list, dim=0)
+ # shape (batch_size, num_queries)
+ mask_weights = torch.stack(mask_weights_list, dim=0)
+
+        # classification loss
+ # shape (batch_size * num_queries, )
+ cls_scores = cls_scores.flatten(0, 1)
+ labels = labels.flatten(0, 1)
+ label_weights = label_weights.flatten(0, 1)
+
+ class_weight = cls_scores.new_tensor(self.class_weight)
+ loss_cls = self.loss_cls(
+ cls_scores,
+ labels,
+ label_weights,
+ avg_factor=class_weight[labels].sum())
+
+ num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
+ num_total_masks = max(num_total_masks, 1)
+
+ # extract positive ones
+ # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
+ mask_preds = mask_preds[mask_weights > 0]
+ target_shape = mask_targets.shape[-2:]
+
+ if mask_targets.shape[0] == 0:
+ # zero match
+ loss_dice = mask_preds.sum()
+ loss_mask = mask_preds.sum()
+ return loss_cls, loss_mask, loss_dice
+
+ # upsample to shape of target
+ # shape (num_total_gts, h, w)
+ mask_preds = F.interpolate(
+ mask_preds.unsqueeze(1),
+ target_shape,
+ mode='bilinear',
+ align_corners=False).squeeze(1)
+
+ # dice loss
+ loss_dice = self.loss_dice(
+ mask_preds, mask_targets, avg_factor=num_total_masks)
+
+ # mask loss
+ # FocalLoss support input of shape (n, num_class)
+ h, w = mask_preds.shape[-2:]
+ # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1)
+ mask_preds = mask_preds.reshape(-1, 1)
+ # shape (num_total_gts, h, w) -> (num_total_gts * h * w)
+ mask_targets = mask_targets.reshape(-1)
+ # target is (1 - mask_targets) !!!
+ loss_mask = self.loss_mask(
+ mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)
+
+ return loss_cls, loss_mask, loss_dice
+
+ def forward(self, feats, img_metas):
+ """Forward function.
+
+ Args:
+ feats (list[Tensor]): Features from the upstream network, each
+ is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple: a tuple contains two elements.
+ - all_cls_scores (Tensor): Classification scores for each\
+ scale level. Each is a 4D-tensor with shape\
+ (num_decoder, batch_size, num_queries, cls_out_channels).\
+              Note `cls_out_channels` should include background.
+ - all_mask_preds (Tensor): Mask scores for each decoder\
+ layer. Each with shape (num_decoder, batch_size,\
+ num_queries, h, w).
+ """
+ batch_size = len(img_metas)
+ input_img_h, input_img_w = img_metas[0]['batch_input_shape']
+ padding_mask = feats[-1].new_ones(
+ (batch_size, input_img_h, input_img_w), dtype=torch.float32)
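+        # 1 marks padded pixels; the valid image region is zeroed out below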
+ for i in range(batch_size):
+ img_h, img_w, _ = img_metas[i]['img_shape']
+ padding_mask[i, :img_h, :img_w] = 0
+ padding_mask = F.interpolate(
+ padding_mask.unsqueeze(1),
+ size=feats[-1].shape[-2:],
+ mode='nearest').to(torch.bool).squeeze(1)
+        # when the backbone is Swin, memory is the output of its last stage;
+        # when it is R50, memory is the output of the transformer encoder.
+ mask_features, memory = self.pixel_decoder(feats, img_metas)
+ pos_embed = self.decoder_pe(padding_mask)
+ memory = self.decoder_input_proj(memory)
+ # shape (batch_size, c, h, w) -> (h*w, batch_size, c)
+ memory = memory.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ # shape (batch_size, h * w)
+ padding_mask = padding_mask.flatten(1)
+ # shape = (num_queries, embed_dims)
+ query_embed = self.query_embed.weight
+ # shape = (num_queries, batch_size, embed_dims)
+ query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1)
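+        # decoder queries start from zeros; query_embed is only used as the
+        # learned positional embedding of the queries (query_pos)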
+ target = torch.zeros_like(query_embed)
+ # shape (num_decoder, num_queries, batch_size, embed_dims)
+ out_dec = self.transformer_decoder(
+ query=target,
+ key=memory,
+ value=memory,
+ key_pos=pos_embed,
+ query_pos=query_embed,
+ key_padding_mask=padding_mask)
+ # shape (num_decoder, batch_size, num_queries, embed_dims)
+ out_dec = out_dec.transpose(1, 2)
+
+ # cls_scores
+ all_cls_scores = self.cls_embed(out_dec)
+
+ # mask_preds
+ mask_embed = self.mask_embed(out_dec)
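+        # mask logits are the per-pixel dot products between each query's
+        # mask embedding and the pixel decoder's mask features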
+ all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
+ mask_features)
+
+ return all_cls_scores, all_mask_preds
+
+ def forward_train(self,
+ feats,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ gt_semantic_seg,
+ gt_bboxes_ignore=None):
+ """Forward function for training mode.
+
+ Args:
+ feats (list[Tensor]): Multi-level features from the upstream
+ network, each is a 4D-tensor.
+ img_metas (list[Dict]): List of image information.
+ gt_bboxes (list[Tensor]): Each element is ground truth bboxes of
+ the image, shape (num_gts, 4). Not used here.
+ gt_labels (list[Tensor]): Each element is ground truth labels of
+ each box, shape (num_gts,).
+ gt_masks (list[BitmapMasks]): Each element is masks of instances
+                of an image, shape (num_gts, h, w).
+ gt_semantic_seg (list[tensor] | None): Each element is the ground
+ truth of semantic segmentation with the shape (N, H, W).
+ [0, num_thing_class - 1] means things,
+ [num_thing_class, num_class-1] means stuff,
+ 255 means VOID. It's None when training instance segmentation.
+ gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
+ ignored. Defaults to None.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # not consider ignoring bboxes
+ assert gt_bboxes_ignore is None
+
+ # forward
+ all_cls_scores, all_mask_preds = self(feats, img_metas)
+
+ # preprocess ground truth
+ gt_labels, gt_masks = self.preprocess_gt(gt_labels, gt_masks,
+ gt_semantic_seg, img_metas)
+
+ # loss
+ losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks,
+ img_metas)
+
+ return losses
+
+ def simple_test(self, feats, img_metas, **kwargs):
+ """Test without augmentaton.
+
+ Args:
+ feats (list[Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple: A tuple contains two tensors.
+
+ - mask_cls_results (Tensor): Mask classification logits,\
+ shape (batch_size, num_queries, cls_out_channels).
+                Note `cls_out_channels` should include background.
+ - mask_pred_results (Tensor): Mask logits, shape \
+ (batch_size, num_queries, h, w).
+ """
+ all_cls_scores, all_mask_preds = self(feats, img_metas)
+ mask_cls_results = all_cls_scores[-1]
+ mask_pred_results = all_mask_preds[-1]
+
+ # upsample masks
+ img_shape = img_metas[0]['batch_input_shape']
+ mask_pred_results = F.interpolate(
+ mask_pred_results,
+ size=(img_shape[0], img_shape[1]),
+ mode='bilinear',
+ align_corners=False)
+
+ return mask_cls_results, mask_pred_results
diff --git a/mmdet/models/dense_heads/nasfcos_head.py b/mmdet/models/dense_heads/nasfcos_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..380c912c763445a32acad3be6da965966cd9ae53
--- /dev/null
+++ b/mmdet/models/dense_heads/nasfcos_head.py
@@ -0,0 +1,80 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale
+
+from mmdet.models.dense_heads.fcos_head import FCOSHead
+from ..builder import HEADS
+
+
+@HEADS.register_module()
+class NASFCOSHead(FCOSHead):
+ """Anchor-free head used in `NASFCOS `_.
+
+    It is quite similar to the FCOS head, except for the searched structure of
+ classification branch and bbox regression branch, where a structure of
+ "dconv3x3, conv3x3, dconv3x3, conv1x1" is utilized instead.
+ """
+
+ def __init__(self, *args, init_cfg=None, **kwargs):
+ if init_cfg is None:
+ init_cfg = [
+ dict(type='Caffe2Xavier', layer=['ConvModule', 'Conv2d']),
+ dict(
+ type='Normal',
+ std=0.01,
+ override=[
+ dict(name='conv_reg'),
+ dict(name='conv_centerness'),
+ dict(
+ name='conv_cls',
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01)
+ ]),
+ ]
+ super(NASFCOSHead, self).__init__(*args, init_cfg=init_cfg, **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ dconv3x3_config = dict(
+ type='DCNv2',
+ kernel_size=3,
+ use_bias=True,
+ deform_groups=2,
+ padding=1)
+ conv3x3_config = dict(type='Conv', kernel_size=3, padding=1)
+ conv1x1_config = dict(type='Conv', kernel_size=1)
+
+ self.arch_config = [
+ dconv3x3_config, conv3x3_config, dconv3x3_config, conv1x1_config
+ ]
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i, op_ in enumerate(self.arch_config):
+ op = copy.deepcopy(op_)
+ chn = self.in_channels if i == 0 else self.feat_channels
+ assert isinstance(op, dict)
+ use_bias = op.pop('use_bias', False)
+ padding = op.pop('padding', 0)
+ kernel_size = op.pop('kernel_size')
+ module = ConvModule(
+ chn,
+ self.feat_channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ norm_cfg=self.norm_cfg,
+ bias=use_bias,
+ conv_cfg=op)
+
+ self.cls_convs.append(copy.deepcopy(module))
+ self.reg_convs.append(copy.deepcopy(module))
+
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+ self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
+
+ self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
diff --git a/mmdet/models/dense_heads/paa_head.py b/mmdet/models/dense_heads/paa_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d79b5b9f40778fb775b76919cadc80579fa00ba0
--- /dev/null
+++ b/mmdet/models/dense_heads/paa_head.py
@@ -0,0 +1,756 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import multi_apply, multiclass_nms
+from mmdet.core.bbox.iou_calculators import bbox_overlaps
+from mmdet.models import HEADS
+from mmdet.models.dense_heads import ATSSHead
+
+EPS = 1e-12
+try:
+ import sklearn.mixture as skm
+except ImportError:
+ skm = None
+
+
+def levels_to_images(mlvl_tensor):
+ """Concat multi-level feature maps by image.
+
+ [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
+ Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
+ (N, H*W , C), then split the element to N elements with shape (H*W, C), and
+ concat elements in same image of all level along first dimension.
+
+ Args:
+ mlvl_tensor (list[torch.Tensor]): list of Tensor which collect from
+ corresponding level. Each element is of shape (N, C, H, W)
+
+ Returns:
+ list[torch.Tensor]: A list that contains N tensors and each tensor is
+ of shape (num_elements, C)
+ """
+ batch_size = mlvl_tensor[0].size(0)
+ batch_list = [[] for _ in range(batch_size)]
+ channels = mlvl_tensor[0].size(1)
+ for t in mlvl_tensor:
+ t = t.permute(0, 2, 3, 1)
+ t = t.view(batch_size, -1, channels).contiguous()
+ for img in range(batch_size):
+ batch_list[img].append(t[img])
+ return [torch.cat(item, 0) for item in batch_list]
+
+
+@HEADS.register_module()
+class PAAHead(ATSSHead):
+ """Head of PAAAssignment: Probabilistic Anchor Assignment with IoU
+ Prediction for Object Detection.
+
+    Code is modified from the `official github repo
+    <https://github.com/kkhoot/PAA>`_.
+
+    More details can be found in the PAA paper.
+
+ Args:
+ topk (int): Select topk samples with smallest loss in
+ each level.
+ score_voting (bool): Whether to use score voting in post-process.
+ covariance_type : String describing the type of covariance parameters
+ to be used in :class:`sklearn.mixture.GaussianMixture`.
+ It must be one of:
+
+ - 'full': each component has its own general covariance matrix
+ - 'tied': all components share the same general covariance matrix
+ - 'diag': each component has its own diagonal covariance matrix
+ - 'spherical': each component has its own single variance
+ Default: 'diag'. From 'full' to 'spherical', the gmm fitting
+ process is faster yet the performance could be influenced. For most
+ cases, 'diag' should be a good choice.
+ """
+
+ def __init__(self,
+ *args,
+ topk=9,
+ score_voting=True,
+ covariance_type='diag',
+ **kwargs):
+ # topk used in paa reassign process
+ self.topk = topk
+ self.with_score_voting = score_voting
+ self.covariance_type = covariance_type
+ super(PAAHead, self).__init__(*args, **kwargs)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ iou_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ iou_preds (list[Tensor]): iou_preds for each scale
+ level with shape (N, num_anchors * 1, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+                boxes can be ignored when computing the loss.
+
+ Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ )
+ (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
+ pos_gt_index) = cls_reg_targets
+ cls_scores = levels_to_images(cls_scores)
+ cls_scores = [
+ item.reshape(-1, self.cls_out_channels) for item in cls_scores
+ ]
+ bbox_preds = levels_to_images(bbox_preds)
+ bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
+ iou_preds = levels_to_images(iou_preds)
+ iou_preds = [item.reshape(-1, 1) for item in iou_preds]
+ pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
+ cls_scores, bbox_preds, labels,
+ labels_weight, bboxes_target,
+ bboxes_weight, pos_inds)
+
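+        # PAA reassignment only rewrites targets and weights, so it is done
+        # without tracking gradients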
+ with torch.no_grad():
+ reassign_labels, reassign_label_weight, \
+ reassign_bbox_weights, num_pos = multi_apply(
+ self.paa_reassign,
+ pos_losses_list,
+ labels,
+ labels_weight,
+ bboxes_weight,
+ pos_inds,
+ pos_gt_index,
+ anchor_list)
+ num_pos = sum(num_pos)
+ # convert all tensor list to a flatten tensor
+ cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
+ bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
+ iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
+ labels = torch.cat(reassign_labels, 0).view(-1)
+ flatten_anchors = torch.cat(
+ [torch.cat(item, 0) for item in anchor_list])
+ labels_weight = torch.cat(reassign_label_weight, 0).view(-1)
+ bboxes_target = torch.cat(bboxes_target,
+ 0).view(-1, bboxes_target[0].size(-1))
+
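+        # positive samples are those whose reassigned label is a foreground
+        # class (background is encoded as self.num_classes)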
+ pos_inds_flatten = ((labels >= 0)
+ &
+ (labels < self.num_classes)).nonzero().reshape(-1)
+
+ losses_cls = self.loss_cls(
+ cls_scores,
+ labels,
+ labels_weight,
+ avg_factor=max(num_pos, len(img_metas))) # avoid num_pos=0
+ if num_pos:
+ pos_bbox_pred = self.bbox_coder.decode(
+ flatten_anchors[pos_inds_flatten],
+ bbox_preds[pos_inds_flatten])
+ pos_bbox_target = bboxes_target[pos_inds_flatten]
+ iou_target = bbox_overlaps(
+ pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
+ losses_iou = self.loss_centerness(
+ iou_preds[pos_inds_flatten],
+ iou_target.unsqueeze(-1),
+ avg_factor=num_pos)
+ losses_bbox = self.loss_bbox(
+ pos_bbox_pred,
+ pos_bbox_target,
+ iou_target.clamp(min=EPS),
+ avg_factor=iou_target.sum())
+ else:
+ losses_iou = iou_preds.sum() * 0
+ losses_bbox = bbox_preds.sum() * 0
+
+ return dict(
+ loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)
+
+ def get_pos_loss(self, anchors, cls_score, bbox_pred, label, label_weight,
+ bbox_target, bbox_weight, pos_inds):
+ """Calculate loss of all potential positive samples obtained from first
+ match process.
+
+ Args:
+ anchors (list[Tensor]): Anchors of each scale.
+ cls_score (Tensor): Box scores of single image with shape
+ (num_anchors, num_classes)
+ bbox_pred (Tensor): Box energies / deltas of single image
+ with shape (num_anchors, 4)
+ label (Tensor): classification target of each anchor with
+ shape (num_anchors,)
+ label_weight (Tensor): Classification loss weight of each
+ anchor with shape (num_anchors).
+            bbox_target (Tensor): Regression target of each anchor with
+ shape (num_anchors, 4).
+ bbox_weight (Tensor): Bbox weight of each anchor with shape
+ (num_anchors, 4).
+ pos_inds (Tensor): Index of all positive samples got from
+ first assign process.
+
+ Returns:
+ Tensor: Losses of all positive samples in single image.
+ """
+ if not len(pos_inds):
+ return cls_score.new([]),
+ anchors_all_level = torch.cat(anchors, 0)
+ pos_scores = cls_score[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_label = label[pos_inds]
+ pos_label_weight = label_weight[pos_inds]
+ pos_bbox_target = bbox_target[pos_inds]
+ pos_bbox_weight = bbox_weight[pos_inds]
+ pos_anchors = anchors_all_level[pos_inds]
+ pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)
+
+ # to keep loss dimension
+ loss_cls = self.loss_cls(
+ pos_scores,
+ pos_label,
+ pos_label_weight,
+ avg_factor=1.0,
+ reduction_override='none')
+
+ loss_bbox = self.loss_bbox(
+ pos_bbox_pred,
+ pos_bbox_target,
+ pos_bbox_weight,
+ avg_factor=1.0, # keep same loss weight before reassign
+ reduction_override='none')
+
+ loss_cls = loss_cls.sum(-1)
+ pos_loss = loss_bbox + loss_cls
+ return pos_loss,
+
+ def paa_reassign(self, pos_losses, label, label_weight, bbox_weight,
+ pos_inds, pos_gt_inds, anchors):
+ """Fit loss to GMM distribution and separate positive, ignore, negative
+ samples again with GMM model.
+
+ Args:
+ pos_losses (Tensor): Losses of all positive samples in
+ single image.
+ label (Tensor): classification target of each anchor with
+ shape (num_anchors,)
+ label_weight (Tensor): Classification loss weight of each
+ anchor with shape (num_anchors).
+ bbox_weight (Tensor): Bbox weight of each anchor with shape
+ (num_anchors, 4).
+ pos_inds (Tensor): Index of all positive samples got from
+ first assign process.
+ pos_gt_inds (Tensor): Gt_index of all positive samples got
+ from first assign process.
+ anchors (list[Tensor]): Anchors of each scale.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - label (Tensor): classification target of each anchor after
+ paa assign, with shape (num_anchors,)
+ - label_weight (Tensor): Classification loss weight of each
+ anchor after paa assign, with shape (num_anchors).
+ - bbox_weight (Tensor): Bbox weight of each anchor with shape
+ (num_anchors, 4).
+ - num_pos (int): The number of positive samples after paa
+ assign.
+ """
+ if not len(pos_inds):
+ return label, label_weight, bbox_weight, 0
+ label = label.clone()
+ label_weight = label_weight.clone()
+ bbox_weight = bbox_weight.clone()
+ num_gt = pos_gt_inds.max() + 1
+ num_level = len(anchors)
+ num_anchors_each_level = [item.size(0) for item in anchors]
+ num_anchors_each_level.insert(0, 0)
+ inds_level_interval = np.cumsum(num_anchors_each_level)
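+        # bucket the flattened positive indices by FPN level using the
+        # cumulative anchor counts computed above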
+ pos_level_mask = []
+ for i in range(num_level):
+ mask = (pos_inds >= inds_level_interval[i]) & (
+ pos_inds < inds_level_interval[i + 1])
+ pos_level_mask.append(mask)
+ pos_inds_after_paa = [label.new_tensor([])]
+ ignore_inds_after_paa = [label.new_tensor([])]
+ for gt_ind in range(num_gt):
+ pos_inds_gmm = []
+ pos_loss_gmm = []
+ gt_mask = pos_gt_inds == gt_ind
+ for level in range(num_level):
+ level_mask = pos_level_mask[level]
+ level_gt_mask = level_mask & gt_mask
+ value, topk_inds = pos_losses[level_gt_mask].topk(
+ min(level_gt_mask.sum(), self.topk), largest=False)
+ pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds])
+ pos_loss_gmm.append(value)
+ pos_inds_gmm = torch.cat(pos_inds_gmm)
+ pos_loss_gmm = torch.cat(pos_loss_gmm)
+            # skip this gt: GMM needs at least two samples to fit
+ if len(pos_inds_gmm) < 2:
+ continue
+ device = pos_inds_gmm.device
+ pos_loss_gmm, sort_inds = pos_loss_gmm.sort()
+ pos_inds_gmm = pos_inds_gmm[sort_inds]
+ pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy()
+ min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max()
+ means_init = np.array([min_loss, max_loss]).reshape(2, 1)
+ weights_init = np.array([0.5, 0.5])
+ precisions_init = np.array([1.0, 1.0]).reshape(2, 1, 1) # full
+ if self.covariance_type == 'spherical':
+ precisions_init = precisions_init.reshape(2)
+ elif self.covariance_type == 'diag':
+ precisions_init = precisions_init.reshape(2, 1)
+ elif self.covariance_type == 'tied':
+ precisions_init = np.array([[1.0]])
+ if skm is None:
+                raise ImportError('Please run "pip install scikit-learn" '
+                                  'to install scikit-learn first.')
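+            # fit a two-component 1-D GMM to the sorted losses; component 0
+            # is initialized at the minimum loss (the positive component)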
+ gmm = skm.GaussianMixture(
+ 2,
+ weights_init=weights_init,
+ means_init=means_init,
+ precisions_init=precisions_init,
+ covariance_type=self.covariance_type)
+ gmm.fit(pos_loss_gmm)
+ gmm_assignment = gmm.predict(pos_loss_gmm)
+ scores = gmm.score_samples(pos_loss_gmm)
+ gmm_assignment = torch.from_numpy(gmm_assignment).to(device)
+ scores = torch.from_numpy(scores).to(device)
+
+ pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme(
+ gmm_assignment, scores, pos_inds_gmm)
+ pos_inds_after_paa.append(pos_inds_temp)
+ ignore_inds_after_paa.append(ignore_inds_temp)
+
+ pos_inds_after_paa = torch.cat(pos_inds_after_paa)
+ ignore_inds_after_paa = torch.cat(ignore_inds_after_paa)
+ reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1)
+ reassign_ids = pos_inds[reassign_mask]
+ label[reassign_ids] = self.num_classes
+ label_weight[ignore_inds_after_paa] = 0
+ bbox_weight[reassign_ids] = 0
+ num_pos = len(pos_inds_after_paa)
+ return label, label_weight, bbox_weight, num_pos
+
+ def gmm_separation_scheme(self, gmm_assignment, scores, pos_inds_gmm):
+ """A general separation scheme for gmm model.
+
+ It separates a GMM distribution of candidate samples into three
+        parts (0, 1 and an uncertain area), and you can implement other
+ separation schemes by rewriting this function.
+
+ Args:
+ gmm_assignment (Tensor): The prediction of GMM which is of shape
+ (num_samples,). The 0/1 value indicates the distribution
+ that each sample comes from.
+ scores (Tensor): The probability of sample coming from the
+ fit GMM distribution. The tensor is of shape (num_samples,).
+ pos_inds_gmm (Tensor): All the indexes of samples which are used
+ to fit GMM model. The tensor is of shape (num_samples,)
+
+ Returns:
+ tuple[Tensor]: The indices of positive and ignored samples.
+
+ - pos_inds_temp (Tensor): Indices of positive samples.
+ - ignore_inds_temp (Tensor): Indices of ignore samples.
+ """
+        # The implementation is (c) in Fig.3 of the original paper, not (b).
+ # You can refer to issues such as
+ # https://github.com/kkhoot/PAA/issues/8 and
+ # https://github.com/kkhoot/PAA/issues/9.
+ fgs = gmm_assignment == 0
+ pos_inds_temp = fgs.new_tensor([], dtype=torch.long)
+ ignore_inds_temp = fgs.new_tensor([], dtype=torch.long)
+ if fgs.nonzero().numel():
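+            # losses are sorted in ascending order, so the component-0
+            # candidates up to its highest-scoring sample are taken as positives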
+ _, pos_thr_ind = scores[fgs].topk(1)
+ pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]
+ ignore_inds_temp = pos_inds_gmm.new_tensor([])
+ return pos_inds_temp, ignore_inds_temp
+
+ def get_targets(
+ self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True,
+ ):
+ """Get targets for PAA head.
+
+        This method is almost the same as `AnchorHead.get_targets()`. We
+        directly return the results from `_get_targets_single` instead of
+        mapping them to levels by the `images_to_levels` function.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, 4).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - labels (list[Tensor]): Labels of all anchors, each with
+ shape (num_anchors,).
+                - label_weights (list[Tensor]): Label weights of all anchors,
+                  each with shape (num_anchors,).
+                - bbox_targets (list[Tensor]): BBox targets of all anchors,
+                  each with shape (num_anchors, 4).
+                - bbox_weights (list[Tensor]): BBox weights of all anchors,
+                  each with shape (num_anchors, 4).
+                - pos_inds (list[Tensor]): Indices of all positive samples
+                  among all anchors of each image.
+                - gt_inds (list[Tensor]): GT indices of all positive samples
+                  among all anchors of each image.
+ """
+
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+ concat_anchor_list = []
+ concat_valid_flag_list = []
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ results = multi_apply(
+ self._get_targets_single,
+ concat_anchor_list,
+ concat_valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+
+ (labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds,
+ valid_neg_inds, sampling_result) = results
+
+        # Due to the valid flags of anchors, we have to calculate the real
+        # pos_inds in the original anchor set.
+ pos_inds = []
+ for i, single_labels in enumerate(labels):
+ pos_mask = (0 <= single_labels) & (
+ single_labels < self.num_classes)
+ pos_inds.append(pos_mask.nonzero().view(-1))
+
+ gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ gt_inds)
+
+ def _get_targets_single(self,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ This method is same as `AnchorHead._get_targets_single()`.
+ """
+        assert unmap_outputs, 'We must map outputs back to the original ' \
+            'set of anchors in PAAHead'
+ return super(ATSSHead, self)._get_targets_single(
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ score_factors=None,
+ img_metas=None,
+ cfg=None,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ assert with_nms, 'PAA only supports "with_nms=True" now and it ' \
+ 'means PAAHead does not support ' \
+ 'test-time augmentation'
+ return super(ATSSHead, self).get_bboxes(cls_scores, bbox_preds,
+ score_factors, img_metas, cfg,
+ rescale, with_nms, **kwargs)
+
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ score_factor_list,
+ mlvl_priors,
+ img_meta,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform outputs of a single image into bbox predictions.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_priors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has shape
+ (num_priors * 4, H, W).
+ score_factor_list (list[Tensor]): Score factors from all scale
+ levels of a single image, each item has shape
+ (num_priors * 1, H, W).
+ mlvl_priors (list[Tensor]): Each element in the list is
+ the priors of a single level in feature pyramid, has shape
+ (num_priors, 4).
+ img_meta (dict): Image meta info.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+                mlvl_score_factor. Usually with_nms=False is used for aug
+                test. If with_nms is True, then return the following format:
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column are scores between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ img_shape = img_meta['img_shape']
+ nms_pre = cfg.get('nms_pre', -1)
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_score_factors = []
+ for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
+ enumerate(zip(cls_score_list, bbox_pred_list,
+ score_factor_list, mlvl_priors)):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+ scores = cls_score.permute(1, 2, 0).reshape(
+ -1, self.cls_out_channels).sigmoid()
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()
+
+ if 0 < nms_pre < scores.shape[0]:
+ max_scores, _ = (scores *
+ score_factor[:, None]).sqrt().max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ priors = priors[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ score_factor = score_factor[topk_inds]
+
+ bboxes = self.bbox_coder.decode(
+ priors, bbox_pred, max_shape=img_shape)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_score_factors.append(score_factor)
+
+ return self._bbox_post_process(mlvl_scores, mlvl_bboxes,
+ img_meta['scale_factor'], cfg, rescale,
+ with_nms, mlvl_score_factors, **kwargs)
+
+ def _bbox_post_process(self,
+ mlvl_scores,
+ mlvl_bboxes,
+ scale_factor,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ mlvl_score_factors=None,
+ **kwargs):
+ """bbox post-processing method.
+
+        The boxes are rescaled to the original image scale and NMS is
+        applied. Usually with_nms=False is used for aug test.
+
+ Args:
+ mlvl_scores (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_bboxes, num_class).
+ mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale
+ levels of a single image, each item has shape (num_bboxes, 4).
+            scale_factor (ndarray, optional): Scale factor of the image,
+                arranged as (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+ mlvl_score_factors (list[Tensor], optional): Score factor from
+ all scale levels of a single image, each item has shape
+ (num_bboxes, ). Default: None.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+                mlvl_score_factor. Usually with_nms=False is used for aug
+                test. If with_nms is True, then return the following format:
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column are scores between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ # Add a dummy background class to the backend when using sigmoid
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+
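+        # rank boxes for NMS by the geometric mean of classification score
+        # and predicted IoU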
+ mlvl_iou_preds = torch.cat(mlvl_score_factors)
+ mlvl_nms_scores = (mlvl_scores * mlvl_iou_preds[:, None]).sqrt()
+ det_bboxes, det_labels = multiclass_nms(
+ mlvl_bboxes,
+ mlvl_nms_scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=None)
+ if self.with_score_voting and len(det_bboxes) > 0:
+ det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels,
+ mlvl_bboxes,
+ mlvl_nms_scores,
+ cfg.score_thr)
+
+ return det_bboxes, det_labels
+
+ def score_voting(self, det_bboxes, det_labels, mlvl_bboxes,
+ mlvl_nms_scores, score_thr):
+ """Implementation of score voting method works on each remaining boxes
+ after NMS procedure.
+
+ Args:
+ det_bboxes (Tensor): Remaining boxes after NMS procedure,
+ with shape (k, 5), each dimension means
+ (x1, y1, x2, y2, score).
+ det_labels (Tensor): The label of remaining boxes, with shape
+                (k, 1). Labels are 0-based.
+ mlvl_bboxes (Tensor): All boxes before the NMS procedure,
+ with shape (num_anchors,4).
+ mlvl_nms_scores (Tensor): The scores of all boxes which is used
+ in the NMS procedure, with shape (num_anchors, num_class)
+ score_thr (float): The score threshold of bboxes.
+
+ Returns:
+ tuple: Usually returns a tuple containing voting results.
+
+ - det_bboxes_voted (Tensor): Remaining boxes after
+ score voting procedure, with shape (k, 5), each
+ dimension means (x1, y1, x2, y2, score).
+ - det_labels_voted (Tensor): Label of remaining bboxes
+ after voting, with shape (num_anchors,).
+ """
+ candidate_mask = mlvl_nms_scores > score_thr
+ candidate_mask_nonzeros = candidate_mask.nonzero(as_tuple=False)
+ candidate_inds = candidate_mask_nonzeros[:, 0]
+ candidate_labels = candidate_mask_nonzeros[:, 1]
+ candidate_bboxes = mlvl_bboxes[candidate_inds]
+ candidate_scores = mlvl_nms_scores[candidate_mask]
+ det_bboxes_voted = []
+ det_labels_voted = []
+ for cls in range(self.cls_out_channels):
+ candidate_cls_mask = candidate_labels == cls
+ if not candidate_cls_mask.any():
+ continue
+ candidate_cls_scores = candidate_scores[candidate_cls_mask]
+ candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask]
+ det_cls_mask = det_labels == cls
+ det_cls_bboxes = det_bboxes[det_cls_mask].view(
+ -1, det_bboxes.size(-1))
+ det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4],
+ candidate_cls_bboxes)
+ for det_ind in range(len(det_cls_bboxes)):
+ single_det_ious = det_candidate_ious[det_ind]
+ pos_ious_mask = single_det_ious > 0.01
+ pos_ious = single_det_ious[pos_ious_mask]
+ pos_bboxes = candidate_cls_bboxes[pos_ious_mask]
+ pos_scores = candidate_cls_scores[pos_ious_mask]
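+                # weight neighbouring candidates by a Gaussian of their IoU
+                # distance to the detected box, scaled by their scores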
+ pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) *
+ pos_scores)[:, None]
+ voted_box = torch.sum(
+ pis * pos_bboxes, dim=0) / torch.sum(
+ pis, dim=0)
+ voted_score = det_cls_bboxes[det_ind][-1:][None, :]
+ det_bboxes_voted.append(
+ torch.cat((voted_box[None, :], voted_score), dim=1))
+ det_labels_voted.append(cls)
+
+ det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0)
+ det_labels_voted = det_labels.new_tensor(det_labels_voted)
+ return det_bboxes_voted, det_labels_voted
diff --git a/mmdet/models/dense_heads/pisa_retinanet_head.py b/mmdet/models/dense_heads/pisa_retinanet_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8654ef453a849f038f68c78df64b4fdc4b26549b
--- /dev/null
+++ b/mmdet/models/dense_heads/pisa_retinanet_head.py
@@ -0,0 +1,155 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import images_to_levels
+from ..builder import HEADS
+from ..losses import carl_loss, isr_p
+from .retina_head import RetinaHead
+
+
+@HEADS.register_module()
+class PISARetinaHead(RetinaHead):
+ """PISA Retinanet Head.
+
+    The head has the same structure as the RetinaNet head, but differs in two
+ aspects:
+ 1. Importance-based Sample Reweighting Positive (ISR-P) is applied to
+ change the positive loss weights.
+ 2. Classification-aware regression loss is adopted as a third loss.
+ """
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image
+ with shape (num_obj, 4).
+ gt_labels (list[Tensor]): Ground truth labels of each image
+                with shape (num_obj, ).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image.
+ Default: None.
+
+ Returns:
+            dict: Loss dict, comprising classification loss, regression loss
+                and carl loss.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ return_sampling_results=True)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg, sampling_results_list) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+
+ num_imgs = len(img_metas)
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, label_channels)
+ for cls_score in cls_scores
+ ]
+ flatten_cls_scores = torch.cat(
+ flatten_cls_scores, dim=1).reshape(-1,
+ flatten_cls_scores[0].size(-1))
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ flatten_bbox_preds = torch.cat(
+ flatten_bbox_preds, dim=1).view(-1, flatten_bbox_preds[0].size(-1))
+ flatten_labels = torch.cat(labels_list, dim=1).reshape(-1)
+ flatten_label_weights = torch.cat(
+ label_weights_list, dim=1).reshape(-1)
+ flatten_anchors = torch.cat(all_anchor_list, dim=1).reshape(-1, 4)
+ flatten_bbox_targets = torch.cat(
+ bbox_targets_list, dim=1).reshape(-1, 4)
+ flatten_bbox_weights = torch.cat(
+ bbox_weights_list, dim=1).reshape(-1, 4)
+
+ # Apply ISR-P
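+        # ISR-P returns reweighted targets that replace the flattened
+        # targets below before the losses are computed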
+ isr_cfg = self.train_cfg.get('isr', None)
+ if isr_cfg is not None:
+ all_targets = (flatten_labels, flatten_label_weights,
+ flatten_bbox_targets, flatten_bbox_weights)
+ with torch.no_grad():
+ all_targets = isr_p(
+ flatten_cls_scores,
+ flatten_bbox_preds,
+ all_targets,
+ flatten_anchors,
+ sampling_results_list,
+ bbox_coder=self.bbox_coder,
+ loss_cls=self.loss_cls,
+ num_class=self.num_classes,
+ **self.train_cfg.isr)
+ (flatten_labels, flatten_label_weights, flatten_bbox_targets,
+ flatten_bbox_weights) = all_targets
+
+        # For convenience we compute the loss once instead of separating it
+        # by FPN level, so we don't need to split the weights by level again.
+        # The result should be the same.
+ losses_cls = self.loss_cls(
+ flatten_cls_scores,
+ flatten_labels,
+ flatten_label_weights,
+ avg_factor=num_total_samples)
+ losses_bbox = self.loss_bbox(
+ flatten_bbox_preds,
+ flatten_bbox_targets,
+ flatten_bbox_weights,
+ avg_factor=num_total_samples)
+ loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+
+ # CARL Loss
+ carl_cfg = self.train_cfg.get('carl', None)
+ if carl_cfg is not None:
+ loss_carl = carl_loss(
+ flatten_cls_scores,
+ flatten_labels,
+ flatten_bbox_preds,
+ flatten_bbox_targets,
+ self.loss_bbox,
+ **self.train_cfg.carl,
+ avg_factor=num_total_pos,
+ sigmoid=True,
+ num_class=self.num_classes)
+ loss_dict.update(loss_carl)
+
+ return loss_dict
diff --git a/mmdet/models/dense_heads/pisa_ssd_head.py b/mmdet/models/dense_heads/pisa_ssd_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..86b67abe932262c7f0177a34cb94ea43a12ac5d4
--- /dev/null
+++ b/mmdet/models/dense_heads/pisa_ssd_head.py
@@ -0,0 +1,140 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core import multi_apply
+from ..builder import HEADS
+from ..losses import CrossEntropyLoss, SmoothL1Loss, carl_loss, isr_p
+from .ssd_head import SSDHead
+
+
+# TODO: add loss evaluator for SSD
+@HEADS.register_module()
+class PISASSDHead(SSDHead):
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image
+ with shape (num_obj, 4).
+ gt_labels (list[Tensor]): Ground truth labels of each image
+                with shape (num_obj, ).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image.
+ Default: None.
+
+ Returns:
+            dict: Loss dict, comprising classification loss, regression loss
+                and carl loss.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=1,
+ unmap_outputs=False,
+ return_sampling_results=True)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg, sampling_results_list) = cls_reg_targets
+
+ num_images = len(img_metas)
+ all_cls_scores = torch.cat([
+ s.permute(0, 2, 3, 1).reshape(
+ num_images, -1, self.cls_out_channels) for s in cls_scores
+ ], 1)
+ all_labels = torch.cat(labels_list, -1).view(num_images, -1)
+ all_label_weights = torch.cat(label_weights_list,
+ -1).view(num_images, -1)
+ all_bbox_preds = torch.cat([
+ b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
+ for b in bbox_preds
+ ], -2)
+ all_bbox_targets = torch.cat(bbox_targets_list,
+ -2).view(num_images, -1, 4)
+ all_bbox_weights = torch.cat(bbox_weights_list,
+ -2).view(num_images, -1, 4)
+
+ # concat all level anchors to a single tensor
+ all_anchors = []
+ for i in range(num_images):
+ all_anchors.append(torch.cat(anchor_list[i]))
+
+ isr_cfg = self.train_cfg.get('isr', None)
+ all_targets = (all_labels.view(-1), all_label_weights.view(-1),
+ all_bbox_targets.view(-1,
+ 4), all_bbox_weights.view(-1, 4))
+ # apply ISR-P
+ if isr_cfg is not None:
+ all_targets = isr_p(
+ all_cls_scores.view(-1, all_cls_scores.size(-1)),
+ all_bbox_preds.view(-1, 4),
+ all_targets,
+ torch.cat(all_anchors),
+ sampling_results_list,
+ loss_cls=CrossEntropyLoss(),
+ bbox_coder=self.bbox_coder,
+ **self.train_cfg.isr,
+ num_class=self.num_classes)
+ (new_labels, new_label_weights, new_bbox_targets,
+ new_bbox_weights) = all_targets
+ all_labels = new_labels.view(all_labels.shape)
+ all_label_weights = new_label_weights.view(all_label_weights.shape)
+ all_bbox_targets = new_bbox_targets.view(all_bbox_targets.shape)
+ all_bbox_weights = new_bbox_weights.view(all_bbox_weights.shape)
+
+ # add CARL loss
+ carl_loss_cfg = self.train_cfg.get('carl', None)
+ if carl_loss_cfg is not None:
+ loss_carl = carl_loss(
+ all_cls_scores.view(-1, all_cls_scores.size(-1)),
+ all_targets[0],
+ all_bbox_preds.view(-1, 4),
+ all_targets[2],
+ SmoothL1Loss(beta=1.),
+ **self.train_cfg.carl,
+ avg_factor=num_total_pos,
+ num_class=self.num_classes)
+
+ # check NaN and Inf
+ assert torch.isfinite(all_cls_scores).all().item(), \
+ 'classification scores become infinite or NaN!'
+ assert torch.isfinite(all_bbox_preds).all().item(), \
+            'bbox predictions become infinite or NaN!'
+
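+        # compute per-image losses; SSDHead.loss_single performs hard
+        # negative mining internally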
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ all_cls_scores,
+ all_bbox_preds,
+ all_anchors,
+ all_labels,
+ all_label_weights,
+ all_bbox_targets,
+ all_bbox_weights,
+ num_total_samples=num_total_pos)
+ loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+ if carl_loss_cfg is not None:
+ loss_dict.update(loss_carl)
+ return loss_dict
diff --git a/mmdet/models/dense_heads/reppoints_head.py b/mmdet/models/dense_heads/reppoints_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7204141db43a3754031bc175c87876a2d7df3e5
--- /dev/null
+++ b/mmdet/models/dense_heads/reppoints_head.py
@@ -0,0 +1,764 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.ops import DeformConv2d
+
+from mmdet.core import (build_assigner, build_sampler, images_to_levels,
+ multi_apply, unmap)
+from mmdet.core.anchor.point_generator import MlvlPointGenerator
+from mmdet.core.utils import filter_scores_and_topk
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+
+
+@HEADS.register_module()
+class RepPointsHead(AnchorFreeHead):
+ """RepPoint head.
+
+ Args:
+ point_feat_channels (int): Number of channels of points features.
+ gradient_mul (float): The multiplier to gradients from
+ points refinement and recognition.
+ point_strides (Iterable): points strides.
+ point_base_scale (int): bbox scale for assigning labels.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox_init (dict): Config of initial points loss.
+ loss_bbox_refine (dict): Config of points loss in refinement.
+        use_grid_points (bool): If True, a bounding box representation is
+            used and the RepPoints are generated as grid points on the
+            bounding box.
+ center_init (bool): Whether to use center point assignment.
+ transform_method (str): The methods to transform RepPoints to bbox.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ point_feat_channels=256,
+ num_points=9,
+ gradient_mul=0.1,
+ point_strides=[8, 16, 32, 64, 128],
+ point_base_scale=4,
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_init=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
+ loss_bbox_refine=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ use_grid_points=False,
+ center_init=True,
+ transform_method='moment',
+ moment_mul=0.01,
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='reppoints_cls_out',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ self.num_points = num_points
+ self.point_feat_channels = point_feat_channels
+ self.use_grid_points = use_grid_points
+ self.center_init = center_init
+
+ # we use deform conv to extract points features
+ self.dcn_kernel = int(np.sqrt(num_points))
+ self.dcn_pad = int((self.dcn_kernel - 1) / 2)
+ assert self.dcn_kernel * self.dcn_kernel == num_points, \
+ 'The points number should be a square number.'
+ assert self.dcn_kernel % 2 == 1, \
+ 'The points number should be an odd square number.'
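+        # precompute the base offsets of a regular dcn_kernel x dcn_kernel
+        # grid (y-first), used to turn point offsets into DeformConv offsets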
+ dcn_base = np.arange(-self.dcn_pad,
+ self.dcn_pad + 1).astype(np.float64)
+ dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
+ dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
+ dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
+ (-1))
+ self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
+
+ super().__init__(
+ num_classes,
+ in_channels,
+ loss_cls=loss_cls,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ self.gradient_mul = gradient_mul
+ self.point_base_scale = point_base_scale
+ self.point_strides = point_strides
+ self.prior_generator = MlvlPointGenerator(
+ self.point_strides, offset=0.)
+
+ self.sampling = loss_cls['type'] not in ['FocalLoss']
+ if self.train_cfg:
+ self.init_assigner = build_assigner(self.train_cfg.init.assigner)
+ self.refine_assigner = build_assigner(
+ self.train_cfg.refine.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.transform_method = transform_method
+ if self.transform_method == 'moment':
+ self.moment_transfer = nn.Parameter(
+ data=torch.zeros(2), requires_grad=True)
+ self.moment_mul = moment_mul
+
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = self.num_classes
+ else:
+ self.cls_out_channels = self.num_classes + 1
+ self.loss_bbox_init = build_loss(loss_bbox_init)
+ self.loss_bbox_refine = build_loss(loss_bbox_refine)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
+ self.reppoints_cls_conv = DeformConv2d(self.feat_channels,
+ self.point_feat_channels,
+ self.dcn_kernel, 1,
+ self.dcn_pad)
+ self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
+ self.cls_out_channels, 1, 1, 0)
+ self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
+ self.point_feat_channels, 3,
+ 1, 1)
+ self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
+ pts_out_dim, 1, 1, 0)
+ self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels,
+ self.point_feat_channels,
+ self.dcn_kernel, 1,
+ self.dcn_pad)
+ self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
+ pts_out_dim, 1, 1, 0)
+
+ def points2bbox(self, pts, y_first=True):
+ """Converting the points set into bounding box.
+
+ :param pts: the input points sets (fields), each points
+ set (fields) is represented as 2n scalar.
+ :param y_first: if y_first=True, the point set is represented as
+ [y1, x1, y2, x2 ... yn, xn], otherwise the point set is
+ represented as [x1, y1, x2, y2 ... xn, yn].
+ :return: each points set is converting to a bbox [x1, y1, x2, y2].
+ """
+ pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
+ pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
+ ...]
+ pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
+ ...]
+ if self.transform_method == 'minmax':
+ bbox_left = pts_x.min(dim=1, keepdim=True)[0]
+ bbox_right = pts_x.max(dim=1, keepdim=True)[0]
+ bbox_up = pts_y.min(dim=1, keepdim=True)[0]
+ bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
+ bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
+ dim=1)
+ elif self.transform_method == 'partial_minmax':
+ pts_y = pts_y[:, :4, ...]
+ pts_x = pts_x[:, :4, ...]
+ bbox_left = pts_x.min(dim=1, keepdim=True)[0]
+ bbox_right = pts_x.max(dim=1, keepdim=True)[0]
+ bbox_up = pts_y.min(dim=1, keepdim=True)[0]
+ bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
+ bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
+ dim=1)
+ elif self.transform_method == 'moment':
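+            # box size is derived from the std of the point set, scaled by a
+            # learnable log-space factor whose gradient is damped by
+            # moment_mul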
+ pts_y_mean = pts_y.mean(dim=1, keepdim=True)
+ pts_x_mean = pts_x.mean(dim=1, keepdim=True)
+ pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
+ pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
+ moment_transfer = (self.moment_transfer * self.moment_mul) + (
+ self.moment_transfer.detach() * (1 - self.moment_mul))
+ moment_width_transfer = moment_transfer[0]
+ moment_height_transfer = moment_transfer[1]
+ half_width = pts_x_std * torch.exp(moment_width_transfer)
+ half_height = pts_y_std * torch.exp(moment_height_transfer)
+ bbox = torch.cat([
+ pts_x_mean - half_width, pts_y_mean - half_height,
+ pts_x_mean + half_width, pts_y_mean + half_height
+ ],
+ dim=1)
+ else:
+ raise NotImplementedError
+ return bbox
+
+ def gen_grid_from_reg(self, reg, previous_boxes):
+ """Base on the previous bboxes and regression values, we compute the
+ regressed bboxes and generate the grids on the bboxes.
+
+ :param reg: the regression value to previous bboxes.
+ :param previous_boxes: previous bboxes.
+        :return: the grids generated on the regressed bboxes and the
+            regressed bboxes themselves.
+ """
+ b, _, h, w = reg.shape
+ bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
+ bwh = (previous_boxes[:, 2:, ...] -
+ previous_boxes[:, :2, ...]).clamp(min=1e-6)
+ grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
+ reg[:, 2:, ...])
+ grid_wh = bwh * torch.exp(reg[:, 2:, ...])
+ grid_left = grid_topleft[:, [0], ...]
+ grid_top = grid_topleft[:, [1], ...]
+ grid_width = grid_wh[:, [0], ...]
+ grid_height = grid_wh[:, [1], ...]
+ intervel = torch.linspace(0., 1., self.dcn_kernel).view(
+ 1, self.dcn_kernel, 1, 1).type_as(reg)
+ grid_x = grid_left + grid_width * intervel
+ grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
+ grid_x = grid_x.view(b, -1, h, w)
+ grid_y = grid_top + grid_height * intervel
+ grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
+ grid_y = grid_y.view(b, -1, h, w)
+ grid_yx = torch.stack([grid_y, grid_x], dim=2)
+ grid_yx = grid_yx.view(b, -1, h, w)
+ regressed_bbox = torch.cat([
+ grid_left, grid_top, grid_left + grid_width, grid_top + grid_height
+ ], 1)
+ return grid_yx, regressed_bbox
+
+ def forward(self, feats):
+ return multi_apply(self.forward_single, feats)
+
+ def forward_single(self, x):
+ """Forward feature map of a single FPN level."""
+ dcn_base_offset = self.dcn_base_offset.type_as(x)
+ # If we use center_init, the initial reppoints is from center points.
+ # If we use bounding bbox representation, the initial reppoints is
+ # from regular grid placed on a pre-defined bbox.
+ if self.use_grid_points or not self.center_init:
+ scale = self.point_base_scale / 2
+ points_init = dcn_base_offset / dcn_base_offset.max() * scale
+ bbox_init = x.new_tensor([-scale, -scale, scale,
+ scale]).view(1, 4, 1, 1)
+ else:
+ points_init = 0
+ cls_feat = x
+ pts_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ pts_feat = reg_conv(pts_feat)
+ # initialize reppoints
+ pts_out_init = self.reppoints_pts_init_out(
+ self.relu(self.reppoints_pts_init_conv(pts_feat)))
+ if self.use_grid_points:
+ pts_out_init, bbox_out_init = self.gen_grid_from_reg(
+ pts_out_init, bbox_init.detach())
+ else:
+ pts_out_init = pts_out_init + points_init
+ # refine and classify reppoints
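+        # only a fraction (gradient_mul) of the gradient flows back through
+        # the initial point offsets during refinement and classification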
+ pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
+ ) + self.gradient_mul * pts_out_init
+ dcn_offset = pts_out_init_grad_mul - dcn_base_offset
+ cls_out = self.reppoints_cls_out(
+ self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
+ pts_out_refine = self.reppoints_pts_refine_out(
+ self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
+ if self.use_grid_points:
+ pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
+ pts_out_refine, bbox_out_init.detach())
+ else:
+ pts_out_refine = pts_out_refine + pts_out_init.detach()
+
+ if self.training:
+ return cls_out, pts_out_init, pts_out_refine
+ else:
+ return cls_out, self.points2bbox(pts_out_refine)
+
+ def get_points(self, featmap_sizes, img_metas, device):
+ """Get points according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+
+ Returns:
+ tuple: points of each image, valid flags of each image
+ """
+ num_imgs = len(img_metas)
+
+        # since the feature map sizes of all images are the same, we only
+        # compute the point centers once
+ multi_level_points = self.prior_generator.grid_priors(
+ featmap_sizes, device=device, with_stride=True)
+ points_list = [[point.clone() for point in multi_level_points]
+ for _ in range(num_imgs)]
+
+ # for each image, we compute valid flags of multi level grids
+ valid_flag_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_flags = self.prior_generator.valid_flags(
+ featmap_sizes, img_meta['pad_shape'])
+ valid_flag_list.append(multi_level_flags)
+
+ return points_list, valid_flag_list
+
+ def centers_to_bboxes(self, point_list):
+ """Get bboxes according to center points.
+
+ Only used in :class:`MaxIoUAssigner`.
+ """
+ bbox_list = []
+ for i_img, point in enumerate(point_list):
+ bbox = []
+ for i_lvl in range(len(self.point_strides)):
+ scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
+ bbox_shift = torch.Tensor([-scale, -scale, scale,
+ scale]).view(1, 4).type_as(point[0])
+ bbox_center = torch.cat(
+ [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
+ bbox.append(bbox_center + bbox_shift)
+ bbox_list.append(bbox)
+ return bbox_list
+
+ def offset_to_pts(self, center_list, pred_list):
+ """Change from point offset to point coordinate."""
+ pts_list = []
+ for i_lvl in range(len(self.point_strides)):
+ pts_lvl = []
+ for i_img in range(len(center_list)):
+ pts_center = center_list[i_img][i_lvl][:, :2].repeat(
+ 1, self.num_points)
+ pts_shift = pred_list[i_lvl][i_img]
+ yx_pts_shift = pts_shift.permute(1, 2, 0).view(
+ -1, 2 * self.num_points)
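+                # offsets are stored y-first (y1, x1, y2, x2, ...); reorder
+                # them into (x, y) pairs and scale by the level stride to
+                # obtain absolute point coordinates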
+ y_pts_shift = yx_pts_shift[..., 0::2]
+ x_pts_shift = yx_pts_shift[..., 1::2]
+ xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
+ xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
+ pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
+ pts_lvl.append(pts)
+ pts_lvl = torch.stack(pts_lvl, 0)
+ pts_list.append(pts_lvl)
+ return pts_list
+
+ def _point_target_single(self,
+ flat_proposals,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ stage='init',
+ unmap_outputs=True):
+ inside_flags = valid_flags
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample proposals
+ proposals = flat_proposals[inside_flags, :]
+
+ if stage == 'init':
+ assigner = self.init_assigner
+ pos_weight = self.train_cfg.init.pos_weight
+ else:
+ assigner = self.refine_assigner
+ pos_weight = self.train_cfg.refine.pos_weight
+ assign_result = assigner.assign(proposals, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+ sampling_result = self.sampler.sample(assign_result, proposals,
+ gt_bboxes)
+
+ num_valid_proposals = proposals.shape[0]
+ bbox_gt = proposals.new_zeros([num_valid_proposals, 4])
+ pos_proposals = torch.zeros_like(proposals)
+ proposals_weights = proposals.new_zeros([num_valid_proposals, 4])
+ labels = proposals.new_full((num_valid_proposals, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = proposals.new_zeros(
+ num_valid_proposals, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ pos_gt_bboxes = sampling_result.pos_gt_bboxes
+ bbox_gt[pos_inds, :] = pos_gt_bboxes
+ pos_proposals[pos_inds, :] = proposals[pos_inds, :]
+ proposals_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of proposals
+ if unmap_outputs:
+ num_total_proposals = flat_proposals.size(0)
+ labels = unmap(labels, num_total_proposals, inside_flags)
+ label_weights = unmap(label_weights, num_total_proposals,
+ inside_flags)
+ bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
+ pos_proposals = unmap(pos_proposals, num_total_proposals,
+ inside_flags)
+ proposals_weights = unmap(proposals_weights, num_total_proposals,
+ inside_flags)
+
+ return (labels, label_weights, bbox_gt, pos_proposals,
+ proposals_weights, pos_inds, neg_inds)
+
+ def get_targets(self,
+ proposals_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ stage='init',
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute corresponding GT box and classification targets for
+ proposals.
+
+ Args:
+ proposals_list (list[list]): Multi level points/bboxes of each
+ image.
+ valid_flag_list (list[list]): Multi level valid flags of each
+ image.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+            gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ stage (str): `init` or `refine`. Generate target for init stage or
+ refine stage
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple:
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each level. # noqa: E501
+ - bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
+ - proposal_list (list[Tensor]): Proposals(points/bboxes) of each level. # noqa: E501
+ - proposal_weights_list (list[Tensor]): Proposal weights of each level. # noqa: E501
+ - num_total_pos (int): Number of positive samples in all images. # noqa: E501
+ - num_total_neg (int): Number of negative samples in all images. # noqa: E501
+ """
+ assert stage in ['init', 'refine']
+ num_imgs = len(img_metas)
+ assert len(proposals_list) == len(valid_flag_list) == num_imgs
+
+ # points number of multi levels
+ num_level_proposals = [points.size(0) for points in proposals_list[0]]
+
+ # concat all level points and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(proposals_list[i]) == len(valid_flag_list[i])
+ proposals_list[i] = torch.cat(proposals_list[i])
+ valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_labels, all_label_weights, all_bbox_gt, all_proposals,
+ all_proposal_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._point_target_single,
+ proposals_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ stage=stage,
+ unmap_outputs=unmap_outputs)
+ # no valid points
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled points of all images
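+        # clamp each count to at least 1 so the loss `avg_factor` can never
+        # be zero when an image has no positive samples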
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ labels_list = images_to_levels(all_labels, num_level_proposals)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_proposals)
+ bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
+ proposals_list = images_to_levels(all_proposals, num_level_proposals)
+ proposal_weights_list = images_to_levels(all_proposal_weights,
+ num_level_proposals)
+ return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
+ proposal_weights_list, num_total_pos, num_total_neg)
+
+ def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
+ label_weights, bbox_gt_init, bbox_weights_init,
+ bbox_gt_refine, bbox_weights_refine, stride,
+ num_total_samples_init, num_total_samples_refine):
+ # classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ cls_score = cls_score.contiguous()
+ loss_cls = self.loss_cls(
+ cls_score,
+ labels,
+ label_weights,
+ avg_factor=num_total_samples_refine)
+
+ # points loss
+ bbox_gt_init = bbox_gt_init.reshape(-1, 4)
+ bbox_weights_init = bbox_weights_init.reshape(-1, 4)
+ bbox_pred_init = self.points2bbox(
+ pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
+ bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
+ bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
+ bbox_pred_refine = self.points2bbox(
+ pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
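+        # normalize boxes by the base bbox size of this level so the
+        # regression loss is comparable across FPN strides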
+ normalize_term = self.point_base_scale * stride
+ loss_pts_init = self.loss_bbox_init(
+ bbox_pred_init / normalize_term,
+ bbox_gt_init / normalize_term,
+ bbox_weights_init,
+ avg_factor=num_total_samples_init)
+ loss_pts_refine = self.loss_bbox_refine(
+ bbox_pred_refine / normalize_term,
+ bbox_gt_refine / normalize_term,
+ bbox_weights_refine,
+ avg_factor=num_total_samples_refine)
+ return loss_cls, loss_pts_init, loss_pts_refine
+
+ def loss(self,
+ cls_scores,
+ pts_preds_init,
+ pts_preds_refine,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
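+        """Compute losses of the head for both the init and refine stages."""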
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ device = cls_scores[0].device
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ # target for initial stage
+ center_list, valid_flag_list = self.get_points(featmap_sizes,
+ img_metas, device)
+ pts_coordinate_preds_init = self.offset_to_pts(center_list,
+ pts_preds_init)
+ if self.train_cfg.init.assigner['type'] == 'PointAssigner':
+ # Assign target for center list
+ candidate_list = center_list
+ else:
+ # transform center list to bbox list and
+ # assign target for bbox list
+ bbox_list = self.centers_to_bboxes(center_list)
+ candidate_list = bbox_list
+ cls_reg_targets_init = self.get_targets(
+ candidate_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ stage='init',
+ label_channels=label_channels)
+ (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
+ num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
+ num_total_samples_init = (
+ num_total_pos_init +
+ num_total_neg_init if self.sampling else num_total_pos_init)
+
+ # target for refinement stage
+ center_list, valid_flag_list = self.get_points(featmap_sizes,
+ img_metas, device)
+ pts_coordinate_preds_refine = self.offset_to_pts(
+ center_list, pts_preds_refine)
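+        # decode the detached init-stage predictions into bboxes (scaled by
+        # the level stride and centered on each point); these boxes serve as
+        # the proposals for refine-stage target assignment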
+ bbox_list = []
+ for i_img, center in enumerate(center_list):
+ bbox = []
+ for i_lvl in range(len(pts_preds_refine)):
+ bbox_preds_init = self.points2bbox(
+ pts_preds_init[i_lvl].detach())
+ bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
+ bbox_center = torch.cat(
+ [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
+ bbox.append(bbox_center +
+ bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
+ bbox_list.append(bbox)
+ cls_reg_targets_refine = self.get_targets(
+ bbox_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ stage='refine',
+ label_channels=label_channels)
+ (labels_list, label_weights_list, bbox_gt_list_refine,
+ candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
+ num_total_neg_refine) = cls_reg_targets_refine
+ num_total_samples_refine = (
+ num_total_pos_refine +
+ num_total_neg_refine if self.sampling else num_total_pos_refine)
+
+ # compute loss
+ losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
+ self.loss_single,
+ cls_scores,
+ pts_coordinate_preds_init,
+ pts_coordinate_preds_refine,
+ labels_list,
+ label_weights_list,
+ bbox_gt_list_init,
+ bbox_weights_list_init,
+ bbox_gt_list_refine,
+ bbox_weights_list_refine,
+ self.point_strides,
+ num_total_samples_init=num_total_samples_init,
+ num_total_samples_refine=num_total_samples_refine)
+ loss_dict_all = {
+ 'loss_cls': losses_cls,
+ 'loss_pts_init': losses_pts_init,
+ 'loss_pts_refine': losses_pts_refine
+ }
+ return loss_dict_all
+
+ # Same as base_dense_head/_get_bboxes_single except self._bbox_decode
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ score_factor_list,
+ mlvl_priors,
+ img_meta,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform outputs of a single image into bbox predictions.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_priors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has shape
+ (num_priors * 4, H, W).
+ score_factor_list (list[Tensor]): Score factor from all scale
+ levels of a single image. RepPoints head does not need
+ this value.
+ mlvl_priors (list[Tensor]): Each element in the list is
+ the priors of a single level in feature pyramid, has shape
+ (num_priors, 2).
+ img_meta (dict): Image meta info.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+            with_nms (bool): If True, do nms before returning boxes.
+ Default: True.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+                mlvl_score_factor. Usually with_nms=False is used for aug
+                test. If with_nms is True, then return the following format:
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column are scores between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_score_list) == len(bbox_pred_list)
+ img_shape = img_meta['img_shape']
+ nms_pre = cfg.get('nms_pre', -1)
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_labels = []
+ for level_idx, (cls_score, bbox_pred, priors) in enumerate(
+ zip(cls_score_list, bbox_pred_list, mlvl_priors)):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)[:, :-1]
+
+ # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+ # this operation keeps fewer bboxes under the same `nms_pre`.
+ # There is no difference in performance for most models. If you
+ # find a slight drop in performance, you can set a larger
+ # `nms_pre` than before.
+ results = filter_scores_and_topk(
+ scores, cfg.score_thr, nms_pre,
+ dict(bbox_pred=bbox_pred, priors=priors))
+ scores, labels, _, filtered_results = results
+
+ bbox_pred = filtered_results['bbox_pred']
+ priors = filtered_results['priors']
+
+ bboxes = self._bbox_decode(priors, bbox_pred,
+ self.point_strides[level_idx],
+ img_shape)
+
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_labels.append(labels)
+
+ return self._bbox_post_process(
+ mlvl_scores,
+ mlvl_labels,
+ mlvl_bboxes,
+ img_meta['scale_factor'],
+ cfg,
+ rescale=rescale,
+ with_nms=with_nms)
+
+ def _bbox_decode(self, points, bbox_pred, stride, max_shape):
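+        """Decode point-based bbox predictions.
+
+        Offsets are in units of the level stride, relative to the point
+        center, and the decoded boxes are clamped to the image shape.
+        """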
+ bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
+ bboxes = bbox_pred * stride + bbox_pos_center
+ x1 = bboxes[:, 0].clamp(min=0, max=max_shape[1])
+ y1 = bboxes[:, 1].clamp(min=0, max=max_shape[0])
+ x2 = bboxes[:, 2].clamp(min=0, max=max_shape[1])
+ y2 = bboxes[:, 3].clamp(min=0, max=max_shape[0])
+ decoded_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ return decoded_bboxes
diff --git a/mmdet/models/dense_heads/retina_head.py b/mmdet/models/dense_heads/retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a48720c2ee88c47c9602d6e49b3b4f60a129e380
--- /dev/null
+++ b/mmdet/models/dense_heads/retina_head.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class RetinaHead(AnchorHead):
+    r"""An anchor-based head used in `RetinaNet
+    <https://arxiv.org/pdf/1708.02002.pdf>`_.
+
+ The head contains two subnetworks. The first classifies anchor boxes and
+ the second regresses deltas for the anchors.
+
+ Example:
+ >>> import torch
+ >>> self = RetinaHead(11, 7)
+ >>> x = torch.rand(1, 7, 32, 32)
+ >>> cls_score, bbox_pred = self.forward_single(x)
+ >>> # Each anchor predicts a score for each class except background
+ >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
+ >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
+ >>> assert cls_per_anchor == (self.num_classes)
+ >>> assert box_per_anchor == 4
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=4,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128]),
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='retina_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ super(RetinaHead, self).__init__(
+ num_classes,
+ in_channels,
+ anchor_generator=anchor_generator,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.retina_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.retina_reg = nn.Conv2d(
+ self.feat_channels, self.num_base_priors * 4, 3, padding=1)
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+
+ Returns:
+ tuple:
+                cls_score (Tensor): Cls scores for a single scale level,
+ the channels number is num_anchors * num_classes.
+ bbox_pred (Tensor): Box energies / deltas for a single scale
+ level, the channels number is num_anchors * 4.
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.retina_cls(cls_feat)
+ bbox_pred = self.retina_reg(reg_feat)
+ return cls_score, bbox_pred
diff --git a/mmdet/models/dense_heads/retina_sepbn_head.py b/mmdet/models/dense_heads/retina_sepbn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b385c61816fd24d091589635ad0211d73b8fdd9f
--- /dev/null
+++ b/mmdet/models/dense_heads/retina_sepbn_head.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class RetinaSepBNHead(AnchorHead):
+    """RetinaHead with separate BN.
+
+ In RetinaHead, conv/norm layers are shared across different FPN levels,
+ while in RetinaSepBNHead, conv layers are shared across different FPN
+ levels, but BN layers are separated.
+ """
+
+ def __init__(self,
+ num_classes,
+ num_ins,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ init_cfg=None,
+ **kwargs):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.num_ins = num_ins
+ super(RetinaSepBNHead, self).__init__(
+ num_classes, in_channels, init_cfg=init_cfg, **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.num_ins):
+ cls_convs = nn.ModuleList()
+ reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.cls_convs.append(cls_convs)
+ self.reg_convs.append(reg_convs)
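+        # share the conv weights of level 0 with all other FPN levels;
+        # only the `.conv` sub-module is tied, so every level keeps its
+        # own norm (BN) layer and statistics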
+ for i in range(self.stacked_convs):
+ for j in range(1, self.num_ins):
+ self.cls_convs[j][i].conv = self.cls_convs[0][i].conv
+ self.reg_convs[j][i].conv = self.reg_convs[0][i].conv
+ self.retina_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.retina_reg = nn.Conv2d(
+ self.feat_channels, self.num_base_priors * 4, 3, padding=1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ super(RetinaSepBNHead, self).init_weights()
+ for m in self.cls_convs[0]:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs[0]:
+ normal_init(m.conv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.retina_cls, std=0.01, bias=bias_cls)
+ normal_init(self.retina_reg, std=0.01)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of classification scores and bbox prediction
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * 4.
+ """
+ cls_scores = []
+ bbox_preds = []
+ for i, x in enumerate(feats):
+ cls_feat = feats[i]
+ reg_feat = feats[i]
+ for cls_conv in self.cls_convs[i]:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs[i]:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.retina_cls(cls_feat)
+ bbox_pred = self.retina_reg(reg_feat)
+ cls_scores.append(cls_score)
+ bbox_preds.append(bbox_pred)
+ return cls_scores, bbox_preds
diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..54cd39a213e4da120435e972addd40553d880a20
--- /dev/null
+++ b/mmdet/models/dense_heads/rpn_head.py
@@ -0,0 +1,265 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.ops import batched_nms
+
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class RPNHead(AnchorHead):
+ """RPN head.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ num_convs (int): Number of convolution layers in the head. Default 1.
+ """ # noqa: W605
+
+ def __init__(self,
+ in_channels,
+ init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
+ num_convs=1,
+ **kwargs):
+ self.num_convs = num_convs
+ super(RPNHead, self).__init__(
+ 1, in_channels, init_cfg=init_cfg, **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ if self.num_convs > 1:
+ rpn_convs = []
+ for i in range(self.num_convs):
+ if i == 0:
+ in_channels = self.in_channels
+ else:
+ in_channels = self.feat_channels
+ # use ``inplace=False`` to avoid error: one of the variables
+ # needed for gradient computation has been modified by an
+ # inplace operation.
+ rpn_convs.append(
+ ConvModule(
+ in_channels,
+ self.feat_channels,
+ 3,
+ padding=1,
+ inplace=False))
+ self.rpn_conv = nn.Sequential(*rpn_convs)
+ else:
+ self.rpn_conv = nn.Conv2d(
+ self.in_channels, self.feat_channels, 3, padding=1)
+ self.rpn_cls = nn.Conv2d(self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 1)
+ self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * 4,
+ 1)
+
+ def forward_single(self, x):
+ """Forward feature map of a single scale level."""
+ x = self.rpn_conv(x)
+ x = F.relu(x, inplace=False)
+ rpn_cls_score = self.rpn_cls(x)
+ rpn_bbox_pred = self.rpn_reg(x)
+ return rpn_cls_score, rpn_bbox_pred
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ losses = super(RPNHead, self).loss(
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ None,
+ img_metas,
+ gt_bboxes_ignore=gt_bboxes_ignore)
+ return dict(
+ loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
+
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ score_factor_list,
+ mlvl_anchors,
+ img_meta,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform outputs of a single image into bbox predictions.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_anchors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has
+ shape (num_anchors * 4, H, W).
+ score_factor_list (list[Tensor]): Score factor from all scale
+ levels of a single image. RPN head does not need this value.
+ mlvl_anchors (list[Tensor]): Anchors of all scale level
+ each item has shape (num_anchors, 4).
+ img_meta (dict): Image meta info.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+            with_nms (bool): If True, do nms before returning boxes.
+ Default: True.
+
+ Returns:
+ Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the
+ 5-th column is a score between 0 and 1.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ cfg = copy.deepcopy(cfg)
+ img_shape = img_meta['img_shape']
+
+ # bboxes from different level should be independent during NMS,
+ # level_ids are used as labels for batched NMS to separate them
+ level_ids = []
+ mlvl_scores = []
+ mlvl_bbox_preds = []
+ mlvl_valid_anchors = []
+ nms_pre = cfg.get('nms_pre', -1)
+ for level_idx in range(len(cls_score_list)):
+ rpn_cls_score = cls_score_list[level_idx]
+ rpn_bbox_pred = bbox_pred_list[level_idx]
+ assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
+ rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
+ if self.use_sigmoid_cls:
+ rpn_cls_score = rpn_cls_score.reshape(-1)
+ scores = rpn_cls_score.sigmoid()
+ else:
+ rpn_cls_score = rpn_cls_score.reshape(-1, 2)
+ # We set FG labels to [0, num_class-1] and BG label to
+ # num_class in RPN head since mmdet v2.5, which is unified to
+ # be consistent with other head since mmdet v2.0. In mmdet v2.0
+ # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
+ scores = rpn_cls_score.softmax(dim=1)[:, 0]
+ rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+
+ anchors = mlvl_anchors[level_idx]
+ if 0 < nms_pre < scores.shape[0]:
+ # sort is faster than topk
+ # _, topk_inds = scores.topk(cfg.nms_pre)
+ ranked_scores, rank_inds = scores.sort(descending=True)
+ topk_inds = rank_inds[:nms_pre]
+ scores = ranked_scores[:nms_pre]
+ rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
+ anchors = anchors[topk_inds, :]
+
+ mlvl_scores.append(scores)
+ mlvl_bbox_preds.append(rpn_bbox_pred)
+ mlvl_valid_anchors.append(anchors)
+ level_ids.append(
+ scores.new_full((scores.size(0), ),
+ level_idx,
+ dtype=torch.long))
+
+ return self._bbox_post_process(mlvl_scores, mlvl_bbox_preds,
+ mlvl_valid_anchors, level_ids, cfg,
+ img_shape)
+
+ def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors,
+ level_ids, cfg, img_shape, **kwargs):
+ """bbox post-processing method.
+
+        Do the nms operation for bboxes within the same level.
+
+ Args:
+ mlvl_scores (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_bboxes, ).
+ mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale
+ levels of a single image, each item has shape (num_bboxes, 4).
+ mlvl_valid_anchors (list[Tensor]): Anchors of all scale level
+ each item has shape (num_bboxes, 4).
+ level_ids (list[Tensor]): Indexes from all scale levels of a
+ single image, each item has shape (num_bboxes, ).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, `self.test_cfg` would be used.
+            img_shape (tuple[int]): The shape of the model's input image.
+
+ Returns:
+ Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the
+ 5-th column is a score between 0 and 1.
+ """
+ scores = torch.cat(mlvl_scores)
+ anchors = torch.cat(mlvl_valid_anchors)
+ rpn_bbox_pred = torch.cat(mlvl_bboxes)
+ proposals = self.bbox_coder.decode(
+ anchors, rpn_bbox_pred, max_shape=img_shape)
+ ids = torch.cat(level_ids)
+
+ if cfg.min_bbox_size >= 0:
+ w = proposals[:, 2] - proposals[:, 0]
+ h = proposals[:, 3] - proposals[:, 1]
+ valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
+ if not valid_mask.all():
+ proposals = proposals[valid_mask]
+ scores = scores[valid_mask]
+ ids = ids[valid_mask]
+
+ if proposals.numel() > 0:
+ dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
+ else:
+ return proposals.new_zeros(0, 5)
+
+ return dets[:cfg.max_per_img]
+
+ def onnx_export(self, x, img_metas):
+ """Test without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+            img_metas (list[dict]): Meta info of each image.
+
+        Returns:
+ Tensor: dets of shape [N, num_det, 5].
+ """
+ cls_scores, bbox_preds = self(x)
+
+ assert len(cls_scores) == len(bbox_preds)
+
+ batch_bboxes, batch_scores = super(RPNHead, self).onnx_export(
+ cls_scores, bbox_preds, img_metas=img_metas, with_nms=False)
+ # Use ONNX::NonMaxSuppression in deployment
+ from mmdet.core.export import add_dummy_nms_for_onnx
+ cfg = copy.deepcopy(self.test_cfg)
+ score_threshold = cfg.nms.get('score_thr', 0.0)
+ nms_pre = cfg.get('deploy_nms_pre', -1)
+ # Different from the normal forward doing NMS level by level,
+ # we do NMS across all levels when exporting ONNX.
+ dets, _ = add_dummy_nms_for_onnx(batch_bboxes, batch_scores,
+ cfg.max_per_img,
+ cfg.nms.iou_threshold,
+ score_threshold, nms_pre,
+ cfg.max_per_img)
+ return dets
diff --git a/mmdet/models/dense_heads/sabl_retina_head.py b/mmdet/models/dense_heads/sabl_retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fede7109dfcb36ab4e43df3da6900cef6a6a1c8
--- /dev/null
+++ b/mmdet/models/dense_heads/sabl_retina_head.py
@@ -0,0 +1,630 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import force_fp32
+
+from mmdet.core import (build_assigner, build_bbox_coder,
+ build_prior_generator, build_sampler, images_to_levels,
+ multi_apply, unmap)
+from mmdet.core.utils import filter_scores_and_topk
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+from .guided_anchor_head import GuidedAnchorHead
+
+
+@HEADS.register_module()
+class SABLRetinaHead(BaseDenseHead, BBoxTestMixin):
+ """Side-Aware Boundary Localization (SABL) for RetinaNet.
+
+    Anchor generation, assignment and sampling in SABLRetinaHead are the
+    same as in GuidedAnchorHead.
+
+ Please refer to https://arxiv.org/abs/1912.04260 for more details.
+
+ Args:
+ num_classes (int): Number of classes.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): Number of Convs for classification \
+ and regression branches. Defaults to 4.
+ feat_channels (int): Number of hidden channels. \
+ Defaults to 256.
+ approx_anchor_generator (dict): Config dict for approx generator.
+ square_anchor_generator (dict): Config dict for square generator.
+ conv_cfg (dict): Config dict for ConvModule. Defaults to None.
+ norm_cfg (dict): Config dict for Norm Layer. Defaults to None.
+ bbox_coder (dict): Config dict for bbox coder.
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ train_cfg (dict): Training config of SABLRetinaHead.
+ test_cfg (dict): Testing config of SABLRetinaHead.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox_cls (dict): Config of classification loss for bbox branch.
+ loss_bbox_reg (dict): Config of regression loss for bbox branch.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ feat_channels=256,
+ approx_anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=4,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128]),
+ square_anchor_generator=dict(
+ type='AnchorGenerator',
+ ratios=[1.0],
+ scales=[4],
+ strides=[8, 16, 32, 64, 128]),
+ conv_cfg=None,
+ norm_cfg=None,
+ bbox_coder=dict(
+ type='BucketingBBoxCoder',
+ num_buckets=14,
+ scale_factor=3.0),
+ reg_decoded_bbox=False,
+ train_cfg=None,
+ test_cfg=None,
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.5),
+ loss_bbox_reg=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5),
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='retina_cls',
+ std=0.01,
+ bias_prob=0.01))):
+ super(SABLRetinaHead, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.feat_channels = feat_channels
+ self.num_buckets = bbox_coder['num_buckets']
+ self.side_num = int(np.ceil(self.num_buckets / 2))
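+        # SABL localizes each of the four box boundaries by classifying it
+        # into one of `side_num` buckets and regressing a fine offset,
+        # hence the `side_num * 4` channels of the bbox branches below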
+
+ assert (approx_anchor_generator['octave_base_scale'] ==
+ square_anchor_generator['scales'][0])
+ assert (approx_anchor_generator['strides'] ==
+ square_anchor_generator['strides'])
+
+ self.approx_anchor_generator = build_prior_generator(
+ approx_anchor_generator)
+ self.square_anchor_generator = build_prior_generator(
+ square_anchor_generator)
+ self.approxs_per_octave = (
+ self.approx_anchor_generator.num_base_priors[0])
+
+ # one anchor per location
+ self.num_base_priors = self.square_anchor_generator.num_base_priors[0]
+
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.reg_decoded_bbox = reg_decoded_bbox
+
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ self.sampling = loss_cls['type'] not in [
+ 'FocalLoss', 'GHMC', 'QualityFocalLoss'
+ ]
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = num_classes
+ else:
+ self.cls_out_channels = num_classes + 1
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox_cls = build_loss(loss_bbox_cls)
+ self.loss_bbox_reg = build_loss(loss_bbox_reg)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.fp16_enabled = False
+ self._init_layers()
+
+ @property
+ def num_anchors(self):
+ warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
+ 'please use "num_base_priors" instead')
+ return self.square_anchor_generator.num_base_priors[0]
+
+ def _init_layers(self):
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.retina_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ self.retina_bbox_reg = nn.Conv2d(
+ self.feat_channels, self.side_num * 4, 3, padding=1)
+ self.retina_bbox_cls = nn.Conv2d(
+ self.feat_channels, self.side_num * 4, 3, padding=1)
+
+ def forward_single(self, x):
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.retina_cls(cls_feat)
+ bbox_cls_pred = self.retina_bbox_cls(reg_feat)
+ bbox_reg_pred = self.retina_bbox_reg(reg_feat)
+ bbox_pred = (bbox_cls_pred, bbox_reg_pred)
+ return cls_score, bbox_pred
+
+ def forward(self, feats):
+ return multi_apply(self.forward_single, feats)
+
+ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
+        """Get squares according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+ device (torch.device | str): device for returned tensors
+
+ Returns:
+            list[list[Tensor]]: Square anchors of each image.
+ """
+ num_imgs = len(img_metas)
+
+ # since feature map sizes of all images are the same, we only compute
+        # squares once
+ multi_level_squares = self.square_anchor_generator.grid_priors(
+ featmap_sizes, device=device)
+ squares_list = [multi_level_squares for _ in range(num_imgs)]
+
+ return squares_list
+
+ def get_target(self,
+ approx_list,
+ inside_flag_list,
+ square_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=None,
+ sampling=True,
+ unmap_outputs=True):
+ """Compute bucketing targets.
+ Args:
+ approx_list (list[list]): Multi level approxs of each image.
+ inside_flag_list (list[list]): Multi level inside flags of each
+ image.
+ square_list (list[list]): Multi level squares of each image.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): ignore list of gt bboxes.
+            gt_labels_list (list[Tensor]): Gt labels of each image.
+ label_channels (int): Channel of label.
+ sampling (bool): Sample Anchors or not.
+ unmap_outputs (bool): unmap outputs or not.
+
+ Returns:
+ tuple: Returns a tuple containing learning targets.
+
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each \
+ level.
+ - bbox_cls_targets_list (list[Tensor]): BBox cls targets of \
+ each level.
+ - bbox_cls_weights_list (list[Tensor]): BBox cls weights of \
+ each level.
+ - bbox_reg_targets_list (list[Tensor]): BBox reg targets of \
+ each level.
+ - bbox_reg_weights_list (list[Tensor]): BBox reg weights of \
+ each level.
+ - num_total_pos (int): Number of positive samples in all \
+ images.
+ - num_total_neg (int): Number of negative samples in all \
+ images.
+ """
+ num_imgs = len(img_metas)
+ assert len(approx_list) == len(inside_flag_list) == len(
+ square_list) == num_imgs
+ # anchor number of multi levels
+ num_level_squares = [squares.size(0) for squares in square_list[0]]
+ # concat all level anchors and flags to a single tensor
+ inside_flag_flat_list = []
+ approx_flat_list = []
+ square_flat_list = []
+ for i in range(num_imgs):
+ assert len(square_list[i]) == len(inside_flag_list[i])
+ inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
+ approx_flat_list.append(torch.cat(approx_list[i]))
+ square_flat_list.append(torch.cat(square_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_labels, all_label_weights, all_bbox_cls_targets,
+ all_bbox_cls_weights, all_bbox_reg_targets, all_bbox_reg_weights,
+ pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ approx_flat_list,
+ inside_flag_flat_list,
+ square_flat_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ sampling=sampling,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ labels_list = images_to_levels(all_labels, num_level_squares)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_squares)
+ bbox_cls_targets_list = images_to_levels(all_bbox_cls_targets,
+ num_level_squares)
+ bbox_cls_weights_list = images_to_levels(all_bbox_cls_weights,
+ num_level_squares)
+ bbox_reg_targets_list = images_to_levels(all_bbox_reg_targets,
+ num_level_squares)
+ bbox_reg_weights_list = images_to_levels(all_bbox_reg_weights,
+ num_level_squares)
+ return (labels_list, label_weights_list, bbox_cls_targets_list,
+ bbox_cls_weights_list, bbox_reg_targets_list,
+ bbox_reg_weights_list, num_total_pos, num_total_neg)
+
+ def _get_target_single(self,
+ flat_approxs,
+ inside_flags,
+ flat_squares,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=None,
+ sampling=True,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ Args:
+ flat_approxs (Tensor): flat approxs of a single image,
+ shape (n, 4)
+ inside_flags (Tensor): inside flags of a single image,
+ shape (n, ).
+ flat_squares (Tensor): flat squares of a single image,
+ shape (approxs_per_octave * n, 4)
+ gt_bboxes (Tensor): Ground truth bboxes of a single image, \
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ sampling (bool): Sample Anchors or not.
+ unmap_outputs (bool): unmap outputs or not.
+
+ Returns:
+ tuple:
+
+ - labels_list (Tensor): Labels in a single image
+ - label_weights (Tensor): Label weights in a single image
+ - bbox_cls_targets (Tensor): BBox cls targets in a single image
+ - bbox_cls_weights (Tensor): BBox cls weights in a single image
+ - bbox_reg_targets (Tensor): BBox reg targets in a single image
+ - bbox_reg_weights (Tensor): BBox reg weights in a single image
+ - num_total_pos (int): Number of positive samples \
+ in a single image
+ - num_total_neg (int): Number of negative samples \
+ in a single image
+ """
+ if not inside_flags.any():
+ return (None, ) * 8
+ # assign gt and sample anchors
+ expand_inside_flags = inside_flags[:, None].expand(
+ -1, self.approxs_per_octave).reshape(-1)
+ approxs = flat_approxs[expand_inside_flags, :]
+ squares = flat_squares[inside_flags, :]
+
+ assign_result = self.assigner.assign(approxs, squares,
+ self.approxs_per_octave,
+ gt_bboxes, gt_bboxes_ignore)
+ sampling_result = self.sampler.sample(assign_result, squares,
+ gt_bboxes)
+
+ num_valid_squares = squares.shape[0]
+ bbox_cls_targets = squares.new_zeros(
+ (num_valid_squares, self.side_num * 4))
+ bbox_cls_weights = squares.new_zeros(
+ (num_valid_squares, self.side_num * 4))
+ bbox_reg_targets = squares.new_zeros(
+ (num_valid_squares, self.side_num * 4))
+ bbox_reg_weights = squares.new_zeros(
+ (num_valid_squares, self.side_num * 4))
+ labels = squares.new_full((num_valid_squares, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = squares.new_zeros(num_valid_squares, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ (pos_bbox_reg_targets, pos_bbox_reg_weights, pos_bbox_cls_targets,
+ pos_bbox_cls_weights) = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+
+ bbox_cls_targets[pos_inds, :] = pos_bbox_cls_targets
+ bbox_reg_targets[pos_inds, :] = pos_bbox_reg_targets
+ bbox_cls_weights[pos_inds, :] = pos_bbox_cls_weights
+ bbox_reg_weights[pos_inds, :] = pos_bbox_reg_weights
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_squares.size(0)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_cls_targets = unmap(bbox_cls_targets, num_total_anchors,
+ inside_flags)
+ bbox_cls_weights = unmap(bbox_cls_weights, num_total_anchors,
+ inside_flags)
+ bbox_reg_targets = unmap(bbox_reg_targets, num_total_anchors,
+ inside_flags)
+ bbox_reg_weights = unmap(bbox_reg_weights, num_total_anchors,
+ inside_flags)
+ return (labels, label_weights, bbox_cls_targets, bbox_cls_weights,
+ bbox_reg_targets, bbox_reg_weights, pos_inds, neg_inds)
+
+ def loss_single(self, cls_score, bbox_pred, labels, label_weights,
+ bbox_cls_targets, bbox_cls_weights, bbox_reg_targets,
+ bbox_reg_weights, num_total_samples):
+ # classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ # regression loss
+ bbox_cls_targets = bbox_cls_targets.reshape(-1, self.side_num * 4)
+ bbox_cls_weights = bbox_cls_weights.reshape(-1, self.side_num * 4)
+ bbox_reg_targets = bbox_reg_targets.reshape(-1, self.side_num * 4)
+ bbox_reg_weights = bbox_reg_weights.reshape(-1, self.side_num * 4)
+ (bbox_cls_pred, bbox_reg_pred) = bbox_pred
+ bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(
+ -1, self.side_num * 4)
+ bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape(
+ -1, self.side_num * 4)
+ loss_bbox_cls = self.loss_bbox_cls(
+ bbox_cls_pred,
+ bbox_cls_targets.long(),
+ bbox_cls_weights,
+ avg_factor=num_total_samples * 4 * self.side_num)
+ loss_bbox_reg = self.loss_bbox_reg(
+ bbox_reg_pred,
+ bbox_reg_targets,
+ bbox_reg_weights,
+ avg_factor=num_total_samples * 4 * self.bbox_coder.offset_topk)
+ return loss_cls, loss_bbox_cls, loss_bbox_reg
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.approx_anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ # get sampled approxes
+ approxs_list, inside_flag_list = GuidedAnchorHead.get_sampled_approxs(
+ self, featmap_sizes, img_metas, device=device)
+
+ square_list = self.get_anchors(featmap_sizes, img_metas, device=device)
+
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = self.get_target(
+ approxs_list,
+ inside_flag_list,
+ square_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ sampling=self.sampling)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_cls_targets_list,
+ bbox_cls_weights_list, bbox_reg_targets_list, bbox_reg_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+ losses_cls, losses_bbox_cls, losses_bbox_reg = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ labels_list,
+ label_weights_list,
+ bbox_cls_targets_list,
+ bbox_cls_weights_list,
+ bbox_reg_targets_list,
+ bbox_reg_weights_list,
+ num_total_samples=num_total_samples)
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox_cls=losses_bbox_cls,
+ loss_bbox_reg=losses_bbox_reg)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ img_metas,
+ cfg=None,
+ rescale=False):
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+
+ device = cls_scores[0].device
+ mlvl_anchors = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_cls_pred_list = [
+ bbox_preds[i][0][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_reg_pred_list = [
+ bbox_preds[i][1][img_id].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(
+ cls_score_list, bbox_cls_pred_list, bbox_reg_pred_list,
+ mlvl_anchors[img_id], img_shape, scale_factor, cfg, rescale)
+ result_list.append(proposals)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_cls_preds,
+ bbox_reg_preds,
+ mlvl_anchors,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ cfg = self.test_cfg if cfg is None else cfg
+ nms_pre = cfg.get('nms_pre', -1)
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_confids = []
+ mlvl_labels = []
+ assert len(cls_scores) == len(bbox_cls_preds) == len(
+ bbox_reg_preds) == len(mlvl_anchors)
+ for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip(
+ cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_cls_pred.size(
+            )[-2:] == bbox_reg_pred.size()[-2:]
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)[:, :-1]
+ bbox_cls_pred = bbox_cls_pred.permute(1, 2, 0).reshape(
+ -1, self.side_num * 4)
+ bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape(
+ -1, self.side_num * 4)
+
+ # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+ # this operation keeps fewer bboxes under the same `nms_pre`.
+ # There is no difference in performance for most models. If you
+ # find a slight drop in performance, you can set a larger
+ # `nms_pre` than before.
+ results = filter_scores_and_topk(
+ scores, cfg.score_thr, nms_pre,
+ dict(
+ anchors=anchors,
+ bbox_cls_pred=bbox_cls_pred,
+ bbox_reg_pred=bbox_reg_pred))
+ scores, labels, _, filtered_results = results
+
+ anchors = filtered_results['anchors']
+ bbox_cls_pred = filtered_results['bbox_cls_pred']
+ bbox_reg_pred = filtered_results['bbox_reg_pred']
+
+ bbox_preds = [
+ bbox_cls_pred.contiguous(),
+ bbox_reg_pred.contiguous()
+ ]
+ bboxes, confids = self.bbox_coder.decode(
+ anchors.contiguous(), bbox_preds, max_shape=img_shape)
+
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_confids.append(confids)
+ mlvl_labels.append(labels)
+ return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes,
+ scale_factor, cfg, rescale, True,
+ mlvl_confids)
diff --git a/mmdet/models/dense_heads/solo_head.py b/mmdet/models/dense_heads/solo_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e89aacb420af4f5df11183e656e04c87f3dc8fe4
--- /dev/null
+++ b/mmdet/models/dense_heads/solo_head.py
@@ -0,0 +1,1197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from mmdet.core import InstanceData, mask_matrix_nms, multi_apply
+from mmdet.core.utils import center_of_mass, generate_coordinate
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.utils.misc import floordiv
+from .base_mask_head import BaseMaskHead
+
+
+@HEADS.register_module()
+class SOLOHead(BaseMaskHead):
+ """SOLO mask head used in `SOLO: Segmenting Objects by Locations.
+
+    <https://arxiv.org/abs/1912.04488>`_
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels. Used in child classes.
+ Default: 256.
+ stacked_convs (int): Number of stacking convs of the head.
+ Default: 4.
+ strides (tuple): Downsample factor of each feature map.
+ scale_ranges (tuple[tuple[int, int]]): Area range of multiple
+ level masks, in the format [(min1, max1), (min2, max2), ...].
+            A range of (16, 64) means that instances whose scale (the square
+            root of the bbox area) falls between 16 and 64 are assigned to
+            this level.
+ pos_scale (float): Constant scale factor to control the center region.
+        num_grids (list[int]): The image is divided into a uniform grid for
+            each level; each feature map level has a different grid number.
+            The number of output channels is grid ** 2.
+            Default: [40, 36, 24, 16, 12].
+ cls_down_index (int): The index of downsample operation in
+ classification branch. Default: 0.
+ loss_mask (dict): Config of mask loss.
+ loss_cls (dict): Config of classification loss.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: norm_cfg=dict(type='GN', num_groups=32,
+ requires_grad=True).
+ train_cfg (dict): Training config of head.
+ test_cfg (dict): Testing config of head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(
+ self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ stacked_convs=4,
+ strides=(4, 8, 16, 32, 64),
+ scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
+ pos_scale=0.2,
+ num_grids=[40, 36, 24, 16, 12],
+ cls_down_index=0,
+ loss_mask=None,
+ loss_cls=None,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=[
+ dict(type='Normal', layer='Conv2d', std=0.01),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_mask_list')),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_cls'))
+ ],
+ ):
+ super(SOLOHead, self).__init__(init_cfg)
+ self.num_classes = num_classes
+ self.cls_out_channels = self.num_classes
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.stacked_convs = stacked_convs
+ self.strides = strides
+ self.num_grids = num_grids
+ # number of FPN feats
+ self.num_levels = len(strides)
+ assert self.num_levels == len(scale_ranges) == len(num_grids)
+ self.scale_ranges = scale_ranges
+ self.pos_scale = pos_scale
+
+ self.cls_down_index = cls_down_index
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_mask = build_loss(loss_mask)
+ self.norm_cfg = norm_cfg
+ self.init_cfg = init_cfg
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self._init_layers()
+
+ def _init_layers(self):
+ self.mask_convs = nn.ModuleList()
+ self.cls_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels + 2 if i == 0 else self.feat_channels
+ self.mask_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=self.norm_cfg))
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=self.norm_cfg))
+ self.conv_mask_list = nn.ModuleList()
+ for num_grid in self.num_grids:
+ self.conv_mask_list.append(
+ nn.Conv2d(self.feat_channels, num_grid**2, 1))
+
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+
+ def resize_feats(self, feats):
+        """Downsample the first feat and upsample the last feat in feats."""
+ out = []
+ for i in range(len(feats)):
+ if i == 0:
+ out.append(
+ F.interpolate(
+ feats[0],
+ size=feats[i + 1].shape[-2:],
+ mode='bilinear',
+ align_corners=False))
+ elif i == len(feats) - 1:
+ out.append(
+ F.interpolate(
+ feats[i],
+ size=feats[i - 1].shape[-2:],
+ mode='bilinear',
+ align_corners=False))
+ else:
+ out.append(feats[i])
+ return out
+
+ def forward(self, feats):
+ assert len(feats) == self.num_levels
+ feats = self.resize_feats(feats)
+ mlvl_mask_preds = []
+ mlvl_cls_preds = []
+ for i in range(self.num_levels):
+ x = feats[i]
+ mask_feat = x
+ cls_feat = x
+ # generate and concat the coordinate
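+            # (CoordConv-style: two extra channels holding normalized x/y
+            # coordinates make the mask branch location-aware)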
+ coord_feat = generate_coordinate(mask_feat.size(),
+ mask_feat.device)
+ mask_feat = torch.cat([mask_feat, coord_feat], 1)
+
+ for mask_layer in (self.mask_convs):
+ mask_feat = mask_layer(mask_feat)
+
+ mask_feat = F.interpolate(
+ mask_feat, scale_factor=2, mode='bilinear')
+ mask_pred = self.conv_mask_list[i](mask_feat)
+
+ # cls branch
+ for j, cls_layer in enumerate(self.cls_convs):
+ if j == self.cls_down_index:
+ num_grid = self.num_grids[i]
+ cls_feat = F.interpolate(
+ cls_feat, size=num_grid, mode='bilinear')
+ cls_feat = cls_layer(cls_feat)
+
+ cls_pred = self.conv_cls(cls_feat)
+
+ if not self.training:
+ feat_wh = feats[0].size()[-2:]
+ upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
+ mask_pred = F.interpolate(
+ mask_pred.sigmoid(), size=upsampled_size, mode='bilinear')
+ cls_pred = cls_pred.sigmoid()
+ # get local maximum
+ local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
+ keep_mask = local_max[:, :, :-1, :-1] == cls_pred
+ cls_pred = cls_pred * keep_mask
+
+ mlvl_mask_preds.append(mask_pred)
+ mlvl_cls_preds.append(cls_pred)
+ return mlvl_mask_preds, mlvl_cls_preds
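+ # A hypothetical usage sketch added by the editor; the loss configs below
+ # are assumptions (they mirror what the standard SOLO configs use, to the
+ # best of the editor's knowledge) and are only needed to build the head:
+ # >>> head = SOLOHead(
+ # ...     num_classes=4,
+ # ...     in_channels=256,
+ # ...     loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
+ # ...     loss_cls=dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0))
+ # >>> feats = [torch.rand(1, 256, 64 // 2**i, 64 // 2**i) for i in range(5)]
+ # >>> mask_preds, cls_preds = head(feats)
+ # >>> [p.shape[1] for p in mask_preds]   # one mask channel per grid cell
+ # [1600, 1296, 576, 256, 144]
+ # >>> [p.shape[1] for p in cls_preds]    # num_classes scores per grid cell
+ # [4, 4, 4, 4, 4]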
+
+ def loss(self,
+ mlvl_mask_preds,
+ mlvl_cls_preds,
+ gt_labels,
+ gt_masks,
+ img_metas,
+ gt_bboxes=None,
+ **kwargs):
+ """Calculate the loss of total batch.
+
+ Args:
+ mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
+ Each element in the list has shape
+ (batch_size, num_grids**2 ,h ,w).
+ mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
+ in the list has shape
+ (batch_size, num_classes, num_grids ,num_grids).
+ gt_labels (list[Tensor]): Labels of multiple images.
+ gt_masks (list[Tensor]): Ground truth masks of multiple images.
+ Each has shape (num_instances, h, w).
+ img_metas (list[dict]): Meta information of multiple images.
+ gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
+ images. Default: None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_levels = self.num_levels
+ num_imgs = len(gt_labels)
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]
+
+ # The `BoolTensor`s in `pos_masks` indicate whether
+ # the corresponding grid points are positive samples
+ pos_mask_targets, labels, pos_masks = multi_apply(
+ self._get_targets_single,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ featmap_sizes=featmap_sizes)
+
+ # regroup the targets so that the outer list indexes feature
+ # levels instead of images
+ mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
+ mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
+ mlvl_pos_masks = [[] for _ in range(num_levels)]
+ mlvl_labels = [[] for _ in range(num_levels)]
+ for img_id in range(num_imgs):
+ assert num_levels == len(pos_mask_targets[img_id])
+ for lvl in range(num_levels):
+ mlvl_pos_mask_targets[lvl].append(
+ pos_mask_targets[img_id][lvl])
+ mlvl_pos_mask_preds[lvl].append(
+ mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
+ mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
+ mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
+
+ # concatenate targets and predictions of multiple images level by level
+ temp_mlvl_cls_preds = []
+ for lvl in range(num_levels):
+ mlvl_pos_mask_targets[lvl] = torch.cat(
+ mlvl_pos_mask_targets[lvl], dim=0)
+ mlvl_pos_mask_preds[lvl] = torch.cat(
+ mlvl_pos_mask_preds[lvl], dim=0)
+ mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
+ mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
+ temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
+ 0, 2, 3, 1).reshape(-1, self.cls_out_channels))
+
+ num_pos = sum(item.sum() for item in mlvl_pos_masks)
+ # dice loss
+ loss_mask = []
+ for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
+ if pred.size()[0] == 0:
+ loss_mask.append(pred.sum().unsqueeze(0))
+ continue
+ loss_mask.append(
+ self.loss_mask(pred, target, reduction_override='none'))
+ if num_pos > 0:
+ loss_mask = torch.cat(loss_mask).sum() / num_pos
+ else:
+ loss_mask = torch.cat(loss_mask).mean()
+
+ flatten_labels = torch.cat(mlvl_labels)
+ flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
+ loss_cls = self.loss_cls(
+ flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
+ return dict(loss_mask=loss_mask, loss_cls=loss_cls)
+
+ def _get_targets_single(self,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ featmap_sizes=None):
+ """Compute targets for predictions of single image.
+
+ Args:
+ gt_bboxes (Tensor): Ground truth bbox of each instance,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth label of each instance,
+ shape (num_gts,).
+ gt_masks (Tensor): Ground truth mask of each instance,
+ shape (num_gts, h, w).
+ featmap_sizes (list[:obj:`torch.size`]): Size of each
+ feature map from feature pyramid, each element
+ means (feat_h, feat_w). Default: None.
+
+ Returns:
+ Tuple: Usually returns a tuple containing targets for predictions.
+
+ - mlvl_pos_mask_targets (list[Tensor]): Each element represents
+ the binary mask targets for positive points in this
+ level, has shape (num_pos, out_h, out_w).
+ - mlvl_labels (list[Tensor]): Each element is
+ classification labels for all
+ points in this level, has shape
+ (num_grid, num_grid).
+ - mlvl_pos_masks (list[Tensor]): Each element is
+ a `BoolTensor` to represent whether the
+ corresponding point in single level
+ is positive, has shape (num_grid **2).
+ """
+ device = gt_labels.device
+ gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
+ (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
+
+ mlvl_pos_mask_targets = []
+ mlvl_labels = []
+ mlvl_pos_masks = []
+ for (lower_bound, upper_bound), stride, featmap_size, num_grid \
+ in zip(self.scale_ranges, self.strides,
+ featmap_sizes, self.num_grids):
+
+ mask_target = torch.zeros(
+ [num_grid**2, featmap_size[0], featmap_size[1]],
+ dtype=torch.uint8,
+ device=device)
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ labels = torch.zeros([num_grid, num_grid],
+ dtype=torch.int64,
+ device=device) + self.num_classes
+ pos_mask = torch.zeros([num_grid**2],
+ dtype=torch.bool,
+ device=device)
+
+ gt_inds = ((gt_areas >= lower_bound) &
+ (gt_areas <= upper_bound)).nonzero().flatten()
+ if len(gt_inds) == 0:
+ mlvl_pos_mask_targets.append(
+ mask_target.new_zeros(0, featmap_size[0], featmap_size[1]))
+ mlvl_labels.append(labels)
+ mlvl_pos_masks.append(pos_mask)
+ continue
+ hit_gt_bboxes = gt_bboxes[gt_inds]
+ hit_gt_labels = gt_labels[gt_inds]
+ hit_gt_masks = gt_masks[gt_inds, ...]
+
+ pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
+ hit_gt_bboxes[:, 0]) * self.pos_scale
+ pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
+ hit_gt_bboxes[:, 1]) * self.pos_scale
+
+ # Make sure hit_gt_masks has a value
+ valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
+ output_stride = stride / 2
+
+ for gt_mask, gt_label, pos_h_range, pos_w_range, \
+ valid_mask_flag in \
+ zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
+ pos_w_ranges, valid_mask_flags):
+ if not valid_mask_flag:
+ continue
+ upsampled_size = (featmap_sizes[0][0] * 4,
+ featmap_sizes[0][1] * 4)
+ center_h, center_w = center_of_mass(gt_mask)
+
+ coord_w = int(
+ floordiv((center_w / upsampled_size[1]), (1. / num_grid),
+ rounding_mode='trunc'))
+ coord_h = int(
+ floordiv((center_h / upsampled_size[0]), (1. / num_grid),
+ rounding_mode='trunc'))
+
+ # top, down, left and right boundaries of the positive grid region
+ top_box = max(
+ 0,
+ int(
+ floordiv(
+ (center_h - pos_h_range) / upsampled_size[0],
+ (1. / num_grid),
+ rounding_mode='trunc')))
+ down_box = min(
+ num_grid - 1,
+ int(
+ floordiv(
+ (center_h + pos_h_range) / upsampled_size[0],
+ (1. / num_grid),
+ rounding_mode='trunc')))
+ left_box = max(
+ 0,
+ int(
+ floordiv(
+ (center_w - pos_w_range) / upsampled_size[1],
+ (1. / num_grid),
+ rounding_mode='trunc')))
+ right_box = min(
+ num_grid - 1,
+ int(
+ floordiv(
+ (center_w + pos_w_range) / upsampled_size[1],
+ (1. / num_grid),
+ rounding_mode='trunc')))
+
+ top = max(top_box, coord_h - 1)
+ down = min(down_box, coord_h + 1)
+ left = max(coord_w - 1, left_box)
+ right = min(right_box, coord_w + 1)
+
+ labels[top:(down + 1), left:(right + 1)] = gt_label
+ # ins
+ gt_mask = np.uint8(gt_mask.cpu().numpy())
+ # Follow the original implementation and rescale with mmcv (OpenCV),
+ # since F.interpolate behaves slightly differently
+ gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
+ gt_mask = torch.from_numpy(gt_mask).to(device=device)
+
+ for i in range(top, down + 1):
+ for j in range(left, right + 1):
+ index = int(i * num_grid + j)
+ mask_target[index, :gt_mask.shape[0], :gt_mask.
+ shape[1]] = gt_mask
+ pos_mask[index] = True
+ mlvl_pos_mask_targets.append(mask_target[pos_mask])
+ mlvl_labels.append(labels)
+ mlvl_pos_masks.append(pos_mask)
+ return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks
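+ # A worked example added by the editor (all numbers are made up): assume
+ # num_grid=40, an upsampled size of (800, 800), pos_scale=0.2 and one
+ # instance whose mask centre is at (center_h, center_w)=(400, 300) with a
+ # 100x200 (h x w) box.  Then pos_h_range=10, pos_w_range=20,
+ # coord_h=int(400/800*40)=20 and coord_w=int(300/800*40)=15; the box
+ # bounds become top_box=19, down_box=20, left_box=14, right_box=16, and
+ # after clamping to the centre cell +/-1 the cells with rows 19-20 and
+ # columns 14-16 are marked positive, each receiving the rescaled GT mask
+ # as its target.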
+
+ def get_results(self, mlvl_mask_preds, mlvl_cls_scores, img_metas,
+ **kwargs):
+ """Get multi-image mask results.
+
+ Args:
+ mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
+ Each element in the list has shape
+ (batch_size, num_grids**2 ,h ,w).
+ mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
+ in the list has shape
+ (batch_size, num_classes, num_grids ,num_grids).
+ img_metas (list[dict]): Meta information of all images.
+
+ Returns:
+ list[:obj:`InstanceData`]: Processed results of multiple
+ images. Each :obj:`InstanceData` usually contains
+ the following keys.
+
+ - scores (Tensor): Classification scores, has shape
+ (num_instance,).
+ - labels (Tensor): Has shape (num_instances,).
+ - masks (Tensor): Processed mask results, has
+ shape (num_instances, h, w).
+ """
+ mlvl_cls_scores = [
+ item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
+ ]
+ assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
+ num_levels = len(mlvl_cls_scores)
+
+ results_list = []
+ for img_id in range(len(img_metas)):
+ cls_pred_list = [
+ mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
+ for lvl in range(num_levels)
+ ]
+ mask_pred_list = [
+ mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
+ ]
+
+ cls_pred_list = torch.cat(cls_pred_list, dim=0)
+ mask_pred_list = torch.cat(mask_pred_list, dim=0)
+
+ results = self._get_results_single(
+ cls_pred_list, mask_pred_list, img_meta=img_metas[img_id])
+ results_list.append(results)
+
+ return results_list
+
+ def _get_results_single(self, cls_scores, mask_preds, img_meta, cfg=None):
+ """Get processed mask related results of single image.
+
+ Args:
+ cls_scores (Tensor): Classification score of all points
+ in single image, has shape (num_points, num_classes).
+ mask_preds (Tensor): Mask prediction of all points in
+ single image, has shape (num_points, feat_h, feat_w).
+ img_meta (dict): Meta information of corresponding image.
+ cfg (dict, optional): Config used in test phase.
+ Default: None.
+
+ Returns:
+ :obj:`InstanceData`: Processed results of single image.
+ It usually contains the following keys.
+
+ - scores (Tensor): Classification scores, has shape
+ (num_instance,).
+ - labels (Tensor): Has shape (num_instances,).
+ - masks (Tensor): Processed mask results, has
+ shape (num_instances, h, w).
+ """
+
+ def empty_results(results, cls_scores):
+ """Generate a empty results."""
+ results.scores = cls_scores.new_ones(0)
+ results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
+ results.labels = cls_scores.new_ones(0)
+ return results
+
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(mask_preds)
+ results = InstanceData(img_meta)
+
+ featmap_size = mask_preds.size()[-2:]
+
+ img_shape = results.img_shape
+ ori_shape = results.ori_shape
+
+ h, w, _ = img_shape
+ upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)
+
+ score_mask = (cls_scores > cfg.score_thr)
+ cls_scores = cls_scores[score_mask]
+ if len(cls_scores) == 0:
+ return empty_results(results, cls_scores)
+
+ inds = score_mask.nonzero()
+ cls_labels = inds[:, 1]
+
+ # Filter out masks whose area is smaller than the
+ # stride of the corresponding feature level
+ lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
+ strides = cls_scores.new_ones(lvl_interval[-1])
+ strides[:lvl_interval[0]] *= self.strides[0]
+ for lvl in range(1, self.num_levels):
+ strides[lvl_interval[lvl -
+ 1]:lvl_interval[lvl]] *= self.strides[lvl]
+ strides = strides[inds[:, 0]]
+ mask_preds = mask_preds[inds[:, 0]]
+
+ masks = mask_preds > cfg.mask_thr
+ sum_masks = masks.sum((1, 2)).float()
+ keep = sum_masks > strides
+ if keep.sum() == 0:
+ return empty_results(results, cls_scores)
+ masks = masks[keep]
+ mask_preds = mask_preds[keep]
+ sum_masks = sum_masks[keep]
+ cls_scores = cls_scores[keep]
+ cls_labels = cls_labels[keep]
+
+ # maskness.
+ mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
+ cls_scores *= mask_scores
+
+ scores, labels, _, keep_inds = mask_matrix_nms(
+ masks,
+ cls_labels,
+ cls_scores,
+ mask_area=sum_masks,
+ nms_pre=cfg.nms_pre,
+ max_num=cfg.max_per_img,
+ kernel=cfg.kernel,
+ sigma=cfg.sigma,
+ filter_thr=cfg.filter_thr)
+ mask_preds = mask_preds[keep_inds]
+ mask_preds = F.interpolate(
+ mask_preds.unsqueeze(0), size=upsampled_size,
+ mode='bilinear')[:, :, :h, :w]
+ mask_preds = F.interpolate(
+ mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0)
+ masks = mask_preds > cfg.mask_thr
+
+ results.masks = masks
+ results.labels = labels
+ results.scores = scores
+
+ return results
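+ # Editor's sketch of the "maskness" rescoring above (toy numbers):
+ # >>> mask_preds = torch.tensor([[[0.9, 0.8], [0.1, 0.0]]])
+ # >>> masks = mask_preds > 0.5
+ # >>> (mask_preds * masks).sum((1, 2)) / masks.sum((1, 2)).float()
+ # tensor([0.8500])
+ # i.e. each classification score is multiplied by the mean predicted
+ # probability inside its binarised mask before Matrix NMS.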
+
+
+@HEADS.register_module()
+class DecoupledSOLOHead(SOLOHead):
+ """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations.
+
+ `_
+
+ Args:
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ *args,
+ init_cfg=[
+ dict(type='Normal', layer='Conv2d', std=0.01),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_mask_list_x')),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_mask_list_y')),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_cls'))
+ ],
+ **kwargs):
+ super(DecoupledSOLOHead, self).__init__(
+ *args, init_cfg=init_cfg, **kwargs)
+
+ def _init_layers(self):
+ self.mask_convs_x = nn.ModuleList()
+ self.mask_convs_y = nn.ModuleList()
+ self.cls_convs = nn.ModuleList()
+
+ for i in range(self.stacked_convs):
+ chn = self.in_channels + 1 if i == 0 else self.feat_channels
+ self.mask_convs_x.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=self.norm_cfg))
+ self.mask_convs_y.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=self.norm_cfg))
+
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_cfg=self.norm_cfg))
+
+ self.conv_mask_list_x = nn.ModuleList()
+ self.conv_mask_list_y = nn.ModuleList()
+ for num_grid in self.num_grids:
+ self.conv_mask_list_x.append(
+ nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
+ self.conv_mask_list_y.append(
+ nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+
+ def forward(self, feats):
+ assert len(feats) == self.num_levels
+ feats = self.resize_feats(feats)
+ mask_preds_x = []
+ mask_preds_y = []
+ cls_preds = []
+ for i in range(self.num_levels):
+ x = feats[i]
+ mask_feat = x
+ cls_feat = x
+ # generate and concat the coordinate
+ coord_feat = generate_coordinate(mask_feat.size(),
+ mask_feat.device)
+ mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1)
+ mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1)
+
+ for mask_layer_x, mask_layer_y in \
+ zip(self.mask_convs_x, self.mask_convs_y):
+ mask_feat_x = mask_layer_x(mask_feat_x)
+ mask_feat_y = mask_layer_y(mask_feat_y)
+
+ mask_feat_x = F.interpolate(
+ mask_feat_x, scale_factor=2, mode='bilinear')
+ mask_feat_y = F.interpolate(
+ mask_feat_y, scale_factor=2, mode='bilinear')
+
+ mask_pred_x = self.conv_mask_list_x[i](mask_feat_x)
+ mask_pred_y = self.conv_mask_list_y[i](mask_feat_y)
+
+ # cls branch
+ for j, cls_layer in enumerate(self.cls_convs):
+ if j == self.cls_down_index:
+ num_grid = self.num_grids[i]
+ cls_feat = F.interpolate(
+ cls_feat, size=num_grid, mode='bilinear')
+ cls_feat = cls_layer(cls_feat)
+
+ cls_pred = self.conv_cls(cls_feat)
+
+ if not self.training:
+ feat_wh = feats[0].size()[-2:]
+ upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
+ mask_pred_x = F.interpolate(
+ mask_pred_x.sigmoid(),
+ size=upsampled_size,
+ mode='bilinear')
+ mask_pred_y = F.interpolate(
+ mask_pred_y.sigmoid(),
+ size=upsampled_size,
+ mode='bilinear')
+ cls_pred = cls_pred.sigmoid()
+ # get local maximum
+ local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
+ keep_mask = local_max[:, :, :-1, :-1] == cls_pred
+ cls_pred = cls_pred * keep_mask
+
+ mask_preds_x.append(mask_pred_x)
+ mask_preds_y.append(mask_pred_y)
+ cls_preds.append(cls_pred)
+ return mask_preds_x, mask_preds_y, cls_preds
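+ # Editor's note: instead of the num_grid**2 mask channels of the vanilla
+ # head, the decoupled head predicts num_grid channels per axis; the mask
+ # for grid cell (i, j) is recovered as the product of the sigmoids of
+ # mask_preds_y[i] and mask_preds_x[j] (see `loss`, and the eval branch
+ # above where the sigmoids are applied before `_get_results_single`
+ # multiplies the two branches).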
+
+ def loss(self,
+ mlvl_mask_preds_x,
+ mlvl_mask_preds_y,
+ mlvl_cls_preds,
+ gt_labels,
+ gt_masks,
+ img_metas,
+ gt_bboxes=None,
+ **kwargs):
+ """Calculate the loss of total batch.
+
+ Args:
+ mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
+ from x branch. Each element in the list has shape
+ (batch_size, num_grids ,h ,w).
+ mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
+ from y branch. Each element in the list has shape
+ (batch_size, num_grids ,h ,w).
+ mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
+ in the list has shape
+ (batch_size, num_classes, num_grids ,num_grids).
+ gt_labels (list[Tensor]): Labels of multiple images.
+ gt_masks (list[Tensor]): Ground truth masks of multiple images.
+ Each has shape (num_instances, h, w).
+ img_metas (list[dict]): Meta information of multiple images.
+ gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
+ images. Default: None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_levels = self.num_levels
+ num_imgs = len(gt_labels)
+ featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]
+
+ pos_mask_targets, labels, \
+ xy_pos_indexes = \
+ multi_apply(self._get_targets_single,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ featmap_sizes=featmap_sizes)
+
+ # regroup the targets so that the outer list indexes feature
+ # levels instead of images
+ mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
+ mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
+ mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
+ mlvl_labels = [[] for _ in range(num_levels)]
+ for img_id in range(num_imgs):
+
+ for lvl in range(num_levels):
+ mlvl_pos_mask_targets[lvl].append(
+ pos_mask_targets[img_id][lvl])
+ mlvl_pos_mask_preds_x[lvl].append(
+ mlvl_mask_preds_x[lvl][img_id,
+ xy_pos_indexes[img_id][lvl][:, 1]])
+ mlvl_pos_mask_preds_y[lvl].append(
+ mlvl_mask_preds_y[lvl][img_id,
+ xy_pos_indexes[img_id][lvl][:, 0]])
+ mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
+
+ # concatenate targets and predictions of multiple images level by level
+ temp_mlvl_cls_preds = []
+ for lvl in range(num_levels):
+ mlvl_pos_mask_targets[lvl] = torch.cat(
+ mlvl_pos_mask_targets[lvl], dim=0)
+ mlvl_pos_mask_preds_x[lvl] = torch.cat(
+ mlvl_pos_mask_preds_x[lvl], dim=0)
+ mlvl_pos_mask_preds_y[lvl] = torch.cat(
+ mlvl_pos_mask_preds_y[lvl], dim=0)
+ mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
+ temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
+ 0, 2, 3, 1).reshape(-1, self.cls_out_channels))
+
+ num_pos = 0.
+ # dice loss
+ loss_mask = []
+ for pred_x, pred_y, target in \
+ zip(mlvl_pos_mask_preds_x,
+ mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
+ num_masks = pred_x.size(0)
+ if num_masks == 0:
+ # make sure can get grad
+ loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
+ continue
+ num_pos += num_masks
+ pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
+ loss_mask.append(
+ self.loss_mask(pred_mask, target, reduction_override='none'))
+ if num_pos > 0:
+ loss_mask = torch.cat(loss_mask).sum() / num_pos
+ else:
+ loss_mask = torch.cat(loss_mask).mean()
+
+ # cate
+ flatten_labels = torch.cat(mlvl_labels)
+ flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
+
+ loss_cls = self.loss_cls(
+ flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
+ return dict(loss_mask=loss_mask, loss_cls=loss_cls)
+
+ def _get_targets_single(self,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ featmap_sizes=None):
+ """Compute targets for predictions of single image.
+
+ Args:
+ gt_bboxes (Tensor): Ground truth bbox of each instance,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth label of each instance,
+ shape (num_gts,).
+ gt_masks (Tensor): Ground truth mask of each instance,
+ shape (num_gts, h, w).
+ featmap_sizes (list[:obj:`torch.size`]): Size of each
+ feature map from feature pyramid, each element
+ means (feat_h, feat_w). Default: None.
+
+ Returns:
+ Tuple: Usually returns a tuple containing targets for predictions.
+
+ - mlvl_pos_mask_targets (list[Tensor]): Each element represents
+ the binary mask targets for positive points in this
+ level, has shape (num_pos, out_h, out_w).
+ - mlvl_labels (list[Tensor]): Each element is
+ classification labels for all
+ points in this level, has shape
+ (num_grid, num_grid).
+ - mlvl_xy_pos_indexes (list[Tensor]): Each element
+ in the list contains the index of positive samples in
+ corresponding level, has shape (num_pos, 2), last
+ dimension 2 present (index_x, index_y).
+ """
+ mlvl_pos_mask_targets, mlvl_labels, \
+ mlvl_pos_masks = \
+ super()._get_targets_single(gt_bboxes, gt_labels, gt_masks,
+ featmap_sizes=featmap_sizes)
+
+ mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero()
+ for item in mlvl_labels]
+
+ return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes
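+ # A small sketch of the index trick above (editor's example): background
+ # cells in `labels` hold the value `num_classes`, so subtracting it makes
+ # them zero and `nonzero()` returns the (y, x) grid indices of positives.
+ # >>> labels = torch.full((3, 3), 80, dtype=torch.long)  # 80 == num_classes
+ # >>> labels[1, 2] = 15                                  # one positive cell
+ # >>> (labels - 80).nonzero()
+ # tensor([[1, 2]])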
+
+ def get_results(self,
+ mlvl_mask_preds_x,
+ mlvl_mask_preds_y,
+ mlvl_cls_scores,
+ img_metas,
+ rescale=None,
+ **kwargs):
+ """Get multi-image mask results.
+
+ Args:
+ mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
+ from x branch. Each element in the list has shape
+ (batch_size, num_grids ,h ,w).
+ mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
+ from y branch. Each element in the list has shape
+ (batch_size, num_grids ,h ,w).
+ mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
+ in the list has shape
+ (batch_size, num_classes ,num_grids ,num_grids).
+ img_metas (list[dict]): Meta information of all images.
+
+ Returns:
+ list[:obj:`InstanceData`]: Processed results of multiple
+ images. Each :obj:`InstanceData` usually contains
+ the following keys.
+
+ - scores (Tensor): Classification scores, has shape
+ (num_instance,).
+ - labels (Tensor): Has shape (num_instances,).
+ - masks (Tensor): Processed mask results, has
+ shape (num_instances, h, w).
+ """
+ mlvl_cls_scores = [
+ item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
+ ]
+ assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores)
+ num_levels = len(mlvl_cls_scores)
+
+ results_list = []
+ for img_id in range(len(img_metas)):
+ cls_pred_list = [
+ mlvl_cls_scores[i][img_id].view(
+ -1, self.cls_out_channels).detach()
+ for i in range(num_levels)
+ ]
+ mask_pred_list_x = [
+ mlvl_mask_preds_x[i][img_id] for i in range(num_levels)
+ ]
+ mask_pred_list_y = [
+ mlvl_mask_preds_y[i][img_id] for i in range(num_levels)
+ ]
+
+ cls_pred_list = torch.cat(cls_pred_list, dim=0)
+ mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0)
+ mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0)
+
+ results = self._get_results_single(
+ cls_pred_list,
+ mask_pred_list_x,
+ mask_pred_list_y,
+ img_meta=img_metas[img_id],
+ cfg=self.test_cfg)
+ results_list.append(results)
+ return results_list
+
+ def _get_results_single(self, cls_scores, mask_preds_x, mask_preds_y,
+ img_meta, cfg):
+ """Get processed mask related results of single image.
+
+ Args:
+ cls_scores (Tensor): Classification score of all points
+ in single image, has shape (num_points, num_classes).
+ mask_preds_x (Tensor): Mask prediction of x branch of
+ all points in single image, has shape
+ (sum_num_grids, feat_h, feat_w).
+ mask_preds_y (Tensor): Mask prediction of y branch of
+ all points in single image, has shape
+ (sum_num_grids, feat_h, feat_w).
+ img_meta (dict): Meta information of corresponding image.
+ cfg (dict): Config used in test phase.
+
+ Returns:
+ :obj:`InstanceData`: Processed results of single image.
+ It usually contains the following keys.
+
+ - scores (Tensor): Classification scores, has shape
+ (num_instance,).
+ - labels (Tensor): Has shape (num_instances,).
+ - masks (Tensor): Processed mask results, has
+ shape (num_instances, h, w).
+ """
+
+ def empty_results(results, cls_scores):
+ """Generate a empty results."""
+ results.scores = cls_scores.new_ones(0)
+ results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
+ results.labels = cls_scores.new_ones(0)
+ return results
+
+ cfg = self.test_cfg if cfg is None else cfg
+
+ results = InstanceData(img_meta)
+ img_shape = results.img_shape
+ ori_shape = results.ori_shape
+ h, w, _ = img_shape
+ featmap_size = mask_preds_x.size()[-2:]
+ upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)
+
+ score_mask = (cls_scores > cfg.score_thr)
+ cls_scores = cls_scores[score_mask]
+ inds = score_mask.nonzero()
+ lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
+ num_all_points = lvl_interval[-1]
+ lvl_start_index = inds.new_ones(num_all_points)
+ num_grids = inds.new_ones(num_all_points)
+ seg_size = inds.new_tensor(self.num_grids).cumsum(0)
+ mask_lvl_start_index = inds.new_ones(num_all_points)
+ strides = inds.new_ones(num_all_points)
+
+ lvl_start_index[:lvl_interval[0]] *= 0
+ mask_lvl_start_index[:lvl_interval[0]] *= 0
+ num_grids[:lvl_interval[0]] *= self.num_grids[0]
+ strides[:lvl_interval[0]] *= self.strides[0]
+
+ for lvl in range(1, self.num_levels):
+ lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
+ lvl_interval[lvl - 1]
+ mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
+ seg_size[lvl - 1]
+ num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
+ self.num_grids[lvl]
+ strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
+ self.strides[lvl]
+
+ lvl_start_index = lvl_start_index[inds[:, 0]]
+ mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
+ num_grids = num_grids[inds[:, 0]]
+ strides = strides[inds[:, 0]]
+
+ y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
+ x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
+ y_inds = mask_lvl_start_index + y_lvl_offset
+ x_inds = mask_lvl_start_index + x_lvl_offset
+
+ cls_labels = inds[:, 1]
+ mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]
+
+ masks = mask_preds > cfg.mask_thr
+ sum_masks = masks.sum((1, 2)).float()
+ keep = sum_masks > strides
+ if keep.sum() == 0:
+ return empty_results(results, cls_scores)
+
+ masks = masks[keep]
+ mask_preds = mask_preds[keep]
+ sum_masks = sum_masks[keep]
+ cls_scores = cls_scores[keep]
+ cls_labels = cls_labels[keep]
+
+ # maskness.
+ mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
+ cls_scores *= mask_scores
+
+ scores, labels, _, keep_inds = mask_matrix_nms(
+ masks,
+ cls_labels,
+ cls_scores,
+ mask_area=sum_masks,
+ nms_pre=cfg.nms_pre,
+ max_num=cfg.max_per_img,
+ kernel=cfg.kernel,
+ sigma=cfg.sigma,
+ filter_thr=cfg.filter_thr)
+ mask_preds = mask_preds[keep_inds]
+ mask_preds = F.interpolate(
+ mask_preds.unsqueeze(0), size=upsampled_size,
+ mode='bilinear')[:, :, :h, :w]
+ mask_preds = F.interpolate(
+ mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0)
+ masks = mask_preds > cfg.mask_thr
+
+ results.masks = masks
+ results.labels = labels
+ results.scores = scores
+
+ return results
+
+
+@HEADS.register_module()
+class DecoupledSOLOLightHead(DecoupledSOLOHead):
+ """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
+ Locations `_
+
+ Args:
+ dcn_cfg (dict, optional): Config of DCN applied to the last layer of
+ mask_convs and cls_convs. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ *args,
+ dcn_cfg=None,
+ init_cfg=[
+ dict(type='Normal', layer='Conv2d', std=0.01),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_mask_list_x')),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_mask_list_y')),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_cls'))
+ ],
+ **kwargs):
+ assert dcn_cfg is None or isinstance(dcn_cfg, dict)
+ self.dcn_cfg = dcn_cfg
+ super(DecoupledSOLOLightHead, self).__init__(
+ *args, init_cfg=init_cfg, **kwargs)
+
+ def _init_layers(self):
+ self.mask_convs = nn.ModuleList()
+ self.cls_convs = nn.ModuleList()
+
+ for i in range(self.stacked_convs):
+ if self.dcn_cfg is not None\
+ and i == self.stacked_convs - 1:
+ conv_cfg = self.dcn_cfg
+ else:
+ conv_cfg = None
+
+ chn = self.in_channels + 2 if i == 0 else self.feat_channels
+ self.mask_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg))
+
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg))
+
+ self.conv_mask_list_x = nn.ModuleList()
+ self.conv_mask_list_y = nn.ModuleList()
+ for num_grid in self.num_grids:
+ self.conv_mask_list_x.append(
+ nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
+ self.conv_mask_list_y.append(
+ nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+
+ def forward(self, feats):
+ assert len(feats) == self.num_levels
+ feats = self.resize_feats(feats)
+ mask_preds_x = []
+ mask_preds_y = []
+ cls_preds = []
+ for i in range(self.num_levels):
+ x = feats[i]
+ mask_feat = x
+ cls_feat = x
+ # generate and concat the coordinate
+ coord_feat = generate_coordinate(mask_feat.size(),
+ mask_feat.device)
+ mask_feat = torch.cat([mask_feat, coord_feat], 1)
+
+ for mask_layer in self.mask_convs:
+ mask_feat = mask_layer(mask_feat)
+
+ mask_feat = F.interpolate(
+ mask_feat, scale_factor=2, mode='bilinear')
+
+ mask_pred_x = self.conv_mask_list_x[i](mask_feat)
+ mask_pred_y = self.conv_mask_list_y[i](mask_feat)
+
+ # cls branch
+ for j, cls_layer in enumerate(self.cls_convs):
+ if j == self.cls_down_index:
+ num_grid = self.num_grids[i]
+ cls_feat = F.interpolate(
+ cls_feat, size=num_grid, mode='bilinear')
+ cls_feat = cls_layer(cls_feat)
+
+ cls_pred = self.conv_cls(cls_feat)
+
+ if not self.training:
+ feat_wh = feats[0].size()[-2:]
+ upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
+ mask_pred_x = F.interpolate(
+ mask_pred_x.sigmoid(),
+ size=upsampled_size,
+ mode='bilinear')
+ mask_pred_y = F.interpolate(
+ mask_pred_y.sigmoid(),
+ size=upsampled_size,
+ mode='bilinear')
+ cls_pred = cls_pred.sigmoid()
+ # get local maximum
+ local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
+ keep_mask = local_max[:, :, :-1, :-1] == cls_pred
+ cls_pred = cls_pred * keep_mask
+
+ mask_preds_x.append(mask_pred_x)
+ mask_preds_y.append(mask_pred_y)
+ cls_preds.append(cls_pred)
+ return mask_preds_x, mask_preds_y, cls_preds
diff --git a/mmdet/models/dense_heads/solov2_head.py b/mmdet/models/dense_heads/solov2_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..975306c29722e057853c5253cf44608954e25af2
--- /dev/null
+++ b/mmdet/models/dense_heads/solov2_head.py
@@ -0,0 +1,766 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, auto_fp16, force_fp32
+
+from mmdet.core import InstanceData, mask_matrix_nms, multi_apply
+from mmdet.core.utils import center_of_mass, generate_coordinate
+from mmdet.models.builder import HEADS
+from mmdet.utils.misc import floordiv
+from .solo_head import SOLOHead
+
+
+class MaskFeatModule(BaseModule):
+ """SOLOv2 mask feature map branch used in `SOLOv2: Dynamic and Fast
+ Instance Segmentation. `_
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels of the mask feature
+ map branch.
+ start_level (int): The starting feature pyramid level whose features
+ will be used to predict the mask feature map.
+ end_level (int): The ending feature pyramid level whose features
+ will be used to predict the mask feature map.
+ out_channels (int): Number of output channels of the mask feature
+ map branch. This is the channel count of the mask feature map
+ that is dynamically convolved with the predicted kernels.
+ mask_stride (int): Downsample factor of the mask feature map output.
+ Default: 4.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ in_channels,
+ feat_channels,
+ start_level,
+ end_level,
+ out_channels,
+ mask_stride=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ init_cfg=[dict(type='Normal', layer='Conv2d', std=0.01)]):
+ super().__init__(init_cfg=init_cfg)
+
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.start_level = start_level
+ self.end_level = end_level
+ self.mask_stride = mask_stride
+ assert start_level >= 0 and end_level >= start_level
+ self.out_channels = out_channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self._init_layers()
+ self.fp16_enabled = False
+
+ def _init_layers(self):
+ self.convs_all_levels = nn.ModuleList()
+ for i in range(self.start_level, self.end_level + 1):
+ convs_per_level = nn.Sequential()
+ if i == 0:
+ convs_per_level.add_module(
+ f'conv{i}',
+ ConvModule(
+ self.in_channels,
+ self.feat_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=False))
+ self.convs_all_levels.append(convs_per_level)
+ continue
+
+ for j in range(i):
+ if j == 0:
+ if i == self.end_level:
+ chn = self.in_channels + 2
+ else:
+ chn = self.in_channels
+ convs_per_level.add_module(
+ f'conv{j}',
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=False))
+ convs_per_level.add_module(
+ f'upsample{j}',
+ nn.Upsample(
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=False))
+ continue
+
+ convs_per_level.add_module(
+ f'conv{j}',
+ ConvModule(
+ self.feat_channels,
+ self.feat_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=False))
+ convs_per_level.add_module(
+ f'upsample{j}',
+ nn.Upsample(
+ scale_factor=2, mode='bilinear', align_corners=False))
+
+ self.convs_all_levels.append(convs_per_level)
+
+ self.conv_pred = ConvModule(
+ self.feat_channels,
+ self.out_channels,
+ 1,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+
+ @auto_fp16()
+ def forward(self, feats):
+ inputs = feats[self.start_level:self.end_level + 1]
+ assert len(inputs) == (self.end_level - self.start_level + 1)
+ feature_add_all_level = self.convs_all_levels[0](inputs[0])
+ for i in range(1, len(inputs)):
+ input_p = inputs[i]
+ if i == len(inputs) - 1:
+ coord_feat = generate_coordinate(input_p.size(),
+ input_p.device)
+ input_p = torch.cat([input_p, coord_feat], 1)
+
+ # fix runtime error of "+=" inplace operation in PyTorch 1.10
+ feature_add_all_level = feature_add_all_level + \
+ self.convs_all_levels[i](input_p)
+
+ feature_pred = self.conv_pred(feature_add_all_level)
+ return feature_pred
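+ # Editor's note (hypothetical sizes): level i of the selected range is
+ # passed through i conv + 2x-upsample stages, so with start_level=0 and
+ # end_level=3 a 1/4-resolution input of 100x152 is joined by the coarser
+ # levels upsampled 2x/4x/8x to the same 100x152 grid; the coordinate
+ # channels are concatenated only to the coarsest input, and the summed
+ # feature is projected by a 1x1 conv to `out_channels`, giving one unified
+ # mask feature map at stride `mask_stride`.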
+
+
+@HEADS.register_module()
+class SOLOV2Head(SOLOHead):
+ """SOLOv2 mask head used in `SOLOv2: Dynamic and Fast Instance
+ Segmentation. `_
+
+ Args:
+ mask_feature_head (dict): Config of SOLOv2MaskFeatHead.
+ dynamic_conv_size (int): Dynamic Conv kernel size. Default: 1.
+ dcn_cfg (dict, optional): DCN config applied in kernel_convs and
+ cls_convs. Default: None.
+ dcn_apply_to_all_conv (bool): Whether to use DCN in every layer of
+ kernel_convs and cls_convs, or only in the last layer. It should be
+ set to `True` for the standard version of SOLOv2 and `False` for the
+ light-weight version. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ *args,
+ mask_feature_head,
+ dynamic_conv_size=1,
+ dcn_cfg=None,
+ dcn_apply_to_all_conv=True,
+ init_cfg=[
+ dict(type='Normal', layer='Conv2d', std=0.01),
+ dict(
+ type='Normal',
+ std=0.01,
+ bias_prob=0.01,
+ override=dict(name='conv_cls'))
+ ],
+ **kwargs):
+ assert dcn_cfg is None or isinstance(dcn_cfg, dict)
+ self.dcn_cfg = dcn_cfg
+ self.with_dcn = dcn_cfg is not None
+ self.dcn_apply_to_all_conv = dcn_apply_to_all_conv
+ self.dynamic_conv_size = dynamic_conv_size
+ mask_out_channels = mask_feature_head.get('out_channels')
+ self.kernel_out_channels = \
+ mask_out_channels * self.dynamic_conv_size * self.dynamic_conv_size
+
+ super().__init__(*args, init_cfg=init_cfg, **kwargs)
+
+ # update the in_channels of mask_feature_head
+ if mask_feature_head.get('in_channels', None) is not None:
+ if mask_feature_head.in_channels != self.in_channels:
+ warnings.warn('The `in_channels` of SOLOv2MaskFeatHead and '
+ 'SOLOv2Head should be same, changing '
+ 'mask_feature_head.in_channels to '
+ f'{self.in_channels}')
+ mask_feature_head.update(in_channels=self.in_channels)
+ else:
+ mask_feature_head.update(in_channels=self.in_channels)
+
+ self.mask_feature_head = MaskFeatModule(**mask_feature_head)
+ self.mask_stride = self.mask_feature_head.mask_stride
+ self.fp16_enabled = False
+
+ def _init_layers(self):
+ self.cls_convs = nn.ModuleList()
+ self.kernel_convs = nn.ModuleList()
+ conv_cfg = None
+ for i in range(self.stacked_convs):
+ if self.with_dcn:
+ if self.dcn_apply_to_all_conv:
+ conv_cfg = self.dcn_cfg
+ elif i == self.stacked_convs - 1:
+ # light head
+ conv_cfg = self.dcn_cfg
+
+ chn = self.in_channels + 2 if i == 0 else self.feat_channels
+ self.kernel_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.norm_cfg is None))
+
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.norm_cfg is None))
+
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+
+ self.conv_kernel = nn.Conv2d(
+ self.feat_channels, self.kernel_out_channels, 3, padding=1)
+
+ @auto_fp16()
+ def forward(self, feats):
+ assert len(feats) == self.num_levels
+ mask_feats = self.mask_feature_head(feats)
+ feats = self.resize_feats(feats)
+ mlvl_kernel_preds = []
+ mlvl_cls_preds = []
+ for i in range(self.num_levels):
+ ins_kernel_feat = feats[i]
+ # ins branch
+ # concat coord
+ coord_feat = generate_coordinate(ins_kernel_feat.size(),
+ ins_kernel_feat.device)
+ ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)
+
+ # kernel branch
+ kernel_feat = ins_kernel_feat
+ kernel_feat = F.interpolate(
+ kernel_feat,
+ size=self.num_grids[i],
+ mode='bilinear',
+ align_corners=False)
+
+ cate_feat = kernel_feat[:, :-2, :, :]
+
+ kernel_feat = kernel_feat.contiguous()
+ for kernel_conv in self.kernel_convs:
+ kernel_feat = kernel_conv(kernel_feat)
+ kernel_pred = self.conv_kernel(kernel_feat)
+
+ # cate branch
+ cate_feat = cate_feat.contiguous()
+ for cls_conv in self.cls_convs:
+ cate_feat = cls_conv(cate_feat)
+ cate_pred = self.conv_cls(cate_feat)
+
+ mlvl_kernel_preds.append(kernel_pred)
+ mlvl_cls_preds.append(cate_pred)
+
+ return mlvl_kernel_preds, mlvl_cls_preds, mask_feats
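+ # Editor's note on the shapes returned above: at level i the kernel branch
+ # emits (batch, E * k * k, S_i, S_i) with E = mask_feature_head's
+ # out_channels and k = dynamic_conv_size, i.e. one dynamic filter per grid
+ # cell; the cls branch emits (batch, num_classes, S_i, S_i); and
+ # `mask_feats` is the single (batch, E, H/mask_stride, W/mask_stride) map
+ # that those filters are later convolved with.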
+
+ def _get_targets_single(self,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ featmap_size=None):
+ """Compute targets for predictions of single image.
+
+ Args:
+ gt_bboxes (Tensor): Ground truth bbox of each instance,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth label of each instance,
+ shape (num_gts,).
+ gt_masks (Tensor): Ground truth mask of each instance,
+ shape (num_gts, h, w).
+ featmap_size (:obj:`torch.size`): Size of the unified mask
+ feature map used to generate instance segmentation
+ masks by dynamic convolution, in the form
+ (feat_h, feat_w). Default: None.
+
+ Returns:
+ Tuple: Usually returns a tuple containing targets for predictions.
+
+ - mlvl_pos_mask_targets (list[Tensor]): Each element represents
+ the binary mask targets for positive points in this
+ level, has shape (num_pos, out_h, out_w).
+ - mlvl_labels (list[Tensor]): Each element is
+ classification labels for all
+ points in this level, has shape
+ (num_grid, num_grid).
+ - mlvl_pos_masks (list[Tensor]): Each element is
+ a `BoolTensor` to represent whether the
+ corresponding point in single level
+ is positive, has shape (num_grid **2).
+ - mlvl_pos_indexes (list[list]): Each element
+ in the list contains the positive index in
+ corresponding level, has shape (num_pos).
+ """
+
+ device = gt_labels.device
+ gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
+ (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
+
+ mlvl_pos_mask_targets = []
+ mlvl_pos_indexes = []
+ mlvl_labels = []
+ mlvl_pos_masks = []
+ for (lower_bound, upper_bound), num_grid \
+ in zip(self.scale_ranges, self.num_grids):
+ mask_target = []
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ pos_index = []
+ labels = torch.zeros([num_grid, num_grid],
+ dtype=torch.int64,
+ device=device) + self.num_classes
+ pos_mask = torch.zeros([num_grid**2],
+ dtype=torch.bool,
+ device=device)
+
+ gt_inds = ((gt_areas >= lower_bound) &
+ (gt_areas <= upper_bound)).nonzero().flatten()
+ if len(gt_inds) == 0:
+ mlvl_pos_mask_targets.append(
+ torch.zeros([0, featmap_size[0], featmap_size[1]],
+ dtype=torch.uint8,
+ device=device))
+ mlvl_labels.append(labels)
+ mlvl_pos_masks.append(pos_mask)
+ mlvl_pos_indexes.append([])
+ continue
+ hit_gt_bboxes = gt_bboxes[gt_inds]
+ hit_gt_labels = gt_labels[gt_inds]
+ hit_gt_masks = gt_masks[gt_inds, ...]
+
+ pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
+ hit_gt_bboxes[:, 0]) * self.pos_scale
+ pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
+ hit_gt_bboxes[:, 1]) * self.pos_scale
+
+ # Make sure hit_gt_masks has a value
+ valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
+
+ for gt_mask, gt_label, pos_h_range, pos_w_range, \
+ valid_mask_flag in \
+ zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
+ pos_w_ranges, valid_mask_flags):
+ if not valid_mask_flag:
+ continue
+ upsampled_size = (featmap_size[0] * self.mask_stride,
+ featmap_size[1] * self.mask_stride)
+ center_h, center_w = center_of_mass(gt_mask)
+
+ coord_w = int(
+ floordiv((center_w / upsampled_size[1]), (1. / num_grid),
+ rounding_mode='trunc'))
+ coord_h = int(
+ floordiv((center_h / upsampled_size[0]), (1. / num_grid),
+ rounding_mode='trunc'))
+
+ # top, down, left and right boundaries of the positive grid region
+ top_box = max(
+ 0,
+ int(
+ floordiv(
+ (center_h - pos_h_range) / upsampled_size[0],
+ (1. / num_grid),
+ rounding_mode='trunc')))
+ down_box = min(
+ num_grid - 1,
+ int(
+ floordiv(
+ (center_h + pos_h_range) / upsampled_size[0],
+ (1. / num_grid),
+ rounding_mode='trunc')))
+ left_box = max(
+ 0,
+ int(
+ floordiv(
+ (center_w - pos_w_range) / upsampled_size[1],
+ (1. / num_grid),
+ rounding_mode='trunc')))
+ right_box = min(
+ num_grid - 1,
+ int(
+ floordiv(
+ (center_w + pos_w_range) / upsampled_size[1],
+ (1. / num_grid),
+ rounding_mode='trunc')))
+
+ top = max(top_box, coord_h - 1)
+ down = min(down_box, coord_h + 1)
+ left = max(coord_w - 1, left_box)
+ right = min(right_box, coord_w + 1)
+
+ labels[top:(down + 1), left:(right + 1)] = gt_label
+ # ins
+ gt_mask = np.uint8(gt_mask.cpu().numpy())
+ # Follow the original implementation and rescale with mmcv (OpenCV),
+ # since F.interpolate behaves slightly differently
+ gt_mask = mmcv.imrescale(gt_mask, scale=1. / self.mask_stride)
+ gt_mask = torch.from_numpy(gt_mask).to(device=device)
+
+ for i in range(top, down + 1):
+ for j in range(left, right + 1):
+ index = int(i * num_grid + j)
+ this_mask_target = torch.zeros(
+ [featmap_size[0], featmap_size[1]],
+ dtype=torch.uint8,
+ device=device)
+ this_mask_target[:gt_mask.shape[0], :gt_mask.
+ shape[1]] = gt_mask
+ mask_target.append(this_mask_target)
+ pos_mask[index] = True
+ pos_index.append(index)
+ if len(mask_target) == 0:
+ mask_target = torch.zeros(
+ [0, featmap_size[0], featmap_size[1]],
+ dtype=torch.uint8,
+ device=device)
+ else:
+ mask_target = torch.stack(mask_target, 0)
+ mlvl_pos_mask_targets.append(mask_target)
+ mlvl_labels.append(labels)
+ mlvl_pos_masks.append(pos_mask)
+ mlvl_pos_indexes.append(pos_index)
+ return (mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks,
+ mlvl_pos_indexes)
+
+ @force_fp32(apply_to=('mlvl_kernel_preds', 'mlvl_cls_preds', 'mask_feats'))
+ def loss(self,
+ mlvl_kernel_preds,
+ mlvl_cls_preds,
+ mask_feats,
+ gt_labels,
+ gt_masks,
+ img_metas,
+ gt_bboxes=None,
+ **kwargs):
+ """Calculate the loss of total batch.
+
+ Args:
+ mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel
+ prediction. The kernel is used to generate instance
+ segmentation masks by dynamic convolution. Each element in the
+ list has shape
+ (batch_size, kernel_out_channels, num_grids, num_grids).
+ mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
+ in the list has shape
+ (batch_size, num_classes, num_grids, num_grids).
+ mask_feats (Tensor): Unified mask feature map used to generate
+ instance segmentation masks by dynamic convolution. Has shape
+ (batch_size, mask_out_channels, h, w).
+ gt_labels (list[Tensor]): Labels of multiple images.
+ gt_masks (list[Tensor]): Ground truth masks of multiple images.
+ Each has shape (num_instances, h, w).
+ img_metas (list[dict]): Meta information of multiple images.
+ gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
+ images. Default: None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_size = mask_feats.size()[-2:]
+
+ pos_mask_targets, labels, pos_masks, pos_indexes = multi_apply(
+ self._get_targets_single,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ featmap_size=featmap_size)
+
+ mlvl_mask_targets = [
+ torch.cat(lvl_mask_targets, 0)
+ for lvl_mask_targets in zip(*pos_mask_targets)
+ ]
+
+ mlvl_pos_kernel_preds = []
+ for lvl_kernel_preds, lvl_pos_indexes in zip(mlvl_kernel_preds,
+ zip(*pos_indexes)):
+ lvl_pos_kernel_preds = []
+ for img_lvl_kernel_preds, img_lvl_pos_indexes in zip(
+ lvl_kernel_preds, lvl_pos_indexes):
+ img_lvl_pos_kernel_preds = img_lvl_kernel_preds.view(
+ img_lvl_kernel_preds.shape[0], -1)[:, img_lvl_pos_indexes]
+ lvl_pos_kernel_preds.append(img_lvl_pos_kernel_preds)
+ mlvl_pos_kernel_preds.append(lvl_pos_kernel_preds)
+
+ # make multilevel mlvl_mask_pred
+ mlvl_mask_preds = []
+ for lvl_pos_kernel_preds in mlvl_pos_kernel_preds:
+ lvl_mask_preds = []
+ for img_id, img_lvl_pos_kernel_pred in enumerate(
+ lvl_pos_kernel_preds):
+ if img_lvl_pos_kernel_pred.size()[-1] == 0:
+ continue
+ img_mask_feats = mask_feats[[img_id]]
+ h, w = img_mask_feats.shape[-2:]
+ num_kernel = img_lvl_pos_kernel_pred.shape[1]
+ img_lvl_mask_pred = F.conv2d(
+ img_mask_feats,
+ img_lvl_pos_kernel_pred.permute(1, 0).view(
+ num_kernel, -1, self.dynamic_conv_size,
+ self.dynamic_conv_size),
+ stride=1).view(-1, h, w)
+ lvl_mask_preds.append(img_lvl_mask_pred)
+ if len(lvl_mask_preds) == 0:
+ lvl_mask_preds = None
+ else:
+ lvl_mask_preds = torch.cat(lvl_mask_preds, 0)
+ mlvl_mask_preds.append(lvl_mask_preds)
+ # dice loss
+ num_pos = 0
+ for img_pos_masks in pos_masks:
+ for lvl_img_pos_masks in img_pos_masks:
+ num_pos += lvl_img_pos_masks.count_nonzero()
+
+ loss_mask = []
+ for lvl_mask_preds, lvl_mask_targets in zip(mlvl_mask_preds,
+ mlvl_mask_targets):
+ if lvl_mask_preds is None:
+ continue
+ loss_mask.append(
+ self.loss_mask(
+ lvl_mask_preds,
+ lvl_mask_targets,
+ reduction_override='none'))
+ if num_pos > 0:
+ loss_mask = torch.cat(loss_mask).sum() / num_pos
+ else:
+ loss_mask = mask_feats.sum() * 0
+
+ # cate
+ flatten_labels = [
+ torch.cat(
+ [img_lvl_labels.flatten() for img_lvl_labels in lvl_labels])
+ for lvl_labels in zip(*labels)
+ ]
+ flatten_labels = torch.cat(flatten_labels)
+
+ flatten_cls_preds = [
+ lvl_cls_preds.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
+ for lvl_cls_preds in mlvl_cls_preds
+ ]
+ flatten_cls_preds = torch.cat(flatten_cls_preds)
+
+ loss_cls = self.loss_cls(
+ flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
+ return dict(loss_mask=loss_mask, loss_cls=loss_cls)
+
+ @force_fp32(
+ apply_to=('mlvl_kernel_preds', 'mlvl_cls_scores', 'mask_feats'))
+ def get_results(self, mlvl_kernel_preds, mlvl_cls_scores, mask_feats,
+ img_metas, **kwargs):
+ """Get multi-image mask results.
+
+ Args:
+ mlvl_kernel_preds (list[Tensor]): Multi-level dynamic kernel
+ prediction. The kernel is used to generate instance
+ segmentation masks by dynamic convolution. Each element in the
+ list has shape
+ (batch_size, kernel_out_channels, num_grids, num_grids).
+ mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
+ in the list has shape
+ (batch_size, num_classes, num_grids, num_grids).
+ mask_feats (Tensor): Unified mask feature map used to generate
+ instance segmentation masks by dynamic convolution. Has shape
+ (batch_size, mask_out_channels, h, w).
+ img_metas (list[dict]): Meta information of all images.
+
+ Returns:
+ list[:obj:`InstanceData`]: Processed results of multiple
+ images. Each :obj:`InstanceData` usually contains
+ the following keys.
+
+ - scores (Tensor): Classification scores, has shape
+ (num_instance,).
+ - labels (Tensor): Has shape (num_instances,).
+ - masks (Tensor): Processed mask results, has
+ shape (num_instances, h, w).
+ """
+ num_levels = len(mlvl_cls_scores)
+ assert len(mlvl_kernel_preds) == len(mlvl_cls_scores)
+
+ for lvl in range(num_levels):
+ cls_scores = mlvl_cls_scores[lvl]
+ cls_scores = cls_scores.sigmoid()
+ local_max = F.max_pool2d(cls_scores, 2, stride=1, padding=1)
+ keep_mask = local_max[:, :, :-1, :-1] == cls_scores
+ cls_scores = cls_scores * keep_mask
+ mlvl_cls_scores[lvl] = cls_scores.permute(0, 2, 3, 1)
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ img_cls_pred = [
+ mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
+ for lvl in range(num_levels)
+ ]
+ img_mask_feats = mask_feats[[img_id]]
+ img_kernel_pred = [
+ mlvl_kernel_preds[lvl][img_id].permute(1, 2, 0).view(
+ -1, self.kernel_out_channels) for lvl in range(num_levels)
+ ]
+ img_cls_pred = torch.cat(img_cls_pred, dim=0)
+ img_kernel_pred = torch.cat(img_kernel_pred, dim=0)
+ result = self._get_results_single(
+ img_kernel_pred,
+ img_cls_pred,
+ img_mask_feats,
+ img_meta=img_metas[img_id])
+ result_list.append(result)
+ return result_list
+
+ def _get_results_single(self,
+ kernel_preds,
+ cls_scores,
+ mask_feats,
+ img_meta,
+ cfg=None):
+ """Get processed mask related results of single image.
+
+ Args:
+ kernel_preds (Tensor): Dynamic kernel prediction of all points
+ in single image, has shape
+ (num_points, kernel_out_channels).
+ cls_scores (Tensor): Classification score of all points
+ in single image, has shape (num_points, num_classes).
+ mask_feats (Tensor): Unified mask feature map used to generate
+ instance segmentation masks by dynamic convolution, has shape
+ (1, mask_out_channels, feat_h, feat_w).
+ img_meta (dict): Meta information of corresponding image.
+ cfg (dict, optional): Config used in test phase.
+ Default: None.
+
+ Returns:
+ :obj:`InstanceData`: Processed results of single image.
+ It usually contains the following keys.
+
+ - scores (Tensor): Classification scores, has shape
+ (num_instance,).
+ - labels (Tensor): Has shape (num_instances,).
+ - masks (Tensor): Processed mask results, has
+ shape (num_instances, h, w).
+ """
+
+ def empty_results(results, cls_scores):
+ """Generate a empty results."""
+ results.scores = cls_scores.new_ones(0)
+ results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
+ results.labels = cls_scores.new_ones(0)
+ return results
+
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(kernel_preds) == len(cls_scores)
+ results = InstanceData(img_meta)
+
+ featmap_size = mask_feats.size()[-2:]
+
+ img_shape = results.img_shape
+ ori_shape = results.ori_shape
+
+ # overall info
+ h, w, _ = img_shape
+ upsampled_size = (featmap_size[0] * self.mask_stride,
+ featmap_size[1] * self.mask_stride)
+
+ # process.
+ score_mask = (cls_scores > cfg.score_thr)
+ cls_scores = cls_scores[score_mask]
+ if len(cls_scores) == 0:
+ return empty_results(results, cls_scores)
+
+ # cate_labels & kernel_preds
+ inds = score_mask.nonzero()
+ cls_labels = inds[:, 1]
+ kernel_preds = kernel_preds[inds[:, 0]]
+
+ # build the per-point stride lookup.
+ lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
+ strides = kernel_preds.new_ones(lvl_interval[-1])
+
+ strides[:lvl_interval[0]] *= self.strides[0]
+ for lvl in range(1, self.num_levels):
+ strides[lvl_interval[lvl -
+ 1]:lvl_interval[lvl]] *= self.strides[lvl]
+ strides = strides[inds[:, 0]]
+
+ # mask encoding.
+ kernel_preds = kernel_preds.view(
+ kernel_preds.size(0), -1, self.dynamic_conv_size,
+ self.dynamic_conv_size)
+ mask_preds = F.conv2d(
+ mask_feats, kernel_preds, stride=1).squeeze(0).sigmoid()
+ # mask.
+ masks = mask_preds > cfg.mask_thr
+ sum_masks = masks.sum((1, 2)).float()
+ keep = sum_masks > strides
+ if keep.sum() == 0:
+ return empty_results(results, cls_scores)
+ masks = masks[keep]
+ mask_preds = mask_preds[keep]
+ sum_masks = sum_masks[keep]
+ cls_scores = cls_scores[keep]
+ cls_labels = cls_labels[keep]
+
+ # maskness.
+ mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
+ cls_scores *= mask_scores
+
+ scores, labels, _, keep_inds = mask_matrix_nms(
+ masks,
+ cls_labels,
+ cls_scores,
+ mask_area=sum_masks,
+ nms_pre=cfg.nms_pre,
+ max_num=cfg.max_per_img,
+ kernel=cfg.kernel,
+ sigma=cfg.sigma,
+ filter_thr=cfg.filter_thr)
+ mask_preds = mask_preds[keep_inds]
+ mask_preds = F.interpolate(
+ mask_preds.unsqueeze(0),
+ size=upsampled_size,
+ mode='bilinear',
+ align_corners=False)[:, :, :h, :w]
+ mask_preds = F.interpolate(
+ mask_preds,
+ size=ori_shape[:2],
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+ masks = mask_preds > cfg.mask_thr
+
+ results.masks = masks
+ results.labels = labels
+ results.scores = scores
+
+ return results
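+ # A self-contained sketch of the dynamic convolution used above, added by
+ # the editor with illustrative sizes only:
+ # >>> import torch
+ # >>> import torch.nn.functional as F
+ # >>> E, k = 256, 1                        # mask channels, kernel size
+ # >>> mask_feats = torch.rand(1, E, 100, 152)
+ # >>> kernels = torch.rand(5, E * k * k)   # 5 surviving grid cells
+ # >>> inst = F.conv2d(mask_feats, kernels.view(5, E, k, k)).squeeze(0)
+ # >>> inst.sigmoid().shape                 # one soft mask per instance
+ # torch.Size([5, 100, 152])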
diff --git a/mmdet/models/dense_heads/ssd_head.py b/mmdet/models/dense_heads/ssd_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e362fd8016a0b0f7d0d371adb4fc39249ceb2f6a
--- /dev/null
+++ b/mmdet/models/dense_heads/ssd_head.py
@@ -0,0 +1,357 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.runner import force_fp32
+
+from mmdet.core import (build_assigner, build_bbox_coder,
+ build_prior_generator, build_sampler, multi_apply)
+from ..builder import HEADS
+from ..losses import smooth_l1_loss
+from .anchor_head import AnchorHead
+
+
+# TODO: add loss evaluator for SSD
+@HEADS.register_module()
+class SSDHead(AnchorHead):
+ """SSD head used in https://arxiv.org/abs/1512.02325.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): Number of conv layers in cls and reg tower.
+ Default: 0.
+ feat_channels (int): Number of hidden channels when stacked_convs
+ > 0. Default: 256.
+ use_depthwise (bool): Whether to use DepthwiseSeparableConv.
+ Default: False.
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: None.
+ act_cfg (dict): Dictionary to construct and config activation layer.
+ Default: None.
+ anchor_generator (dict): Config dict for anchor generator
+ bbox_coder (dict): Config of bounding box coder.
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes=80,
+ in_channels=(512, 1024, 512, 256, 256, 256),
+ stacked_convs=0,
+ feat_channels=256,
+ use_depthwise=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ anchor_generator=dict(
+ type='SSDAnchorGenerator',
+ scale_major=False,
+ input_size=300,
+ strides=[8, 16, 32, 64, 100, 300],
+ ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
+ basesize_ratio_range=(0.1, 0.9)),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ clip_border=True,
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0],
+ ),
+ reg_decoded_bbox=False,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=dict(
+ type='Xavier',
+ layer='Conv2d',
+ distribution='uniform',
+ bias=0)):
+ super(AnchorHead, self).__init__(init_cfg)
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.stacked_convs = stacked_convs
+ self.feat_channels = feat_channels
+ self.use_depthwise = use_depthwise
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+
+ self.cls_out_channels = num_classes + 1 # add background class
+ self.prior_generator = build_prior_generator(anchor_generator)
+
+ # Usually the numbers of anchors for each level are the same
+ # except for SSD detectors, so it is an int in most dense
+ # heads but a list of ints in SSDHead.
+ self.num_base_priors = self.prior_generator.num_base_priors
+
+ self._init_layers()
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.reg_decoded_bbox = reg_decoded_bbox
+ self.use_sigmoid_cls = False
+ self.cls_focal_loss = False
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ # set sampling=False for anchor_target
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # SSD sampling=False so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.fp16_enabled = False
+
+ @property
+ def num_anchors(self):
+ """
+ Returns:
+ list[int]: Number of base_anchors on each point of each level.
+ """
+ warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
+ 'please use "num_base_priors" instead')
+ return self.num_base_priors
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ # TODO: Use registry to choose ConvModule type
+ conv = DepthwiseSeparableConvModule \
+ if self.use_depthwise else ConvModule
+
+ for channel, num_base_priors in zip(self.in_channels,
+ self.num_base_priors):
+ cls_layers = []
+ reg_layers = []
+ in_channel = channel
+ # build stacked conv tower, not used in default ssd
+ for i in range(self.stacked_convs):
+ cls_layers.append(
+ conv(
+ in_channel,
+ self.feat_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ reg_layers.append(
+ conv(
+ in_channel,
+ self.feat_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ in_channel = self.feat_channels
+ # SSD-Lite head
+ if self.use_depthwise:
+ cls_layers.append(
+ ConvModule(
+ in_channel,
+ in_channel,
+ 3,
+ padding=1,
+ groups=in_channel,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ reg_layers.append(
+ ConvModule(
+ in_channel,
+ in_channel,
+ 3,
+ padding=1,
+ groups=in_channel,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ cls_layers.append(
+ nn.Conv2d(
+ in_channel,
+ num_base_priors * self.cls_out_channels,
+ kernel_size=1 if self.use_depthwise else 3,
+ padding=0 if self.use_depthwise else 1))
+ reg_layers.append(
+ nn.Conv2d(
+ in_channel,
+ num_base_priors * 4,
+ kernel_size=1 if self.use_depthwise else 3,
+ padding=0 if self.use_depthwise else 1))
+ self.cls_convs.append(nn.Sequential(*cls_layers))
+ self.reg_convs.append(nn.Sequential(*reg_layers))
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple:
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * 4.
+ """
+ cls_scores = []
+ bbox_preds = []
+ for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
+ self.cls_convs):
+ cls_scores.append(cls_conv(feat))
+ bbox_preds.append(reg_conv(feat))
+ return cls_scores, bbox_preds
+
+ def loss_single(self, cls_score, bbox_pred, anchor, labels, label_weights,
+ bbox_targets, bbox_weights, num_total_samples):
+ """Compute loss of a single image.
+
+ Args:
+ cls_score (Tensor): Box scores for each image with shape
+ (num_total_anchors, num_classes).
+ bbox_pred (Tensor): Box energies / deltas for each image with
+ shape (num_total_anchors, 4).
+ anchor (Tensor): Box reference for each scale level with shape
+ (num_total_anchors, 4).
+ labels (Tensor): Labels of each anchor with shape
+ (num_total_anchors,).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (num_total_anchors,).
+ bbox_targets (Tensor): BBox regression targets of each anchor
+ with shape (num_total_anchors, 4).
+ bbox_weights (Tensor): BBox regression loss weights of each anchor
+ with shape (num_total_anchors, 4).
+ num_total_samples (int): If sampling, the number of total samples
+ equals the number of total anchors; otherwise, it is the number
+ of positive anchors.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ loss_cls_all = F.cross_entropy(
+ cls_score, labels, reduction='none') * label_weights
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
+ as_tuple=False).reshape(-1)
+ neg_inds = (labels == self.num_classes).nonzero(
+ as_tuple=False).view(-1)
+
+ num_pos_samples = pos_inds.size(0)
+ num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
+ if num_neg_samples > neg_inds.size(0):
+ num_neg_samples = neg_inds.size(0)
+ topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+ loss_cls_pos = loss_cls_all[pos_inds].sum()
+ loss_cls_neg = topk_loss_cls_neg.sum()
+ loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
+
+ if self.reg_decoded_bbox:
+ # When the regression loss (e.g. `IoULoss`, `GIoULoss`)
+ # is applied directly on the decoded bounding boxes, it
+ # decodes the already encoded coordinates to absolute format.
+ bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)
+
+ loss_bbox = smooth_l1_loss(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ beta=self.train_cfg.smoothl1_beta,
+ avg_factor=num_total_samples)
+ return loss_cls[None], loss_bbox
+
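+# Standalone sketch (not part of SSDHead): the hard negative mining performed
+# in `loss_single` above. All positives are kept, while only the
+# `neg_pos_ratio` hardest negatives (largest per-anchor CE loss) contribute.
+# Shapes and the 3:1 ratio are illustrative; the real loss divides by the
+# number of positives over the whole batch (`num_total_samples`).
+import torch
+import torch.nn.functional as F
+
+num_classes, num_anchors, neg_pos_ratio = 4, 12, 3
+logits = torch.randn(num_anchors, num_classes + 1)   # background is class 4
+labels = torch.full((num_anchors, ), num_classes, dtype=torch.long)
+labels[:2] = torch.tensor([0, 2])                    # two positive anchors
+loss_all = F.cross_entropy(logits, labels, reduction='none')
+pos_inds = (labels < num_classes).nonzero(as_tuple=False).reshape(-1)
+neg_inds = (labels == num_classes).nonzero(as_tuple=False).reshape(-1)
+num_neg = min(neg_pos_ratio * pos_inds.numel(), neg_inds.numel())
+topk_neg_loss, _ = loss_all[neg_inds].topk(num_neg)
+loss_cls = (loss_all[pos_inds].sum() + topk_neg_loss.sum()) / pos_inds.numel()
+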
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Each item is the ground-truth boxes of
+ one image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=1,
+ unmap_outputs=True)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+
+ num_images = len(img_metas)
+ all_cls_scores = torch.cat([
+ s.permute(0, 2, 3, 1).reshape(
+ num_images, -1, self.cls_out_channels) for s in cls_scores
+ ], 1)
+ all_labels = torch.cat(labels_list, -1).view(num_images, -1)
+ all_label_weights = torch.cat(label_weights_list,
+ -1).view(num_images, -1)
+ all_bbox_preds = torch.cat([
+ b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
+ for b in bbox_preds
+ ], -2)
+ all_bbox_targets = torch.cat(bbox_targets_list,
+ -2).view(num_images, -1, 4)
+ all_bbox_weights = torch.cat(bbox_weights_list,
+ -2).view(num_images, -1, 4)
+
+ # concat all level anchors to a single tensor
+ all_anchors = []
+ for i in range(num_images):
+ all_anchors.append(torch.cat(anchor_list[i]))
+
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ all_cls_scores,
+ all_bbox_preds,
+ all_anchors,
+ all_labels,
+ all_label_weights,
+ all_bbox_targets,
+ all_bbox_weights,
+ num_total_samples=num_total_pos)
+ return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
diff --git a/mmdet/models/dense_heads/tood_head.py b/mmdet/models/dense_heads/tood_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c64ebf7a8ce6d428e4e7f8cc60be06baed5752c9
--- /dev/null
+++ b/mmdet/models/dense_heads/tood_head.py
@@ -0,0 +1,778 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init
+from mmcv.ops import deform_conv2d
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_assigner, distance2bbox,
+ images_to_levels, multi_apply, reduce_mean, unmap)
+from mmdet.core.utils import filter_scores_and_topk
+from mmdet.models.utils import sigmoid_geometric_mean
+from ..builder import HEADS, build_loss
+from .atss_head import ATSSHead
+
+
+class TaskDecomposition(nn.Module):
+ """Task decomposition module in task-aligned predictor of TOOD.
+
+ Args:
+ feat_channels (int): Number of feature channels in TOOD head.
+ stacked_convs (int): Number of conv layers in TOOD head.
+ la_down_rate (int): Downsample rate of layer attention.
+ conv_cfg (dict): Config dict for convolution layer.
+ norm_cfg (dict): Config dict for normalization layer.
+ """
+
+ def __init__(self,
+ feat_channels,
+ stacked_convs,
+ la_down_rate=8,
+ conv_cfg=None,
+ norm_cfg=None):
+ super(TaskDecomposition, self).__init__()
+ self.feat_channels = feat_channels
+ self.stacked_convs = stacked_convs
+ self.in_channels = self.feat_channels * self.stacked_convs
+ self.norm_cfg = norm_cfg
+ self.layer_attention = nn.Sequential(
+ nn.Conv2d(self.in_channels, self.in_channels // la_down_rate, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(
+ self.in_channels // la_down_rate,
+ self.stacked_convs,
+ 1,
+ padding=0), nn.Sigmoid())
+
+ self.reduction_conv = ConvModule(
+ self.in_channels,
+ self.feat_channels,
+ 1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ bias=norm_cfg is None)
+
+ def init_weights(self):
+ for m in self.layer_attention.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.001)
+ normal_init(self.reduction_conv.conv, std=0.01)
+
+ def forward(self, feat, avg_feat=None):
+ b, c, h, w = feat.shape
+ if avg_feat is None:
+ avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
+ weight = self.layer_attention(avg_feat)
+
+ # Here we first fold the layer attention weights into the 1x1
+ # reduction conv weights, and then convolve the feature map with
+ # the new weights, in order to save memory and FLOPs.
+ conv_weight = weight.reshape(
+ b, 1, self.stacked_convs,
+ 1) * self.reduction_conv.conv.weight.reshape(
+ 1, self.feat_channels, self.stacked_convs, self.feat_channels)
+ conv_weight = conv_weight.reshape(b, self.feat_channels,
+ self.in_channels)
+ feat = feat.reshape(b, self.in_channels, h * w)
+ feat = torch.bmm(conv_weight, feat).reshape(b, self.feat_channels, h,
+ w)
+ if self.norm_cfg is not None:
+ feat = self.reduction_conv.norm(feat)
+ feat = self.reduction_conv.activate(feat)
+
+ return feat
+
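+# Standalone sketch (not part of TaskDecomposition): the memory-saving trick
+# in `forward` above. Folding the layer attention into the 1x1 reduction conv
+# weight gives the same result as scaling the stacked features first and then
+# convolving. All shapes below are illustrative.
+import torch
+
+b, n_layers, c, h, w = 2, 6, 8, 5, 5
+feat = torch.randn(b, n_layers * c, h, w)      # stacked interactive features
+w1x1 = torch.randn(c, n_layers * c)            # 1x1 reduction conv weight
+attn = torch.rand(b, n_layers)                 # per-layer attention in [0, 1]
+
+# (a) scale each layer's features, then apply the 1x1 conv
+scaled = (feat.reshape(b, n_layers, c, h, w) *
+          attn[:, :, None, None, None]).reshape(b, n_layers * c, h * w)
+out_a = torch.einsum('oi,bip->bop', w1x1, scaled)
+
+# (b) fold the attention into the conv weight first (what the module does)
+w_fold = (attn.reshape(b, 1, n_layers, 1) *
+          w1x1.reshape(1, c, n_layers, c)).reshape(b, c, n_layers * c)
+out_b = torch.bmm(w_fold, feat.reshape(b, n_layers * c, h * w))
+
+assert torch.allclose(out_a, out_b, atol=1e-4)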
+
+@HEADS.register_module()
+class TOODHead(ATSSHead):
+ """TOODHead used in `TOOD: Task-aligned One-stage Object Detection.
+
+ `_.
+
+ TOOD uses Task-aligned head (T-head) and is optimized by Task Alignment
+ Learning (TAL).
+
+ Args:
+ num_dcn (int): Number of deformable convolution in the head.
+ Default: 0.
+ anchor_type (str): If set to `anchor_free`, the head will use centers
+ to regress bboxes. If set to `anchor_based`, the head will
+ regress bboxes based on anchors. Default: `anchor_free`.
+ initial_loss_cls (dict): Config of initial loss.
+
+ Example:
+ >>> self = TOODHead(11, 7)
+ >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+ >>> cls_score, bbox_pred = self.forward(feats)
+ >>> assert len(cls_score) == len(self.scales)
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ num_dcn=0,
+ anchor_type='anchor_free',
+ initial_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ activated=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ **kwargs):
+ assert anchor_type in ['anchor_free', 'anchor_based']
+ self.num_dcn = num_dcn
+ self.anchor_type = anchor_type
+ self.epoch = 0 # which is updated by SetEpochInfoHook!
+ super(TOODHead, self).__init__(num_classes, in_channels, **kwargs)
+
+ if self.train_cfg:
+ self.initial_epoch = self.train_cfg.initial_epoch
+ self.initial_assigner = build_assigner(
+ self.train_cfg.initial_assigner)
+ self.initial_loss_cls = build_loss(initial_loss_cls)
+ self.assigner = self.initial_assigner
+ self.alignment_assigner = build_assigner(self.train_cfg.assigner)
+ self.alpha = self.train_cfg.alpha
+ self.beta = self.train_cfg.beta
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.inter_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ if i < self.num_dcn:
+ conv_cfg = dict(type='DCNv2', deform_groups=4)
+ else:
+ conv_cfg = self.conv_cfg
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.inter_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg))
+
+ self.cls_decomp = TaskDecomposition(self.feat_channels,
+ self.stacked_convs,
+ self.stacked_convs * 8,
+ self.conv_cfg, self.norm_cfg)
+ self.reg_decomp = TaskDecomposition(self.feat_channels,
+ self.stacked_convs,
+ self.stacked_convs * 8,
+ self.conv_cfg, self.norm_cfg)
+
+ self.tood_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.tood_reg = nn.Conv2d(
+ self.feat_channels, self.num_base_priors * 4, 3, padding=1)
+
+ self.cls_prob_module = nn.Sequential(
+ nn.Conv2d(self.feat_channels * self.stacked_convs,
+ self.feat_channels // 4, 1), nn.ReLU(inplace=True),
+ nn.Conv2d(self.feat_channels // 4, 1, 3, padding=1))
+ self.reg_offset_module = nn.Sequential(
+ nn.Conv2d(self.feat_channels * self.stacked_convs,
+ self.feat_channels // 4, 1), nn.ReLU(inplace=True),
+ nn.Conv2d(self.feat_channels // 4, 4 * 2, 3, padding=1))
+
+ self.scales = nn.ModuleList(
+ [Scale(1.0) for _ in self.prior_generator.strides])
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ bias_cls = bias_init_with_prob(0.01)
+ for m in self.inter_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.cls_prob_module:
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.01)
+ for m in self.reg_offset_module:
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, std=0.001)
+ normal_init(self.cls_prob_module[-1], std=0.01, bias=bias_cls)
+
+ self.cls_decomp.init_weights()
+ self.reg_decomp.init_weights()
+
+ normal_init(self.tood_cls, std=0.01, bias=bias_cls)
+ normal_init(self.tood_reg, std=0.01)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of classification scores and bbox prediction
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * num_classes.
+ bbox_preds (list[Tensor]): Decoded box for all scale levels,
+ each is a 4D-tensor, the channels number is
+ num_anchors * 4. In [tl_x, tl_y, br_x, br_y] format.
+ """
+ cls_scores = []
+ bbox_preds = []
+ for idx, (x, scale, stride) in enumerate(
+ zip(feats, self.scales, self.prior_generator.strides)):
+ b, c, h, w = x.shape
+ anchor = self.prior_generator.single_level_grid_priors(
+ (h, w), idx, device=x.device)
+ anchor = torch.cat([anchor for _ in range(b)])
+ # extract task interactive features
+ inter_feats = []
+ for inter_conv in self.inter_convs:
+ x = inter_conv(x)
+ inter_feats.append(x)
+ feat = torch.cat(inter_feats, 1)
+
+ # task decomposition
+ avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
+ cls_feat = self.cls_decomp(feat, avg_feat)
+ reg_feat = self.reg_decomp(feat, avg_feat)
+
+ # cls prediction and alignment
+ cls_logits = self.tood_cls(cls_feat)
+ cls_prob = self.cls_prob_module(feat)
+ cls_score = sigmoid_geometric_mean(cls_logits, cls_prob)
+
+ # reg prediction and alignment
+ if self.anchor_type == 'anchor_free':
+ reg_dist = scale(self.tood_reg(reg_feat).exp()).float()
+ reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
+ reg_bbox = distance2bbox(
+ self.anchor_center(anchor) / stride[0],
+ reg_dist).reshape(b, h, w, 4).permute(0, 3, 1,
+ 2) # (b, c, h, w)
+ elif self.anchor_type == 'anchor_based':
+ reg_dist = scale(self.tood_reg(reg_feat)).float()
+ reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
+ reg_bbox = self.bbox_coder.decode(anchor, reg_dist).reshape(
+ b, h, w, 4).permute(0, 3, 1, 2) / stride[0]
+ else:
+ raise NotImplementedError(
+ f'Unknown anchor type: {self.anchor_type}.'
+ f'Please use `anchor_free` or `anchor_based`.')
+ reg_offset = self.reg_offset_module(feat)
+ bbox_pred = self.deform_sampling(reg_bbox.contiguous(),
+ reg_offset.contiguous())
+
+ # After deform_sampling, some boxes may become invalid (the
+ # top-left point lies to the right of or below the bottom-right
+ # point), which would make the GIoULoss negative.
+ invalid_bbox_idx = (bbox_pred[:, [0]] > bbox_pred[:, [2]]) | \
+ (bbox_pred[:, [1]] > bbox_pred[:, [3]])
+ invalid_bbox_idx = invalid_bbox_idx.expand_as(bbox_pred)
+ bbox_pred = torch.where(invalid_bbox_idx, reg_bbox, bbox_pred)
+
+ cls_scores.append(cls_score)
+ bbox_preds.append(bbox_pred)
+ return tuple(cls_scores), tuple(bbox_preds)
+
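+# Standalone sketch (not part of TOODHead): the task-aligned classification
+# score used above. Assuming, as the name suggests, that
+# `sigmoid_geometric_mean(a, b)` returns sqrt(sigmoid(a) * sigmoid(b))
+# (mmdet's op also customizes the backward pass, which is omitted here):
+import torch
+
+def sigmoid_geometric_mean_ref(cls_logits, cls_prob):
+    # geometric mean of the two probabilities, still in [0, 1]
+    return (cls_logits.sigmoid() * cls_prob.sigmoid()).sqrt()
+
+logits = torch.randn(2, 80, 8, 8)   # per-class logits from tood_cls
+prob = torch.randn(2, 1, 8, 8)      # class-agnostic logit from cls_prob_module
+score = sigmoid_geometric_mean_ref(logits, prob)
+assert score.min() >= 0 and score.max() <= 1
+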
+ def deform_sampling(self, feat, offset):
+ """Sampling the feature x according to offset.
+
+ Args:
+ feat (Tensor): Feature
+ offset (Tensor): Spatial offset for feature sampling
+ """
+ # it is an equivalent implementation of bilinear interpolation
+ b, c, h, w = feat.shape
+ weight = feat.new_ones(c, 1, 1, 1)
+ y = deform_conv2d(feat, offset, weight, 1, 0, 1, c, c)
+ return y
+
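+# Standalone sketch (not part of TOODHead): with a 1x1 all-ones weight and
+# groups equal to the channel number, `deform_conv2d` reduces to per-channel
+# bilinear sampling at (y + dy, x + dx). Shown with one channel and a constant
+# half-pixel shift along x; older mmcv builds may only ship the CUDA kernel,
+# in which case the tensors need to be moved to a GPU first.
+import torch
+from mmcv.ops import deform_conv2d
+
+feat = torch.arange(24, dtype=torch.float32).reshape(1, 1, 4, 6)
+offset = torch.zeros(1, 2, 4, 6)   # (dy, dx) per output location
+offset[:, 1] = 0.5                 # sample half a pixel to the right
+weight = feat.new_ones(1, 1, 1, 1)
+out = deform_conv2d(feat, offset, weight, 1, 0, 1, 1, 1)
+# away from the border, this is the mean of two horizontal neighbours
+assert torch.allclose(out[..., :, :-1], feat[..., :, :-1] + 0.5)
+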
+ def anchor_center(self, anchors):
+ """Get anchor centers from anchors.
+
+ Args:
+ anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format.
+
+ Returns:
+ Tensor: Anchor centers with shape (N, 2), "xy" format.
+ """
+ anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
+ anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
+ return torch.stack([anchors_cx, anchors_cy], dim=-1)
+
+ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
+ bbox_targets, alignment_metrics, stride):
+ """Compute loss of a single scale level.
+
+ Args:
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ cls_score (Tensor): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W).
+ bbox_pred (Tensor): Decoded bboxes for each scale
+ level with shape (N, num_anchors * 4, H, W).
+ labels (Tensor): Labels of each anchor with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors).
+ bbox_targets (Tensor): BBox regression targets of each anchor with
+ shape (N, num_total_anchors, 4).
+ alignment_metrics (Tensor): Alignment metrics with shape
+ (N, num_total_anchors).
+ stride (tuple[int]): Downsample stride of the feature map.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert stride[0] == stride[1], 'h stride is not equal to w stride!'
+ anchors = anchors.reshape(-1, 4)
+ cls_score = cls_score.permute(0, 2, 3, 1).reshape(
+ -1, self.cls_out_channels).contiguous()
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ alignment_metrics = alignment_metrics.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ targets = labels if self.epoch < self.initial_epoch else (
+ labels, alignment_metrics)
+ cls_loss_func = self.initial_loss_cls \
+ if self.epoch < self.initial_epoch else self.loss_cls
+
+ loss_cls = cls_loss_func(
+ cls_score, targets, label_weights, avg_factor=1.0)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+
+ pos_decode_bbox_pred = pos_bbox_pred
+ pos_decode_bbox_targets = pos_bbox_targets / stride[0]
+
+ # regression loss
+ pos_bbox_weight = self.centerness_target(
+ pos_anchors, pos_bbox_targets
+ ) if self.epoch < self.initial_epoch else alignment_metrics[
+ pos_inds]
+
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_decode_bbox_targets,
+ weight=pos_bbox_weight,
+ avg_factor=1.0)
+ else:
+ loss_bbox = bbox_pred.sum() * 0
+ pos_bbox_weight = bbox_targets.new_tensor(0.)
+
+ return loss_cls, loss_bbox, alignment_metrics.sum(
+ ), pos_bbox_weight.sum()
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Decoded box for each scale
+ level with shape (N, num_anchors * 4, H, W) in
+ [tl_x, tl_y, br_x, br_y] format.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_imgs = len(img_metas)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ flatten_cls_scores = torch.cat([
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.cls_out_channels)
+ for cls_score in cls_scores
+ ], 1)
+ flatten_bbox_preds = torch.cat([
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) * stride[0]
+ for bbox_pred, stride in zip(bbox_preds,
+ self.prior_generator.strides)
+ ], 1)
+
+ cls_reg_targets = self.get_targets(
+ flatten_cls_scores,
+ flatten_bbox_preds,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ alignment_metrics_list) = cls_reg_targets
+
+ losses_cls, losses_bbox,\
+ cls_avg_factors, bbox_avg_factors = multi_apply(
+ self.loss_single,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ alignment_metrics_list,
+ self.prior_generator.strides)
+
+ cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
+ losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls))
+
+ bbox_avg_factor = reduce_mean(
+ sum(bbox_avg_factors)).clamp_(min=1).item()
+ losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
+ return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ score_factor_list,
+ mlvl_priors,
+ img_meta,
+ cfg,
+ rescale=False,
+ with_nms=True,
+ **kwargs):
+ """Transform outputs of a single image into bbox predictions.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores from all scale
+ levels of a single image, each item has shape
+ (num_priors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas from
+ all scale levels of a single image, each item has shape
+ (num_priors * 4, H, W).
+ score_factor_list (list[Tensor]): Score factor from all scale
+ levels of a single image, each item has shape
+ (num_priors * 1, H, W).
+ mlvl_priors (list[Tensor]): Each element in the list is
+ the priors of a single level in feature pyramid. In all
+ anchor-based methods, it has shape (num_priors, 4). In
+ all anchor-free methods, it has shape (num_priors, 2)
+ when `with_stride=True`, otherwise it still has shape
+ (num_priors, 4).
+ img_meta (dict): Image meta info.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels. If with_nms
+ is False and mlvl_score_factor is None, return mlvl_bboxes and
+ mlvl_scores, else return mlvl_bboxes, mlvl_scores and
+ mlvl_score_factor. Usually with_nms=False is used for aug
+ test. If with_nms is True, then return the following format
+
+ - det_bboxes (Tensor): Predicted bboxes with shape \
+ [num_bboxes, 5], where the first 4 columns are bounding \
+ box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
+ column are scores between 0 and 1.
+ - det_labels (Tensor): Predicted labels of the corresponding \
+ box with shape [num_bboxes].
+ """
+
+ cfg = self.test_cfg if cfg is None else cfg
+ nms_pre = cfg.get('nms_pre', -1)
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_labels = []
+ for cls_score, bbox_pred, priors, stride in zip(
+ cls_score_list, bbox_pred_list, mlvl_priors,
+ self.prior_generator.strides):
+
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) * stride[0]
+ scores = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+
+ # After https://github.com/open-mmlab/mmdetection/pull/6268/,
+ # this operation keeps fewer bboxes under the same `nms_pre`.
+ # There is no difference in performance for most models. If you
+ # find a slight drop in performance, you can set a larger
+ # `nms_pre` than before.
+ results = filter_scores_and_topk(
+ scores, cfg.score_thr, nms_pre,
+ dict(bbox_pred=bbox_pred, priors=priors))
+ scores, labels, keep_idxs, filtered_results = results
+
+ bboxes = filtered_results['bbox_pred']
+
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_labels.append(labels)
+
+ return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes,
+ img_meta['scale_factor'], cfg, rescale,
+ with_nms, None, **kwargs)
+
+ def get_targets(self,
+ cls_scores,
+ bbox_preds,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in
+ multiple images.
+
+ Args:
+ cls_scores (Tensor): Classification predictions of images,
+ a 3D-Tensor with shape [num_imgs, num_priors, num_classes].
+ bbox_preds (Tensor): Decoded bbox predictions of images,
+ a 3D-Tensor with shape [num_imgs, num_priors, 4] in [tl_x,
+ tl_y, br_x, br_y] format.
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, 4).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: a tuple containing learning targets.
+
+ - anchors_list (list[list[Tensor]]): Anchors of each level.
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each
+ level.
+ - bbox_targets_list (list[Tensor]): BBox targets of each level.
+ - norm_alignment_metrics_list (list[Tensor]): Normalized
+ alignment metrics of each level.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ num_level_anchors_list = [num_level_anchors] * num_imgs
+
+ # concat all level anchors and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ anchor_list[i] = torch.cat(anchor_list[i])
+ valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ # anchor_list: list(b * [-1, 4])
+
+ if self.epoch < self.initial_epoch:
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ super()._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ all_assign_metrics = [
+ weight[..., 0] for weight in all_bbox_weights
+ ]
+ else:
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_assign_metrics) = multi_apply(
+ self._get_target_single,
+ cls_scores,
+ bbox_preds,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors)
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ norm_alignment_metrics_list = images_to_levels(all_assign_metrics,
+ num_level_anchors)
+
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, norm_alignment_metrics_list)
+
+ def _get_target_single(self,
+ cls_scores,
+ bbox_preds,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression, classification targets for anchors in a single
+ image.
+
+ Args:
+ cls_scores (list(Tensor)): Box scores for each image.
+ bbox_preds (list(Tensor)): Box energies / deltas for each image.
+ flat_anchors (Tensor): Multi-level anchors of the image, which are
+ concatenated into a single tensor of shape (num_anchors ,4)
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_anchors,).
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: N is the number of total anchors in the image.
+ anchors (Tensor): All anchors in the image with shape (N, 4).
+ labels (Tensor): Labels of all anchors in the image with shape
+ (N,).
+ label_weights (Tensor): Label weights of all anchor in the
+ image with shape (N,).
+ bbox_targets (Tensor): BBox targets of all anchors in the
+ image with shape (N, 4).
+ norm_alignment_metrics (Tensor): Normalized alignment metrics
+ of all priors in the image with shape (N,).
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+ assign_result = self.alignment_assigner.assign(
+ cls_scores[inside_flags, :], bbox_preds[inside_flags, :], anchors,
+ gt_bboxes, gt_bboxes_ignore, gt_labels, self.alpha, self.beta)
+ assign_ious = assign_result.max_overlaps
+ assign_metrics = assign_result.assign_metrics
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+ norm_alignment_metrics = anchors.new_zeros(
+ num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ # point-based
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ class_assigned_gt_inds = torch.unique(
+ sampling_result.pos_assigned_gt_inds)
+ for gt_inds in class_assigned_gt_inds:
+ gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds ==
+ gt_inds]
+ pos_alignment_metrics = assign_metrics[gt_class_inds]
+ pos_ious = assign_ious[gt_class_inds]
+ pos_norm_alignment_metrics = pos_alignment_metrics / (
+ pos_alignment_metrics.max() + 10e-8) * pos_ious.max()
+ norm_alignment_metrics[gt_class_inds] = pos_norm_alignment_metrics
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ anchors = unmap(anchors, num_total_anchors, inside_flags)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ norm_alignment_metrics = unmap(norm_alignment_metrics,
+ num_total_anchors, inside_flags)
+ return (anchors, labels, label_weights, bbox_targets,
+ norm_alignment_metrics)
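+
+# Standalone sketch (not part of TOODHead): the per-GT normalization applied
+# above. For each ground truth, the alignment metrics of its positive anchors
+# are rescaled so that the best anchor's target equals the largest IoU among
+# them. The numbers are made up for illustration.
+import torch
+
+assign_metrics = torch.tensor([0.2, 0.8, 0.4])   # s**alpha * iou**beta per anchor
+pos_ious = torch.tensor([0.5, 0.9, 0.6])         # IoU of the same anchors
+norm = assign_metrics / (assign_metrics.max() + 10e-8) * pos_ious.max()
+# -> tensor([0.2250, 0.9000, 0.4500]); the best-aligned anchor is assigned 0.9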
diff --git a/mmdet/models/dense_heads/vfnet_head.py b/mmdet/models/dense_heads/vfnet_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba285e22e32f3764ffa86f06246ffd5d2fbdd03d
--- /dev/null
+++ b/mmdet/models/dense_heads/vfnet_head.py
@@ -0,0 +1,740 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale
+from mmcv.ops import DeformConv2d
+from mmcv.runner import force_fp32
+
+from mmdet.core import (MlvlPointGenerator, bbox_overlaps, build_assigner,
+ build_prior_generator, build_sampler, multi_apply,
+ reduce_mean)
+from ..builder import HEADS, build_loss
+from .atss_head import ATSSHead
+from .fcos_head import FCOSHead
+
+INF = 1e8
+
+
+@HEADS.register_module()
+class VFNetHead(ATSSHead, FCOSHead):
+ """Head of `VarifocalNet (VFNet): An IoU-aware Dense Object
+ Detector.`_.
+
+ The VFNet predicts IoU-aware classification scores which mix the
+ object presence confidence and object localization accuracy as the
+ detection score. It is built on the FCOS architecture and uses ATSS
+ for defining positive/negative training examples. The VFNet is trained
+ with Varifocal Loss and employs star-shaped deformable convolution to
+ extract features for a bbox.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ regress_ranges (tuple[tuple[int, int]]): Regress range of multiple
+ level points.
+ center_sampling (bool): If true, use center sampling. Default: False.
+ center_sample_radius (float): Radius of center sampling. Default: 1.5.
+ sync_num_pos (bool): If true, synchronize the number of positive
+ examples across GPUs. Default: True
+ gradient_mul (float): The multiplier to gradients from bbox refinement
+ and recognition. Default: 0.1.
+ bbox_norm_type (str): The bbox normalization type, 'reg_denom' or
+ 'stride'. Default: reg_denom
+ loss_cls_fl (dict): Config of focal loss.
+ use_vfl (bool): If true, use varifocal loss for training.
+ Default: True.
+ loss_cls (dict): Config of varifocal loss.
+ loss_bbox (dict): Config of localization loss, GIoU Loss.
+ loss_bbox_refine (dict): Config of localization refinement loss, GIoU Loss.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: norm_cfg=dict(type='GN', num_groups=32,
+ requires_grad=True).
+ use_atss (bool): If true, use ATSS to define positive/negative
+ examples. Default: True.
+ anchor_generator (dict): Config of anchor generator for ATSS.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+
+ Example:
+ >>> self = VFNetHead(11, 7)
+ >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+ >>> cls_score, bbox_pred, bbox_pred_refine= self.forward(feats)
+ >>> assert len(cls_score) == len(self.scales)
+ """ # noqa: E501
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
+ (512, INF)),
+ center_sampling=False,
+ center_sample_radius=1.5,
+ sync_num_pos=True,
+ gradient_mul=0.1,
+ bbox_norm_type='reg_denom',
+ loss_cls_fl=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ use_vfl=True,
+ loss_cls=dict(
+ type='VarifocalLoss',
+ use_sigmoid=True,
+ alpha=0.75,
+ gamma=2.0,
+ iou_weighted=True,
+ loss_weight=1.0),
+ loss_bbox=dict(type='GIoULoss', loss_weight=1.5),
+ loss_bbox_refine=dict(type='GIoULoss', loss_weight=2.0),
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ use_atss=True,
+ reg_decoded_bbox=True,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ ratios=[1.0],
+ octave_base_scale=8,
+ scales_per_octave=1,
+ center_offset=0.0,
+ strides=[8, 16, 32, 64, 128]),
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='vfnet_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ # dcn base offsets, adapted from reppoints_head.py
+ self.num_dconv_points = 9
+ self.dcn_kernel = int(np.sqrt(self.num_dconv_points))
+ self.dcn_pad = int((self.dcn_kernel - 1) / 2)
+ dcn_base = np.arange(-self.dcn_pad,
+ self.dcn_pad + 1).astype(np.float64)
+ dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
+ dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
+ dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
+ (-1))
+ self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
+
+ super(FCOSHead, self).__init__(
+ num_classes,
+ in_channels,
+ norm_cfg=norm_cfg,
+ init_cfg=init_cfg,
+ **kwargs)
+ self.regress_ranges = regress_ranges
+ self.reg_denoms = [
+ regress_range[-1] for regress_range in regress_ranges
+ ]
+ self.reg_denoms[-1] = self.reg_denoms[-2] * 2
+ self.center_sampling = center_sampling
+ self.center_sample_radius = center_sample_radius
+ self.sync_num_pos = sync_num_pos
+ self.bbox_norm_type = bbox_norm_type
+ self.gradient_mul = gradient_mul
+ self.use_vfl = use_vfl
+ if self.use_vfl:
+ self.loss_cls = build_loss(loss_cls)
+ else:
+ self.loss_cls = build_loss(loss_cls_fl)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_bbox_refine = build_loss(loss_bbox_refine)
+
+ # for getting ATSS targets
+ self.use_atss = use_atss
+ self.reg_decoded_bbox = reg_decoded_bbox
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+
+ self.anchor_center_offset = anchor_generator['center_offset']
+
+ self.num_base_priors = self.prior_generator.num_base_priors[0]
+
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ # only used in `get_atss_targets` when `use_atss` is True
+ self.atss_prior_generator = build_prior_generator(anchor_generator)
+
+ self.fcos_prior_generator = MlvlPointGenerator(
+ anchor_generator['strides'],
+ self.anchor_center_offset if self.use_atss else 0.5)
+
+ # In order to reuse the `get_bboxes` in `BaseDenseHead`.
+ # Only used in the testing phase.
+ self.prior_generator = self.fcos_prior_generator
+
+ @property
+ def num_anchors(self):
+ """
+ Returns:
+ int: Number of anchors on each point of feature map.
+ """
+ warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
+ 'please use "num_base_priors" instead')
+ return self.num_base_priors
+
+ @property
+ def anchor_generator(self):
+ warnings.warn('DeprecationWarning: anchor_generator is deprecated, '
+ 'please use "atss_prior_generator" instead')
+ return self.prior_generator
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ super(FCOSHead, self)._init_cls_convs()
+ super(FCOSHead, self)._init_reg_convs()
+ self.relu = nn.ReLU(inplace=True)
+ self.vfnet_reg_conv = ConvModule(
+ self.feat_channels,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias)
+ self.vfnet_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+ self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
+
+ self.vfnet_reg_refine_dconv = DeformConv2d(
+ self.feat_channels,
+ self.feat_channels,
+ self.dcn_kernel,
+ 1,
+ padding=self.dcn_pad)
+ self.vfnet_reg_refine = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+ self.scales_refine = nn.ModuleList([Scale(1.0) for _ in self.strides])
+
+ self.vfnet_cls_dconv = DeformConv2d(
+ self.feat_channels,
+ self.feat_channels,
+ self.dcn_kernel,
+ 1,
+ padding=self.dcn_pad)
+ self.vfnet_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box offsets for each
+ scale level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ bbox_preds_refine (list[Tensor]): Refined Box offsets for
+ each scale level, each is a 4D-tensor, the channel
+ number is num_points * 4.
+ """
+ return multi_apply(self.forward_single, feats, self.scales,
+ self.scales_refine, self.strides, self.reg_denoms)
+
+ def forward_single(self, x, scale, scale_refine, stride, reg_denom):
+ """Forward features of a single scale level.
+
+ Args:
+ x (Tensor): FPN feature maps of the specified stride.
+ scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+ scale_refine (:obj: `mmcv.cnn.Scale`): Learnable scale module to
+ resize the refined bbox prediction.
+ stride (int): The corresponding stride for feature maps,
+ used to normalize the bbox prediction when
+ bbox_norm_type = 'stride'.
+ reg_denom (int): The corresponding regression range for feature
+ maps, only used to normalize the bbox prediction when
+ bbox_norm_type = 'reg_denom'.
+
+ Returns:
+ tuple: iou-aware cls scores for each box, bbox predictions and
+ refined bbox predictions of input feature maps.
+ """
+ cls_feat = x
+ reg_feat = x
+
+ for cls_layer in self.cls_convs:
+ cls_feat = cls_layer(cls_feat)
+
+ for reg_layer in self.reg_convs:
+ reg_feat = reg_layer(reg_feat)
+
+ # predict the bbox_pred of different level
+ reg_feat_init = self.vfnet_reg_conv(reg_feat)
+ if self.bbox_norm_type == 'reg_denom':
+ bbox_pred = scale(
+ self.vfnet_reg(reg_feat_init)).float().exp() * reg_denom
+ elif self.bbox_norm_type == 'stride':
+ bbox_pred = scale(
+ self.vfnet_reg(reg_feat_init)).float().exp() * stride
+ else:
+ raise NotImplementedError
+
+ # compute star deformable convolution offsets
+ # convert dcn_offset to reg_feat.dtype so that VFNet can be
+ # trained with FP16
+ dcn_offset = self.star_dcn_offset(bbox_pred, self.gradient_mul,
+ stride).to(reg_feat.dtype)
+
+ # refine the bbox_pred
+ reg_feat = self.relu(self.vfnet_reg_refine_dconv(reg_feat, dcn_offset))
+ bbox_pred_refine = scale_refine(
+ self.vfnet_reg_refine(reg_feat)).float().exp()
+ bbox_pred_refine = bbox_pred_refine * bbox_pred.detach()
+
+ # predict the iou-aware cls score
+ cls_feat = self.relu(self.vfnet_cls_dconv(cls_feat, dcn_offset))
+ cls_score = self.vfnet_cls(cls_feat)
+
+ if self.training:
+ return cls_score, bbox_pred, bbox_pred_refine
+ else:
+ return cls_score, bbox_pred_refine
+
+ def star_dcn_offset(self, bbox_pred, gradient_mul, stride):
+ """Compute the star deformable conv offsets.
+
+ Args:
+ bbox_pred (Tensor): Predicted bbox distance offsets (l, r, t, b).
+ gradient_mul (float): Gradient multiplier.
+ stride (int): The corresponding stride for feature maps,
+ used to project the bbox onto the feature map.
+
+ Returns:
+ dcn_offsets (Tensor): The offsets for deformable convolution.
+ """
+ dcn_base_offset = self.dcn_base_offset.type_as(bbox_pred)
+ bbox_pred_grad_mul = (1 - gradient_mul) * bbox_pred.detach() + \
+ gradient_mul * bbox_pred
+ # map to the feature map scale
+ bbox_pred_grad_mul = bbox_pred_grad_mul / stride
+ N, C, H, W = bbox_pred.size()
+
+ x1 = bbox_pred_grad_mul[:, 0, :, :]
+ y1 = bbox_pred_grad_mul[:, 1, :, :]
+ x2 = bbox_pred_grad_mul[:, 2, :, :]
+ y2 = bbox_pred_grad_mul[:, 3, :, :]
+ bbox_pred_grad_mul_offset = bbox_pred.new_zeros(
+ N, 2 * self.num_dconv_points, H, W)
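+ # With a 3x3 kernel the 9 sampling points are enumerated row-major
+ # as (y, x) pairs, point k using offset channels (2k, 2k + 1). The
+ # assignments below place the points on a "star" over the predicted
+ # box: y is taken from {-y1 (top), 0 (the location itself), y2
+ # (bottom)} and x from {-x1 (left), 0, x2 (right)}, e.g. point 0 ->
+ # (-y1, -x1) is the top-left corner and point 4 -> (0, 0) is the
+ # location itself. Subtracting dcn_base_offset afterwards cancels
+ # the regular 3x3 grid that the deformable conv adds internally.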
+ bbox_pred_grad_mul_offset[:, 0, :, :] = -1.0 * y1 # -y1
+ bbox_pred_grad_mul_offset[:, 1, :, :] = -1.0 * x1 # -x1
+ bbox_pred_grad_mul_offset[:, 2, :, :] = -1.0 * y1 # -y1
+ bbox_pred_grad_mul_offset[:, 4, :, :] = -1.0 * y1 # -y1
+ bbox_pred_grad_mul_offset[:, 5, :, :] = x2 # x2
+ bbox_pred_grad_mul_offset[:, 7, :, :] = -1.0 * x1 # -x1
+ bbox_pred_grad_mul_offset[:, 11, :, :] = x2 # x2
+ bbox_pred_grad_mul_offset[:, 12, :, :] = y2 # y2
+ bbox_pred_grad_mul_offset[:, 13, :, :] = -1.0 * x1 # -x1
+ bbox_pred_grad_mul_offset[:, 14, :, :] = y2 # y2
+ bbox_pred_grad_mul_offset[:, 16, :, :] = y2 # y2
+ bbox_pred_grad_mul_offset[:, 17, :, :] = x2 # x2
+ dcn_offset = bbox_pred_grad_mul_offset - dcn_base_offset
+
+ return dcn_offset
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'bbox_preds_refine'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ bbox_preds_refine,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box offsets for each
+ scale level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ bbox_preds_refine (list[Tensor]): Refined Box offsets for
+ each scale level, each is a 4D-tensor, the channel
+ number is num_points * 4.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ Default: None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ all_level_points = self.fcos_prior_generator.grid_priors(
+ featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device)
+ labels, label_weights, bbox_targets, bbox_weights = self.get_targets(
+ cls_scores, all_level_points, gt_bboxes, gt_labels, img_metas,
+ gt_bboxes_ignore)
+
+ num_imgs = cls_scores[0].size(0)
+ # flatten cls_scores, bbox_preds and bbox_preds_refine
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3,
+ 1).reshape(-1,
+ self.cls_out_channels).contiguous()
+ for cls_score in cls_scores
+ ]
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4).contiguous()
+ for bbox_pred in bbox_preds
+ ]
+ flatten_bbox_preds_refine = [
+ bbox_pred_refine.permute(0, 2, 3, 1).reshape(-1, 4).contiguous()
+ for bbox_pred_refine in bbox_preds_refine
+ ]
+ flatten_cls_scores = torch.cat(flatten_cls_scores)
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds)
+ flatten_bbox_preds_refine = torch.cat(flatten_bbox_preds_refine)
+ flatten_labels = torch.cat(labels)
+ flatten_bbox_targets = torch.cat(bbox_targets)
+ # repeat points to align with bbox_preds
+ flatten_points = torch.cat(
+ [points.repeat(num_imgs, 1) for points in all_level_points])
+
+ # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = torch.where(
+ ((flatten_labels >= 0) & (flatten_labels < bg_class_ind)) > 0)[0]
+ num_pos = len(pos_inds)
+
+ pos_bbox_preds = flatten_bbox_preds[pos_inds]
+ pos_bbox_preds_refine = flatten_bbox_preds_refine[pos_inds]
+ pos_labels = flatten_labels[pos_inds]
+
+ # sync num_pos across all gpus
+ if self.sync_num_pos:
+ num_pos_avg_per_gpu = reduce_mean(
+ pos_inds.new_tensor(num_pos).float()).item()
+ num_pos_avg_per_gpu = max(num_pos_avg_per_gpu, 1.0)
+ else:
+ num_pos_avg_per_gpu = num_pos
+
+ pos_bbox_targets = flatten_bbox_targets[pos_inds]
+ pos_points = flatten_points[pos_inds]
+
+ pos_decoded_bbox_preds = self.bbox_coder.decode(
+ pos_points, pos_bbox_preds)
+ pos_decoded_target_preds = self.bbox_coder.decode(
+ pos_points, pos_bbox_targets)
+ iou_targets_ini = bbox_overlaps(
+ pos_decoded_bbox_preds,
+ pos_decoded_target_preds.detach(),
+ is_aligned=True).clamp(min=1e-6)
+ bbox_weights_ini = iou_targets_ini.clone().detach()
+ bbox_avg_factor_ini = reduce_mean(
+ bbox_weights_ini.sum()).clamp_(min=1).item()
+
+ pos_decoded_bbox_preds_refine = \
+ self.bbox_coder.decode(pos_points, pos_bbox_preds_refine)
+ iou_targets_rf = bbox_overlaps(
+ pos_decoded_bbox_preds_refine,
+ pos_decoded_target_preds.detach(),
+ is_aligned=True).clamp(min=1e-6)
+ bbox_weights_rf = iou_targets_rf.clone().detach()
+ bbox_avg_factor_rf = reduce_mean(
+ bbox_weights_rf.sum()).clamp_(min=1).item()
+
+ if num_pos > 0:
+ loss_bbox = self.loss_bbox(
+ pos_decoded_bbox_preds,
+ pos_decoded_target_preds.detach(),
+ weight=bbox_weights_ini,
+ avg_factor=bbox_avg_factor_ini)
+
+ loss_bbox_refine = self.loss_bbox_refine(
+ pos_decoded_bbox_preds_refine,
+ pos_decoded_target_preds.detach(),
+ weight=bbox_weights_rf,
+ avg_factor=bbox_avg_factor_rf)
+
+ # build IoU-aware cls_score targets
+ if self.use_vfl:
+ pos_ious = iou_targets_rf.clone().detach()
+ cls_iou_targets = torch.zeros_like(flatten_cls_scores)
+ cls_iou_targets[pos_inds, pos_labels] = pos_ious
+ else:
+ loss_bbox = pos_bbox_preds.sum() * 0
+ loss_bbox_refine = pos_bbox_preds_refine.sum() * 0
+ if self.use_vfl:
+ cls_iou_targets = torch.zeros_like(flatten_cls_scores)
+
+ if self.use_vfl:
+ loss_cls = self.loss_cls(
+ flatten_cls_scores,
+ cls_iou_targets,
+ avg_factor=num_pos_avg_per_gpu)
+ else:
+ loss_cls = self.loss_cls(
+ flatten_cls_scores,
+ flatten_labels,
+ weight=label_weights,
+ avg_factor=num_pos_avg_per_gpu)
+
+ return dict(
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ loss_bbox_rf=loss_bbox_refine)
+
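+# Standalone sketch (not part of VFNetHead): building the IoU-aware
+# classification target consumed by VarifocalLoss above. Positive locations
+# receive their (refined) IoU as the target for the ground-truth class and
+# every other entry stays 0. Shapes and values are illustrative.
+import torch
+
+num_points, num_classes = 6, 3
+flatten_cls_scores = torch.randn(num_points, num_classes)
+pos_inds = torch.tensor([1, 4])
+pos_labels = torch.tensor([2, 0])
+pos_ious = torch.tensor([0.83, 0.41])
+
+cls_iou_targets = torch.zeros_like(flatten_cls_scores)
+cls_iou_targets[pos_inds, pos_labels] = pos_ious
+# row 1 -> [0, 0, 0.83], row 4 -> [0.41, 0, 0], all other rows remain 0
+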
+ def get_targets(self, cls_scores, mlvl_points, gt_bboxes, gt_labels,
+ img_metas, gt_bboxes_ignore):
+ """A wrapper for computing ATSS and FCOS targets for points in multiple
+ images.
+
+ Args:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level with shape (N, num_points * num_classes, H, W).
+ mlvl_points (list[Tensor]): Points of each fpn level, each has
+ shape (num_points, 2).
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+
+ Returns:
+ tuple:
+ labels_list (list[Tensor]): Labels of each level.
+ label_weights (Tensor/None): Label weights of all levels.
+ bbox_targets_list (list[Tensor]): Regression targets of each
+ level, (l, t, r, b).
+ bbox_weights (Tensor/None): Bbox weights of all levels.
+ """
+ if self.use_atss:
+ return self.get_atss_targets(cls_scores, mlvl_points, gt_bboxes,
+ gt_labels, img_metas,
+ gt_bboxes_ignore)
+ else:
+ self.norm_on_bbox = False
+ return self.get_fcos_targets(mlvl_points, gt_bboxes, gt_labels)
+
+ def _get_target_single(self, *args, **kwargs):
+ """Avoid ambiguity in multiple inheritance."""
+ if self.use_atss:
+ return ATSSHead._get_target_single(self, *args, **kwargs)
+ else:
+ return FCOSHead._get_target_single(self, *args, **kwargs)
+
+ def get_fcos_targets(self, points, gt_bboxes_list, gt_labels_list):
+ """Compute FCOS regression and classification targets for points in
+ multiple images.
+
+ Args:
+ points (list[Tensor]): Points of each fpn level, each has shape
+ (num_points, 2).
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels_list (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+
+ Returns:
+ tuple:
+ labels (list[Tensor]): Labels of each level.
+ label_weights: None, to be compatible with ATSS targets.
+ bbox_targets (list[Tensor]): BBox targets of each level.
+ bbox_weights: None, to be compatible with ATSS targets.
+ """
+ labels, bbox_targets = FCOSHead.get_targets(self, points,
+ gt_bboxes_list,
+ gt_labels_list)
+ label_weights = None
+ bbox_weights = None
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
+ """Get anchors according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+ device (torch.device | str): Device for returned tensors
+
+ Returns:
+ tuple:
+ anchor_list (list[Tensor]): Anchors of each image.
+ valid_flag_list (list[Tensor]): Valid flags of each image.
+ """
+ num_imgs = len(img_metas)
+
+ # since feature map sizes of all images are the same, we only
+ # compute anchors once
+ multi_level_anchors = self.atss_prior_generator.grid_priors(
+ featmap_sizes, device=device)
+ anchor_list = [multi_level_anchors for _ in range(num_imgs)]
+
+ # for each image, we compute valid flags of multi level anchors
+ valid_flag_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_flags = self.atss_prior_generator.valid_flags(
+ featmap_sizes, img_meta['pad_shape'], device=device)
+ valid_flag_list.append(multi_level_flags)
+
+ return anchor_list, valid_flag_list
+
+ def get_atss_targets(self,
+ cls_scores,
+ mlvl_points,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """A wrapper for computing ATSS targets for points in multiple images.
+
+ Args:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level with shape (N, num_points * num_classes, H, W).
+ mlvl_points (list[Tensor]): Points of each fpn level, each has
+ shape (num_points, 2).
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4). Default: None.
+
+ Returns:
+ tuple:
+ labels_list (list[Tensor]): Labels of each level.
+ label_weights (Tensor): Label weights of all levels.
+ bbox_targets_list (list[Tensor]): Regression targets of each
+ level, (l, t, r, b).
+ bbox_weights (Tensor): Bbox weights of all levels.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(
+ featmap_sizes
+ ) == self.atss_prior_generator.num_levels == \
+ self.fcos_prior_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = ATSSHead.get_targets(
+ self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ unmap_outputs=True)
+ if cls_reg_targets is None:
+ return None
+
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+ bbox_targets_list = [
+ bbox_targets.reshape(-1, 4) for bbox_targets in bbox_targets_list
+ ]
+
+ num_imgs = len(img_metas)
+ # transform bbox_targets (x1, y1, x2, y2) into (l, t, r, b) format
+ bbox_targets_list = self.transform_bbox_targets(
+ bbox_targets_list, mlvl_points, num_imgs)
+
+ labels_list = [labels.reshape(-1) for labels in labels_list]
+ label_weights_list = [
+ label_weights.reshape(-1) for label_weights in label_weights_list
+ ]
+ bbox_weights_list = [
+ bbox_weights.reshape(-1) for bbox_weights in bbox_weights_list
+ ]
+ label_weights = torch.cat(label_weights_list)
+ bbox_weights = torch.cat(bbox_weights_list)
+ return labels_list, label_weights, bbox_targets_list, bbox_weights
+
+ def transform_bbox_targets(self, decoded_bboxes, mlvl_points, num_imgs):
+ """Transform bbox_targets (x1, y1, x2, y2) into (l, t, r, b) format.
+
+ Args:
+ decoded_bboxes (list[Tensor]): Regression targets of each level,
+ in the form of (x1, y1, x2, y2).
+ mlvl_points (list[Tensor]): Points of each fpn level, each has
+ shape (num_points, 2).
+ num_imgs (int): the number of images in a batch.
+
+ Returns:
+ bbox_targets (list[Tensor]): Regression targets of each level in
+ the form of (l, t, r, b).
+ """
+ # TODO: Re-implemented in Class PointCoder
+ assert len(decoded_bboxes) == len(mlvl_points)
+ num_levels = len(decoded_bboxes)
+ mlvl_points = [points.repeat(num_imgs, 1) for points in mlvl_points]
+ bbox_targets = []
+ for i in range(num_levels):
+ bbox_target = self.bbox_coder.encode(mlvl_points[i],
+ decoded_bboxes[i])
+ bbox_targets.append(bbox_target)
+
+ return bbox_targets
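+
+ # A minimal sketch of the (x1, y1, x2, y2) -> (l, t, r, b) encoding done by
+ # `self.bbox_coder.encode` above (assuming a distance-point style coder):
+ # the targets are simply the distances from each point to the four box
+ # sides, e.g. with made-up numbers:
+ #
+ #     >>> points = torch.tensor([[50., 50.]])
+ #     >>> boxes = torch.tensor([[30., 40., 80., 90.]])
+ #     >>> torch.cat([points - boxes[:, :2], boxes[:, 2:] - points], dim=-1)
+ #     tensor([[20., 10., 30., 40.]])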
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ """Override the method in the parent class to avoid changing para's
+ name."""
+ pass
+
+ def _get_points_single(self,
+ featmap_size,
+ stride,
+ dtype,
+ device,
+ flatten=False):
+ """Get points according to feature map size.
+
+ This function will be deprecated soon.
+ """
+
+ warnings.warn(
+ '`_get_points_single` in `VFNetHead` will be '
+ 'deprecated soon. We support a multi-level point generator now; '
+ 'you can get points of a single-level feature map '
+ 'with `self.fcos_prior_generator.single_level_grid_priors` ')
+
+ h, w = featmap_size
+ x_range = torch.arange(
+ 0, w * stride, stride, dtype=dtype, device=device)
+ y_range = torch.arange(
+ 0, h * stride, stride, dtype=dtype, device=device)
+ y, x = torch.meshgrid(y_range, x_range)
+ # to be compatible with anchor points in ATSS
+ if self.use_atss:
+ points = torch.stack(
+ (x.reshape(-1), y.reshape(-1)), dim=-1) + \
+ stride * self.anchor_center_offset
+ else:
+ points = torch.stack(
+ (x.reshape(-1), y.reshape(-1)), dim=-1) + stride // 2
+ return points
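+
+ # A minimal sketch of the grid computed above, assuming stride=8, a 2x2
+ # feature map and the FCOS-style `stride // 2` offset:
+ #
+ #     >>> y, x = torch.meshgrid(torch.arange(0, 16, 8), torch.arange(0, 16, 8))
+ #     >>> torch.stack((x.reshape(-1), y.reshape(-1)), dim=-1) + 8 // 2
+ #     tensor([[ 4,  4],
+ #             [12,  4],
+ #             [ 4, 12],
+ #             [12, 12]])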
diff --git a/mmdet/models/dense_heads/yolact_head.py b/mmdet/models/dense_heads/yolact_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f89a271baf2fd75eb63dc16e8343870fe640760
--- /dev/null
+++ b/mmdet/models/dense_heads/yolact_head.py
@@ -0,0 +1,1018 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, ModuleList, force_fp32
+
+from mmdet.core import build_sampler, fast_nms, images_to_levels, multi_apply
+from mmdet.core.utils import select_single_mlvl
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class YOLACTHead(AnchorHead):
+ """YOLACT box head used in https://arxiv.org/abs/1904.02689.
+
+ Note that YOLACT head is a light version of RetinaNet head.
+ Four differences are described as follows:
+
+ 1. YOLACT box head has three-times fewer anchors.
+ 2. YOLACT box head shares the convs for box and cls branches.
+ 3. YOLACT box head uses OHEM instead of Focal loss.
+ 4. YOLACT box head predicts a set of mask coefficients for each box.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ anchor_generator (dict): Config dict for anchor generator
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ num_head_convs (int): Number of the conv layers shared by
+ box and cls branches.
+ num_protos (int): Number of the mask coefficients.
+ use_ohem (bool): If true, ``loss_single_OHEM`` will be used for
+ cls loss calculation. If false, ``loss_single`` will be used.
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=3,
+ scales_per_octave=1,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128]),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ reduction='none',
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=1.5),
+ num_head_convs=1,
+ num_protos=32,
+ use_ohem=True,
+ conv_cfg=None,
+ norm_cfg=None,
+ init_cfg=dict(
+ type='Xavier',
+ distribution='uniform',
+ bias=0,
+ layer='Conv2d'),
+ **kwargs):
+ self.num_head_convs = num_head_convs
+ self.num_protos = num_protos
+ self.use_ohem = use_ohem
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ super(YOLACTHead, self).__init__(
+ num_classes,
+ in_channels,
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ anchor_generator=anchor_generator,
+ init_cfg=init_cfg,
+ **kwargs)
+ if self.use_ohem:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.sampling = False
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.head_convs = ModuleList()
+ for i in range(self.num_head_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.head_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.conv_reg = nn.Conv2d(
+ self.feat_channels, self.num_base_priors * 4, 3, padding=1)
+ self.conv_coeff = nn.Conv2d(
+ self.feat_channels,
+ self.num_base_priors * self.num_protos,
+ 3,
+ padding=1)
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+
+ Returns:
+ tuple:
+ cls_score (Tensor): Cls scores for a single scale level \
+ the channels number is num_anchors * num_classes.
+ bbox_pred (Tensor): Box energies / deltas for a single scale \
+ level, the channels number is num_anchors * 4.
+ coeff_pred (Tensor): Mask coefficients for a single scale \
+ level, the channels number is num_anchors * num_protos.
+ """
+ for head_conv in self.head_convs:
+ x = head_conv(x)
+ cls_score = self.conv_cls(x)
+ bbox_pred = self.conv_reg(x)
+ coeff_pred = self.conv_coeff(x).tanh()
+ return cls_score, bbox_pred, coeff_pred
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """A combination of the func:``AnchorHead.loss`` and
+ func:``SSDHead.loss``.
+
+ When ``self.use_ohem == True``, it functions like ``SSDHead.loss``,
+ otherwise, it follows ``AnchorHead.loss``. Besides, it additionally
+ returns ``sampling_results``.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ tuple:
+ dict[str, Tensor]: A dictionary of loss components.
+ List[:obj:``SamplingResult``]: Sampler results for each image.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.prior_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ unmap_outputs=not self.use_ohem,
+ return_sampling_results=True)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg, sampling_results) = cls_reg_targets
+
+ if self.use_ohem:
+ num_images = len(img_metas)
+ all_cls_scores = torch.cat([
+ s.permute(0, 2, 3, 1).reshape(
+ num_images, -1, self.cls_out_channels) for s in cls_scores
+ ], 1)
+ all_labels = torch.cat(labels_list, -1).view(num_images, -1)
+ all_label_weights = torch.cat(label_weights_list,
+ -1).view(num_images, -1)
+ all_bbox_preds = torch.cat([
+ b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
+ for b in bbox_preds
+ ], -2)
+ all_bbox_targets = torch.cat(bbox_targets_list,
+ -2).view(num_images, -1, 4)
+ all_bbox_weights = torch.cat(bbox_weights_list,
+ -2).view(num_images, -1, 4)
+
+ # concat all level anchors to a single tensor
+ all_anchors = []
+ for i in range(num_images):
+ all_anchors.append(torch.cat(anchor_list[i]))
+
+ # check NaN and Inf
+ assert torch.isfinite(all_cls_scores).all().item(), \
+ 'classification scores become infinite or NaN!'
+ assert torch.isfinite(all_bbox_preds).all().item(), \
+ 'bbox predictions become infinite or NaN!'
+
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single_OHEM,
+ all_cls_scores,
+ all_bbox_preds,
+ all_anchors,
+ all_labels,
+ all_label_weights,
+ all_bbox_targets,
+ all_bbox_weights,
+ num_total_samples=num_total_pos)
+ else:
+ num_total_samples = (
+ num_total_pos +
+ num_total_neg if self.sampling else num_total_pos)
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+
+ return dict(
+ loss_cls=losses_cls, loss_bbox=losses_bbox), sampling_results
+
+ def loss_single_OHEM(self, cls_score, bbox_pred, anchors, labels,
+ label_weights, bbox_targets, bbox_weights,
+ num_total_samples):
+ """"See func:``SSDHead.loss``."""
+ loss_cls_all = self.loss_cls(cls_score, labels, label_weights)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
+ as_tuple=False).reshape(-1)
+ neg_inds = (labels == self.num_classes).nonzero(
+ as_tuple=False).view(-1)
+
+ num_pos_samples = pos_inds.size(0)
+ if num_pos_samples == 0:
+ num_neg_samples = neg_inds.size(0)
+ else:
+ num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
+ if num_neg_samples > neg_inds.size(0):
+ num_neg_samples = neg_inds.size(0)
+ topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+ loss_cls_pos = loss_cls_all[pos_inds].sum()
+ loss_cls_neg = topk_loss_cls_neg.sum()
+ loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
+ if self.reg_decoded_bbox:
+ # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+ # is applied directly on the decoded bounding boxes, it
+ # decodes the already encoded coordinates to absolute format.
+ bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+ loss_bbox = self.loss_bbox(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_samples)
+ return loss_cls[None], loss_bbox
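+
+ # A minimal sketch of the hard negative mining above: with
+ # `neg_pos_ratio = 3` (a typical value, read from `self.train_cfg`) and two
+ # positive samples, only the six negatives with the largest classification
+ # losses contribute to `loss_cls_neg`, e.g.:
+ #
+ #     >>> neg_losses = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4])
+ #     >>> neg_losses.topk(6)[0].sum()
+ #     tensor(3.7000)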
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'coeff_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ coeff_preds,
+ img_metas,
+ cfg=None,
+ rescale=False):
+ """"Similar to func:``AnchorHead.get_bboxes``, but additionally
+ processes coeff_preds.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ with shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ coeff_preds (list[Tensor]): Mask coefficients for each scale
+ level with shape (N, num_anchors * num_protos, H, W)
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+
+ Returns:
+ list[tuple[Tensor, Tensor, Tensor]]: Each item in result_list is
+ a 3-tuple. The first item is an (n, 5) tensor, where the
+ first 4 columns are bounding box positions
+ (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
+ between 0 and 1. The second item is an (n,) tensor where each
+ item is the predicted class label of the corresponding box.
+ The third item is an (n, num_protos) tensor where each item
+ is the predicted mask coefficients of instance inside the
+ corresponding box.
+ """
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.prior_generator.grid_priors(
+ featmap_sizes, device=device)
+
+ det_bboxes = []
+ det_labels = []
+ det_coeffs = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = select_single_mlvl(cls_scores, img_id)
+ bbox_pred_list = select_single_mlvl(bbox_preds, img_id)
+ coeff_pred_list = select_single_mlvl(coeff_preds, img_id)
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ bbox_res = self._get_bboxes_single(cls_score_list, bbox_pred_list,
+ coeff_pred_list, mlvl_anchors,
+ img_shape, scale_factor, cfg,
+ rescale)
+ det_bboxes.append(bbox_res[0])
+ det_labels.append(bbox_res[1])
+ det_coeffs.append(bbox_res[2])
+ return det_bboxes, det_labels, det_coeffs
+
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ coeff_preds_list,
+ mlvl_anchors,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ """"Similar to func:``AnchorHead._get_bboxes_single``, but additionally
+ processes coeff_preds_list and uses fast NMS instead of traditional
+ NMS.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores for a single scale level
+ Has shape (num_anchors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas for a single
+ scale level with shape (num_anchors * 4, H, W).
+ coeff_preds_list (list[Tensor]): Mask coefficients for a single
+ scale level with shape (num_anchors * num_protos, H, W).
+ mlvl_anchors (list[Tensor]): Box reference for a single scale level
+ with shape (num_total_anchors, 4).
+ img_shape (tuple[int]): Shape of the input image,
+ (height, width, 3).
+ scale_factor (ndarray): Scale factor of the image, arranged as
+ (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+
+ Returns:
+ tuple[Tensor, Tensor, Tensor]: The first item is an (n, 5) tensor,
+ where the first 4 columns are bounding box positions
+ (tl_x, tl_y, br_x, br_y) and the 5-th column is a score between
+ 0 and 1. The second item is an (n,) tensor where each item is
+ the predicted class label of the corresponding box. The third
+ item is an (n, num_protos) tensor where each item is the
+ predicted mask coefficients of instance inside the
+ corresponding box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
+ nms_pre = cfg.get('nms_pre', -1)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_coeffs = []
+ for cls_score, bbox_pred, coeff_pred, anchors in \
+ zip(cls_score_list, bbox_pred_list,
+ coeff_preds_list, mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ coeff_pred = coeff_pred.permute(1, 2,
+ 0).reshape(-1, self.num_protos)
+
+ if 0 < nms_pre < scores.shape[0]:
+ # Get maximum scores for foreground classes.
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(dim=1)
+ else:
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ max_scores, _ = scores[:, :-1].max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ anchors = anchors[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ coeff_pred = coeff_pred[topk_inds, :]
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shape)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_coeffs.append(coeff_pred)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ mlvl_coeffs = torch.cat(mlvl_coeffs)
+ if self.use_sigmoid_cls:
+ # Add a dummy background class to the backend when using sigmoid
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+ det_bboxes, det_labels, det_coeffs = fast_nms(mlvl_bboxes, mlvl_scores,
+ mlvl_coeffs,
+ cfg.score_thr,
+ cfg.iou_thr, cfg.top_k,
+ cfg.max_per_img)
+ return det_bboxes, det_labels, det_coeffs
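+
+ # A rough sketch of the idea behind `fast_nms` used above (see the YOLACT
+ # paper): each box is compared only against higher-scoring boxes of the same
+ # class via an upper-triangular IoU matrix, so suppression becomes one matrix
+ # operation instead of a sequential loop, e.g. (assuming an `iou` helper that
+ # returns the pairwise IoU matrix of score-sorted boxes):
+ #
+ #     >>> ious = iou(sorted_boxes, sorted_boxes).triu(diagonal=1)
+ #     >>> keep = ious.max(dim=0)[0] <= iou_thr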
+
+
+@HEADS.register_module()
+class YOLACTSegmHead(BaseModule):
+ """YOLACT segmentation head used in https://arxiv.org/abs/1904.02689.
+
+ Apply a semantic segmentation loss on feature space using layers that are
+ only evaluated during training to increase performance with no speed
+ penalty.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ num_classes (int): Number of categories excluding the background
+ category.
+ loss_segm (dict): Config of semantic segmentation loss.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels=256,
+ loss_segm=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ init_cfg=dict(
+ type='Xavier',
+ distribution='uniform',
+ override=dict(name='segm_conv'))):
+ super(YOLACTSegmHead, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.loss_segm = build_loss(loss_segm)
+ self._init_layers()
+ self.fp16_enabled = False
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.segm_conv = nn.Conv2d(
+ self.in_channels, self.num_classes, kernel_size=1)
+
+ def forward(self, x):
+ """Forward feature from the upstream network.
+
+ Args:
+ x (Tensor): Feature from the upstream network, which is
+ a 4D-tensor.
+
+ Returns:
+ Tensor: Predicted semantic segmentation map with shape
+ (N, num_classes, H, W).
+ """
+ return self.segm_conv(x)
+
+ @force_fp32(apply_to=('segm_pred', ))
+ def loss(self, segm_pred, gt_masks, gt_labels):
+ """Compute loss of the head.
+
+ Args:
+ segm_pred (Tensor): Predicted semantic segmentation map
+ with shape (N, num_classes, H, W).
+ gt_masks (list[Tensor]): Ground truth masks for each image with
+ the same shape of the input image.
+ gt_labels (list[Tensor]): Class indices corresponding to each box.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ loss_segm = []
+ num_imgs, num_classes, mask_h, mask_w = segm_pred.size()
+ for idx in range(num_imgs):
+ cur_segm_pred = segm_pred[idx]
+ cur_gt_masks = gt_masks[idx].float()
+ cur_gt_labels = gt_labels[idx]
+ segm_targets = self.get_targets(cur_segm_pred, cur_gt_masks,
+ cur_gt_labels)
+ if segm_targets is None:
+ loss = self.loss_segm(cur_segm_pred,
+ torch.zeros_like(cur_segm_pred),
+ torch.zeros_like(cur_segm_pred))
+ else:
+ loss = self.loss_segm(
+ cur_segm_pred,
+ segm_targets,
+ avg_factor=num_imgs * mask_h * mask_w)
+ loss_segm.append(loss)
+ return dict(loss_segm=loss_segm)
+
+ def get_targets(self, segm_pred, gt_masks, gt_labels):
+ """Compute semantic segmentation targets for each image.
+
+ Args:
+ segm_pred (Tensor): Predicted semantic segmentation map
+ with shape (num_classes, H, W).
+ gt_masks (Tensor): Ground truth masks for each image with
+ the same shape of the input image.
+ gt_labels (Tensor): Class indices corresponding to each box.
+
+ Returns:
+ Tensor: Semantic segmentation targets with shape
+ (num_classes, H, W).
+ """
+ if gt_masks.size(0) == 0:
+ return None
+ num_classes, mask_h, mask_w = segm_pred.size()
+ with torch.no_grad():
+ downsampled_masks = F.interpolate(
+ gt_masks.unsqueeze(0), (mask_h, mask_w),
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+ downsampled_masks = downsampled_masks.gt(0.5).float()
+ segm_targets = torch.zeros_like(segm_pred, requires_grad=False)
+ for obj_idx in range(downsampled_masks.size(0)):
+ segm_targets[gt_labels[obj_idx] - 1] = torch.max(
+ segm_targets[gt_labels[obj_idx] - 1],
+ downsampled_masks[obj_idx])
+ return segm_targets
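+
+ # A minimal sketch of the target construction above: each class channel is
+ # the element-wise maximum of the (downsampled, binarised) instance masks of
+ # that class, e.g. for two instances of the same class on a 2x2 map:
+ #
+ #     >>> m1 = torch.tensor([[1., 0.], [0., 0.]])
+ #     >>> m2 = torch.tensor([[0., 0.], [0., 1.]])
+ #     >>> torch.max(m1, m2)
+ #     tensor([[1., 0.],
+ #             [0., 1.]])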
+
+ def simple_test(self, feats, img_metas, rescale=False):
+ """Test function without test-time augmentation."""
+ raise NotImplementedError(
+ 'simple_test of YOLACTSegmHead is not implemented '
+ 'because this head is only evaluated during training')
+
+
+@HEADS.register_module()
+class YOLACTProtonet(BaseModule):
+ """YOLACT mask head used in https://arxiv.org/abs/1904.02689.
+
+ This head outputs the mask prototypes for YOLACT.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ proto_channels (tuple[int]): Output channels of protonet convs.
+ proto_kernel_sizes (tuple[int]): Kernel sizes of protonet convs.
+ include_last_relu (bool): Whether to keep the last ReLU of the protonet.
+ num_protos (int): Number of prototypes.
+ num_classes (int): Number of categories excluding the background
+ category.
+ loss_mask_weight (float): Reweight the mask loss by this factor.
+ max_masks_to_train (int): Maximum number of masks to train for
+ each image.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels=256,
+ proto_channels=(256, 256, 256, None, 256, 32),
+ proto_kernel_sizes=(3, 3, 3, -2, 3, 1),
+ include_last_relu=True,
+ num_protos=32,
+ loss_mask_weight=1.0,
+ max_masks_to_train=100,
+ init_cfg=dict(
+ type='Xavier',
+ distribution='uniform',
+ override=dict(name='protonet'))):
+ super(YOLACTProtonet, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.proto_channels = proto_channels
+ self.proto_kernel_sizes = proto_kernel_sizes
+ self.include_last_relu = include_last_relu
+ self.protonet = self._init_layers()
+
+ self.loss_mask_weight = loss_mask_weight
+ self.num_protos = num_protos
+ self.num_classes = num_classes
+ self.max_masks_to_train = max_masks_to_train
+ self.fp16_enabled = False
+
+ def _init_layers(self):
+ """A helper function to take a config setting and turn it into a
+ network."""
+ # Possible patterns:
+ # ( 256, 3) -> conv
+ # ( 256,-2) -> deconv
+ # (None,-2) -> bilinear interpolate
+ in_channels = self.in_channels
+ protonets = ModuleList()
+ for num_channels, kernel_size in zip(self.proto_channels,
+ self.proto_kernel_sizes):
+ if kernel_size > 0:
+ layer = nn.Conv2d(
+ in_channels,
+ num_channels,
+ kernel_size,
+ padding=kernel_size // 2)
+ else:
+ if num_channels is None:
+ layer = InterpolateModule(
+ scale_factor=-kernel_size,
+ mode='bilinear',
+ align_corners=False)
+ else:
+ layer = nn.ConvTranspose2d(
+ in_channels,
+ num_channels,
+ -kernel_size,
+ padding=kernel_size // 2)
+ protonets.append(layer)
+ protonets.append(nn.ReLU(inplace=True))
+ in_channels = num_channels if num_channels is not None \
+ else in_channels
+ if not self.include_last_relu:
+ protonets = protonets[:-1]
+ return nn.Sequential(*protonets)
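+
+ # With the default `proto_channels=(256, 256, 256, None, 256, 32)` and
+ # `proto_kernel_sizes=(3, 3, 3, -2, 3, 1)`, the loop above roughly expands to
+ # (ReLU layers and paddings omitted for brevity):
+ #
+ #     Conv2d(in, 256, 3) -> Conv2d(256, 256, 3) -> Conv2d(256, 256, 3)
+ #     -> InterpolateModule(scale_factor=2, mode='bilinear')
+ #     -> Conv2d(256, 256, 3) -> Conv2d(256, 32, 1)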
+
+ def forward_dummy(self, x):
+ prototypes = self.protonet(x)
+ return prototypes
+
+ def forward(self, x, coeff_pred, bboxes, img_meta, sampling_results=None):
+ """Forward feature from the upstream network to get prototypes and
+ linearly combine the prototypes, using the mask coefficients, into
+ instance masks. Finally, crop the instance masks with given bboxes.
+
+ Args:
+ x (Tensor): Feature from the upstream network, which is
+ a 4D-tensor.
+ coeff_pred (list[Tensor]): Mask coefficients for each scale
+ level with shape (N, num_anchors * num_protos, H, W).
+ bboxes (list[Tensor]): Boxes used for cropping, each with shape
+ (num_boxes, 4). During training, they are ground truth
+ boxes. During testing, they are predicted boxes.
+ img_meta (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ sampling_results (List[:obj:``SamplingResult``]): Sampler results
+ for each image.
+
+ Returns:
+ list[Tensor]: Predicted instance segmentation masks.
+ """
+ prototypes = self.protonet(x)
+ prototypes = prototypes.permute(0, 2, 3, 1).contiguous()
+
+ num_imgs = x.size(0)
+
+ # self.training cannot be used here because the val workflow
+ # would hit a dimension mismatch error.
+ # Note that this check is rather tricky.
+ # Fixes https://github.com/open-mmlab/mmdetection/issues/5978
+ is_train_or_val_workflow = (coeff_pred[0].dim() == 4)
+
+ # Train or val workflow
+ if is_train_or_val_workflow:
+ coeff_pred_list = []
+ for coeff_pred_per_level in coeff_pred:
+ coeff_pred_per_level = \
+ coeff_pred_per_level.permute(
+ 0, 2, 3, 1).reshape(num_imgs, -1, self.num_protos)
+ coeff_pred_list.append(coeff_pred_per_level)
+ coeff_pred = torch.cat(coeff_pred_list, dim=1)
+
+ mask_pred_list = []
+ for idx in range(num_imgs):
+ cur_prototypes = prototypes[idx]
+ cur_coeff_pred = coeff_pred[idx]
+ cur_bboxes = bboxes[idx]
+ cur_img_meta = img_meta[idx]
+
+ # Testing state
+ if not is_train_or_val_workflow:
+ bboxes_for_cropping = cur_bboxes
+ else:
+ cur_sampling_results = sampling_results[idx]
+ pos_assigned_gt_inds = \
+ cur_sampling_results.pos_assigned_gt_inds
+ bboxes_for_cropping = cur_bboxes[pos_assigned_gt_inds].clone()
+ pos_inds = cur_sampling_results.pos_inds
+ cur_coeff_pred = cur_coeff_pred[pos_inds]
+
+ # Linearly combine the prototypes with the mask coefficients
+ mask_pred = cur_prototypes @ cur_coeff_pred.t()
+ mask_pred = torch.sigmoid(mask_pred)
+
+ h, w = cur_img_meta['img_shape'][:2]
+ bboxes_for_cropping[:, 0] /= w
+ bboxes_for_cropping[:, 1] /= h
+ bboxes_for_cropping[:, 2] /= w
+ bboxes_for_cropping[:, 3] /= h
+
+ mask_pred = self.crop(mask_pred, bboxes_for_cropping)
+ mask_pred = mask_pred.permute(2, 0, 1).contiguous()
+ mask_pred_list.append(mask_pred)
+ return mask_pred_list
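+
+ # Shape sketch for the linear combination above: with `num_protos = 32`
+ # prototypes of spatial size (H, W) and n kept boxes,
+ #
+ #     cur_prototypes (H, W, 32) @ cur_coeff_pred.t() (32, n)
+ #         -> mask_pred (H, W, n) -> permute(2, 0, 1) -> (n, H, W)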
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def loss(self, mask_pred, gt_masks, gt_bboxes, img_meta, sampling_results):
+ """Compute loss of the head.
+
+ Args:
+ mask_pred (list[Tensor]): Predicted instance masks of each image,
+ each with shape (num_pos, H, W).
+ gt_masks (list[Tensor]): Ground truth masks for each image with
+ the same shape of the input image.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ img_meta (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ sampling_results (List[:obj:``SamplingResult``]): Sampler results
+ for each image.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ loss_mask = []
+ num_imgs = len(mask_pred)
+ total_pos = 0
+ for idx in range(num_imgs):
+ cur_mask_pred = mask_pred[idx]
+ cur_gt_masks = gt_masks[idx].float()
+ cur_gt_bboxes = gt_bboxes[idx]
+ cur_img_meta = img_meta[idx]
+ cur_sampling_results = sampling_results[idx]
+
+ pos_assigned_gt_inds = cur_sampling_results.pos_assigned_gt_inds
+ num_pos = pos_assigned_gt_inds.size(0)
+ # Since we're producing (near) full image masks,
+ # it'd take too much vram to backprop on every single mask.
+ # Thus we select only a subset.
+ if num_pos > self.max_masks_to_train:
+ perm = torch.randperm(num_pos)
+ select = perm[:self.max_masks_to_train]
+ cur_mask_pred = cur_mask_pred[select]
+ pos_assigned_gt_inds = pos_assigned_gt_inds[select]
+ num_pos = self.max_masks_to_train
+ total_pos += num_pos
+
+ gt_bboxes_for_reweight = cur_gt_bboxes[pos_assigned_gt_inds]
+
+ mask_targets = self.get_targets(cur_mask_pred, cur_gt_masks,
+ pos_assigned_gt_inds)
+ if num_pos == 0:
+ loss = cur_mask_pred.sum() * 0.
+ elif mask_targets is None:
+ loss = F.binary_cross_entropy(cur_mask_pred,
+ torch.zeros_like(cur_mask_pred),
+ torch.zeros_like(cur_mask_pred))
+ else:
+ cur_mask_pred = torch.clamp(cur_mask_pred, 0, 1)
+ loss = F.binary_cross_entropy(
+ cur_mask_pred, mask_targets,
+ reduction='none') * self.loss_mask_weight
+
+ h, w = cur_img_meta['img_shape'][:2]
+ gt_bboxes_width = (gt_bboxes_for_reweight[:, 2] -
+ gt_bboxes_for_reweight[:, 0]) / w
+ gt_bboxes_height = (gt_bboxes_for_reweight[:, 3] -
+ gt_bboxes_for_reweight[:, 1]) / h
+ loss = loss.mean(dim=(1,
+ 2)) / gt_bboxes_width / gt_bboxes_height
+ loss = torch.sum(loss)
+ loss_mask.append(loss)
+
+ if total_pos == 0:
+ total_pos += 1 # avoid nan
+ loss_mask = [x / total_pos for x in loss_mask]
+
+ return dict(loss_mask=loss_mask)
+
+ def get_targets(self, mask_pred, gt_masks, pos_assigned_gt_inds):
+ """Compute instance segmentation targets for each image.
+
+ Args:
+ mask_pred (Tensor): Predicted instance masks with shape
+ (num_pos, H, W).
+ gt_masks (Tensor): Ground truth masks for each image with
+ the same shape of the input image.
+ pos_assigned_gt_inds (Tensor): GT indices of the corresponding
+ positive samples.
+ Returns:
+ Tensor: Instance segmentation targets with shape
+ (num_instances, H, W).
+ """
+ if gt_masks.size(0) == 0:
+ return None
+ mask_h, mask_w = mask_pred.shape[-2:]
+ gt_masks = F.interpolate(
+ gt_masks.unsqueeze(0), (mask_h, mask_w),
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+ gt_masks = gt_masks.gt(0.5).float()
+ mask_targets = gt_masks[pos_assigned_gt_inds]
+ return mask_targets
+
+ def get_seg_masks(self, mask_pred, label_pred, img_meta, rescale):
+ """Resize, binarize, and format the instance mask predictions.
+
+ Args:
+ mask_pred (Tensor): shape (N, H, W).
+ label_pred (Tensor): shape (N, ).
+ img_meta (dict): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If rescale is False, then returned masks will
+ fit the scale of imgs[0].
+ Returns:
+ list[ndarray]: Mask predictions grouped by their predicted classes.
+ """
+ ori_shape = img_meta['ori_shape']
+ scale_factor = img_meta['scale_factor']
+ if rescale:
+ img_h, img_w = ori_shape[:2]
+ else:
+ img_h = np.round(ori_shape[0] * scale_factor[1]).astype(np.int32)
+ img_w = np.round(ori_shape[1] * scale_factor[0]).astype(np.int32)
+
+ cls_segms = [[] for _ in range(self.num_classes)]
+ if mask_pred.size(0) == 0:
+ return cls_segms
+
+ mask_pred = F.interpolate(
+ mask_pred.unsqueeze(0), (img_h, img_w),
+ mode='bilinear',
+ align_corners=False).squeeze(0) > 0.5
+ mask_pred = mask_pred.cpu().numpy().astype(np.uint8)
+
+ for m, l in zip(mask_pred, label_pred):
+ cls_segms[l].append(m)
+ return cls_segms
+
+ def crop(self, masks, boxes, padding=1):
+ """Crop predicted masks by zeroing out everything not in the predicted
+ bbox.
+
+ Args:
+ masks (Tensor): shape [H, W, N].
+ boxes (Tensor): bbox coords in relative point form with
+ shape [N, 4].
+
+ Return:
+ Tensor: The cropped masks.
+ """
+ h, w, n = masks.size()
+ x1, x2 = self.sanitize_coordinates(
+ boxes[:, 0], boxes[:, 2], w, padding, cast=False)
+ y1, y2 = self.sanitize_coordinates(
+ boxes[:, 1], boxes[:, 3], h, padding, cast=False)
+
+ rows = torch.arange(
+ w, device=masks.device, dtype=x1.dtype).view(1, -1,
+ 1).expand(h, w, n)
+ cols = torch.arange(
+ h, device=masks.device, dtype=x1.dtype).view(-1, 1,
+ 1).expand(h, w, n)
+
+ masks_left = rows >= x1.view(1, 1, -1)
+ masks_right = rows < x2.view(1, 1, -1)
+ masks_up = cols >= y1.view(1, 1, -1)
+ masks_down = cols < y2.view(1, 1, -1)
+
+ crop_mask = masks_left * masks_right * masks_up * masks_down
+
+ return masks * crop_mask.float()
+
+ def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
+ """Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0,
+ and x2 <= image_size. Also converts from relative to absolute
+ coordinates and casts the results to long tensors.
+
+ Warning: this does things in-place behind the scenes so
+ copy if necessary.
+
+ Args:
+ x1 (Tensor): shape (N, ).
+ x2 (Tensor): shape (N, ).
+ img_size (int): Size of the input image.
+ padding (int): Expand the range by this amount, i.e.
+ x1 -> max(x1 - padding, 0) and x2 -> min(x2 + padding, img_size).
+ cast (bool): If cast is false, the result won't be cast to longs.
+ Returns:
+ tuple:
+ x1 (Tensor): Sanitized x1.
+ x2 (Tensor): Sanitized x2.
+ x2 (Tensor): Sanitized _x2.
+ """
+ x1 = x1 * img_size
+ x2 = x2 * img_size
+ if cast:
+ x1 = x1.long()
+ x2 = x2.long()
+ # take min/max of the original values so swapped inputs are handled
+ x1, x2 = torch.min(x1, x2), torch.max(x1, x2)
+ x1 = torch.clamp(x1 - padding, min=0)
+ x2 = torch.clamp(x2 + padding, max=img_size)
+ return x1, x2
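+
+ # A minimal numeric example of the sanitisation above (relative coordinates,
+ # img_size=100, padding=1, cast=True): the inputs are scaled to 20 and 80,
+ # ordered, padded and clamped, e.g.:
+ #
+ #     >>> self.sanitize_coordinates(
+ #     ...     torch.tensor([0.2]), torch.tensor([0.8]), 100, padding=1)
+ #     (tensor([19]), tensor([81]))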
+
+ def simple_test(self,
+ feats,
+ det_bboxes,
+ det_labels,
+ det_coeffs,
+ img_metas,
+ rescale=False):
+ """Test function without test-time augmentation.
+
+ Args:
+ feats (tuple[torch.Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ det_bboxes (list[Tensor]): BBox results of each image. each
+ element is (n, 5) tensor, where 5 represent
+ (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
+ det_labels (list[Tensor]): BBox results of each image. each
+ element is (n, ) tensor, each element represents the class
+ label of the corresponding box.
+ det_coeffs (list[Tensor]): BBox coefficient of each image. each
+ element is (n, m) tensor, m is vector length.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[list]: encoded masks. The c-th item in the outer list
+ corresponds to the c-th class. Given the c-th outer list, the
+ i-th item in that inner list is the mask for the i-th box with
+ class label c.
+ """
+ num_imgs = len(img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ segm_results = [[[] for _ in range(self.num_classes)]
+ for _ in range(num_imgs)]
+ else:
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i][:, :4]
+ for i in range(len(det_bboxes))
+ ]
+ mask_preds = self.forward(feats[0], det_coeffs, _bboxes, img_metas)
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append([[] for _ in range(self.num_classes)])
+ else:
+ segm_result = self.get_seg_masks(mask_preds[i],
+ det_labels[i],
+ img_metas[i], rescale)
+ segm_results.append(segm_result)
+ return segm_results
+
+
+class InterpolateModule(BaseModule):
+ """This is a module version of F.interpolate.
+
+ Any arguments you give it just get passed along for the ride.
+ """
+
+ def __init__(self, *args, init_cfg=None, **kwargs):
+ super().__init__(init_cfg)
+
+ self.args = args
+ self.kwargs = kwargs
+
+ def forward(self, x):
+ """Forward features from the upstream network."""
+ return F.interpolate(x, *self.args, **self.kwargs)
diff --git a/mmdet/models/dense_heads/yolo_head.py b/mmdet/models/dense_heads/yolo_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b446cb7eb24b6608ba217713a36c917dc4b93407
--- /dev/null
+++ b/mmdet/models/dense_heads/yolo_head.py
@@ -0,0 +1,621 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (ConvModule, bias_init_with_prob, constant_init, is_norm,
+ normal_init)
+from mmcv.runner import force_fp32
+
+from mmdet.core import (build_assigner, build_bbox_coder,
+ build_prior_generator, build_sampler, images_to_levels,
+ multi_apply, multiclass_nms)
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+@HEADS.register_module()
+class YOLOV3Head(BaseDenseHead, BBoxTestMixin):
+ """YOLOV3Head Paper link: https://arxiv.org/abs/1804.02767.
+
+ Args:
+ num_classes (int): The number of object classes (w/o background)
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (List[int]): The number of output channels per scale
+ before the final 1x1 layer. Default: (1024, 512, 256).
+ anchor_generator (dict): Config dict for anchor generator
+ bbox_coder (dict): Config of bounding box coder.
+ featmap_strides (List[int]): The stride of each scale.
+ Should be in descending order. Default: (32, 16, 8).
+ one_hot_smoother (float): Set a non-zero value to enable label
+ smoothing. Default: 0.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ loss_cls (dict): Config of classification loss.
+ loss_conf (dict): Config of confidence loss.
+ loss_xy (dict): Config of xy coordinate loss.
+ loss_wh (dict): Config of wh coordinate loss.
+ train_cfg (dict): Training config of YOLOV3 head. Default: None.
+ test_cfg (dict): Testing config of YOLOV3 head. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ out_channels=(1024, 512, 256),
+ anchor_generator=dict(
+ type='YOLOAnchorGenerator',
+ base_sizes=[[(116, 90), (156, 198), (373, 326)],
+ [(30, 61), (62, 45), (59, 119)],
+ [(10, 13), (16, 30), (33, 23)]],
+ strides=[32, 16, 8]),
+ bbox_coder=dict(type='YOLOBBoxCoder'),
+ featmap_strides=[32, 16, 8],
+ one_hot_smoother=0.,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_conf=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_xy=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_wh=dict(type='MSELoss', loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=dict(
+ type='Normal', std=0.01,
+ override=dict(name='convs_pred'))):
+ super(YOLOV3Head, self).__init__(init_cfg)
+ # Check params
+ assert (len(in_channels) == len(out_channels) == len(featmap_strides))
+
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.featmap_strides = featmap_strides
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ if hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.fp16_enabled = False
+
+ self.one_hot_smoother = one_hot_smoother
+
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+
+ self.prior_generator = build_prior_generator(anchor_generator)
+
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_conf = build_loss(loss_conf)
+ self.loss_xy = build_loss(loss_xy)
+ self.loss_wh = build_loss(loss_wh)
+
+ self.num_base_priors = self.prior_generator.num_base_priors[0]
+ assert len(
+ self.prior_generator.num_base_priors) == len(featmap_strides)
+ self._init_layers()
+
+ @property
+ def anchor_generator(self):
+
+ warnings.warn('DeprecationWarning: `anchor_generator` is deprecated, '
+ 'please use "prior_generator" instead')
+ return self.prior_generator
+
+ @property
+ def num_anchors(self):
+ """
+ Returns:
+ int: Number of anchors on each point of feature map.
+ """
+ warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
+ 'please use "num_base_priors" instead')
+ return self.num_base_priors
+
+ @property
+ def num_levels(self):
+ return len(self.featmap_strides)
+
+ @property
+ def num_attrib(self):
+ """int: number of attributes in pred_map, bboxes (4) +
+ objectness (1) + num_classes"""
+
+ return 5 + self.num_classes
+
+ def _init_layers(self):
+ self.convs_bridge = nn.ModuleList()
+ self.convs_pred = nn.ModuleList()
+ for i in range(self.num_levels):
+ conv_bridge = ConvModule(
+ self.in_channels[i],
+ self.out_channels[i],
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ conv_pred = nn.Conv2d(self.out_channels[i],
+ self.num_base_priors * self.num_attrib, 1)
+
+ self.convs_bridge.append(conv_bridge)
+ self.convs_pred.append(conv_pred)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, mean=0, std=0.01)
+ if is_norm(m):
+ constant_init(m, 1)
+
+ # Use prior in model initialization to improve stability
+ for conv_pred, stride in zip(self.convs_pred, self.featmap_strides):
+ bias = conv_pred.bias.reshape(self.num_base_priors, -1)
+ # init objectness with prior of 8 objects per feature map
+ # refer to https://github.com/ultralytics/yolov3
+ nn.init.constant_(bias.data[:, 4],
+ bias_init_with_prob(8 / (608 / stride)**2))
+ nn.init.constant_(bias.data[:, 5:], bias_init_with_prob(0.01))
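+
+ # A rough numeric example of the objectness prior above, assuming
+ # `bias_init_with_prob(p)` returns the logit -log((1 - p) / p): for stride 32
+ # on a 608x608 input there are (608 / 32) ** 2 = 361 grid cells, so
+ # p = 8 / 361 ~ 0.022 and the objectness bias starts at roughly -3.8.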
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple[Tensor]: A tuple of multi-level prediction maps, each a
+ 4D-tensor of shape
+ (batch_size, num_anchors * (5 + num_classes), height, width).
+ """
+
+ assert len(feats) == self.num_levels
+ pred_maps = []
+ for i in range(self.num_levels):
+ x = feats[i]
+ x = self.convs_bridge[i](x)
+ pred_map = self.convs_pred[i](x)
+ pred_maps.append(pred_map)
+
+ return tuple(pred_maps),
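+
+ # Shape sketch for the prediction maps above, assuming an 80-class
+ # (COCO-style) setup with 3 base priors per level and a 608x608 input: each
+ # pred_map has 3 * (5 + 80) = 255 channels, giving shapes (N, 255, 19, 19),
+ # (N, 255, 38, 38) and (N, 255, 76, 76) for strides 32, 16 and 8.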
+
+ @force_fp32(apply_to=('pred_maps', ))
+ def get_bboxes(self,
+ pred_maps,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions. It has
+ been accelerated since PR #5991.
+
+ Args:
+ pred_maps (list[Tensor]): Raw predictions for a batch of images.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default: None.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before returning boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 5) tensor, where 5 represent
+ (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ assert len(pred_maps) == self.num_levels
+ cfg = self.test_cfg if cfg is None else cfg
+ scale_factors = np.array(
+ [img_meta['scale_factor'] for img_meta in img_metas])
+
+ num_imgs = len(img_metas)
+ featmap_sizes = [pred_map.shape[-2:] for pred_map in pred_maps]
+
+ mlvl_anchors = self.prior_generator.grid_priors(
+ featmap_sizes, device=pred_maps[0].device)
+ flatten_preds = []
+ flatten_strides = []
+ for pred, stride in zip(pred_maps, self.featmap_strides):
+ pred = pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.num_attrib)
+ pred[..., :2].sigmoid_()
+ flatten_preds.append(pred)
+ flatten_strides.append(
+ pred.new_tensor(stride).expand(pred.size(1)))
+
+ flatten_preds = torch.cat(flatten_preds, dim=1)
+ flatten_bbox_preds = flatten_preds[..., :4]
+ flatten_objectness = flatten_preds[..., 4].sigmoid()
+ flatten_cls_scores = flatten_preds[..., 5:].sigmoid()
+ flatten_anchors = torch.cat(mlvl_anchors)
+ flatten_strides = torch.cat(flatten_strides)
+ flatten_bboxes = self.bbox_coder.decode(flatten_anchors,
+ flatten_bbox_preds,
+ flatten_strides.unsqueeze(-1))
+
+ if with_nms and (flatten_objectness.size(0) == 0):
+ return torch.zeros((0, 5)), torch.zeros((0, ))
+
+ if rescale:
+ flatten_bboxes /= flatten_bboxes.new_tensor(
+ scale_factors).unsqueeze(1)
+
+ padding = flatten_bboxes.new_zeros(num_imgs, flatten_bboxes.shape[1],
+ 1)
+ flatten_cls_scores = torch.cat([flatten_cls_scores, padding], dim=-1)
+
+ det_results = []
+ for (bboxes, scores, objectness) in zip(flatten_bboxes,
+ flatten_cls_scores,
+ flatten_objectness):
+ # Filtering out all predictions with conf < conf_thr
+ conf_thr = cfg.get('conf_thr', -1)
+ if conf_thr > 0:
+ conf_inds = objectness >= conf_thr
+ bboxes = bboxes[conf_inds, :]
+ scores = scores[conf_inds, :]
+ objectness = objectness[conf_inds]
+
+ det_bboxes, det_labels = multiclass_nms(
+ bboxes,
+ scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=objectness)
+ det_results.append(tuple([det_bboxes, det_labels]))
+ return det_results
+
+ @force_fp32(apply_to=('pred_maps', ))
+ def loss(self,
+ pred_maps,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ pred_maps (list[Tensor]): Prediction map for each scale level,
+ shape (N, num_anchors * num_attrib, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_imgs = len(img_metas)
+ device = pred_maps[0][0].device
+
+ featmap_sizes = [
+ pred_maps[i].shape[-2:] for i in range(self.num_levels)
+ ]
+ mlvl_anchors = self.prior_generator.grid_priors(
+ featmap_sizes, device=device)
+ anchor_list = [mlvl_anchors for _ in range(num_imgs)]
+
+ responsible_flag_list = []
+ for img_id in range(len(img_metas)):
+ responsible_flag_list.append(
+ self.prior_generator.responsible_flags(featmap_sizes,
+ gt_bboxes[img_id],
+ device))
+
+ target_maps_list, neg_maps_list = self.get_targets(
+ anchor_list, responsible_flag_list, gt_bboxes, gt_labels)
+
+ losses_cls, losses_conf, losses_xy, losses_wh = multi_apply(
+ self.loss_single, pred_maps, target_maps_list, neg_maps_list)
+
+ return dict(
+ loss_cls=losses_cls,
+ loss_conf=losses_conf,
+ loss_xy=losses_xy,
+ loss_wh=losses_wh)
+
+ def loss_single(self, pred_map, target_map, neg_map):
+ """Compute loss of a single image from a batch.
+
+ Args:
+ pred_map (Tensor): Raw predictions for a single level.
+ target_map (Tensor): The Ground-Truth target for a single level.
+ neg_map (Tensor): The negative masks for a single level.
+
+ Returns:
+ tuple:
+ loss_cls (Tensor): Classification loss.
+ loss_conf (Tensor): Confidence loss.
+ loss_xy (Tensor): Regression loss of x, y coordinate.
+ loss_wh (Tensor): Regression loss of w, h coordinate.
+ """
+
+ num_imgs = len(pred_map)
+ pred_map = pred_map.permute(0, 2, 3,
+ 1).reshape(num_imgs, -1, self.num_attrib)
+ neg_mask = neg_map.float()
+ pos_mask = target_map[..., 4]
+ pos_and_neg_mask = neg_mask + pos_mask
+ pos_mask = pos_mask.unsqueeze(dim=-1)
+ if torch.max(pos_and_neg_mask) > 1.:
+ warnings.warn('There is overlap between pos and neg samples.')
+ pos_and_neg_mask = pos_and_neg_mask.clamp(min=0., max=1.)
+
+ pred_xy = pred_map[..., :2]
+ pred_wh = pred_map[..., 2:4]
+ pred_conf = pred_map[..., 4]
+ pred_label = pred_map[..., 5:]
+
+ target_xy = target_map[..., :2]
+ target_wh = target_map[..., 2:4]
+ target_conf = target_map[..., 4]
+ target_label = target_map[..., 5:]
+
+ loss_cls = self.loss_cls(pred_label, target_label, weight=pos_mask)
+ loss_conf = self.loss_conf(
+ pred_conf, target_conf, weight=pos_and_neg_mask)
+ loss_xy = self.loss_xy(pred_xy, target_xy, weight=pos_mask)
+ loss_wh = self.loss_wh(pred_wh, target_wh, weight=pos_mask)
+
+ return loss_cls, loss_conf, loss_xy, loss_wh
+
+ def get_targets(self, anchor_list, responsible_flag_list, gt_bboxes_list,
+ gt_labels_list):
+ """Compute target maps for anchors in multiple images.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_total_anchors, 4).
+ responsible_flag_list (list[list[Tensor]]): Multi level responsible
+ flags of each image. Each element is a tensor of shape
+ (num_total_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+ - target_map_list (list[Tensor]): Target map of each level.
+ - neg_map_list (list[Tensor]): Negative map of each level.
+ """
+ num_imgs = len(anchor_list)
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+ results = multi_apply(self._get_targets_single, anchor_list,
+ responsible_flag_list, gt_bboxes_list,
+ gt_labels_list)
+
+ all_target_maps, all_neg_maps = results
+ assert num_imgs == len(all_target_maps) == len(all_neg_maps)
+ target_maps_list = images_to_levels(all_target_maps, num_level_anchors)
+ neg_maps_list = images_to_levels(all_neg_maps, num_level_anchors)
+
+ return target_maps_list, neg_maps_list
+
+ def _get_targets_single(self, anchors, responsible_flags, gt_bboxes,
+ gt_labels):
+ """Generate matching bounding box prior and converted GT.
+
+ Args:
+ anchors (list[Tensor]): Multi-level anchors of the image.
+ responsible_flags (list[Tensor]): Multi-level responsible flags of
+ anchors
+ gt_bboxes (Tensor): Ground truth bboxes of single image.
+ gt_labels (Tensor): Ground truth labels of single image.
+
+ Returns:
+ tuple:
+ target_map (Tensor): Prediction target map of each
+ scale level, shape (num_total_anchors,
+ 5+num_classes)
+ neg_map (Tensor): Negative map of each scale level,
+ shape (num_total_anchors,)
+ """
+
+ anchor_strides = []
+ for i in range(len(anchors)):
+ anchor_strides.append(
+ torch.tensor(self.featmap_strides[i],
+ device=gt_bboxes.device).repeat(len(anchors[i])))
+ concat_anchors = torch.cat(anchors)
+ concat_responsible_flags = torch.cat(responsible_flags)
+
+ anchor_strides = torch.cat(anchor_strides)
+ assert len(anchor_strides) == len(concat_anchors) == \
+ len(concat_responsible_flags)
+ assign_result = self.assigner.assign(concat_anchors,
+ concat_responsible_flags,
+ gt_bboxes)
+ sampling_result = self.sampler.sample(assign_result, concat_anchors,
+ gt_bboxes)
+
+ target_map = concat_anchors.new_zeros(
+ concat_anchors.size(0), self.num_attrib)
+
+ target_map[sampling_result.pos_inds, :4] = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes,
+ anchor_strides[sampling_result.pos_inds])
+
+ target_map[sampling_result.pos_inds, 4] = 1
+
+ gt_labels_one_hot = F.one_hot(
+ gt_labels, num_classes=self.num_classes).float()
+ if self.one_hot_smoother != 0: # label smooth
+ gt_labels_one_hot = gt_labels_one_hot * (
+ 1 - self.one_hot_smoother
+ ) + self.one_hot_smoother / self.num_classes
+ target_map[sampling_result.pos_inds, 5:] = gt_labels_one_hot[
+ sampling_result.pos_assigned_gt_inds]
+
+ neg_map = concat_anchors.new_zeros(
+ concat_anchors.size(0), dtype=torch.uint8)
+ neg_map[sampling_result.neg_inds] = 1
+
+ return target_map, neg_map
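+
+ # A minimal numeric example of the label smoothing above: with
+ # `one_hot_smoother = 0.1` and 80 classes, the target for the ground-truth
+ # class becomes 1 * (1 - 0.1) + 0.1 / 80 = 0.90125 and every other class
+ # becomes 0 * (1 - 0.1) + 0.1 / 80 = 0.00125.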
+
+ def aug_test(self, feats, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[ndarray]: bbox results of each class
+ """
+ return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
+
+ @force_fp32(apply_to=('pred_maps', ))
+ def onnx_export(self, pred_maps, img_metas, with_nms=True):
+ num_levels = len(pred_maps)
+ pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)]
+
+ cfg = self.test_cfg
+ assert len(pred_maps_list) == self.num_levels
+
+ device = pred_maps_list[0].device
+ batch_size = pred_maps_list[0].shape[0]
+
+ featmap_sizes = [
+ pred_maps_list[i].shape[-2:] for i in range(self.num_levels)
+ ]
+ mlvl_anchors = self.prior_generator.grid_priors(
+ featmap_sizes, device=device)
+ # convert to tensor to keep tracing
+ nms_pre_tensor = torch.tensor(
+ cfg.get('nms_pre', -1), device=device, dtype=torch.long)
+
+ multi_lvl_bboxes = []
+ multi_lvl_cls_scores = []
+ multi_lvl_conf_scores = []
+ for i in range(self.num_levels):
+ # get some key info for current scale
+ pred_map = pred_maps_list[i]
+ stride = self.featmap_strides[i]
+ # (b,h, w, num_anchors*num_attrib) ->
+ # (b,h*w*num_anchors, num_attrib)
+ pred_map = pred_map.permute(0, 2, 3,
+ 1).reshape(batch_size, -1,
+ self.num_attrib)
+ # Inplace operation like
+ # ```pred_map[..., :2] = \torch.sigmoid(pred_map[..., :2])```
+ # would create constant tensor when exporting to onnx
+ pred_map_conf = torch.sigmoid(pred_map[..., :2])
+ pred_map_rest = pred_map[..., 2:]
+ pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=-1)
+ pred_map_boxes = pred_map[..., :4]
+ multi_lvl_anchor = mlvl_anchors[i]
+ multi_lvl_anchor = multi_lvl_anchor.expand_as(pred_map_boxes)
+ bbox_pred = self.bbox_coder.decode(multi_lvl_anchor,
+ pred_map_boxes, stride)
+ # conf and cls
+ conf_pred = torch.sigmoid(pred_map[..., 4])
+ cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
+ batch_size, -1, self.num_classes) # Cls pred one-hot.
+
+ # Get top-k prediction
+ from mmdet.core.export import get_k_for_topk
+ nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1])
+ if nms_pre > 0:
+ _, topk_inds = conf_pred.topk(nms_pre)
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds).long()
+ # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501
+ transformed_inds = (
+ bbox_pred.shape[1] * batch_inds + topk_inds)
+ bbox_pred = bbox_pred.reshape(-1,
+ 4)[transformed_inds, :].reshape(
+ batch_size, -1, 4)
+ cls_pred = cls_pred.reshape(
+ -1, self.num_classes)[transformed_inds, :].reshape(
+ batch_size, -1, self.num_classes)
+ conf_pred = conf_pred.reshape(-1, 1)[transformed_inds].reshape(
+ batch_size, -1)
+
+ # Save the result of current scale
+ multi_lvl_bboxes.append(bbox_pred)
+ multi_lvl_cls_scores.append(cls_pred)
+ multi_lvl_conf_scores.append(conf_pred)
+
+ # Merge the results of different scales together
+ batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
+ batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1)
+ batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)
+
+ # Replace multiclass_nms with ONNX::NonMaxSuppression in deployment
+ from mmdet.core.export import add_dummy_nms_for_onnx
+ conf_thr = cfg.get('conf_thr', -1)
+ score_thr = cfg.get('score_thr', -1)
+ # follow original pipeline of YOLOv3
+ if conf_thr > 0:
+ mask = (batch_mlvl_conf_scores >= conf_thr).float()
+ batch_mlvl_conf_scores *= mask
+ if score_thr > 0:
+ mask = (batch_mlvl_scores > score_thr).float()
+ batch_mlvl_scores *= mask
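+        # Broadcast the objectness confidence over the class dimension and
+        # fuse it with the class probabilities, mirroring YOLOv3's
+        # score = objectness * class_prob used at inference time.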
+ batch_mlvl_conf_scores = batch_mlvl_conf_scores.unsqueeze(2).expand_as(
+ batch_mlvl_scores)
+ batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_conf_scores
+ if with_nms:
+ max_output_boxes_per_class = cfg.nms.get(
+ 'max_output_boxes_per_class', 200)
+ iou_threshold = cfg.nms.get('iou_threshold', 0.5)
+ # keep aligned with original pipeline, improve
+ # mAP by 1% for YOLOv3 in ONNX
+ score_threshold = 0
+ nms_pre = cfg.get('deploy_nms_pre', -1)
+ return add_dummy_nms_for_onnx(
+ batch_mlvl_bboxes,
+ batch_mlvl_scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ nms_pre,
+ cfg.max_per_img,
+ )
+ else:
+ return batch_mlvl_bboxes, batch_mlvl_scores
diff --git a/mmdet/models/dense_heads/yolof_head.py b/mmdet/models/dense_heads/yolof_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1063524a7d17f2bb037ca64c35f5ce3e658771eb
--- /dev/null
+++ b/mmdet/models/dense_heads/yolof_head.py
@@ -0,0 +1,416 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import (ConvModule, bias_init_with_prob, constant_init, is_norm,
+ normal_init)
+from mmcv.runner import force_fp32
+
+from mmdet.core import anchor_inside_flags, multi_apply, reduce_mean, unmap
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+
+INF = 1e8
+
+
+def levels_to_images(mlvl_tensor):
+ """Concat multi-level feature maps by image.
+
+ [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
+    Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
+    (N, H*W, C), split each element into N tensors of shape (H*W, C), and
+    concatenate the tensors that belong to the same image across all levels
+    along the first dimension.
+
+ Args:
+ mlvl_tensor (list[torch.Tensor]): list of Tensor which collect from
+ corresponding level. Each element is of shape (N, C, H, W)
+
+ Returns:
+ list[torch.Tensor]: A list that contains N tensors and each tensor is
+ of shape (num_elements, C)
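+
+    Example (shape sketch with hypothetical tensor sizes):
+        >>> feats = [torch.rand(2, 8, 4, 4), torch.rand(2, 8, 2, 2)]
+        >>> [x.shape for x in levels_to_images(feats)]
+        [torch.Size([20, 8]), torch.Size([20, 8])]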
+ """
+ batch_size = mlvl_tensor[0].size(0)
+ batch_list = [[] for _ in range(batch_size)]
+ channels = mlvl_tensor[0].size(1)
+ for t in mlvl_tensor:
+ t = t.permute(0, 2, 3, 1)
+ t = t.view(batch_size, -1, channels).contiguous()
+ for img in range(batch_size):
+ batch_list[img].append(t[img])
+ return [torch.cat(item, 0) for item in batch_list]
+
+
+@HEADS.register_module()
+class YOLOFHead(AnchorHead):
+ """YOLOFHead Paper link: https://arxiv.org/abs/2103.09460.
+
+ Args:
+ num_classes (int): The number of object classes (w/o background)
+ in_channels (List[int]): The number of input channels per scale.
+        num_cls_convs (int): The number of convolutions of the cls branch.
+            Default 2.
+        num_reg_convs (int): The number of convolutions of the reg branch.
+            Default 4.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ num_cls_convs=2,
+ num_reg_convs=4,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ **kwargs):
+ self.num_cls_convs = num_cls_convs
+ self.num_reg_convs = num_reg_convs
+ self.norm_cfg = norm_cfg
+ super(YOLOFHead, self).__init__(num_classes, in_channels, **kwargs)
+
+ def _init_layers(self):
+ cls_subnet = []
+ bbox_subnet = []
+ for i in range(self.num_cls_convs):
+ cls_subnet.append(
+ ConvModule(
+ self.in_channels,
+ self.in_channels,
+ kernel_size=3,
+ padding=1,
+ norm_cfg=self.norm_cfg))
+ for i in range(self.num_reg_convs):
+ bbox_subnet.append(
+ ConvModule(
+ self.in_channels,
+ self.in_channels,
+ kernel_size=3,
+ padding=1,
+ norm_cfg=self.norm_cfg))
+ self.cls_subnet = nn.Sequential(*cls_subnet)
+ self.bbox_subnet = nn.Sequential(*bbox_subnet)
+ self.cls_score = nn.Conv2d(
+ self.in_channels,
+ self.num_base_priors * self.num_classes,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.bbox_pred = nn.Conv2d(
+ self.in_channels,
+ self.num_base_priors * 4,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.object_pred = nn.Conv2d(
+ self.in_channels,
+ self.num_base_priors,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, mean=0, std=0.01)
+ if is_norm(m):
+ constant_init(m, 1)
+
+ # Use prior in model initialization to improve stability
+ bias_cls = bias_init_with_prob(0.01)
+ torch.nn.init.constant_(self.cls_score.bias, bias_cls)
+
+ def forward_single(self, feature):
+ cls_score = self.cls_score(self.cls_subnet(feature))
+ N, _, H, W = cls_score.shape
+ cls_score = cls_score.view(N, -1, self.num_classes, H, W)
+
+ reg_feat = self.bbox_subnet(feature)
+ bbox_reg = self.bbox_pred(reg_feat)
+ objectness = self.object_pred(reg_feat)
+
+ # implicit objectness
+ objectness = objectness.view(N, -1, 1, H, W)
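+        # Fuse the two logits so that sigmoid(normalized_cls_score) equals
+        # sigmoid(cls_score) * sigmoid(objectness) (up to the clamp); the
+        # clamp on exp() guards against overflow.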
+ normalized_cls_score = cls_score + objectness - torch.log(
+ 1. + torch.clamp(cls_score.exp(), max=INF) +
+ torch.clamp(objectness.exp(), max=INF))
+ normalized_cls_score = normalized_cls_score.view(N, -1, H, W)
+ return normalized_cls_score, bbox_reg
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (batch, num_anchors * num_classes, h, w)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (batch, num_anchors * 4, h, w)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert len(cls_scores) == 1
+ assert self.prior_generator.num_levels == 1
+
+ device = cls_scores[0].device
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+
+ # The output level is always 1
+ anchor_list = [anchors[0] for anchors in anchor_list]
+ valid_flag_list = [valid_flags[0] for valid_flags in valid_flag_list]
+
+ cls_scores_list = levels_to_images(cls_scores)
+ bbox_preds_list = levels_to_images(bbox_preds)
+
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ cls_scores_list,
+ bbox_preds_list,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (batch_labels, batch_label_weights, num_total_pos, num_total_neg,
+ batch_bbox_weights, batch_pos_predicted_boxes,
+ batch_target_boxes) = cls_reg_targets
+
+ flatten_labels = batch_labels.reshape(-1)
+ batch_label_weights = batch_label_weights.reshape(-1)
+ cls_score = cls_scores[0].permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+
+ num_total_samples = (num_total_pos +
+ num_total_neg) if self.sampling else num_total_pos
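+        # reduce_mean averages the sample count across GPUs so that the loss
+        # normalization stays consistent under distributed training.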
+ num_total_samples = reduce_mean(
+ cls_score.new_tensor(num_total_samples)).clamp_(1.0).item()
+
+ # classification loss
+ loss_cls = self.loss_cls(
+ cls_score,
+ flatten_labels,
+ batch_label_weights,
+ avg_factor=num_total_samples)
+
+ # regression loss
+ if batch_pos_predicted_boxes.shape[0] == 0:
+ # no pos sample
+ loss_bbox = batch_pos_predicted_boxes.sum() * 0
+ else:
+ loss_bbox = self.loss_bbox(
+ batch_pos_predicted_boxes,
+ batch_target_boxes,
+ batch_bbox_weights.float(),
+ avg_factor=num_total_samples)
+
+ return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
+
+ def get_targets(self,
+ cls_scores_list,
+ bbox_preds_list,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in
+ multiple images.
+
+ Args:
+            cls_scores_list (list[Tensor]): Classification scores of
+                each image. Each is a 2D tensor of shape
+                (h * w, num_anchors * num_classes).
+            bbox_preds_list (list[Tensor]): Bbox preds of each image.
+                Each is a 2D tensor of shape (h * w, num_anchors * 4).
+            anchor_list (list[Tensor]): Anchors of each image. Each element
+                is a tensor of shape (h * w * num_anchors, 4).
+            valid_flag_list (list[Tensor]): Valid flags of each image. Each
+                element is a tensor of shape (h * w * num_anchors, ).
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+                - batch_labels (Tensor): Labels of all images, a tensor of \
+                    shape (batch, h * w * num_anchors).
+                - batch_label_weights (Tensor): Label weights of all images, \
+                    a tensor of shape (batch, h * w * num_anchors).
+                - num_total_pos (int): Number of positive samples in all \
+                    images.
+                - num_total_neg (int): Number of negative samples in all \
+                    images.
+            additional_returns: This function enables user-defined returns
+                from `self._get_targets_single`. These returns are currently
+                refined to properties at each feature map (i.e. having HxW
+                dimension). The results will be concatenated at the end of
+                this function.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ results = multi_apply(
+ self._get_targets_single,
+ bbox_preds_list,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ (all_labels, all_label_weights, pos_inds_list, neg_inds_list,
+ sampling_results_list) = results[:5]
+ rest_results = list(results[5:]) # user-added return values
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+
+ batch_labels = torch.stack(all_labels, 0)
+ batch_label_weights = torch.stack(all_label_weights, 0)
+
+ res = (batch_labels, batch_label_weights, num_total_pos, num_total_neg)
+ for i, rests in enumerate(rest_results): # user-added return values
+ rest_results[i] = torch.cat(rests, 0)
+
+ return res + tuple(rest_results)
+
+ def _get_targets_single(self,
+ bbox_preds,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ Args:
+            bbox_preds (Tensor): Bbox predictions of the image, with shape
+                (h * w, 4).
+            flat_anchors (Tensor): Anchors of the image, with shape
+                (h * w * num_anchors, 4).
+            valid_flags (Tensor): Valid flags of the image, with shape
+                (h * w * num_anchors, ).
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ img_meta (dict): Meta info of the image.
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple:
+                labels (Tensor): Labels of the image, with shape
+                    (h * w * num_anchors, ).
+                label_weights (Tensor): Label weights of the image, with shape
+                    (h * w * num_anchors, ).
+                pos_inds (Tensor): Indices of positive anchors in the image.
+                neg_inds (Tensor): Indices of negative anchors in the image.
+                sampling_result (:obj:`SamplingResult`): Sampling result.
+                pos_bbox_weights (Tensor): Weights used to compute the bbox
+                    branch loss, with shape (num, ).
+                pos_predicted_boxes (Tensor): Predicted boxes used to compute
+                    the bbox branch loss, with shape (num, 4).
+                pos_target_boxes (Tensor): Target boxes used to compute the
+                    bbox branch loss, with shape (num, 4).
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 8
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+ bbox_preds = bbox_preds.reshape(-1, 4)
+ bbox_preds = bbox_preds[inside_flags, :]
+
+ # decoded bbox
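+        # YOLOF assigns targets based on the decoded predicted boxes rather
+        # than on the raw anchors (e.g. a UniformAssigner in the default
+        # config), so predictions are decoded before assignment.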
+ decoder_bbox_preds = self.bbox_coder.decode(anchors, bbox_preds)
+ assign_result = self.assigner.assign(
+ decoder_bbox_preds, anchors, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+
+ pos_bbox_weights = assign_result.get_extra_property('pos_idx')
+ pos_predicted_boxes = assign_result.get_extra_property(
+ 'pos_predicted_boxes')
+ pos_target_boxes = assign_result.get_extra_property('target_boxes')
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+ num_valid_anchors = anchors.shape[0]
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags,
+ fill=self.num_classes) # fill bg label
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+
+ return (labels, label_weights, pos_inds, neg_inds, sampling_result,
+ pos_bbox_weights, pos_predicted_boxes, pos_target_boxes)
diff --git a/mmdet/models/dense_heads/yolox_head.py b/mmdet/models/dense_heads/yolox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..f317e14760b2948609309016e6b4a87eae2e26a8
--- /dev/null
+++ b/mmdet/models/dense_heads/yolox_head.py
@@ -0,0 +1,493 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
+ bias_init_with_prob)
+from mmcv.ops.nms import batched_nms
+from mmcv.runner import force_fp32
+
+from mmdet.core import (MlvlPointGenerator, bbox_xyxy_to_cxcywh,
+ build_assigner, build_sampler, multi_apply,
+ reduce_mean)
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+@HEADS.register_module()
+class YOLOXHead(BaseDenseHead, BBoxTestMixin):
+ """YOLOXHead head used in `YOLOX `_.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels in stacking convs.
+ Default: 256
+ stacked_convs (int): Number of stacking convs of the head.
+ Default: 2.
+ strides (tuple): Downsample factor of each feature map.
+        use_depthwise (bool): Whether to use depthwise separable convolutions
+            in the blocks. Default: False
+ dcn_on_last_conv (bool): If true, use dcn in the last layer of
+ towers. Default: False.
+ conv_bias (bool | str): If specified as `auto`, it will be decided by
+ the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
+ None, otherwise False. Default: "auto".
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer. Default: None.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ loss_obj (dict): Config of objectness loss.
+ loss_l1 (dict): Config of L1 loss.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ stacked_convs=2,
+ strides=[8, 16, 32],
+ use_depthwise=False,
+ dcn_on_last_conv=False,
+ conv_bias='auto',
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish'),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ reduction='sum',
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='IoULoss',
+ mode='square',
+ eps=1e-16,
+ reduction='sum',
+ loss_weight=5.0),
+ loss_obj=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ reduction='sum',
+ loss_weight=1.0),
+ loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=dict(
+ type='Kaiming',
+ layer='Conv2d',
+ a=math.sqrt(5),
+ distribution='uniform',
+ mode='fan_in',
+ nonlinearity='leaky_relu')):
+
+ super().__init__(init_cfg=init_cfg)
+ self.num_classes = num_classes
+ self.cls_out_channels = num_classes
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.stacked_convs = stacked_convs
+ self.strides = strides
+ self.use_depthwise = use_depthwise
+ self.dcn_on_last_conv = dcn_on_last_conv
+ assert conv_bias == 'auto' or isinstance(conv_bias, bool)
+ self.conv_bias = conv_bias
+ self.use_sigmoid_cls = True
+
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_obj = build_loss(loss_obj)
+
+ self.use_l1 = False # This flag will be modified by hooks.
+ self.loss_l1 = build_loss(loss_l1)
+
+ self.prior_generator = MlvlPointGenerator(strides, offset=0)
+
+ self.test_cfg = test_cfg
+ self.train_cfg = train_cfg
+
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # sampling=False so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.fp16_enabled = False
+ self._init_layers()
+
+ def _init_layers(self):
+ self.multi_level_cls_convs = nn.ModuleList()
+ self.multi_level_reg_convs = nn.ModuleList()
+ self.multi_level_conv_cls = nn.ModuleList()
+ self.multi_level_conv_reg = nn.ModuleList()
+ self.multi_level_conv_obj = nn.ModuleList()
+ for _ in self.strides:
+ self.multi_level_cls_convs.append(self._build_stacked_convs())
+ self.multi_level_reg_convs.append(self._build_stacked_convs())
+ conv_cls, conv_reg, conv_obj = self._build_predictor()
+ self.multi_level_conv_cls.append(conv_cls)
+ self.multi_level_conv_reg.append(conv_reg)
+ self.multi_level_conv_obj.append(conv_obj)
+
+ def _build_stacked_convs(self):
+ """Initialize conv layers of a single level head."""
+ conv = DepthwiseSeparableConvModule \
+ if self.use_depthwise else ConvModule
+ stacked_convs = []
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ if self.dcn_on_last_conv and i == self.stacked_convs - 1:
+ conv_cfg = dict(type='DCNv2')
+ else:
+ conv_cfg = self.conv_cfg
+ stacked_convs.append(
+ conv(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ bias=self.conv_bias))
+ return nn.Sequential(*stacked_convs)
+
+ def _build_predictor(self):
+ """Initialize predictor layers of a single level head."""
+ conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1)
+ conv_reg = nn.Conv2d(self.feat_channels, 4, 1)
+ conv_obj = nn.Conv2d(self.feat_channels, 1, 1)
+ return conv_cls, conv_reg, conv_obj
+
+ def init_weights(self):
+ super(YOLOXHead, self).init_weights()
+ # Use prior in model initialization to improve stability
+ bias_init = bias_init_with_prob(0.01)
+ for conv_cls, conv_obj in zip(self.multi_level_conv_cls,
+ self.multi_level_conv_obj):
+ conv_cls.bias.data.fill_(bias_init)
+ conv_obj.bias.data.fill_(bias_init)
+
+ def forward_single(self, x, cls_convs, reg_convs, conv_cls, conv_reg,
+ conv_obj):
+ """Forward feature of a single scale level."""
+
+ cls_feat = cls_convs(x)
+ reg_feat = reg_convs(x)
+
+ cls_score = conv_cls(cls_feat)
+ bbox_pred = conv_reg(reg_feat)
+ objectness = conv_obj(reg_feat)
+
+ return cls_score, bbox_pred, objectness
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ Returns:
+            tuple[Tensor]: A tuple of multi-level prediction maps, each is a
+ 4D-tensor of shape (batch_size, 5+num_classes, height, width).
+ """
+
+ return multi_apply(self.forward_single, feats,
+ self.multi_level_cls_convs,
+ self.multi_level_reg_convs,
+ self.multi_level_conv_cls,
+ self.multi_level_conv_reg,
+ self.multi_level_conv_obj)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'objectnesses'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ objectnesses,
+ img_metas=None,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ """Transform network outputs of a batch into bbox results.
+ Args:
+ cls_scores (list[Tensor]): Classification scores for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for all
+ scale levels, each is a 4D-tensor, has shape
+ (batch_size, num_priors * 4, H, W).
+ objectnesses (list[Tensor], Optional): Score factor for
+ all scale level, each is a 4D-tensor, has shape
+ (batch_size, 1, H, W).
+ img_metas (list[dict], Optional): Image meta info. Default None.
+ cfg (mmcv.Config, Optional): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default None.
+ rescale (bool): If True, return boxes in original image space.
+ Default False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default True.
+ Returns:
+            list[list[Tensor, Tensor]]: Each item in result_list is a 2-tuple.
+ The first item is an (n, 5) tensor, where the first 4 columns
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the
+ 5-th column is a score between 0 and 1. The second item is a
+ (n,) tensor where each item is the predicted class label of
+ the corresponding box.
+ """
+ assert len(cls_scores) == len(bbox_preds) == len(objectnesses)
+ cfg = self.test_cfg if cfg is None else cfg
+ scale_factors = np.array(
+ [img_meta['scale_factor'] for img_meta in img_metas])
+
+ num_imgs = len(img_metas)
+ featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
+ mlvl_priors = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=cls_scores[0].dtype,
+ device=cls_scores[0].device,
+ with_stride=True)
+
+ # flatten cls_scores, bbox_preds and objectness
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.cls_out_channels)
+ for cls_score in cls_scores
+ ]
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ flatten_objectness = [
+ objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
+ for objectness in objectnesses
+ ]
+
+ flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
+ flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
+ flatten_priors = torch.cat(mlvl_priors)
+
+ flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)
+
+ if rescale:
+ flatten_bboxes[..., :4] /= flatten_bboxes.new_tensor(
+ scale_factors).unsqueeze(1)
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_scores = flatten_cls_scores[img_id]
+ score_factor = flatten_objectness[img_id]
+ bboxes = flatten_bboxes[img_id]
+
+ result_list.append(
+ self._bboxes_nms(cls_scores, bboxes, score_factor, cfg))
+
+ return result_list
+
+ def _bbox_decode(self, priors, bbox_preds):
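+        # Priors carry (x, y, stride_w, stride_h); centers are shifted by
+        # pred_xy * stride from the prior point and widths/heights are
+        # exp(pred_wh) * stride.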
+ xys = (bbox_preds[..., :2] * priors[:, 2:]) + priors[:, :2]
+ whs = bbox_preds[..., 2:].exp() * priors[:, 2:]
+
+ tl_x = (xys[..., 0] - whs[..., 0] / 2)
+ tl_y = (xys[..., 1] - whs[..., 1] / 2)
+ br_x = (xys[..., 0] + whs[..., 0] / 2)
+ br_y = (xys[..., 1] + whs[..., 1] / 2)
+
+ decoded_bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
+ return decoded_bboxes
+
+ def _bboxes_nms(self, cls_scores, bboxes, score_factor, cfg):
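+        # Keep only predictions whose joint score (objectness * max class
+        # probability) passes score_thr, then run class-wise batched NMS.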
+ max_scores, labels = torch.max(cls_scores, 1)
+ valid_mask = score_factor * max_scores >= cfg.score_thr
+
+ bboxes = bboxes[valid_mask]
+ scores = max_scores[valid_mask] * score_factor[valid_mask]
+ labels = labels[valid_mask]
+
+ if labels.numel() == 0:
+ return bboxes, labels
+ else:
+ dets, keep = batched_nms(bboxes, scores, labels, cfg.nms)
+ return dets, labels[keep]
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'objectnesses'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ objectnesses,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level,
+ each is a 4D-tensor, the channel number is
+ num_priors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_priors * 4.
+ objectnesses (list[Tensor], Optional): Score factor for
+ all scale level, each is a 4D-tensor, has shape
+ (batch_size, 1, H, W).
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ """
+ num_imgs = len(img_metas)
+ featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
+ mlvl_priors = self.prior_generator.grid_priors(
+ featmap_sizes,
+ dtype=cls_scores[0].dtype,
+ device=cls_scores[0].device,
+ with_stride=True)
+
+ flatten_cls_preds = [
+ cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+ self.cls_out_channels)
+ for cls_pred in cls_scores
+ ]
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ flatten_objectness = [
+ objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
+ for objectness in objectnesses
+ ]
+
+ flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
+ flatten_objectness = torch.cat(flatten_objectness, dim=1)
+ flatten_priors = torch.cat(mlvl_priors)
+ flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)
+
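+        # Predictions are detached before target assignment, so the matching
+        # step (e.g. SimOTA in the default YOLOX config) receives no
+        # gradients; only the losses computed below are differentiated.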
+ (pos_masks, cls_targets, obj_targets, bbox_targets, l1_targets,
+ num_fg_imgs) = multi_apply(
+ self._get_target_single, flatten_cls_preds.detach(),
+ flatten_objectness.detach(),
+ flatten_priors.unsqueeze(0).repeat(num_imgs, 1, 1),
+ flatten_bboxes.detach(), gt_bboxes, gt_labels)
+
+ # The experimental results show that ‘reduce_mean’ can improve
+ # performance on the COCO dataset.
+ num_pos = torch.tensor(
+ sum(num_fg_imgs),
+ dtype=torch.float,
+ device=flatten_cls_preds.device)
+ num_total_samples = max(reduce_mean(num_pos), 1.0)
+
+ pos_masks = torch.cat(pos_masks, 0)
+ cls_targets = torch.cat(cls_targets, 0)
+ obj_targets = torch.cat(obj_targets, 0)
+ bbox_targets = torch.cat(bbox_targets, 0)
+ if self.use_l1:
+ l1_targets = torch.cat(l1_targets, 0)
+
+ loss_bbox = self.loss_bbox(
+ flatten_bboxes.view(-1, 4)[pos_masks],
+ bbox_targets) / num_total_samples
+ loss_obj = self.loss_obj(flatten_objectness.view(-1, 1),
+ obj_targets) / num_total_samples
+ loss_cls = self.loss_cls(
+ flatten_cls_preds.view(-1, self.num_classes)[pos_masks],
+ cls_targets) / num_total_samples
+
+ loss_dict = dict(
+ loss_cls=loss_cls, loss_bbox=loss_bbox, loss_obj=loss_obj)
+
+ if self.use_l1:
+ loss_l1 = self.loss_l1(
+ flatten_bbox_preds.view(-1, 4)[pos_masks],
+ l1_targets) / num_total_samples
+ loss_dict.update(loss_l1=loss_l1)
+
+ return loss_dict
+
+ @torch.no_grad()
+ def _get_target_single(self, cls_preds, objectness, priors, decoded_bboxes,
+ gt_bboxes, gt_labels):
+ """Compute classification, regression, and objectness targets for
+ priors in a single image.
+ Args:
+ cls_preds (Tensor): Classification predictions of one image,
+ a 2D-Tensor with shape [num_priors, num_classes]
+ objectness (Tensor): Objectness predictions of one image,
+ a 1D-Tensor with shape [num_priors]
+ priors (Tensor): All priors of one image, a 2D-Tensor with shape
+                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
+ decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
+ a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
+ br_x, br_y] format.
+ gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
+ with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (Tensor): Ground truth labels of one image, a Tensor
+ with shape [num_gts].
+ """
+
+ num_priors = priors.size(0)
+ num_gts = gt_labels.size(0)
+ gt_bboxes = gt_bboxes.to(decoded_bboxes.dtype)
+ # No target
+ if num_gts == 0:
+ cls_target = cls_preds.new_zeros((0, self.num_classes))
+ bbox_target = cls_preds.new_zeros((0, 4))
+ l1_target = cls_preds.new_zeros((0, 4))
+ obj_target = cls_preds.new_zeros((num_priors, 1))
+ foreground_mask = cls_preds.new_zeros(num_priors).bool()
+ return (foreground_mask, cls_target, obj_target, bbox_target,
+ l1_target, 0)
+
+ # YOLOX uses center priors with 0.5 offset to assign targets,
+        # but uses center priors without offset to regress bboxes.
+ offset_priors = torch.cat(
+ [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1)
+
+ assign_result = self.assigner.assign(
+ cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid(),
+ offset_priors, decoded_bboxes, gt_bboxes, gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, priors, gt_bboxes)
+ pos_inds = sampling_result.pos_inds
+ num_pos_per_img = pos_inds.size(0)
+
+ pos_ious = assign_result.max_overlaps[pos_inds]
+ # IOU aware classification score
+ cls_target = F.one_hot(sampling_result.pos_gt_labels,
+ self.num_classes) * pos_ious.unsqueeze(-1)
+ obj_target = torch.zeros_like(objectness).unsqueeze(-1)
+ obj_target[pos_inds] = 1
+ bbox_target = sampling_result.pos_gt_bboxes
+ l1_target = cls_preds.new_zeros((num_pos_per_img, 4))
+ if self.use_l1:
+ l1_target = self._get_l1_target(l1_target, bbox_target,
+ priors[pos_inds])
+ foreground_mask = torch.zeros_like(objectness).to(torch.bool)
+ foreground_mask[pos_inds] = 1
+ return (foreground_mask, cls_target, obj_target, bbox_target,
+ l1_target, num_pos_per_img)
+
+ def _get_l1_target(self, l1_target, gt_bboxes, priors, eps=1e-8):
+ """Convert gt bboxes to center offset and log width height."""
+ gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes)
+ l1_target[:, :2] = (gt_cxcywh[:, :2] - priors[:, :2]) / priors[:, 2:]
+ l1_target[:, 2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
+ return l1_target
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0a89b87ece2271f1d769413a2712a7bcf3c8620
--- /dev/null
+++ b/mmdet/models/detectors/__init__.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .atss import ATSS
+from .autoassign import AutoAssign
+from .base import BaseDetector
+from .cascade_rcnn import CascadeRCNN
+from .centernet import CenterNet
+from .cornernet import CornerNet
+from .ddod import DDOD
+from .deformable_detr import DeformableDETR
+from .detr import DETR
+from .fast_rcnn import FastRCNN
+from .faster_rcnn import FasterRCNN
+from .fcos import FCOS
+from .fovea import FOVEA
+from .fsaf import FSAF
+from .gfl import GFL
+from .grid_rcnn import GridRCNN
+from .htc import HybridTaskCascade
+from .kd_one_stage import KnowledgeDistillationSingleStageDetector
+from .lad import LAD
+from .mask2former import Mask2Former
+from .mask_rcnn import MaskRCNN
+from .mask_scoring_rcnn import MaskScoringRCNN
+from .maskformer import MaskFormer
+from .nasfcos import NASFCOS
+from .paa import PAA
+from .panoptic_fpn import PanopticFPN
+from .panoptic_two_stage_segmentor import TwoStagePanopticSegmentor
+from .point_rend import PointRend
+from .queryinst import QueryInst
+from .reppoints_detector import RepPointsDetector
+from .retinanet import RetinaNet
+from .rpn import RPN
+from .scnet import SCNet
+from .single_stage import SingleStageDetector
+from .solo import SOLO
+from .solov2 import SOLOv2
+from .sparse_rcnn import SparseRCNN
+from .tood import TOOD
+from .trident_faster_rcnn import TridentFasterRCNN
+from .two_stage import TwoStageDetector
+from .vfnet import VFNet
+from .yolact import YOLACT
+from .yolo import YOLOV3
+from .yolof import YOLOF
+from .yolox import YOLOX
+
+__all__ = [
+ 'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN',
+ 'KnowledgeDistillationSingleStageDetector', 'FastRCNN', 'FasterRCNN',
+ 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'RetinaNet', 'FCOS',
+ 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector', 'FOVEA', 'FSAF',
+ 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT',
+ 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
+ 'SOLOv2', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
+ 'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
+ 'MaskFormer', 'DDOD', 'Mask2Former'
+]
diff --git a/mmdet/models/detectors/atss.py b/mmdet/models/detectors/atss.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f1acd9a1595ecea0fd7a19ccd63cd991130657
--- /dev/null
+++ b/mmdet/models/detectors/atss.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class ATSS(SingleStageDetector):
+ """Implementation of `ATSS `_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(ATSS, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/autoassign.py b/mmdet/models/detectors/autoassign.py
new file mode 100644
index 0000000000000000000000000000000000000000..30ab72075807fbe565ede7e15bbf5ad1ebbec001
--- /dev/null
+++ b/mmdet/models/detectors/autoassign.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class AutoAssign(SingleStageDetector):
+ """Implementation of `AutoAssign: Differentiable Label Assignment for Dense
+ Object Detection `_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(AutoAssign, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f87097b1b86dee94f5cde10a31948593e93516a1
--- /dev/null
+++ b/mmdet/models/detectors/base.py
@@ -0,0 +1,365 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import BaseModule, auto_fp16
+
+from mmdet.core.visualization import imshow_det_bboxes
+
+
+class BaseDetector(BaseModule, metaclass=ABCMeta):
+ """Base class for detectors."""
+
+ def __init__(self, init_cfg=None):
+ super(BaseDetector, self).__init__(init_cfg)
+ self.fp16_enabled = False
+
+ @property
+ def with_neck(self):
+ """bool: whether the detector has a neck"""
+ return hasattr(self, 'neck') and self.neck is not None
+
+ # TODO: these properties need to be carefully handled
+ # for both single stage & two stage detectors
+ @property
+ def with_shared_head(self):
+ """bool: whether the detector has a shared head in the RoI Head"""
+ return hasattr(self, 'roi_head') and self.roi_head.with_shared_head
+
+ @property
+ def with_bbox(self):
+ """bool: whether the detector has a bbox head"""
+ return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
+ or (hasattr(self, 'bbox_head') and self.bbox_head is not None))
+
+ @property
+ def with_mask(self):
+ """bool: whether the detector has a mask head"""
+ return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
+ or (hasattr(self, 'mask_head') and self.mask_head is not None))
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Extract features from images."""
+ pass
+
+ def extract_feats(self, imgs):
+ """Extract features from multiple images.
+
+ Args:
+ imgs (list[torch.Tensor]): A list of images. The images are
+ augmented from the same image but in different ways.
+
+ Returns:
+ list[torch.Tensor]: Features of different images
+ """
+ assert isinstance(imgs, list)
+ return [self.extract_feat(img) for img in imgs]
+
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+            imgs (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys, see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ kwargs (keyword arguments): Specific to concrete implementation.
+ """
+ # NOTE the batched image size information may be useful, e.g.
+ # in DETR, this is needed for the construction of masks, which is
+ # then used for the transformer_head.
+ batch_input_shape = tuple(imgs[0].size()[-2:])
+ for img_meta in img_metas:
+ img_meta['batch_input_shape'] = batch_input_shape
+
+ async def async_simple_test(self, img, img_metas, **kwargs):
+ raise NotImplementedError
+
+ @abstractmethod
+ def simple_test(self, img, img_metas, **kwargs):
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Test function with test time augmentation."""
+ pass
+
+ async def aforward_test(self, *, img, img_metas, **kwargs):
+ for var, name in [(img, 'img'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got {type(var)}')
+
+ num_augs = len(img)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(img)}) '
+ f'!= num of image metas ({len(img_metas)})')
+ # TODO: remove the restriction of samples_per_gpu == 1 when prepared
+ samples_per_gpu = img[0].size(0)
+ assert samples_per_gpu == 1
+
+ if num_augs == 1:
+ return await self.async_simple_test(img[0], img_metas[0], **kwargs)
+ else:
+ raise NotImplementedError
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got {type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) '
+ f'!= num of image meta ({len(img_metas)})')
+
+ # NOTE the batched image size information may be useful, e.g.
+ # in DETR, this is needed for the construction of masks, which is
+ # then used for the transformer_head.
+ for img, img_meta in zip(imgs, img_metas):
+ batch_size = len(img_meta)
+ for img_id in range(batch_size):
+ img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])
+
+ if num_augs == 1:
+ # proposals (List[List[Tensor]]): the outer list indicates
+ # test-time augs (multiscale, flip, etc.) and the inner list
+ # indicates images in a batch.
+ # The Tensor should have a shape Px4, where P is the number of
+ # proposals.
+ if 'proposals' in kwargs:
+ kwargs['proposals'] = kwargs['proposals'][0]
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ assert imgs[0].size(0) == 1, 'aug test does not support ' \
+ 'inference with batch size ' \
+ f'{imgs[0].size(0)}'
+ # TODO: support test augmentation for predefined proposals
+ assert 'proposals' not in kwargs
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ @auto_fp16(apply_to=('img', ))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+        and List[dict]), and when ``return_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if torch.onnx.is_in_onnx_export():
+ assert len(img_metas) == 1
+ return self.onnx_export(img[0], img_metas[0])
+
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def _parse_losses(self, losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
+ which may be a weighted sum of all losses, log_vars contains \
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
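+        # Only entries whose key contains 'loss' contribute to the total
+        # objective; the remaining entries are logged for monitoring only.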
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
+
+        # If log_vars has a different length across GPUs, training will hang
+        # while waiting for the collective operations below
+ if dist.is_available() and dist.is_initialized():
+ log_var_length = torch.tensor(len(log_vars), device=loss.device)
+ dist.all_reduce(log_var_length)
+ message = (f'rank {dist.get_rank()}' +
+ f' len(log_vars): {len(log_vars)}' + ' keys: ' +
+ ','.join(log_vars.keys()))
+ assert log_var_length == len(log_vars) * dist.get_world_size(), \
+ 'loss log variables are different across GPUs!\n' + message
+
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
+
+ def train_step(self, data, optimizer):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+        including back propagation and optimizer updating is also defined in
+        this method, such as in GANs.
+
+ Args:
+ data (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
+ ``num_samples``.
+
+ - ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ - ``log_vars`` contains all the variables to be sent to the
+ logger.
+ - ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ losses = self(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
+
+ return outputs
+
+ def val_step(self, data, optimizer=None):
+ """The iteration step during validation.
+
+        This method shares the same signature as :func:`train_step`, but is
+        used during val epochs. Note that the evaluation after training epochs
+        is not implemented by this method, but by an evaluation hook.
+ """
+ losses = self(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ log_vars_ = dict()
+ for loss_name, loss_value in log_vars.items():
+ k = loss_name + '_val'
+ log_vars_[k] = loss_value
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars_, num_samples=len(data['img_metas']))
+
+ return outputs
+
+ def show_result(self,
+ img,
+ result,
+ score_thr=0.3,
+ bbox_color=(72, 101, 241),
+ text_color=(72, 101, 241),
+ mask_color=None,
+ thickness=2,
+ font_size=13,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (Tensor or tuple): The results to draw over `img`
+ bbox_result or (bbox_result, segm_result).
+ score_thr (float, optional): Minimum score of bboxes to be shown.
+ Default: 0.3.
+            bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox
+                lines. The tuple of color should be in BGR order.
+                Default: (72, 101, 241)
+            text_color (str or tuple(int) or :obj:`Color`): Color of texts.
+                The tuple of color should be in BGR order.
+                Default: (72, 101, 241)
+ mask_color (None or str or tuple(int) or :obj:`Color`):
+ Color of masks. The tuple of color should be in BGR order.
+ Default: None
+ thickness (int): Thickness of lines. Default: 2
+ font_size (int): Font size of texts. Default: 13
+ win_name (str): The window name. Default: ''
+ wait_time (float): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+
+ Returns:
+            np.ndarray: The image with results drawn on it. Only returned
+                when neither `show` nor `out_file` is set.
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ if isinstance(result, tuple):
+ bbox_result, segm_result = result
+ if isinstance(segm_result, tuple):
+ segm_result = segm_result[0] # ms rcnn
+ else:
+ bbox_result, segm_result = result, None
+ bboxes = np.vstack(bbox_result)
+ labels = [
+ np.full(bbox.shape[0], i, dtype=np.int32)
+ for i, bbox in enumerate(bbox_result)
+ ]
+ labels = np.concatenate(labels)
+ # draw segmentation masks
+ segms = None
+ if segm_result is not None and len(labels) > 0: # non empty
+ segms = mmcv.concat_list(segm_result)
+ if isinstance(segms[0], torch.Tensor):
+ segms = torch.stack(segms, dim=0).detach().cpu().numpy()
+ else:
+ segms = np.stack(segms, axis=0)
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ # draw bounding boxes
+ img = imshow_det_bboxes(
+ img,
+ bboxes,
+ labels,
+ segms,
+ class_names=self.CLASSES,
+ score_thr=score_thr,
+ bbox_color=bbox_color,
+ text_color=text_color,
+ mask_color=mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file)
+
+ if not (show or out_file):
+ return img
+
+ def onnx_export(self, img, img_metas):
+ raise NotImplementedError(f'{self.__class__.__name__} does '
+ f'not support ONNX EXPORT')
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8c738271d1c8bdc374a6deeab19902ad8d74b38
--- /dev/null
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class CascadeRCNN(TwoStageDetector):
+ r"""Implementation of `Cascade R-CNN: Delving into High Quality Object
+ Detection `_"""
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ rpn_head=None,
+ roi_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(CascadeRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
+
+ def show_result(self, data, result, **kwargs):
+ """Show prediction results of the detector.
+
+ Args:
+ data (str or np.ndarray): Image filename or loaded image.
+ result (Tensor or tuple): The results to draw over `img`
+ bbox_result or (bbox_result, segm_result).
+
+ Returns:
+ np.ndarray: The image with bboxes drawn on it.
+ """
+ if self.with_mask:
+ ms_bbox_result, ms_segm_result = result
+ if isinstance(ms_bbox_result, dict):
+ result = (ms_bbox_result['ensemble'],
+ ms_segm_result['ensemble'])
+ else:
+ if isinstance(result, dict):
+ result = result['ensemble']
+ return super(CascadeRCNN, self).show_result(data, result, **kwargs)
diff --git a/mmdet/models/detectors/centernet.py b/mmdet/models/detectors/centernet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e3fd3ccd4d49832f7450ed359a0bbea13bf631
--- /dev/null
+++ b/mmdet/models/detectors/centernet.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core import bbox2result
+from mmdet.models.builder import DETECTORS
+from ...core.utils import flip_tensor
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class CenterNet(SingleStageDetector):
+ """Implementation of CenterNet(Objects as Points)
+
+ .
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(CenterNet, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
+
+ def merge_aug_results(self, aug_results, with_nms):
+ """Merge augmented detection bboxes and score.
+
+ Args:
+ aug_results (list[list[Tensor]]): Det_bboxes and det_labels of each
+ image.
+ with_nms (bool): If True, do nms before return boxes.
+
+ Returns:
+ tuple: (out_bboxes, out_labels)
+ """
+ recovered_bboxes, aug_labels = [], []
+ for single_result in aug_results:
+ recovered_bboxes.append(single_result[0][0])
+ aug_labels.append(single_result[0][1])
+
+ bboxes = torch.cat(recovered_bboxes, dim=0).contiguous()
+ labels = torch.cat(aug_labels).contiguous()
+ if with_nms:
+ out_bboxes, out_labels = self.bbox_head._bboxes_nms(
+ bboxes, labels, self.bbox_head.test_cfg)
+ else:
+ out_bboxes, out_labels = bboxes, labels
+
+ return out_bboxes, out_labels
+
+ def aug_test(self, imgs, img_metas, rescale=True):
+ """Augment testing of CenterNet. Aug test must have flipped image pair,
+ and unlike CornerNet, it will perform an averaging operation on the
+ feature map instead of detecting bbox.
+
+ Args:
+ imgs (list[Tensor]): Augmented images.
+ img_metas (list[list[dict]]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: True.
+
+ Note:
+            ``imgs`` must include flipped image pairs.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ img_inds = list(range(len(imgs)))
+ assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
+ 'aug test must have flipped image pair')
+ aug_results = []
+ for ind, flip_ind in zip(img_inds[0::2], img_inds[1::2]):
+ flip_direction = img_metas[flip_ind][0]['flip_direction']
+ img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
+ x = self.extract_feat(img_pair)
+ center_heatmap_preds, wh_preds, offset_preds = self.bbox_head(x)
+ assert len(center_heatmap_preds) == len(wh_preds) == len(
+ offset_preds) == 1
+
+ # Feature map averaging
+ center_heatmap_preds[0] = (
+ center_heatmap_preds[0][0:1] +
+ flip_tensor(center_heatmap_preds[0][1:2], flip_direction)) / 2
+ wh_preds[0] = (wh_preds[0][0:1] +
+ flip_tensor(wh_preds[0][1:2], flip_direction)) / 2
+
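+            # Offset maps are not simply flip-equivariant, so only the
+            # prediction from the un-flipped image is used below.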
+ bbox_list = self.bbox_head.get_bboxes(
+ center_heatmap_preds,
+ wh_preds, [offset_preds[0][0:1]],
+ img_metas[ind],
+ rescale=rescale,
+ with_nms=False)
+ aug_results.append(bbox_list)
+
+ nms_cfg = self.bbox_head.test_cfg.get('nms_cfg', None)
+ if nms_cfg is None:
+ with_nms = False
+ else:
+ with_nms = True
+ bbox_list = [self.merge_aug_results(aug_results, with_nms)]
+ bbox_results = [
+ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
+ for det_bboxes, det_labels in bbox_list
+ ]
+ return bbox_results
diff --git a/mmdet/models/detectors/cornernet.py b/mmdet/models/detectors/cornernet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce921cc3b38e81c1629abeea0cd4e3b317bf7a83
--- /dev/null
+++ b/mmdet/models/detectors/cornernet.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core import bbox2result, bbox_mapping_back
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class CornerNet(SingleStageDetector):
+ """CornerNet.
+
+    This detector is the implementation of the paper `CornerNet: Detecting
+    Objects as Paired Keypoints <https://arxiv.org/abs/1808.01244>`_.
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(CornerNet, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
+
+ def merge_aug_results(self, aug_results, img_metas):
+ """Merge augmented detection bboxes and score.
+
+ Args:
+ aug_results (list[list[Tensor]]): Det_bboxes and det_labels of each
+ image.
+ img_metas (list[list[dict]]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+
+ Returns:
+ tuple: (bboxes, labels)
+ """
+ recovered_bboxes, aug_labels = [], []
+ for bboxes_labels, img_info in zip(aug_results, img_metas):
+ img_shape = img_info[0]['img_shape'] # using shape before padding
+ scale_factor = img_info[0]['scale_factor']
+ flip = img_info[0]['flip']
+ bboxes, labels = bboxes_labels
+ bboxes, scores = bboxes[:, :4], bboxes[:, -1:]
+ bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip)
+ recovered_bboxes.append(torch.cat([bboxes, scores], dim=-1))
+ aug_labels.append(labels)
+
+ bboxes = torch.cat(recovered_bboxes, dim=0)
+ labels = torch.cat(aug_labels)
+
+ if bboxes.shape[0] > 0:
+ out_bboxes, out_labels = self.bbox_head._bboxes_nms(
+ bboxes, labels, self.bbox_head.test_cfg)
+ else:
+ out_bboxes, out_labels = bboxes, labels
+
+ return out_bboxes, out_labels
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Augment testing of CornerNet.
+
+ Args:
+ imgs (list[Tensor]): Augmented images.
+ img_metas (list[list[dict]]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+
+ Note:
+            ``imgs`` must include flipped image pairs.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ img_inds = list(range(len(imgs)))
+
+ assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
+ 'aug test must have flipped image pair')
+ aug_results = []
+ for ind, flip_ind in zip(img_inds[0::2], img_inds[1::2]):
+ img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
+ x = self.extract_feat(img_pair)
+ outs = self.bbox_head(x)
+ bbox_list = self.bbox_head.get_bboxes(
+ *outs, [img_metas[ind], img_metas[flip_ind]], False, False)
+ aug_results.append(bbox_list[0])
+ aug_results.append(bbox_list[1])
+
+ bboxes, labels = self.merge_aug_results(aug_results, img_metas)
+ bbox_results = bbox2result(bboxes, labels, self.bbox_head.num_classes)
+
+ return [bbox_results]
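``aug_test`` consumes ``imgs`` as (original, flipped) pairs, which the flip augmentation in the test pipeline normally produces. A rough sketch of assembling such a pair by hand (``cornernet_model`` is assumed to be an already-built CornerNet; the metas carry only the keys used above):

import torch

img = torch.rand(1, 3, 511, 511)
flipped = torch.flip(img, dims=[3])            # horizontally flipped copy

img_metas = [[dict(img_shape=(511, 511, 3), scale_factor=1.0,
                   flip=False, flip_direction=None)],
             [dict(img_shape=(511, 511, 3), scale_factor=1.0,
                   flip=True, flip_direction='horizontal')]]
# results = cornernet_model.aug_test([img, flipped], img_metas, rescale=True)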
diff --git a/mmdet/models/detectors/ddod.py b/mmdet/models/detectors/ddod.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ae0a74172ecca07aa8fad399425b19b4ce63eab
--- /dev/null
+++ b/mmdet/models/detectors/ddod.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class DDOD(SingleStageDetector):
+ """Implementation of `DDOD `_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(DDOD, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/deformable_detr.py b/mmdet/models/detectors/deformable_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1f164221d2f6ac21448eeb04d685d93f7b86853
--- /dev/null
+++ b/mmdet/models/detectors/deformable_detr.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .detr import DETR
+
+
+@DETECTORS.register_module()
+class DeformableDETR(DETR):
+
+ def __init__(self, *args, **kwargs):
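+        # NOTE: ``super(DETR, self)`` deliberately skips ``DETR.__init__``
+        # (which hard-codes ``neck=None``) and calls
+        # ``SingleStageDetector.__init__`` directly, so a neck can still be
+        # passed through ``**kwargs``.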
+ super(DETR, self).__init__(*args, **kwargs)
diff --git a/mmdet/models/detectors/detr.py b/mmdet/models/detectors/detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..06d76913be64b98e3a497c043cf71c7d2d4491ae
--- /dev/null
+++ b/mmdet/models/detectors/detr.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class DETR(SingleStageDetector):
+ r"""Implementation of `DETR: End-to-End Object Detection with
+ Transformers `_"""
+
+ def __init__(self,
+ backbone,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(DETR, self).__init__(backbone, None, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
+
+ # over-write `forward_dummy` because:
+ # the forward of bbox_head requires img_metas
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ warnings.warn('Warning! MultiheadAttention in DETR does not '
+ 'support flops computation! Do not use the '
+ 'results in your papers!')
+
+ batch_size, _, height, width = img.shape
+ dummy_img_metas = [
+ dict(
+ batch_input_shape=(height, width),
+ img_shape=(height, width, 3)) for _ in range(batch_size)
+ ]
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x, dummy_img_metas)
+ return outs
+
+ # over-write `onnx_export` because:
+ # (1) the forward of bbox_head requires img_metas
+ # (2) the different behavior (e.g. construction of `masks`) between
+ # torch and ONNX model, during the forward of bbox_head
+ def onnx_export(self, img, img_metas):
+ """Test function for exporting to ONNX, without test time augmentation.
+
+ Args:
+ img (torch.Tensor): input images.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
+ and class labels of shape [N, num_det].
+ """
+ x = self.extract_feat(img)
+ # forward of this head requires img_metas
+ outs = self.bbox_head.forward_onnx(x, img_metas)
+ # get shape as tensor
+ img_shape = torch._shape_as_tensor(img)[2:]
+ img_metas[0]['img_shape_for_onnx'] = img_shape
+
+ det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)
+
+ return det_bboxes, det_labels
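``forward_dummy`` fabricates its own ``img_metas`` because the transformer head needs the padded input shape even for a FLOPs pass. A small sketch mirroring that meta construction, with ``detr_model`` assumed to be an already-built DETR:

import torch

# mirror of the dummy metas that forward_dummy builds internally
batch = torch.rand(2, 3, 800, 1333)
dummy_metas = [
    dict(batch_input_shape=(800, 1333), img_shape=(800, 1333, 3))
    for _ in range(batch.size(0))
]
# a built DETR (assumed available as `detr_model`) could then run:
# outs = detr_model.bbox_head(detr_model.extract_feat(batch), dummy_metas)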
diff --git a/mmdet/models/detectors/fast_rcnn.py b/mmdet/models/detectors/fast_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7aebe151feb22354573b7b06675e15be3f610fe6
--- /dev/null
+++ b/mmdet/models/detectors/fast_rcnn.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class FastRCNN(TwoStageDetector):
+ """Implementation of `Fast R-CNN `_"""
+
+ def __init__(self,
+ backbone,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ init_cfg=None):
+ super(FastRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
+
+ def forward_test(self, imgs, img_metas, proposals, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ proposals (List[List[Tensor]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. The Tensor should have a shape Px4, where
+ P is the number of proposals.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got {type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) '
+ f'!= num of image meta ({len(img_metas)})')
+
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], proposals[0],
+ **kwargs)
+ else:
+ # TODO: support test-time augmentation
+            raise NotImplementedError
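``forward_test`` expects proposals nested the same way as the images: one list per test-time augmentation, one Px4 tensor per image. A toy sketch of that nesting for a single, un-augmented batch of two images (boxes are arbitrary placeholders; ``fast_rcnn_model`` is assumed to be a built FastRCNN):

import torch

imgs = [torch.rand(2, 3, 800, 1333)]                       # one aug, batch of two
img_metas = [[dict(img_shape=(800, 1333, 3), scale_factor=1.0, flip=False),
              dict(img_shape=(800, 1333, 3), scale_factor=1.0, flip=False)]]
proposals = [[torch.tensor([[10., 10., 200., 200.]]),      # proposals for image 0
              torch.tensor([[50., 60., 300., 400.]])]]     # proposals for image 1
# results = fast_rcnn_model.forward_test(imgs, img_metas, proposals)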
diff --git a/mmdet/models/detectors/faster_rcnn.py b/mmdet/models/detectors/faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..70fb662f1705997be8d899f4760ab9a3aafec18d
--- /dev/null
+++ b/mmdet/models/detectors/faster_rcnn.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class FasterRCNN(TwoStageDetector):
+ """Implementation of `Faster R-CNN `_"""
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ init_cfg=None):
+ super(FasterRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/fcos.py b/mmdet/models/detectors/fcos.py
new file mode 100644
index 0000000000000000000000000000000000000000..d985bd02d7ca5c13e86dfdb9a7a5ed9b29d890cc
--- /dev/null
+++ b/mmdet/models/detectors/fcos.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class FCOS(SingleStageDetector):
+ """Implementation of `FCOS `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/fovea.py b/mmdet/models/detectors/fovea.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd908c7e1795f3f216481d7a3f6975e710a33b5
--- /dev/null
+++ b/mmdet/models/detectors/fovea.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class FOVEA(SingleStageDetector):
+ """Implementation of `FoveaBox `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(FOVEA, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/fsaf.py b/mmdet/models/detectors/fsaf.py
new file mode 100644
index 0000000000000000000000000000000000000000..81ed1bdef1a8957077788397422725c83e3ffed2
--- /dev/null
+++ b/mmdet/models/detectors/fsaf.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class FSAF(SingleStageDetector):
+ """Implementation of `FSAF `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(FSAF, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/gfl.py b/mmdet/models/detectors/gfl.py
new file mode 100644
index 0000000000000000000000000000000000000000..4628e2e7c929bb7195ef51f741da9ca66bf9c3d8
--- /dev/null
+++ b/mmdet/models/detectors/gfl.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class GFL(SingleStageDetector):
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(GFL, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/grid_rcnn.py b/mmdet/models/detectors/grid_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..bba7873bcf3df1ca82f471a86cce5a3f15ccf724
--- /dev/null
+++ b/mmdet/models/detectors/grid_rcnn.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class GridRCNN(TwoStageDetector):
+ """Grid R-CNN.
+
+ This detector is the implementation of:
+ - Grid R-CNN (https://arxiv.org/abs/1811.12030)
+ - Grid R-CNN Plus: Faster and Better (https://arxiv.org/abs/1906.05688)
+ """
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ init_cfg=None):
+ super(GridRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7c95338a78fad03ffa7db3a479865a416d0d70c
--- /dev/null
+++ b/mmdet/models/detectors/htc.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .cascade_rcnn import CascadeRCNN
+
+
+@DETECTORS.register_module()
+class HybridTaskCascade(CascadeRCNN):
+ """Implementation of `HTC `_"""
+
+ def __init__(self, **kwargs):
+ super(HybridTaskCascade, self).__init__(**kwargs)
+
+ @property
+ def with_semantic(self):
+ """bool: whether the detector has a semantic head"""
+ return self.roi_head.with_semantic
diff --git a/mmdet/models/detectors/kd_one_stage.py b/mmdet/models/detectors/kd_one_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb66b5152cdeb1dd9698cff011108de3f3f12ac2
--- /dev/null
+++ b/mmdet/models/detectors/kd_one_stage.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+
+import mmcv
+import torch
+from mmcv.runner import load_checkpoint
+
+from .. import build_detector
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class KnowledgeDistillationSingleStageDetector(SingleStageDetector):
+ r"""Implementation of `Distilling the Knowledge in a Neural Network.
+ `_.
+
+ Args:
+ teacher_config (str | dict): Config file path
+ or the config object of teacher model.
+ teacher_ckpt (str, optional): Checkpoint path of teacher model.
+ If left as None, the model will not load any weights.
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ teacher_config,
+ teacher_ckpt=None,
+ eval_teacher=True,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+ pretrained)
+ self.eval_teacher = eval_teacher
+ # Build teacher model
+ if isinstance(teacher_config, (str, Path)):
+ teacher_config = mmcv.Config.fromfile(teacher_config)
+ self.teacher_model = build_detector(teacher_config['model'])
+ if teacher_ckpt is not None:
+ load_checkpoint(
+ self.teacher_model, teacher_ckpt, map_location='cpu')
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Ground-truth boxes of each image, in
+                [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ x = self.extract_feat(img)
+ with torch.no_grad():
+ teacher_x = self.teacher_model.extract_feat(img)
+ out_teacher = self.teacher_model.bbox_head(teacher_x)
+ losses = self.bbox_head.forward_train(x, out_teacher, img_metas,
+ gt_bboxes, gt_labels,
+ gt_bboxes_ignore)
+ return losses
+
+ def cuda(self, device=None):
+ """Since teacher_model is registered as a plain object, it is necessary
+ to put the teacher model to cuda when calling cuda function."""
+ self.teacher_model.cuda(device=device)
+ return super().cuda(device=device)
+
+ def train(self, mode=True):
+ """Set the same train mode for teacher and student model."""
+ if self.eval_teacher:
+ self.teacher_model.train(False)
+ else:
+ self.teacher_model.train(mode)
+ super().train(mode)
+
+ def __setattr__(self, name, value):
+ """Set attribute, i.e. self.name = value
+
+        This override prevents the teacher model from being registered as an
+        nn.Module. The teacher module is registered as a plain object, so that
+ the teacher parameters will not show up when calling
+ ``self.parameters``, ``self.modules``, ``self.children`` methods.
+ """
+ if name == 'teacher_model':
+ object.__setattr__(self, name, value)
+ else:
+ super().__setattr__(name, value)
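Because of the ``__setattr__`` override above, the teacher never appears in ``parameters()``, so an optimizer built from the student updates only the student. A standalone toy illustration of the same pattern (these classes are illustrative, not the mmdet ones):

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, teacher):
        super().__init__()
        self.fc = nn.Linear(4, 2)                            # student parameters
        object.__setattr__(self, 'teacher_model', teacher)   # hidden from nn.Module

toy = Toy(nn.Linear(4, 2))
print(sum(p.numel() for p in toy.parameters()))  # 10 -> only the student's fc layer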
diff --git a/mmdet/models/detectors/lad.py b/mmdet/models/detectors/lad.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6cc1e0b2d9fd91dabc606da5192522e908ccebf
--- /dev/null
+++ b/mmdet/models/detectors/lad.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.runner import load_checkpoint
+
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .kd_one_stage import KnowledgeDistillationSingleStageDetector
+
+
+@DETECTORS.register_module()
+class LAD(KnowledgeDistillationSingleStageDetector):
+ """Implementation of `LAD `_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ teacher_backbone,
+ teacher_neck,
+ teacher_bbox_head,
+ teacher_ckpt,
+ eval_teacher=True,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(KnowledgeDistillationSingleStageDetector,
+ self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+ pretrained)
+ self.eval_teacher = eval_teacher
+ self.teacher_model = nn.Module()
+ self.teacher_model.backbone = build_backbone(teacher_backbone)
+ if teacher_neck is not None:
+ self.teacher_model.neck = build_neck(teacher_neck)
+ teacher_bbox_head.update(train_cfg=train_cfg)
+ teacher_bbox_head.update(test_cfg=test_cfg)
+ self.teacher_model.bbox_head = build_head(teacher_bbox_head)
+ if teacher_ckpt is not None:
+ load_checkpoint(
+ self.teacher_model, teacher_ckpt, map_location='cpu')
+
+ @property
+ def with_teacher_neck(self):
+ """bool: whether the detector has a teacher_neck"""
+ return hasattr(self.teacher_model, 'neck') and \
+ self.teacher_model.neck is not None
+
+ def extract_teacher_feat(self, img):
+ """Directly extract teacher features from the backbone+neck."""
+ x = self.teacher_model.backbone(img)
+ if self.with_teacher_neck:
+ x = self.teacher_model.neck(x)
+ return x
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Ground-truth boxes of each image, in
+                [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # get label assignment from the teacher
+ with torch.no_grad():
+ x_teacher = self.extract_teacher_feat(img)
+ outs_teacher = self.teacher_model.bbox_head(x_teacher)
+ label_assignment_results = \
+ self.teacher_model.bbox_head.get_label_assignment(
+ *outs_teacher, gt_bboxes, gt_labels, img_metas,
+ gt_bboxes_ignore)
+
+        # the student uses the teacher's label assignment to learn
+ x = self.extract_feat(img)
+ losses = self.bbox_head.forward_train(x, label_assignment_results,
+ img_metas, gt_bboxes, gt_labels,
+ gt_bboxes_ignore)
+ return losses
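The training step above runs two networks per batch: the frozen teacher only to decide the label assignment, and the student to compute its loss under that assignment. A schematic of the flow with hypothetical helpers (``assign_labels`` and ``compute_loss`` are placeholders, not mmdet APIs):

import torch

def lad_train_step(student, teacher, imgs, targets):
    # 1) teacher forward pass without gradients, used only for the assignment
    with torch.no_grad():
        teacher_outs = teacher(imgs)
        assignment = assign_labels(teacher_outs, targets)    # hypothetical helper
    # 2) student forward pass and loss under the teacher's assignment
    student_outs = student(imgs)
    return compute_loss(student_outs, assignment, targets)   # hypothetical helper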
diff --git a/mmdet/models/detectors/mask2former.py b/mmdet/models/detectors/mask2former.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9ad2ed25d30072aeb8ec99e4a865c9cad092444
--- /dev/null
+++ b/mmdet/models/detectors/mask2former.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .maskformer import MaskFormer
+
+
+@DETECTORS.register_module()
+class Mask2Former(MaskFormer):
+ r"""Implementation of `Masked-attention Mask
+ Transformer for Universal Image Segmentation
+ `_."""
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ panoptic_head=None,
+ panoptic_fusion_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=None):
+ super().__init__(
+ backbone,
+ neck=neck,
+ panoptic_head=panoptic_head,
+ panoptic_fusion_head=panoptic_fusion_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/mask_rcnn.py b/mmdet/models/detectors/mask_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c68489f9c22e112ceae9c265e916cc3c1a6ae301
--- /dev/null
+++ b/mmdet/models/detectors/mask_rcnn.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class MaskRCNN(TwoStageDetector):
+ """Implementation of `Mask R-CNN `_"""
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ init_cfg=None):
+ super(MaskRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/mask_scoring_rcnn.py b/mmdet/models/detectors/mask_scoring_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f55656f3043564c7f974739c764180c9230738b
--- /dev/null
+++ b/mmdet/models/detectors/mask_scoring_rcnn.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class MaskScoringRCNN(TwoStageDetector):
+ """Mask Scoring RCNN.
+
+ https://arxiv.org/abs/1903.00241
+ """
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ init_cfg=None):
+ super(MaskScoringRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/maskformer.py b/mmdet/models/detectors/maskformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d251adad139997d28827e3ad7ed79a48bcce8bb
--- /dev/null
+++ b/mmdet/models/detectors/maskformer.py
@@ -0,0 +1,258 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import mmcv
+import numpy as np
+
+from mmdet.core import INSTANCE_OFFSET, bbox2result
+from mmdet.core.visualization import imshow_det_bboxes
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class MaskFormer(SingleStageDetector):
+ r"""Implementation of `Per-Pixel Classification is
+ NOT All You Need for Semantic Segmentation
+ `_."""
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ panoptic_head=None,
+ panoptic_fusion_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=None):
+ super(SingleStageDetector, self).__init__(init_cfg=init_cfg)
+ self.backbone = build_backbone(backbone)
+ if neck is not None:
+ self.neck = build_neck(neck)
+
+ panoptic_head_ = copy.deepcopy(panoptic_head)
+ panoptic_head_.update(train_cfg=train_cfg)
+ panoptic_head_.update(test_cfg=test_cfg)
+ self.panoptic_head = build_head(panoptic_head_)
+
+ panoptic_fusion_head_ = copy.deepcopy(panoptic_fusion_head)
+ panoptic_fusion_head_.update(test_cfg=test_cfg)
+ self.panoptic_fusion_head = build_head(panoptic_fusion_head_)
+
+ self.num_things_classes = self.panoptic_head.num_things_classes
+ self.num_stuff_classes = self.panoptic_head.num_stuff_classes
+ self.num_classes = self.panoptic_head.num_classes
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ # BaseDetector.show_result default for instance segmentation
+ if self.num_stuff_classes > 0:
+ self.show_result = self._show_pan_result
+
+ def forward_dummy(self, img, img_metas):
+ """Used for computing network flops. See
+ `mmdetection/tools/analysis_tools/get_flops.py`
+
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[Dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+ """
+ super(SingleStageDetector, self).forward_train(img, img_metas)
+ x = self.extract_feat(img)
+ outs = self.panoptic_head(x, img_metas)
+ return outs
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_masks,
+ gt_semantic_seg=None,
+ gt_bboxes_ignore=None,
+ **kargs):
+ """
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[Dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box.
+ gt_masks (list[BitmapMasks]): true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+ gt_semantic_seg (list[tensor]): semantic segmentation mask for
+ images for panoptic segmentation.
+ Defaults to None for instance segmentation.
+ gt_bboxes_ignore (list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ Defaults to None.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # add batch_input_shape in img_metas
+ super(SingleStageDetector, self).forward_train(img, img_metas)
+ x = self.extract_feat(img)
+ losses = self.panoptic_head.forward_train(x, img_metas, gt_bboxes,
+ gt_labels, gt_masks,
+ gt_semantic_seg,
+ gt_bboxes_ignore)
+
+ return losses
+
+ def simple_test(self, imgs, img_metas, **kwargs):
+ """Test without augmentation.
+
+ Args:
+ imgs (Tensor): A batch of images.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ list[dict[str, np.array | tuple[list]] | tuple[list]]:
+ Semantic segmentation results and panoptic segmentation \
+ results of each image for panoptic segmentation, or formatted \
+ bbox and mask results of each image for instance segmentation.
+
+ .. code-block:: none
+
+ [
+ # panoptic segmentation
+ {
+ 'pan_results': np.array, # shape = [h, w]
+ 'ins_results': tuple[list],
+ # semantic segmentation results are not supported yet
+ 'sem_results': np.array
+ },
+ ...
+ ]
+
+ or
+
+ .. code-block:: none
+
+ [
+ # instance segmentation
+ (
+ bboxes, # list[np.array]
+ masks # list[list[np.array]]
+ ),
+ ...
+ ]
+ """
+ feats = self.extract_feat(imgs)
+ mask_cls_results, mask_pred_results = self.panoptic_head.simple_test(
+ feats, img_metas, **kwargs)
+ results = self.panoptic_fusion_head.simple_test(
+ mask_cls_results, mask_pred_results, img_metas, **kwargs)
+ for i in range(len(results)):
+ if 'pan_results' in results[i]:
+ results[i]['pan_results'] = results[i]['pan_results'].detach(
+ ).cpu().numpy()
+
+ if 'ins_results' in results[i]:
+ labels_per_image, bboxes, mask_pred_binary = results[i][
+ 'ins_results']
+ bbox_results = bbox2result(bboxes, labels_per_image,
+ self.num_things_classes)
+ mask_results = [[] for _ in range(self.num_things_classes)]
+ for j, label in enumerate(labels_per_image):
+ mask = mask_pred_binary[j].detach().cpu().numpy()
+ mask_results[label].append(mask)
+ results[i]['ins_results'] = bbox_results, mask_results
+
+            assert 'sem_results' not in results[i], 'semantic segmentation '\
+                'results are not supported yet.'
+
+ if self.num_stuff_classes == 0:
+ results = [res['ins_results'] for res in results]
+
+ return results
+
+ def aug_test(self, imgs, img_metas, **kwargs):
+ raise NotImplementedError
+
+ def onnx_export(self, img, img_metas):
+ raise NotImplementedError
+
+ def _show_pan_result(self,
+ img,
+ result,
+ score_thr=0.3,
+ bbox_color=(72, 101, 241),
+ text_color=(72, 101, 241),
+ mask_color=None,
+ thickness=2,
+ font_size=13,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None):
+ """Draw `panoptic result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (dict): The results.
+
+ score_thr (float, optional): Minimum score of bboxes to be shown.
+ Default: 0.3.
+ bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
+ The tuple of color should be in BGR order. Default: 'green'.
+ text_color (str or tuple(int) or :obj:`Color`):Color of texts.
+ The tuple of color should be in BGR order. Default: 'green'.
+ mask_color (None or str or tuple(int) or :obj:`Color`):
+ Color of masks. The tuple of color should be in BGR order.
+ Default: None.
+ thickness (int): Thickness of lines. Default: 2.
+ font_size (int): Font size of texts. Default: 13.
+ win_name (str): The window name. Default: ''.
+ wait_time (float): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+
+        Returns:
+            np.ndarray: The image with the panoptic result drawn on it.
+                Returned only when neither ``show`` nor ``out_file`` is set.
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ pan_results = result['pan_results']
+ # keep objects ahead
+ ids = np.unique(pan_results)[::-1]
+ legal_indices = ids != self.num_classes # for VOID label
+ ids = ids[legal_indices]
+ labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)
+ segms = (pan_results[None] == ids[:, None, None])
+
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ # draw bounding boxes
+ img = imshow_det_bboxes(
+ img,
+ segms=segms,
+ labels=labels,
+ class_names=self.CLASSES,
+ bbox_color=bbox_color,
+ text_color=text_color,
+ mask_color=mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file)
+
+ if not (show or out_file):
+ return img
diff --git a/mmdet/models/detectors/nasfcos.py b/mmdet/models/detectors/nasfcos.py
new file mode 100644
index 0000000000000000000000000000000000000000..a34c2280f59f93139e716b54ef1799fc0941149f
--- /dev/null
+++ b/mmdet/models/detectors/nasfcos.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class NASFCOS(SingleStageDetector):
+ """NAS-FCOS: Fast Neural Architecture Search for Object Detection.
+
+ https://arxiv.org/abs/1906.0442
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(NASFCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/paa.py b/mmdet/models/detectors/paa.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5cb8372a02e84fc1405c05cd814e8109bc19d20
--- /dev/null
+++ b/mmdet/models/detectors/paa.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class PAA(SingleStageDetector):
+ """Implementation of `PAA `_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(PAA, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/panoptic_fpn.py b/mmdet/models/detectors/panoptic_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8ac751fad188a85a75a87678ee76693c5609df2
--- /dev/null
+++ b/mmdet/models/detectors/panoptic_fpn.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .panoptic_two_stage_segmentor import TwoStagePanopticSegmentor
+
+
+@DETECTORS.register_module()
+class PanopticFPN(TwoStagePanopticSegmentor):
+ r"""Implementation of `Panoptic feature pyramid
+ networks `_"""
+
+ def __init__(
+ self,
+ backbone,
+ neck=None,
+ rpn_head=None,
+ roi_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None,
+ # for panoptic segmentation
+ semantic_head=None,
+ panoptic_fusion_head=None):
+ super(PanopticFPN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg,
+ semantic_head=semantic_head,
+ panoptic_fusion_head=panoptic_fusion_head)
diff --git a/mmdet/models/detectors/panoptic_two_stage_segmentor.py b/mmdet/models/detectors/panoptic_two_stage_segmentor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ad49bac705a677d1656cf95d2686fd83d2b1b47
--- /dev/null
+++ b/mmdet/models/detectors/panoptic_two_stage_segmentor.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import torch
+
+from mmdet.core import INSTANCE_OFFSET, bbox2roi, multiclass_nms
+from mmdet.core.visualization import imshow_det_bboxes
+from ..builder import DETECTORS, build_head
+from ..roi_heads.mask_heads.fcn_mask_head import _do_paste_mask
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class TwoStagePanopticSegmentor(TwoStageDetector):
+ """Base class of Two-stage Panoptic Segmentor.
+
+    In addition to the components of a TwoStageDetector, the panoptic
+    segmentor has an extra semantic_head and a panoptic_fusion_head.
+ """
+
+ def __init__(
+ self,
+ backbone,
+ neck=None,
+ rpn_head=None,
+ roi_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None,
+ # for panoptic segmentation
+ semantic_head=None,
+ panoptic_fusion_head=None):
+ super(TwoStagePanopticSegmentor,
+ self).__init__(backbone, neck, rpn_head, roi_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
+ if semantic_head is not None:
+ self.semantic_head = build_head(semantic_head)
+ if panoptic_fusion_head is not None:
+ panoptic_cfg = test_cfg.panoptic if test_cfg is not None else None
+ panoptic_fusion_head_ = panoptic_fusion_head.deepcopy()
+ panoptic_fusion_head_.update(test_cfg=panoptic_cfg)
+ self.panoptic_fusion_head = build_head(panoptic_fusion_head_)
+
+ self.num_things_classes = self.panoptic_fusion_head.\
+ num_things_classes
+ self.num_stuff_classes = self.panoptic_fusion_head.\
+ num_stuff_classes
+ self.num_classes = self.panoptic_fusion_head.num_classes
+
+ @property
+ def with_semantic_head(self):
+ return hasattr(self,
+ 'semantic_head') and self.semantic_head is not None
+
+ @property
+ def with_panoptic_fusion_head(self):
+        return hasattr(self, 'panoptic_fusion_head') and \
+ self.panoptic_fusion_head is not None
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/get_flops.py`
+ """
+ raise NotImplementedError(
+ f'`forward_dummy` is not implemented in {self.__class__.__name__}')
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ gt_semantic_seg=None,
+ proposals=None,
+ **kwargs):
+ x = self.extract_feat(img)
+ losses = dict()
+
+ # RPN forward and loss
+ if self.with_rpn:
+ proposal_cfg = self.train_cfg.get('rpn_proposal',
+ self.test_cfg.rpn)
+ rpn_losses, proposal_list = self.rpn_head.forward_train(
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ proposal_cfg=proposal_cfg)
+ losses.update(rpn_losses)
+ else:
+ proposal_list = proposals
+
+ roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
+ gt_bboxes, gt_labels,
+ gt_bboxes_ignore, gt_masks,
+ **kwargs)
+ losses.update(roi_losses)
+
+ semantic_loss = self.semantic_head.forward_train(x, gt_semantic_seg)
+ losses.update(semantic_loss)
+
+ return losses
+
+ def simple_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False):
+ """Simple test for mask head without augmentation."""
+ img_shapes = tuple(meta['ori_shape']
+ for meta in img_metas) if rescale else tuple(
+ meta['pad_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ masks = []
+ for img_shape in img_shapes:
+ out_shape = (0, self.roi_head.bbox_head.num_classes) \
+ + img_shape[:2]
+ masks.append(det_bboxes[0].new_zeros(out_shape))
+ mask_pred = det_bboxes[0].new_zeros((0, 80, 28, 28))
+ mask_results = dict(
+ masks=masks, mask_pred=mask_pred, mask_feats=None)
+ return mask_results
+
+ _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
+ if rescale:
+ if not isinstance(scale_factors[0], float):
+ scale_factors = [
+ det_bboxes[0].new_tensor(scale_factor)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ _bboxes[i] * scale_factors[i] for i in range(len(_bboxes))
+ ]
+
+ mask_rois = bbox2roi(_bboxes)
+ mask_results = self.roi_head._mask_forward(x, mask_rois)
+ mask_pred = mask_results['mask_pred']
+ # split batch mask prediction back to each image
+ num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
+ mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+
+ # resize the mask_preds to (K, H, W)
+ masks = []
+ for i in range(len(_bboxes)):
+ det_bbox = det_bboxes[i][:, :4]
+ det_label = det_labels[i]
+
+ mask_pred = mask_preds[i].sigmoid()
+
+ box_inds = torch.arange(mask_pred.shape[0])
+ mask_pred = mask_pred[box_inds, det_label][:, None]
+
+ img_h, img_w, _ = img_shapes[i]
+ mask_pred, _ = _do_paste_mask(
+ mask_pred, det_bbox, img_h, img_w, skip_empty=False)
+ masks.append(mask_pred)
+
+ mask_results['masks'] = masks
+
+ return mask_results
+
+ def simple_test(self, img, img_metas, proposals=None, rescale=False):
+ """Test without Augmentation."""
+ x = self.extract_feat(img)
+
+ if proposals is None:
+ proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
+ else:
+ proposal_list = proposals
+
+ bboxes, scores = self.roi_head.simple_test_bboxes(
+ x, img_metas, proposal_list, None, rescale=rescale)
+
+ pan_cfg = self.test_cfg.panoptic
+ # class-wise predictions
+ det_bboxes = []
+ det_labels = []
+        for bbox, score in zip(bboxes, scores):
+            det_bbox, det_label = multiclass_nms(bbox, score,
+                                                 pan_cfg.score_thr,
+                                                 pan_cfg.nms,
+                                                 pan_cfg.max_per_img)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+
+ mask_results = self.simple_test_mask(
+ x, img_metas, det_bboxes, det_labels, rescale=rescale)
+ masks = mask_results['masks']
+
+ seg_preds = self.semantic_head.simple_test(x, img_metas, rescale)
+
+ results = []
+ for i in range(len(det_bboxes)):
+ pan_results = self.panoptic_fusion_head.simple_test(
+ det_bboxes[i], det_labels[i], masks[i], seg_preds[i])
+ pan_results = pan_results.int().detach().cpu().numpy()
+ result = dict(pan_results=pan_results)
+ results.append(result)
+ return results
+
+ def show_result(self,
+ img,
+ result,
+ score_thr=0.3,
+ bbox_color=(72, 101, 241),
+ text_color=(72, 101, 241),
+ mask_color=None,
+ thickness=2,
+ font_size=13,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (dict): The results.
+
+ score_thr (float, optional): Minimum score of bboxes to be shown.
+ Default: 0.3.
+ bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
+ The tuple of color should be in BGR order. Default: 'green'.
+ text_color (str or tuple(int) or :obj:`Color`):Color of texts.
+ The tuple of color should be in BGR order. Default: 'green'.
+ mask_color (None or str or tuple(int) or :obj:`Color`):
+ Color of masks. The tuple of color should be in BGR order.
+ Default: None.
+ thickness (int): Thickness of lines. Default: 2.
+ font_size (int): Font size of texts. Default: 13.
+ win_name (str): The window name. Default: ''.
+ wait_time (float): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+
+        Returns:
+            np.ndarray: The image with the panoptic result drawn on it.
+                Returned only when neither ``show`` nor ``out_file`` is set.
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ pan_results = result['pan_results']
+ # keep objects ahead
+ ids = np.unique(pan_results)[::-1]
+ legal_indices = ids != self.num_classes # for VOID label
+ ids = ids[legal_indices]
+ labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)
+ segms = (pan_results[None] == ids[:, None, None])
+
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ # draw bounding boxes
+ img = imshow_det_bboxes(
+ img,
+ segms=segms,
+ labels=labels,
+ class_names=self.CLASSES,
+ bbox_color=bbox_color,
+ text_color=text_color,
+ mask_color=mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file)
+
+ if not (show or out_file):
+ return img
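The ``id % INSTANCE_OFFSET`` step in ``show_result`` is how one integer id map encodes both class and instance. A small decoding sketch with a made-up offset value (the real constant is imported from ``mmdet.core``):

import numpy as np

INSTANCE_OFFSET = 1000                   # illustrative value only
pan = np.array([[3,    3,    1003],
                [2005, 2005, 1003]])     # toy panoptic id map

ids = np.unique(pan)[::-1]
labels = ids % INSTANCE_OFFSET           # class index of every segment
segms = pan[None] == ids[:, None, None]  # one boolean mask per segment id
print(list(zip(ids.tolist(), labels.tolist())))
# [(2005, 5), (1003, 3), (3, 3)] -> two class-3 segments and one class-5 segment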
diff --git a/mmdet/models/detectors/point_rend.py b/mmdet/models/detectors/point_rend.py
new file mode 100644
index 0000000000000000000000000000000000000000..90eb4d40eb179e41cfe0cc2772c9120c093b3d93
--- /dev/null
+++ b/mmdet/models/detectors/point_rend.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class PointRend(TwoStageDetector):
+ """PointRend: Image Segmentation as Rendering
+
+ This detector is the implementation of
+ `PointRend `_.
+
+ """
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ init_cfg=None):
+ super(PointRend, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/queryinst.py b/mmdet/models/detectors/queryinst.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc216c47340fc79344c8eae908b1ec45da2b2b2
--- /dev/null
+++ b/mmdet/models/detectors/queryinst.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .sparse_rcnn import SparseRCNN
+
+
+@DETECTORS.register_module()
+class QueryInst(SparseRCNN):
+ r"""Implementation of
+ `Instances as Queries `_"""
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ init_cfg=None):
+ super(QueryInst, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
diff --git a/mmdet/models/detectors/reppoints_detector.py b/mmdet/models/detectors/reppoints_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1986cdccf3da96cd179f6bfe9f4f16ff54c411e
--- /dev/null
+++ b/mmdet/models/detectors/reppoints_detector.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class RepPointsDetector(SingleStageDetector):
+ """RepPoints: Point Set Representation for Object Detection.
+
+ This detector is the implementation of:
+ - RepPoints detector (https://arxiv.org/pdf/1904.11490)
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(RepPointsDetector,
+ self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+ pretrained, init_cfg)
diff --git a/mmdet/models/detectors/retinanet.py b/mmdet/models/detectors/retinanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c28545abb011fa838c56d04fc2583428d61a42f8
--- /dev/null
+++ b/mmdet/models/detectors/retinanet.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class RetinaNet(SingleStageDetector):
+ """Implementation of `RetinaNet `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(RetinaNet, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/rpn.py b/mmdet/models/detectors/rpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..707e02b0ec94a55ac68fd8ee099a92a478e02184
--- /dev/null
+++ b/mmdet/models/detectors/rpn.py
@@ -0,0 +1,162 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from inspect import signature
+
+import mmcv
+import torch
+from mmcv.image import tensor2imgs
+
+from mmdet.core import bbox_mapping
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .base import BaseDetector
+
+
+@DETECTORS.register_module()
+class RPN(BaseDetector):
+ """Implementation of Region Proposal Network."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ rpn_head,
+ train_cfg,
+ test_cfg,
+ pretrained=None,
+ init_cfg=None):
+ super(RPN, self).__init__(init_cfg)
+ if pretrained:
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ backbone.pretrained = pretrained
+ self.backbone = build_backbone(backbone)
+ self.neck = build_neck(neck) if neck is not None else None
+ rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
+ rpn_head.update(train_cfg=rpn_train_cfg)
+ rpn_head.update(test_cfg=test_cfg.rpn)
+ self.rpn_head = build_head(rpn_head)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ def extract_feat(self, img):
+ """Extract features.
+
+ Args:
+ img (torch.Tensor): Image tensor with shape (n, c, h ,w).
+
+ Returns:
+ list[torch.Tensor]: Multi-level features that may have
+ different resolutions.
+ """
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ x = self.extract_feat(img)
+ rpn_outs = self.rpn_head(x)
+ return rpn_outs
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes=None,
+ gt_bboxes_ignore=None):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Ground-truth boxes of each image, in
+                [tl_x, tl_y, br_x, br_y] format.
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ if (isinstance(self.train_cfg.rpn, dict)
+ and self.train_cfg.rpn.get('debug', False)):
+ self.rpn_head.debug_imgs = tensor2imgs(img)
+
+ x = self.extract_feat(img)
+ losses = self.rpn_head.forward_train(x, img_metas, gt_bboxes, None,
+ gt_bboxes_ignore)
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test time augmentation.
+
+ Args:
+            img (torch.Tensor): Input images of shape (N, C, H, W).
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[np.ndarray]: proposals
+ """
+ x = self.extract_feat(img)
+ # get origin input shape to onnx dynamic input shape
+ if torch.onnx.is_in_onnx_export():
+ img_shape = torch._shape_as_tensor(img)[2:]
+ img_metas[0]['img_shape_for_onnx'] = img_shape
+ proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
+ if rescale:
+ for proposals, meta in zip(proposal_list, img_metas):
+ proposals[:, :4] /= proposals.new_tensor(meta['scale_factor'])
+ if torch.onnx.is_in_onnx_export():
+ return proposal_list
+
+ return [proposal.cpu().numpy() for proposal in proposal_list]
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ imgs (list[torch.Tensor]): List of multiple images
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[np.ndarray]: proposals
+ """
+ proposal_list = self.rpn_head.aug_test_rpn(
+ self.extract_feats(imgs), img_metas)
+ if not rescale:
+ for proposals, img_meta in zip(proposal_list, img_metas[0]):
+ img_shape = img_meta['img_shape']
+ scale_factor = img_meta['scale_factor']
+ flip = img_meta['flip']
+ flip_direction = img_meta['flip_direction']
+ proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
+ scale_factor, flip,
+ flip_direction)
+ return [proposal.cpu().numpy() for proposal in proposal_list]
+
+ def show_result(self, data, result, top_k=20, **kwargs):
+ """Show RPN proposals on the image.
+
+ Args:
+ data (str or np.ndarray): Image filename or loaded image.
+ result (Tensor or tuple): The results to draw over `img`
+ bbox_result or (bbox_result, segm_result).
+ top_k (int): Plot the first k bboxes only
+ if set positive. Default: 20
+
+ Returns:
+ np.ndarray: The image with bboxes drawn on it.
+ """
+        kwargs['colors'] = 'green'
+ sig = signature(mmcv.imshow_bboxes)
+ for k in list(kwargs.keys()):
+ if k not in sig.parameters:
+ kwargs.pop(k)
+ mmcv.imshow_bboxes(data, result, top_k=top_k, **kwargs)
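``simple_test`` returns one (P, 5) proposal array per image with the objectness score in the last column, and ``show_result`` plots only the first ``top_k`` of them. A quick sketch of filtering such an array (toy numbers, not real proposals):

import numpy as np

proposals = np.array([[10., 10., 100., 120., 0.98],
                      [15., 12., 105., 118., 0.60],
                      [300., 40., 360., 90., 0.10]])   # (P, 5): x1, y1, x2, y2, score
keep = proposals[:, 4] >= 0.5
print(proposals[keep].shape)   # (2, 5) -> proposals above the score threshold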
diff --git a/mmdet/models/detectors/scnet.py b/mmdet/models/detectors/scnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a361d81c3aa62de0ff98b303cb5e0b838b8045fa
--- /dev/null
+++ b/mmdet/models/detectors/scnet.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .cascade_rcnn import CascadeRCNN
+
+
+@DETECTORS.register_module()
+class SCNet(CascadeRCNN):
+ """Implementation of `SCNet `_"""
+
+ def __init__(self, **kwargs):
+ super(SCNet, self).__init__(**kwargs)
diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..c375c72d69d21cade02f0b4bff8cb035e56f0d65
--- /dev/null
+++ b/mmdet/models/detectors/single_stage.py
@@ -0,0 +1,171 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+
+from mmdet.core import bbox2result
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .base import BaseDetector
+
+
+@DETECTORS.register_module()
+class SingleStageDetector(BaseDetector):
+ """Base class for single-stage detectors.
+
+ Single-stage detectors directly and densely predict bounding boxes on the
+ output features of the backbone+neck.
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ bbox_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(SingleStageDetector, self).__init__(init_cfg)
+ if pretrained:
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ backbone.pretrained = pretrained
+ self.backbone = build_backbone(backbone)
+ if neck is not None:
+ self.neck = build_neck(neck)
+ bbox_head.update(train_cfg=train_cfg)
+ bbox_head.update(test_cfg=test_cfg)
+ self.bbox_head = build_head(bbox_head)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone+neck."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x)
+ return outs
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Ground-truth boxes of each image, in
+                [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ super(SingleStageDetector, self).forward_train(img, img_metas)
+ x = self.extract_feat(img)
+ losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
+ gt_labels, gt_bboxes_ignore)
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test-time augmentation.
+
+ Args:
+ img (torch.Tensor): Images with shape (N, C, H, W).
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ feat = self.extract_feat(img)
+ results_list = self.bbox_head.simple_test(
+ feat, img_metas, rescale=rescale)
+ bbox_results = [
+ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
+ for det_bboxes, det_labels in results_list
+ ]
+ return bbox_results
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ imgs (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ assert hasattr(self.bbox_head, 'aug_test'), \
+ f'{self.bbox_head.__class__.__name__}' \
+ ' does not support test-time augmentation'
+
+ feats = self.extract_feats(imgs)
+ results_list = self.bbox_head.aug_test(
+ feats, img_metas, rescale=rescale)
+ bbox_results = [
+ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
+ for det_bboxes, det_labels in results_list
+ ]
+ return bbox_results
+
+ def onnx_export(self, img, img_metas, with_nms=True):
+ """Test function without test time augmentation.
+
+ Args:
+ img (torch.Tensor): input images.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
+ and class labels of shape [N, num_det].
+ """
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x)
+ # get origin input shape to support onnx dynamic shape
+
+ # get shape as tensor
+ img_shape = torch._shape_as_tensor(img)[2:]
+ img_metas[0]['img_shape_for_onnx'] = img_shape
+ # get pad input shape to support onnx dynamic shape for exporting
+        # `CornerNet` and `CentripetalNet`, for which 'pad_shape' is used
+ # for inference
+ img_metas[0]['pad_shape_for_onnx'] = img_shape
+
+ if len(outs) == 2:
+ # add dummy score_factor
+ outs = (*outs, None)
+        # TODO: can we fall back to `get_bboxes` when `onnx_export` fails?
+ det_bboxes, det_labels = self.bbox_head.onnx_export(
+ *outs, img_metas, with_nms=with_nms)
+
+ return det_bboxes, det_labels
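Both test paths above format detections with ``bbox2result``, which yields one (N, 5) array per class for every image. A toy sketch of walking that structure (fabricated arrays standing in for real detections):

import numpy as np

num_classes = 3
bbox_results = [np.zeros((0, 5)) for _ in range(num_classes)]
bbox_results[1] = np.array([[20., 30., 80., 90., 0.88]])   # one class-1 detection

for cls_id, dets in enumerate(bbox_results):
    for x1, y1, x2, y2, score in dets:
        print(f'class {cls_id}: score={score:.2f}, box=({x1}, {y1}, {x2}, {y2})')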
diff --git a/mmdet/models/detectors/single_stage_instance_seg.py b/mmdet/models/detectors/single_stage_instance_seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..239b669975239f9b1eebb6efa131db7978266704
--- /dev/null
+++ b/mmdet/models/detectors/single_stage_instance_seg.py
@@ -0,0 +1,363 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+import mmcv
+import numpy as np
+import torch
+
+from mmdet.core.visualization.image import imshow_det_bboxes
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .base import BaseDetector
+
+INF = 1e8
+
+
+@DETECTORS.register_module()
+class SingleStageInstanceSegmentor(BaseDetector):
+ """Base class for single-stage instance segmentors."""
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ bbox_head=None,
+ mask_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+
+ if pretrained:
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ backbone.pretrained = pretrained
+ super(SingleStageInstanceSegmentor, self).__init__(init_cfg=init_cfg)
+ self.backbone = build_backbone(backbone)
+ if neck is not None:
+ self.neck = build_neck(neck)
+ else:
+ self.neck = None
+ if bbox_head is not None:
+ bbox_head.update(train_cfg=copy.deepcopy(train_cfg))
+ bbox_head.update(test_cfg=copy.deepcopy(test_cfg))
+ self.bbox_head = build_head(bbox_head)
+ else:
+ self.bbox_head = None
+
+ assert mask_head, f'`mask_head` must ' \
+ f'be implemented in {self.__class__.__name__}'
+ mask_head.update(train_cfg=copy.deepcopy(train_cfg))
+ mask_head.update(test_cfg=copy.deepcopy(test_cfg))
+ self.mask_head = build_head(mask_head)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone and neck."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ raise NotImplementedError(
+ f'`forward_dummy` is not implemented in {self.__class__.__name__}')
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_masks,
+ gt_labels,
+ gt_bboxes=None,
+ gt_bboxes_ignore=None,
+ **kwargs):
+ """
+ Args:
+ img (Tensor): Input images of shape (B, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+            gt_masks (list[:obj:`BitmapMasks`] | None): The segmentation
+ masks for each box.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes (list[Tensor]): Each item is the truth boxes
+ of each image in [tl_x, tl_y, br_x, br_y] format.
+ Default: None.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ gt_masks = [
+ gt_mask.to_tensor(dtype=torch.bool, device=img.device)
+ for gt_mask in gt_masks
+ ]
+ x = self.extract_feat(img)
+ losses = dict()
+
+ # CondInst and YOLACT have bbox_head
+ if self.bbox_head:
+ # bbox_head_preds is a tuple
+ bbox_head_preds = self.bbox_head(x)
+ # positive_infos is a list of obj:`InstanceData`
+ # It contains the information about the positive samples
+ # CondInst, YOLACT
+ det_losses, positive_infos = self.bbox_head.loss(
+ *bbox_head_preds,
+ gt_bboxes=gt_bboxes,
+ gt_labels=gt_labels,
+ gt_masks=gt_masks,
+ img_metas=img_metas,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ **kwargs)
+ losses.update(det_losses)
+ else:
+ positive_infos = None
+
+ mask_loss = self.mask_head.forward_train(
+ x,
+ gt_labels,
+ gt_masks,
+ img_metas,
+ positive_infos=positive_infos,
+ gt_bboxes=gt_bboxes,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ **kwargs)
+ # avoid loss override
+ assert not set(mask_loss.keys()) & set(losses.keys())
+
+ losses.update(mask_loss)
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test-time augmentation.
+
+ Args:
+ img (torch.Tensor): Images with shape (B, C, H, W).
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list(tuple): Formatted bbox and mask results of multiple \
+ images. The outer list corresponds to each image. \
+                Each tuple contains two types of results for a single image:
+
+                - bbox_results (list[np.ndarray]): BBox results of a
+                    single image. The list corresponds to each class;
+                    each ndarray has shape (N, 5), where N is the number
+                    of bboxes with this category and the last dimension
+                    5 is arranged as (x1, y1, x2, y2, score).
+                - mask_results (list[np.ndarray]): Mask results of a
+                    single image. The list corresponds to each class;
+                    each ndarray has shape (N, img_h, img_w), where N
+                    is the number of masks with this category.
+ """
+ feat = self.extract_feat(img)
+ if self.bbox_head:
+ outs = self.bbox_head(feat)
+ # results_list is list[obj:`InstanceData`]
+ results_list = self.bbox_head.get_results(
+ *outs, img_metas=img_metas, cfg=self.test_cfg, rescale=rescale)
+ else:
+ results_list = None
+
+ results_list = self.mask_head.simple_test(
+ feat, img_metas, rescale=rescale, instances_list=results_list)
+
+ format_results_list = []
+ for results in results_list:
+ format_results_list.append(self.format_results(results))
+
+ return format_results_list
+
+ def format_results(self, results):
+ """Format the model predictions according to the interface with
+ dataset.
+
+ Args:
+ results (:obj:`InstanceData`): Processed
+ results of single images. Usually contains
+ following keys.
+
+ - scores (Tensor): Classification scores, has shape
+ (num_instance,)
+ - labels (Tensor): Has shape (num_instances,).
+ - masks (Tensor): Processed mask results, has
+ shape (num_instances, h, w).
+
+ Returns:
+            tuple: Formatted bbox and mask results. It contains two items:
+
+                - bbox_results (list[np.ndarray]): BBox results of a
+                    single image. The list corresponds to each class;
+                    each ndarray has shape (N, 5), where N is the number
+                    of bboxes with this category and the last dimension
+                    5 is arranged as (x1, y1, x2, y2, score).
+                - mask_results (list[np.ndarray]): Mask results of a
+                    single image. The list corresponds to each class;
+                    each ndarray has shape (N, img_h, img_w), where N
+                    is the number of masks with this category.
+ """
+ data_keys = results.keys()
+ assert 'scores' in data_keys
+ assert 'labels' in data_keys
+
+ assert 'masks' in data_keys, \
+            'results should contain ' \
+            'masks when formatting the results'
+ mask_results = [[] for _ in range(self.mask_head.num_classes)]
+
+ num_masks = len(results)
+
+ if num_masks == 0:
+ bbox_results = [
+ np.zeros((0, 5), dtype=np.float32)
+ for _ in range(self.mask_head.num_classes)
+ ]
+ return bbox_results, mask_results
+
+ labels = results.labels.detach().cpu().numpy()
+
+ if 'bboxes' not in results:
+ # create dummy bbox results to store the scores
+ results.bboxes = results.scores.new_zeros(len(results), 4)
+
+ det_bboxes = torch.cat([results.bboxes, results.scores[:, None]],
+ dim=-1)
+ det_bboxes = det_bboxes.detach().cpu().numpy()
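+        # group detections by class: bbox_results[i] is an (N_i, 5) array of
+        # [x1, y1, x2, y2, score] rows for the i-th class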
+ bbox_results = [
+ det_bboxes[labels == i, :]
+ for i in range(self.mask_head.num_classes)
+ ]
+
+ masks = results.masks.detach().cpu().numpy()
+
+ for idx in range(num_masks):
+ mask = masks[idx]
+ mask_results[labels[idx]].append(mask)
+
+ return bbox_results, mask_results
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ raise NotImplementedError
+
+ def show_result(self,
+ img,
+ result,
+ score_thr=0.3,
+ bbox_color=(72, 101, 241),
+ text_color=(72, 101, 241),
+ mask_color=None,
+ thickness=2,
+ font_size=13,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+            result (tuple): Formatted bbox and mask results.
+                It contains two items:
+
+                - bbox_results (list[np.ndarray]): BBox results of a
+                    single image. The list corresponds to each class;
+                    each ndarray has shape (N, 5), where N is the number
+                    of bboxes with this category and the last dimension
+                    5 is arranged as (x1, y1, x2, y2, score).
+                - mask_results (list[np.ndarray]): Mask results of a
+                    single image. The list corresponds to each class;
+                    each ndarray has shape (N, img_h, img_w), where N
+                    is the number of masks with this category.
+
+ score_thr (float, optional): Minimum score of bboxes to be shown.
+ Default: 0.3.
+            bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox lines.
+               The tuple of color should be in BGR order.
+               Default: (72, 101, 241)
+            text_color (str or tuple(int) or :obj:`Color`): Color of texts.
+               The tuple of color should be in BGR order.
+               Default: (72, 101, 241)
+ mask_color (None or str or tuple(int) or :obj:`Color`):
+ Color of masks. The tuple of color should be in BGR order.
+ Default: None
+ thickness (int): Thickness of lines. Default: 2
+ font_size (int): Font size of texts. Default: 13
+ win_name (str): The window name. Default: ''
+ wait_time (float): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+
+ Returns:
+            img (np.ndarray): The image with results drawn on it. Returned only
+                when neither `show` nor `out_file` is set.
+ """
+
+ assert isinstance(result, tuple)
+ bbox_result, mask_result = result
+ bboxes = np.vstack(bbox_result)
+ img = mmcv.imread(img)
+ img = img.copy()
+ labels = [
+ np.full(bbox.shape[0], i, dtype=np.int32)
+ for i, bbox in enumerate(bbox_result)
+ ]
+ labels = np.concatenate(labels)
+ if len(labels) == 0:
+ bboxes = np.zeros([0, 5])
+ masks = np.zeros([0, 0, 0])
+ # draw segmentation masks
+ else:
+ masks = mmcv.concat_list(mask_result)
+
+ if isinstance(masks[0], torch.Tensor):
+ masks = torch.stack(masks, dim=0).detach().cpu().numpy()
+ else:
+ masks = np.stack(masks, axis=0)
+ # dummy bboxes
+ if bboxes[:, :4].sum() == 0:
+ num_masks = len(bboxes)
+ x_any = masks.any(axis=1)
+ y_any = masks.any(axis=2)
+ for idx in range(num_masks):
+ x = np.where(x_any[idx, :])[0]
+ y = np.where(y_any[idx, :])[0]
+ if len(x) > 0 and len(y) > 0:
+ bboxes[idx, :4] = np.array(
+ [x[0], y[0], x[-1] + 1, y[-1] + 1],
+ dtype=np.float32)
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ # draw bounding boxes
+ img = imshow_det_bboxes(
+ img,
+ bboxes,
+ labels,
+ masks,
+ class_names=self.CLASSES,
+ score_thr=score_thr,
+ bbox_color=bbox_color,
+ text_color=text_color,
+ mask_color=mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file)
+
+ if not (show or out_file):
+ return img
diff --git a/mmdet/models/detectors/solo.py b/mmdet/models/detectors/solo.py
new file mode 100644
index 0000000000000000000000000000000000000000..df6f6de0162ef145ab36c645872337ec7ca4861b
--- /dev/null
+++ b/mmdet/models/detectors/solo.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage_instance_seg import SingleStageInstanceSegmentor
+
+
+@DETECTORS.register_module()
+class SOLO(SingleStageInstanceSegmentor):
+ """`SOLO: Segmenting Objects by Locations
+    <https://arxiv.org/abs/1912.04488>`_
+
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ bbox_head=None,
+ mask_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=None,
+ pretrained=None):
+ super().__init__(
+ backbone=backbone,
+ neck=neck,
+ bbox_head=bbox_head,
+ mask_head=mask_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ init_cfg=init_cfg,
+ pretrained=pretrained)
diff --git a/mmdet/models/detectors/solov2.py b/mmdet/models/detectors/solov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..711fcb495da6738c27e4cbe018104c559e83b0ad
--- /dev/null
+++ b/mmdet/models/detectors/solov2.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage_instance_seg import SingleStageInstanceSegmentor
+
+
+@DETECTORS.register_module()
+class SOLOv2(SingleStageInstanceSegmentor):
+ """`SOLOv2: Dynamic and Fast Instance Segmentation
+    <https://arxiv.org/abs/2003.10152>`_
+
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ bbox_head=None,
+ mask_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ init_cfg=None,
+ pretrained=None):
+ super().__init__(
+ backbone=backbone,
+ neck=neck,
+ bbox_head=bbox_head,
+ mask_head=mask_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ init_cfg=init_cfg,
+ pretrained=pretrained)
diff --git a/mmdet/models/detectors/sparse_rcnn.py b/mmdet/models/detectors/sparse_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e90c2a5aba5a538e024d27aa8d150b4a3982f6fe
--- /dev/null
+++ b/mmdet/models/detectors/sparse_rcnn.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class SparseRCNN(TwoStageDetector):
+ r"""Implementation of `Sparse R-CNN: End-to-End Object Detection with
+    Learnable Proposals <https://arxiv.org/abs/2011.12450>`_"""
+
+ def __init__(self, *args, **kwargs):
+ super(SparseRCNN, self).__init__(*args, **kwargs)
+ assert self.with_rpn, 'Sparse R-CNN and QueryInst ' \
+ 'do not support external proposals'
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ proposals=None,
+ **kwargs):
+ """Forward function of SparseR-CNN and QueryInst in train stage.
+
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_masks (List[Tensor], optional) : Segmentation masks for
+ each box. This is required to train QueryInst.
+ proposals (List[Tensor], optional): override rpn proposals with
+ custom proposals. Use when `with_rpn` is False.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ assert proposals is None, 'Sparse R-CNN and QueryInst ' \
+ 'do not support external proposals'
+
+ x = self.extract_feat(img)
+ proposal_boxes, proposal_features, imgs_whwh = \
+ self.rpn_head.forward_train(x, img_metas)
+ roi_losses = self.roi_head.forward_train(
+ x,
+ proposal_boxes,
+ proposal_features,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ gt_masks=gt_masks,
+ imgs_whwh=imgs_whwh)
+ return roi_losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test time augmentation.
+
+ Args:
+            img (torch.Tensor): Input images of shape (N, C, H, W).
+ img_metas (list[dict]): List of image information.
+ rescale (bool): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ x = self.extract_feat(img)
+ proposal_boxes, proposal_features, imgs_whwh = \
+ self.rpn_head.simple_test_rpn(x, img_metas)
+ results = self.roi_head.simple_test(
+ x,
+ proposal_boxes,
+ proposal_features,
+ img_metas,
+ imgs_whwh=imgs_whwh,
+ rescale=rescale)
+ return results
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ # backbone
+ x = self.extract_feat(img)
+ # rpn
+ num_imgs = len(img)
+ dummy_img_metas = [
+ dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)
+ ]
+ proposal_boxes, proposal_features, imgs_whwh = \
+ self.rpn_head.simple_test_rpn(x, dummy_img_metas)
+ # roi_head
+ roi_outs = self.roi_head.forward_dummy(x, proposal_boxes,
+ proposal_features,
+ dummy_img_metas)
+ return roi_outs
diff --git a/mmdet/models/detectors/tood.py b/mmdet/models/detectors/tood.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dd18c3c96abd0fb4d4eac5a6fb708b242be0571
--- /dev/null
+++ b/mmdet/models/detectors/tood.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class TOOD(SingleStageDetector):
+ r"""Implementation of `TOOD: Task-aligned One-stage Object Detection.
+    <https://arxiv.org/abs/2108.07755>`_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(TOOD, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
+
+ def set_epoch(self, epoch):
+ self.bbox_head.epoch = epoch
diff --git a/mmdet/models/detectors/trident_faster_rcnn.py b/mmdet/models/detectors/trident_faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb26168ca382c2330fefe8065b654dc183d42a74
--- /dev/null
+++ b/mmdet/models/detectors/trident_faster_rcnn.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .faster_rcnn import FasterRCNN
+
+
+@DETECTORS.register_module()
+class TridentFasterRCNN(FasterRCNN):
+ """Implementation of `TridentNet `_"""
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None,
+ init_cfg=None):
+
+ super(TridentFasterRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
+ assert self.backbone.num_branch == self.roi_head.num_branch
+ assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx
+ self.num_branch = self.backbone.num_branch
+ self.test_branch_idx = self.backbone.test_branch_idx
+
+ def simple_test(self, img, img_metas, proposals=None, rescale=False):
+ """Test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ x = self.extract_feat(img)
+        num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
+        trident_img_metas = img_metas * num_branch
+        if proposals is None:
+            proposal_list = self.rpn_head.simple_test_rpn(x, trident_img_metas)
+        else:
+            proposal_list = proposals
+ return self.roi_head.simple_test(
+ x, proposal_list, trident_img_metas, rescale=rescale)
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ x = self.extract_feats(imgs)
+ num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
+        trident_img_metas = [img_meta * num_branch for img_meta in img_metas]
+ proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
+ return self.roi_head.aug_test(
+ x, proposal_list, img_metas, rescale=rescale)
+
+ def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
+ """make copies of img and gts to fit multi-branch."""
+ trident_gt_bboxes = tuple(gt_bboxes * self.num_branch)
+ trident_gt_labels = tuple(gt_labels * self.num_branch)
+ trident_img_metas = tuple(img_metas * self.num_branch)
+
+ return super(TridentFasterRCNN,
+ self).forward_train(img, trident_img_metas,
+ trident_gt_bboxes, trident_gt_labels)
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..870e2b8477f3c08de2029802a2a567592d9f7541
--- /dev/null
+++ b/mmdet/models/detectors/two_stage.py
@@ -0,0 +1,211 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .base import BaseDetector
+
+
+@DETECTORS.register_module()
+class TwoStageDetector(BaseDetector):
+ """Base class for two-stage detectors.
+
+    Two-stage detectors typically consist of a region proposal network and a
+    task-specific regression head.
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ rpn_head=None,
+ roi_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(TwoStageDetector, self).__init__(init_cfg)
+ if pretrained:
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ backbone.pretrained = pretrained
+ self.backbone = build_backbone(backbone)
+
+ if neck is not None:
+ self.neck = build_neck(neck)
+
+ if rpn_head is not None:
+ rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
+ rpn_head_ = rpn_head.copy()
+ rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
+ self.rpn_head = build_head(rpn_head_)
+
+ if roi_head is not None:
+ # update train and test cfg here for now
+ # TODO: refactor assigner & sampler
+ rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
+ roi_head.update(train_cfg=rcnn_train_cfg)
+ roi_head.update(test_cfg=test_cfg.rcnn)
+ roi_head.pretrained = pretrained
+ self.roi_head = build_head(roi_head)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ @property
+ def with_rpn(self):
+ """bool: whether the detector has RPN"""
+ return hasattr(self, 'rpn_head') and self.rpn_head is not None
+
+ @property
+ def with_roi_head(self):
+ """bool: whether the detector has a RoI head"""
+ return hasattr(self, 'roi_head') and self.roi_head is not None
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone+neck."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ outs = ()
+ # backbone
+ x = self.extract_feat(img)
+ # rpn
+ if self.with_rpn:
+ rpn_outs = self.rpn_head(x)
+ outs = outs + (rpn_outs, )
+ proposals = torch.randn(1000, 4).to(img.device)
+ # roi_head
+ roi_outs = self.roi_head.forward_dummy(x, proposals)
+ outs = outs + (roi_outs, )
+ return outs
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ proposals=None,
+ **kwargs):
+ """
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+
+ gt_labels (list[Tensor]): class indices corresponding to each box
+
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ proposals : override rpn proposals with custom proposals. Use when
+ `with_rpn` is False.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ # RPN forward and loss
+ if self.with_rpn:
+ proposal_cfg = self.train_cfg.get('rpn_proposal',
+ self.test_cfg.rpn)
+ rpn_losses, proposal_list = self.rpn_head.forward_train(
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ proposal_cfg=proposal_cfg,
+ **kwargs)
+ losses.update(rpn_losses)
+ else:
+ proposal_list = proposals
+
+ roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
+ gt_bboxes, gt_labels,
+ gt_bboxes_ignore, gt_masks,
+ **kwargs)
+ losses.update(roi_losses)
+
+ return losses
+
+ async def async_simple_test(self,
+ img,
+ img_meta,
+ proposals=None,
+ rescale=False):
+ """Async test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ x = self.extract_feat(img)
+
+ if proposals is None:
+ proposal_list = await self.rpn_head.async_simple_test_rpn(
+ x, img_meta)
+ else:
+ proposal_list = proposals
+
+ return await self.roi_head.async_simple_test(
+ x, proposal_list, img_meta, rescale=rescale)
+
+ def simple_test(self, img, img_metas, proposals=None, rescale=False):
+ """Test without augmentation."""
+
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ x = self.extract_feat(img)
+ if proposals is None:
+ proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
+ else:
+ proposal_list = proposals
+
+ return self.roi_head.simple_test(
+ x, proposal_list, img_metas, rescale=rescale)
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ x = self.extract_feats(imgs)
+ proposal_list = self.rpn_head.aug_test_rpn(x, img_metas)
+ return self.roi_head.aug_test(
+ x, proposal_list, img_metas, rescale=rescale)
+
+ def onnx_export(self, img, img_metas):
+
+ img_shape = torch._shape_as_tensor(img)[2:]
+ img_metas[0]['img_shape_for_onnx'] = img_shape
+ x = self.extract_feat(img)
+ proposals = self.rpn_head.onnx_export(x, img_metas)
+ if hasattr(self.roi_head, 'onnx_export'):
+ return self.roi_head.onnx_export(x, proposals, img_metas)
+ else:
+ raise NotImplementedError(
+                f'{self.__class__.__name__} cannot '
+                f'be exported to ONNX. Please refer to the '
+                f'list of supported models: '
+ f'https://mmdetection.readthedocs.io/en/latest/tutorials/pytorch2onnx.html#list-of-supported-models-exportable-to-onnx' # noqa E501
+ )
diff --git a/mmdet/models/detectors/vfnet.py b/mmdet/models/detectors/vfnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..38ddcdabd47d8ce886c31f89db7fcb0842a8c35f
--- /dev/null
+++ b/mmdet/models/detectors/vfnet.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class VFNet(SingleStageDetector):
+ """Implementation of `VarifocalNet
+    (VFNet) <https://arxiv.org/abs/2008.13367>`_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(VFNet, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/yolact.py b/mmdet/models/detectors/yolact.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ddea0b229df9d661286257e41c37b9028a0fc8f
--- /dev/null
+++ b/mmdet/models/detectors/yolact.py
@@ -0,0 +1,120 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core import bbox2result
+from ..builder import DETECTORS, build_head
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class YOLACT(SingleStageDetector):
+ """Implementation of `YOLACT `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ segm_head,
+ mask_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(YOLACT, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
+ self.segm_head = build_head(segm_head)
+ self.mask_head = build_head(mask_head)
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ feat = self.extract_feat(img)
+ bbox_outs = self.bbox_head(feat)
+ prototypes = self.mask_head.forward_dummy(feat[0])
+ return (bbox_outs, prototypes)
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # convert Bitmap mask or Polygon Mask to Tensor here
+ gt_masks = [
+ gt_mask.to_tensor(dtype=torch.uint8, device=img.device)
+ for gt_mask in gt_masks
+ ]
+
+ x = self.extract_feat(img)
+
+ cls_score, bbox_pred, coeff_pred = self.bbox_head(x)
+ bbox_head_loss_inputs = (cls_score, bbox_pred) + (gt_bboxes, gt_labels,
+ img_metas)
+ losses, sampling_results = self.bbox_head.loss(
+ *bbox_head_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+
+ segm_head_outs = self.segm_head(x[0])
+ loss_segm = self.segm_head.loss(segm_head_outs, gt_masks, gt_labels)
+ losses.update(loss_segm)
+
+ mask_pred = self.mask_head(x[0], coeff_pred, gt_bboxes, img_metas,
+ sampling_results)
+ loss_mask = self.mask_head.loss(mask_pred, gt_masks, gt_bboxes,
+ img_metas, sampling_results)
+ losses.update(loss_mask)
+
+ # check NaN and Inf
+ for loss_name in losses.keys():
+ assert torch.isfinite(torch.stack(losses[loss_name]))\
+ .all().item(), '{} becomes infinite or NaN!'\
+ .format(loss_name)
+
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test-time augmentation."""
+ feat = self.extract_feat(img)
+ det_bboxes, det_labels, det_coeffs = self.bbox_head.simple_test(
+ feat, img_metas, rescale=rescale)
+ bbox_results = [
+ bbox2result(det_bbox, det_label, self.bbox_head.num_classes)
+ for det_bbox, det_label in zip(det_bboxes, det_labels)
+ ]
+
+ segm_results = self.mask_head.simple_test(
+ feat,
+ det_bboxes,
+ det_labels,
+ det_coeffs,
+ img_metas,
+ rescale=rescale)
+
+ return list(zip(bbox_results, segm_results))
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test with augmentations."""
+ raise NotImplementedError(
+ 'YOLACT does not support test-time augmentation')
diff --git a/mmdet/models/detectors/yolo.py b/mmdet/models/detectors/yolo.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ccd41777a190e48308b390e9c96d60085096d13
--- /dev/null
+++ b/mmdet/models/detectors/yolo.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+import torch
+
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class YOLOV3(SingleStageDetector):
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
+
+ def onnx_export(self, img, img_metas):
+ """Test function for exporting to ONNX, without test time augmentation.
+
+ Args:
+ img (torch.Tensor): input images.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
+ and class labels of shape [N, num_det].
+ """
+ x = self.extract_feat(img)
+ outs = self.bbox_head.forward(x)
+ # get shape as tensor
+ img_shape = torch._shape_as_tensor(img)[2:]
+ img_metas[0]['img_shape_for_onnx'] = img_shape
+
+ det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas)
+
+ return det_bboxes, det_labels
diff --git a/mmdet/models/detectors/yolof.py b/mmdet/models/detectors/yolof.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bc4f1abd21eb9ad439e5810dc8dce2c4d0d6329
--- /dev/null
+++ b/mmdet/models/detectors/yolof.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class YOLOF(SingleStageDetector):
+ r"""Implementation of `You Only Look One-level Feature
+ `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(YOLOF, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
diff --git a/mmdet/models/detectors/yolox.py b/mmdet/models/detectors/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..34d51b1482fa55d39ec26e0bcbbe40a4efa661bb
--- /dev/null
+++ b/mmdet/models/detectors/yolox.py
@@ -0,0 +1,136 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from mmcv.runner import get_dist_info
+
+from ...utils import log_img_scale
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class YOLOX(SingleStageDetector):
+ r"""Implementation of `YOLOX: Exceeding YOLO Series in 2021
+    <https://arxiv.org/abs/2107.08430>`_
+
+ Note: Considering the trade-off between training speed and accuracy,
+    multi-scale training is temporarily kept. A more elegant implementation
+    will be adopted in the future.
+
+ Args:
+ backbone (nn.Module): The backbone module.
+ neck (nn.Module): The neck module.
+ bbox_head (nn.Module): The bbox head module.
+ train_cfg (obj:`ConfigDict`, optional): The training config
+ of YOLOX. Default: None.
+ test_cfg (obj:`ConfigDict`, optional): The testing config
+ of YOLOX. Default: None.
+ pretrained (str, optional): model pretrained path.
+ Default: None.
+ input_size (tuple): The model default input image size. The shape
+ order should be (height, width). Default: (640, 640).
+ size_multiplier (int): Image size multiplication factor.
+ Default: 32.
+        random_size_range (tuple): The range from which the multi-scale
+            size factor is randomly sampled during training. The sampled
+            factor is multiplied by ``size_multiplier`` to get the real
+            training image size. Default: (15, 25).
+        random_size_interval (int): The iteration interval at which the
+            image size is changed. Default: 10.
+ init_cfg (dict, optional): Initialization config dict.
+ Default: None.
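+
+    Example:
+        Illustrative arithmetic for the multi-scale resizing under the default
+        arguments above (a sketch of the size computation, not a training API):
+
+        >>> size_multiplier, sampled = 32, 18  # sampled from random_size_range
+        >>> aspect_ratio = 640 / 640
+        >>> (size_multiplier * sampled, size_multiplier * int(aspect_ratio * sampled))
+        (576, 576)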
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ input_size=(640, 640),
+ size_multiplier=32,
+ random_size_range=(15, 25),
+ random_size_interval=10,
+ init_cfg=None):
+ super(YOLOX, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained, init_cfg)
+ log_img_scale(input_size, skip_square=True)
+ self.rank, self.world_size = get_dist_info()
+ self._default_input_size = input_size
+ self._input_size = input_size
+ self._random_size_range = random_size_range
+ self._random_size_interval = random_size_interval
+ self._size_multiplier = size_multiplier
+ self._progress_in_iter = 0
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes of
+                each image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # Multi-scale training
+ img, gt_bboxes = self._preprocess(img, gt_bboxes)
+
+ losses = super(YOLOX, self).forward_train(img, img_metas, gt_bboxes,
+ gt_labels, gt_bboxes_ignore)
+
+ # random resizing
+ if (self._progress_in_iter + 1) % self._random_size_interval == 0:
+ self._input_size = self._random_resize(device=img.device)
+ self._progress_in_iter += 1
+
+ return losses
+
+ def _preprocess(self, img, gt_bboxes):
+ scale_y = self._input_size[0] / self._default_input_size[0]
+ scale_x = self._input_size[1] / self._default_input_size[1]
+ if scale_x != 1 or scale_y != 1:
+ img = F.interpolate(
+ img,
+ size=self._input_size,
+ mode='bilinear',
+ align_corners=False)
+ for gt_bbox in gt_bboxes:
+ gt_bbox[..., 0::2] = gt_bbox[..., 0::2] * scale_x
+ gt_bbox[..., 1::2] = gt_bbox[..., 1::2] * scale_y
+ return img, gt_bboxes
+
+ def _random_resize(self, device):
+ tensor = torch.LongTensor(2).to(device)
+
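+        # Rank 0 samples the new size; broadcasting the result keeps every
+        # process training at the same resolution for this interval.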
+ if self.rank == 0:
+ size = random.randint(*self._random_size_range)
+ aspect_ratio = float(
+ self._default_input_size[1]) / self._default_input_size[0]
+ size = (self._size_multiplier * size,
+ self._size_multiplier * int(aspect_ratio * size))
+ tensor[0] = size[0]
+ tensor[1] = size[1]
+
+ if self.world_size > 1:
+ dist.barrier()
+ dist.broadcast(tensor, 0)
+
+ input_size = (tensor[0].item(), tensor[1].item())
+ return input_size
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..068a54d651b0c8fd13380a67a216e1e7c3629bd7
--- /dev/null
+++ b/mmdet/models/losses/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .accuracy import Accuracy, accuracy
+from .ae_loss import AssociativeEmbeddingLoss
+from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+ cross_entropy, mask_cross_entropy)
+from .dice_loss import DiceLoss
+from .focal_loss import FocalLoss, sigmoid_focal_loss
+from .gaussian_focal_loss import GaussianFocalLoss
+from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
+from .ghm_loss import GHMC, GHMR
+from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss,
+ bounded_iou_loss, iou_loss)
+from .kd_loss import KnowledgeDistillationKLDivLoss
+from .mse_loss import MSELoss, mse_loss
+from .pisa_loss import carl_loss, isr_p
+from .seesaw_loss import SeesawLoss
+from .smooth_l1_loss import L1Loss, SmoothL1Loss, l1_loss, smooth_l1_loss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+from .varifocal_loss import VarifocalLoss
+
+__all__ = [
+ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+ 'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
+ 'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
+ 'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
+ 'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss', 'GHMC',
+ 'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
+ 'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
+ 'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
+ 'VarifocalLoss', 'KnowledgeDistillationKLDivLoss', 'SeesawLoss', 'DiceLoss'
+]
diff --git a/mmdet/models/losses/accuracy.py b/mmdet/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe765a39f2578bbe3387a087f9f9de9c78f6226f
--- /dev/null
+++ b/mmdet/models/losses/accuracy.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch.nn as nn
+
+
+@mmcv.jit(coderize=True)
+def accuracy(pred, target, topk=1, thresh=None):
+ """Calculate accuracy according to the prediction and target.
+
+ Args:
+ pred (torch.Tensor): The model prediction, shape (N, num_class)
+ target (torch.Tensor): The target of each prediction, shape (N, )
+ topk (int | tuple[int], optional): If the predictions in ``topk``
+ matches the target, the predictions will be regarded as
+ correct ones. Defaults to 1.
+ thresh (float, optional): If not None, predictions with scores under
+ this threshold are considered incorrect. Default to None.
+
+ Returns:
+ float | tuple[float]: If the input ``topk`` is a single integer,
+ the function will return a single float as accuracy. If
+ ``topk`` is a tuple containing multiple integers, the
+ function will return a tuple containing accuracies of
+ each ``topk`` number.
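+
+    Example:
+        A minimal sketch with arbitrary logits:
+
+        >>> import torch
+        >>> pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
+        >>> target = torch.tensor([1, 0])
+        >>> # the top-1 prediction matches both targets
+        >>> accuracy(pred, target, topk=1)
+        tensor([100.])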
+ """
+ assert isinstance(topk, (int, tuple))
+ if isinstance(topk, int):
+ topk = (topk, )
+ return_single = True
+ else:
+ return_single = False
+
+ maxk = max(topk)
+ if pred.size(0) == 0:
+ accu = [pred.new_tensor(0.) for i in range(len(topk))]
+ return accu[0] if return_single else accu
+ assert pred.ndim == 2 and target.ndim == 1
+ assert pred.size(0) == target.size(0)
+ assert maxk <= pred.size(1), \
+ f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+ pred_value, pred_label = pred.topk(maxk, dim=1)
+ pred_label = pred_label.t() # transpose to shape (maxk, N)
+ correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
+ if thresh is not None:
+ # Only prediction values larger than thresh are counted as correct
+ correct = correct & (pred_value > thresh).t()
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / pred.size(0)))
+ return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+
+ def __init__(self, topk=(1, ), thresh=None):
+ """Module to calculate the accuracy.
+
+ Args:
+ topk (tuple, optional): The criterion used to calculate the
+ accuracy. Defaults to (1,).
+ thresh (float, optional): If not None, predictions with scores
+ under this threshold are considered incorrect. Default to None.
+ """
+ super().__init__()
+ self.topk = topk
+ self.thresh = thresh
+
+ def forward(self, pred, target):
+ """Forward function to calculate accuracy.
+
+ Args:
+ pred (torch.Tensor): Prediction of models.
+ target (torch.Tensor): Target for each prediction.
+
+ Returns:
+ tuple[float]: The accuracies under different topk criterions.
+ """
+ return accuracy(pred, target, self.topk, self.thresh)
diff --git a/mmdet/models/losses/ae_loss.py b/mmdet/models/losses/ae_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c6da22a9ec6ca057359bfb9f1cee6e4bcecfdc1
--- /dev/null
+++ b/mmdet/models/losses/ae_loss.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def ae_loss_per_image(tl_preds, br_preds, match):
+ """Associative Embedding Loss in one image.
+
+    Associative Embedding Loss consists of two parts: a pull loss and a push
+    loss. The pull loss makes embedding vectors from the same object closer to
+    each other. The push loss separates embedding vectors from different
+    objects so that the gap between them is large enough.
+
+    During computation, there are usually three cases:
+        - no object in the image: both pull loss and push loss are 0.
+        - one object in the image: push loss is 0 and pull loss is computed
+            from the two corners of the only object.
+        - more than one object in the image: pull loss is computed from the
+            corner pair of each object, and push loss is computed between each
+            object and all other objects. We use a confusion matrix with 0 on
+            the diagonal to compute the push loss.
+
+ Args:
+ tl_preds (tensor): Embedding feature map of left-top corner.
+        br_preds (tensor): Embedding feature map of bottom-right corner.
+ match (list): Downsampled coordinates pair of each ground truth box.
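+
+    Example:
+        A minimal sketch with a one-dimensional embedding map and two
+        hypothetical (top-left, bottom-right) corner pairs:
+
+        >>> import torch
+        >>> tl_preds = torch.rand(1, 16, 16)
+        >>> br_preds = torch.rand(1, 16, 16)
+        >>> match = [([2, 3], [10, 12]), ([5, 5], [14, 9])]
+        >>> pull, push = ae_loss_per_image(tl_preds, br_preds, match)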
+ """
+
+ tl_list, br_list, me_list = [], [], []
+ if len(match) == 0: # no object in image
+ pull_loss = tl_preds.sum() * 0.
+ push_loss = tl_preds.sum() * 0.
+ else:
+ for m in match:
+ [tl_y, tl_x], [br_y, br_x] = m
+ tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
+ br_e = br_preds[:, br_y, br_x].view(-1, 1)
+ tl_list.append(tl_e)
+ br_list.append(br_e)
+ me_list.append((tl_e + br_e) / 2.0)
+
+ tl_list = torch.cat(tl_list)
+ br_list = torch.cat(br_list)
+ me_list = torch.cat(me_list)
+
+ assert tl_list.size() == br_list.size()
+
+ # N is object number in image, M is dimension of embedding vector
+ N, M = tl_list.size()
+
+ pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
+ pull_loss = pull_loss.sum() / N
+
+ margin = 1 # exp setting of CornerNet, details in section 3.3 of paper
+
+ # confusion matrix of push loss
+ conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
+ conf_weight = 1 - torch.eye(N).type_as(me_list)
+ conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())
+
+ if N > 1: # more than one object in current image
+ push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
+ else:
+ push_loss = tl_preds.sum() * 0.
+
+ return pull_loss, push_loss
+
+
+@LOSSES.register_module()
+class AssociativeEmbeddingLoss(nn.Module):
+ """Associative Embedding Loss.
+
+ More details can be found in
+    `Associative Embedding <https://arxiv.org/abs/1611.05424>`_ and
+    `CornerNet <https://arxiv.org/abs/1808.01244>`_.
+ Code is modified from `kp_utils.py `_ # noqa: E501
+
+ Args:
+ pull_weight (float): Loss weight for corners from same object.
+ push_weight (float): Loss weight for corners from different object.
+ """
+
+ def __init__(self, pull_weight=0.25, push_weight=0.25):
+ super(AssociativeEmbeddingLoss, self).__init__()
+ self.pull_weight = pull_weight
+ self.push_weight = push_weight
+
+ def forward(self, pred, target, match):
+ """Forward function."""
+ batch = pred.size(0)
+ pull_all, push_all = 0.0, 0.0
+ for i in range(batch):
+ pull, push = ae_loss_per_image(pred[i], target[i], match[i])
+
+ pull_all += self.pull_weight * pull
+ push_all += self.push_weight * push
+
+ return pull_all, push_all
diff --git a/mmdet/models/losses/balanced_l1_loss.py b/mmdet/models/losses/balanced_l1_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8500345f0e41e8d98f75c4616c70eee8bce4473f
--- /dev/null
+++ b/mmdet/models/losses/balanced_l1_loss.py
@@ -0,0 +1,124 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def balanced_l1_loss(pred,
+ target,
+ beta=1.0,
+ alpha=0.5,
+ gamma=1.5,
+ reduction='mean'):
+ """Calculate balanced L1 loss.
+
+    Please see the `Libra R-CNN <https://arxiv.org/abs/1904.02701>`_
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 4).
+ target (torch.Tensor): The learning target of the prediction with
+ shape (N, 4).
+ beta (float): The loss is a piecewise function of prediction and target
+ and ``beta`` serves as a threshold for the difference between the
+ prediction and target. Defaults to 1.0.
+ alpha (float): The denominator ``alpha`` in the balanced L1 loss.
+ Defaults to 0.5.
+ gamma (float): The ``gamma`` in the balanced L1 loss.
+ Defaults to 1.5.
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
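+
+    Example:
+        A minimal sketch; with identical prediction and target the loss is zero
+        (assuming the ``weighted_loss`` decorator's default ``reduction='mean'``):
+
+        >>> import torch
+        >>> pred = torch.zeros(2, 4)
+        >>> target = torch.zeros(2, 4)
+        >>> balanced_l1_loss(pred, target, beta=1.0).item()
+        0.0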
+ """
+ assert beta > 0
+ if target.numel() == 0:
+ return pred.sum() * 0
+
+ assert pred.size() == target.size()
+
+ diff = torch.abs(pred - target)
+ b = np.e**(gamma / alpha) - 1
+ loss = torch.where(
+ diff < beta, alpha / b *
+ (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
+ gamma * diff + gamma / b - alpha * beta)
+
+ return loss
+
+
+@LOSSES.register_module()
+class BalancedL1Loss(nn.Module):
+ """Balanced L1 Loss.
+
+ arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
+
+ Args:
+ alpha (float): The denominator ``alpha`` in the balanced L1 loss.
+ Defaults to 0.5.
+ gamma (float): The ``gamma`` in the balanced L1 loss. Defaults to 1.5.
+ beta (float, optional): The loss is a piecewise function of prediction
+ and target. ``beta`` serves as a threshold for the difference
+ between the prediction and target. Defaults to 1.0.
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+
+ def __init__(self,
+ alpha=0.5,
+ gamma=1.5,
+ beta=1.0,
+ reduction='mean',
+ loss_weight=1.0):
+ super(BalancedL1Loss, self).__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+ self.beta = beta
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function of loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 4).
+ target (torch.Tensor): The learning target of the prediction with
+ shape (N, 4).
+ weight (torch.Tensor, optional): Sample-wise loss weight with
+ shape (N, ).
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_bbox = self.loss_weight * balanced_l1_loss(
+ pred,
+ target,
+ weight,
+ alpha=self.alpha,
+ gamma=self.gamma,
+ beta=self.beta,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_bbox
diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..41411fc5456970d1aad9c11a58c2e4988a5a7440
--- /dev/null
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -0,0 +1,301 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+def cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=-100,
+ avg_non_ignore=False):
+ """Calculate the CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int | None): The label index to be ignored.
+            If None, it will be set to the default value. Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+
+ Returns:
+ torch.Tensor: The calculated loss
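+
+    Example:
+        A minimal sketch showing that ``reduction='none'`` keeps one loss value
+        per sample:
+
+        >>> import torch
+        >>> pred = torch.randn(3, 5)
+        >>> label = torch.tensor([0, 2, 4])
+        >>> cross_entropy(pred, label, reduction='none').shape
+        torch.Size([3])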
+ """
+ # The default value of ignore_index is the same as F.cross_entropy
+ ignore_index = -100 if ignore_index is None else ignore_index
+ # element-wise losses
+ loss = F.cross_entropy(
+ pred,
+ label,
+ weight=class_weight,
+ reduction='none',
+ ignore_index=ignore_index)
+
+ # average loss over non-ignored elements
+    # pytorch's official cross_entropy averages loss over non-ignored elements
+ # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
+ if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+ avg_factor = label.numel() - (label == ignore_index).sum().item()
+
+ # apply weights and do the reduction
+ if weight is not None:
+ weight = weight.float()
+ loss = weight_reduce_loss(
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
+ """Expand onehot labels to match the size of prediction."""
+ bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+ valid_mask = (labels >= 0) & (labels != ignore_index)
+ inds = torch.nonzero(
+ valid_mask & (labels < label_channels), as_tuple=False)
+
+ if inds.numel() > 0:
+ bin_labels[inds, labels[inds]] = 1
+
+ valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
+ label_channels).float()
+ if label_weights is None:
+ bin_label_weights = valid_mask
+ else:
+ bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
+ bin_label_weights *= valid_mask
+
+ return bin_labels, bin_label_weights, valid_mask
+
+
+def binary_cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=-100,
+ avg_non_ignore=False):
+ """Calculate the binary CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
+ When the shape of pred is (N, 1), label will be expanded to
+ one-hot format, and when the shape of pred is (N, ), label
+ will not be expanded to one-hot format.
+ label (torch.Tensor): The learning label of the prediction,
+ with shape (N, ).
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int | None): The label index to be ignored.
+            If None, it will be set to the default value. Default: -100.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+
+ Returns:
+ torch.Tensor: The calculated loss.
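+
+    Example:
+        A minimal sketch; an integer label vector is expanded to one-hot form
+        when ``pred`` has a class dimension:
+
+        >>> import torch
+        >>> pred = torch.randn(4, 3)
+        >>> label = torch.tensor([0, 1, 2, 1])
+        >>> binary_cross_entropy(pred, label, reduction='none').shape
+        torch.Size([4, 3])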
+ """
+ # The default value of ignore_index is the same as F.cross_entropy
+ ignore_index = -100 if ignore_index is None else ignore_index
+
+ if pred.dim() != label.dim():
+ label, weight, valid_mask = _expand_onehot_labels(
+ label, weight, pred.size(-1), ignore_index)
+ else:
+ # should mask out the ignored elements
+ valid_mask = ((label >= 0) & (label != ignore_index)).float()
+ if weight is not None:
+ # The inplace writing method will have a mismatched broadcast
+ # shape error if the weight and valid_mask dimensions
+ # are inconsistent such as (B,N,1) and (B,N,C).
+ weight = weight * valid_mask
+ else:
+ weight = valid_mask
+
+ # average loss over non-ignored elements
+ if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
+ avg_factor = valid_mask.sum().item()
+
+ # weighted element-wise losses
+ weight = weight.float()
+ loss = F.binary_cross_entropy_with_logits(
+ pred, label.float(), pos_weight=class_weight, reduction='none')
+ # do the reduction for the weighted loss
+ loss = weight_reduce_loss(
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def mask_cross_entropy(pred,
+ target,
+ label,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=None,
+ **kwargs):
+ """Calculate the CrossEntropy loss for masks.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C, *), C is the
+ number of classes. The trailing * indicates arbitrary shape.
+ target (torch.Tensor): The learning label of the prediction.
+        label (torch.Tensor): ``label`` indicates the class label of the mask's
+            corresponding object. It is used to select the mask of the class
+            the object belongs to when the mask prediction is not
+            class-agnostic.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (None): Placeholder, to be consistent with other loss.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+
+ Example:
+ >>> N, C = 3, 11
+ >>> H, W = 2, 2
+ >>> pred = torch.randn(N, C, H, W) * 1000
+ >>> target = torch.rand(N, H, W)
+ >>> label = torch.randint(0, C, size=(N,))
+ >>> reduction = 'mean'
+ >>> avg_factor = None
+ >>> class_weights = None
+ >>> loss = mask_cross_entropy(pred, target, label, reduction,
+ >>> avg_factor, class_weights)
+ >>> assert loss.shape == (1,)
+ """
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
+ # TODO: handle these two reserved arguments
+ assert reduction == 'mean' and avg_factor is None
+ num_rois = pred.size()[0]
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+ pred_slice = pred[inds, label].squeeze(1)
+ return F.binary_cross_entropy_with_logits(
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+
+ def __init__(self,
+ use_sigmoid=False,
+ use_mask=False,
+ reduction='mean',
+ class_weight=None,
+ ignore_index=None,
+ loss_weight=1.0,
+ avg_non_ignore=False):
+ """CrossEntropyLoss.
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+                or softmax. Defaults to False.
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
+ Defaults to False.
+            reduction (str, optional): The method used to reduce the loss.
+                Options are "none", "mean" and "sum". Defaults to 'mean'.
+ class_weight (list[float], optional): Weight of each class.
+ Defaults to None.
+ ignore_index (int | None): The label index to be ignored.
+ Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+            avg_non_ignore (bool): Whether the loss is only averaged over
+                non-ignored targets. Default: False.
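+
+        Example:
+            A minimal sketch assuming a 4-class classification head with
+            random logits (softmax-based cross entropy):
+
+            >>> import torch
+            >>> loss_cls = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
+            >>> cls_score = torch.randn(8, 4)
+            >>> label = torch.randint(0, 4, (8,))
+            >>> loss = loss_cls(cls_score, label)  # scalar tensor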
+ """
+ super(CrossEntropyLoss, self).__init__()
+ assert (use_sigmoid is False) or (use_mask is False)
+ self.use_sigmoid = use_sigmoid
+ self.use_mask = use_mask
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = class_weight
+ self.ignore_index = ignore_index
+ self.avg_non_ignore = avg_non_ignore
+ if ((ignore_index is not None) and not self.avg_non_ignore
+ and self.reduction == 'mean'):
+ warnings.warn(
+ 'The default value of ``avg_non_ignore`` is False. If you '
+ 'would like to ignore certain labels and average the loss '
+ 'only over non-ignored labels, which matches the official '
+ 'PyTorch cross_entropy, set ``avg_non_ignore=True``.')
+
+ if self.use_sigmoid:
+ self.cls_criterion = binary_cross_entropy
+ elif self.use_mask:
+ self.cls_criterion = mask_cross_entropy
+ else:
+ self.cls_criterion = cross_entropy
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'avg_non_ignore={self.avg_non_ignore}'
+ return s
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ ignore_index=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ cls_score (torch.Tensor): The prediction.
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The method used to reduce the
+ loss. Options are "none", "mean" and "sum".
+ ignore_index (int | None): The label index to be ignored.
+ If not None, it will override the default value. Default: None.
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if ignore_index is None:
+ ignore_index = self.ignore_index
+
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(
+ self.class_weight, device=cls_score.device)
+ else:
+ class_weight = None
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ weight,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ ignore_index=ignore_index,
+ avg_non_ignore=self.avg_non_ignore,
+ **kwargs)
+ return loss_cls
diff --git a/mmdet/models/losses/dice_loss.py b/mmdet/models/losses/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..585beeaf1c6bb86205f40c73a54e2826edc1fe5d
--- /dev/null
+++ b/mmdet/models/losses/dice_loss.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+def dice_loss(pred,
+ target,
+ weight=None,
+ eps=1e-3,
+ reduction='mean',
+ naive_dice=False,
+ avg_factor=None):
+ """Calculate dice loss. Two forms of dice loss are supported:
+
+ - the one proposed in `V-Net: Fully Convolutional Neural
+ Networks for Volumetric Medical Image Segmentation
+ <https://arxiv.org/abs/1606.04797>`_.
+ - the dice loss in which the power of the number in the
+ denominator is the first power instead of the second
+ power.
+
+ Args:
+ pred (torch.Tensor): The prediction, has a shape (n, *)
+ target (torch.Tensor): The learning label of the prediction,
+ shape (n, *), same shape of pred.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction, has a shape (n,). Defaults to None.
+ eps (float): Avoid dividing by zero. Default: 1e-3.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ Options are "none", "mean" and "sum".
+ naive_dice (bool, optional): If false, use the dice
+ loss defined in the V-Net paper, otherwise, use the
+ naive dice loss in which the power of the number in the
+ denominator is the first power instead of the second
+ power. Defaults to False.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
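+
+ Example:
+ An illustrative sketch (random shapes, not taken from real data):
+
+ >>> import torch
+ >>> pred = torch.rand(3, 16) # probabilities of 3 flattened masks
+ >>> target = torch.randint(0, 2, (3, 16)).float()
+ >>> loss = dice_loss(pred, target, naive_dice=True)
+ >>> assert loss.ndim == 0 # 'mean' reduction yields a scalar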
+ """
+
+ input = pred.flatten(1)
+ target = target.flatten(1).float()
+
+ a = torch.sum(input * target, 1)
+ if naive_dice:
+ b = torch.sum(input, 1)
+ c = torch.sum(target, 1)
+ d = (2 * a + eps) / (b + c + eps)
+ else:
+ b = torch.sum(input * input, 1) + eps
+ c = torch.sum(target * target, 1) + eps
+ d = (2 * a) / (b + c)
+
+ loss = 1 - d
+ if weight is not None:
+ assert weight.ndim == loss.ndim
+ assert len(weight) == len(pred)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+
+ def __init__(self,
+ use_sigmoid=True,
+ activate=True,
+ reduction='mean',
+ naive_dice=False,
+ loss_weight=1.0,
+ eps=1e-3):
+ """Compute dice loss.
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ or softmax. Defaults to True.
+ activate (bool): Whether to activate the predictions inside the
+ loss. If True and ``use_sigmoid`` is True, a sigmoid is applied
+ to the predictions before computing the dice loss.
+ Defaults to True.
+ reduction (str, optional): The method used
+ to reduce the loss. Options are "none",
+ "mean" and "sum". Defaults to 'mean'.
+ naive_dice (bool, optional): If false, use the dice
+ loss defined in the V-Net paper, otherwise, use the
+ naive dice loss in which the power of the number in the
+ denominator is the first power instead of the second
+ power. Defaults to False.
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ eps (float): Avoid dividing by zero. Defaults to 1e-3.
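+
+ Example:
+ A usage sketch with raw logits; the sigmoid is applied internally
+ because ``activate=True`` and ``use_sigmoid=True`` by default:
+
+ >>> import torch
+ >>> loss_dice = DiceLoss(naive_dice=True)
+ >>> pred = torch.randn(2, 1, 8, 8) # mask logits
+ >>> target = torch.randint(0, 2, (2, 1, 8, 8)).float()
+ >>> loss = loss_dice(pred, target)
+ >>> assert loss.ndim == 0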
+ """
+
+ super(DiceLoss, self).__init__()
+ self.use_sigmoid = use_sigmoid
+ self.reduction = reduction
+ self.naive_dice = naive_dice
+ self.loss_weight = loss_weight
+ self.eps = eps
+ self.activate = activate
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ reduction_override=None,
+ avg_factor=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction, has a shape (n, *).
+ target (torch.Tensor): The label of the prediction,
+ shape (n, *), same shape of pred.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction, has a shape (n,). Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+
+ if self.activate:
+ if self.use_sigmoid:
+ pred = pred.sigmoid()
+ else:
+ raise NotImplementedError
+
+ loss = self.loss_weight * dice_loss(
+ pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ naive_dice=self.naive_dice,
+ avg_factor=avg_factor)
+
+ return loss
diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2858c198101c75942d6cc9d18e275dbd6ab359dd
--- /dev/null
+++ b/mmdet/models/losses/focal_loss.py
@@ -0,0 +1,244 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+# This method is only for debugging
+def py_sigmoid_focal_loss(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+ """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the
+ number of classes
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
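+
+ Example:
+ An illustrative sketch with one-hot targets (the ``FocalLoss``
+ module below performs this one-hot conversion itself):
+
+ >>> import torch
+ >>> import torch.nn.functional as F
+ >>> pred = torch.randn(4, 3) # logits for 4 samples, 3 classes
+ >>> labels = torch.randint(0, 3, (4,))
+ >>> target = F.one_hot(labels, num_classes=3).float()
+ >>> loss = py_sigmoid_focal_loss(pred, target)
+ >>> assert loss.ndim == 0 # 'mean' reduction yields a scalar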
+ """
+ pred_sigmoid = pred.sigmoid()
+ target = target.type_as(pred)
+ pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
+ focal_weight = (alpha * target + (1 - alpha) *
+ (1 - target)) * pt.pow(gamma)
+ loss = F.binary_cross_entropy_with_logits(
+ pred, target, reduction='none') * focal_weight
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+def py_focal_loss_with_prob(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+ """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
+ Different from `py_sigmoid_focal_loss`, this function accepts probability
+ as input.
+
+ Args:
+ pred (torch.Tensor): The prediction probability with shape (N, C),
+ C is the number of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ num_classes = pred.size(1)
+ target = F.one_hot(target, num_classes=num_classes + 1)
+ target = target[:, :num_classes]
+
+ target = target.type_as(pred)
+ pt = (1 - pred) * target + pred * (1 - target)
+ focal_weight = (alpha * target + (1 - alpha) *
+ (1 - target)) * pt.pow(gamma)
+ loss = F.binary_cross_entropy(
+ pred, target, reduction='none') * focal_weight
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+def sigmoid_focal_loss(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+ r"""A wrapper of the CUDA version of `Focal Loss
+ <https://arxiv.org/abs/1708.02002>`_.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ # Function.apply does not accept keyword arguments, so the decorator
+ # "weighted_loss" is not applicable
+ loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
+ alpha, None, 'none')
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+@LOSSES.register_module()
+class FocalLoss(nn.Module):
+
+ def __init__(self,
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ loss_weight=1.0,
+ activated=False):
+ """`Focal Loss <https://arxiv.org/abs/1708.02002>`_
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ or softmax. Defaults to True.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and
+ "sum".
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ activated (bool, optional): Whether the input is activated.
+ If True, it means the input has been activated and can be
+ treated as probabilities. Else, it should be treated as logits.
+ Defaults to False.
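+
+ Example:
+ A minimal usage sketch with integer class labels (illustrative
+ shapes; on CPU the pure-PyTorch fallback is used):
+
+ >>> import torch
+ >>> loss_focal = FocalLoss(gamma=2.0, alpha=0.25)
+ >>> pred = torch.randn(8, 20) # logits of 8 priors, 20 classes
+ >>> target = torch.randint(0, 20, (8,)) # class indices
+ >>> loss = loss_focal(pred, target)
+ >>> assert loss.ndim == 0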
+ """
+ super(FocalLoss, self).__init__()
+ assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
+ self.use_sigmoid = use_sigmoid
+ self.gamma = gamma
+ self.alpha = alpha
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.activated = activated
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.use_sigmoid:
+ if self.activated:
+ calculate_loss_func = py_focal_loss_with_prob
+ else:
+ if torch.cuda.is_available() and pred.is_cuda:
+ calculate_loss_func = sigmoid_focal_loss
+ else:
+ num_classes = pred.size(1)
+ target = F.one_hot(target, num_classes=num_classes + 1)
+ target = target[:, :num_classes]
+ calculate_loss_func = py_sigmoid_focal_loss
+
+ loss_cls = self.loss_weight * calculate_loss_func(
+ pred,
+ target,
+ weight,
+ gamma=self.gamma,
+ alpha=self.alpha,
+ reduction=reduction,
+ avg_factor=avg_factor)
+
+ else:
+ raise NotImplementedError
+ return loss_cls
diff --git a/mmdet/models/losses/gaussian_focal_loss.py b/mmdet/models/losses/gaussian_focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7abcb691acbfbb300597a72fcce67ca3b5e9f2f2
--- /dev/null
+++ b/mmdet/models/losses/gaussian_focal_loss.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0):
+ """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
+ distribution.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ gaussian_target (torch.Tensor): The learning target of the prediction
+ in gaussian distribution.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 2.0.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 4.0.
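+
+ Example:
+ An illustrative sketch; ``pred`` should already be a probability
+ heatmap (e.g. after sigmoid) and the target a gaussian heatmap:
+
+ >>> import torch
+ >>> pred = torch.rand(2, 1, 16, 16).clamp(1e-4, 1 - 1e-4)
+ >>> gaussian_target = torch.rand(2, 1, 16, 16)
+ >>> loss = gaussian_focal_loss(pred, gaussian_target)
+ >>> assert loss.ndim == 0 # 'mean' reduction via @weighted_loss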
+ """
+ eps = 1e-12
+ pos_weights = gaussian_target.eq(1)
+ neg_weights = (1 - gaussian_target).pow(gamma)
+ pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
+ neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
+ return pos_loss + neg_loss
+
+
+@LOSSES.register_module()
+class GaussianFocalLoss(nn.Module):
+ """GaussianFocalLoss is a variant of focal loss.
+
+ More details can be found in the `paper
+ `_
+ Code is modified from `kp_utils.py
+ `_ # noqa: E501
+ Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
+ not 0/1 binary target.
+
+ Args:
+ alpha (float): Power of prediction.
+ gamma (float): Power of target for negative samples.
+ reduction (str): Options are "none", "mean" and "sum".
+ loss_weight (float): Loss weight of current loss.
+ """
+
+ def __init__(self,
+ alpha=2.0,
+ gamma=4.0,
+ reduction='mean',
+ loss_weight=1.0):
+ super(GaussianFocalLoss, self).__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction
+ in gaussian distribution.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_reg = self.loss_weight * gaussian_focal_loss(
+ pred,
+ target,
+ weight,
+ alpha=self.alpha,
+ gamma=self.gamma,
+ reduction=reduction,
+ avg_factor=avg_factor)
+ return loss_reg
diff --git a/mmdet/models/losses/gfocal_loss.py b/mmdet/models/losses/gfocal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e8d26373f83f35ad032322d96cdbac995be2749
--- /dev/null
+++ b/mmdet/models/losses/gfocal_loss.py
@@ -0,0 +1,245 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def quality_focal_loss(pred, target, beta=2.0):
+ r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
+ Qualified and Distributed Bounding Boxes for Dense Object Detection
+ <https://arxiv.org/abs/2006.04388>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted joint representation of classification
+ and quality (IoU) estimation with shape (N, C), C is the number of
+ classes.
+ target (tuple([torch.Tensor])): Target category label with shape (N,)
+ and target quality label with shape (N,).
+ beta (float): The beta parameter for calculating the modulating factor.
+ Defaults to 2.0.
+
+ Returns:
+ torch.Tensor: Loss tensor with shape (N,).
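+
+ Example:
+ A usage sketch with random values (4 priors, 3 foreground classes;
+ a label equal to the number of classes denotes background):
+
+ >>> import torch
+ >>> pred = torch.randn(4, 3) # joint cls-quality logits
+ >>> label = torch.tensor([0, 2, 3, 3]) # 3 means background here
+ >>> score = torch.tensor([0.7, 0.9, 0.0, 0.0])
+ >>> loss = quality_focal_loss(pred, (label, score))
+ >>> assert loss.ndim == 0 # 'mean' reduction via @weighted_loss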
+ """
+ assert len(target) == 2, """target for QFL must be a tuple of two elements,
+ including category label and quality label, respectively"""
+ # label denotes the category id, score denotes the quality score
+ label, score = target
+
+ # negatives are supervised by 0 quality score
+ pred_sigmoid = pred.sigmoid()
+ scale_factor = pred_sigmoid
+ zerolabel = scale_factor.new_zeros(pred.shape)
+ loss = F.binary_cross_entropy_with_logits(
+ pred, zerolabel, reduction='none') * scale_factor.pow(beta)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = pred.size(1)
+ pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+ pos_label = label[pos].long()
+ # positives are supervised by bbox quality (IoU) score
+ scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
+ loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+ pred[pos, pos_label], score[pos],
+ reduction='none') * scale_factor.abs().pow(beta)
+
+ loss = loss.sum(dim=1, keepdim=False)
+ return loss
+
+
+@weighted_loss
+def quality_focal_loss_with_prob(pred, target, beta=2.0):
+ r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
+ Qualified and Distributed Bounding Boxes for Dense Object Detection
+ <https://arxiv.org/abs/2006.04388>`_.
+ Different from `quality_focal_loss`, this function accepts probability
+ as input.
+
+ Args:
+ pred (torch.Tensor): Predicted joint representation of classification
+ and quality (IoU) estimation with shape (N, C), C is the number of
+ classes.
+ target (tuple([torch.Tensor])): Target category label with shape (N,)
+ and target quality label with shape (N,).
+ beta (float): The beta parameter for calculating the modulating factor.
+ Defaults to 2.0.
+
+ Returns:
+ torch.Tensor: Loss tensor with shape (N,).
+ """
+ assert len(target) == 2, """target for QFL must be a tuple of two elements,
+ including category label and quality label, respectively"""
+ # label denotes the category id, score denotes the quality score
+ label, score = target
+
+ # negatives are supervised by 0 quality score
+ pred_sigmoid = pred
+ scale_factor = pred_sigmoid
+ zerolabel = scale_factor.new_zeros(pred.shape)
+ loss = F.binary_cross_entropy(
+ pred, zerolabel, reduction='none') * scale_factor.pow(beta)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = pred.size(1)
+ pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+ pos_label = label[pos].long()
+ # positives are supervised by bbox quality (IoU) score
+ scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
+ loss[pos, pos_label] = F.binary_cross_entropy(
+ pred[pos, pos_label], score[pos],
+ reduction='none') * scale_factor.abs().pow(beta)
+
+ loss = loss.sum(dim=1, keepdim=False)
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def distribution_focal_loss(pred, label):
+ r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
+ Qualified and Distributed Bounding Boxes for Dense Object Detection
+ <https://arxiv.org/abs/2006.04388>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted general distribution of bounding boxes
+ (before softmax) with shape (N, n+1), n is the max value of the
+ integral set `{0, ..., n}` in paper.
+ label (torch.Tensor): Target distance label for bounding boxes with
+ shape (N,).
+
+ Returns:
+ torch.Tensor: Loss tensor with shape (N,).
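+
+ Example:
+ A worked sketch: a continuous distance of 2.4 is split between the
+ neighbouring integer bins 2 and 3 with weights 0.6 and 0.4:
+
+ >>> import torch
+ >>> pred = torch.randn(2, 8) # distributions over bins {0, ..., 7}
+ >>> label = torch.tensor([2.4, 5.0])
+ >>> loss = distribution_focal_loss(pred, label)
+ >>> assert loss.ndim == 0 # 'mean' reduction via @weighted_loss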
+ """
+ dis_left = label.long()
+ dis_right = dis_left + 1
+ weight_left = dis_right.float() - label
+ weight_right = label - dis_left.float()
+ loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
+ + F.cross_entropy(pred, dis_right, reduction='none') * weight_right
+ return loss
+
+
+@LOSSES.register_module()
+class QualityFocalLoss(nn.Module):
+ r"""Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss:
+ Learning Qualified and Distributed Bounding Boxes for Dense Object
+ Detection <https://arxiv.org/abs/2006.04388>`_.
+
+ Args:
+ use_sigmoid (bool): Whether sigmoid operation is conducted in QFL.
+ Defaults to True.
+ beta (float): The beta parameter for calculating the modulating factor.
+ Defaults to 2.0.
+ reduction (str): Options are "none", "mean" and "sum".
+ loss_weight (float): Loss weight of current loss.
+ activated (bool, optional): Whether the input is activated.
+ If True, it means the input has been activated and can be
+ treated as probabilities. Else, it should be treated as logits.
+ Defaults to False.
+ """
+
+ def __init__(self,
+ use_sigmoid=True,
+ beta=2.0,
+ reduction='mean',
+ loss_weight=1.0,
+ activated=False):
+ super(QualityFocalLoss, self).__init__()
+ assert use_sigmoid is True, 'Only sigmoid in QFL supported now.'
+ self.use_sigmoid = use_sigmoid
+ self.beta = beta
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.activated = activated
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): Predicted joint representation of
+ classification and quality (IoU) estimation with shape (N, C),
+ C is the number of classes.
+ target (tuple([torch.Tensor])): Target category label with shape
+ (N,) and target quality label with shape (N,).
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.use_sigmoid:
+ if self.activated:
+ calculate_loss_func = quality_focal_loss_with_prob
+ else:
+ calculate_loss_func = quality_focal_loss
+ loss_cls = self.loss_weight * calculate_loss_func(
+ pred,
+ target,
+ weight,
+ beta=self.beta,
+ reduction=reduction,
+ avg_factor=avg_factor)
+ else:
+ raise NotImplementedError
+ return loss_cls
+
+
+@LOSSES.register_module()
+class DistributionFocalLoss(nn.Module):
+ r"""Distribution Focal Loss (DFL) is a variant of `Generalized Focal Loss:
+ Learning Qualified and Distributed Bounding Boxes for Dense Object
+ Detection <https://arxiv.org/abs/2006.04388>`_.
+
+ Args:
+ reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
+ loss_weight (float): Loss weight of current loss.
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super(DistributionFocalLoss, self).__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): Predicted general distribution of bounding
+ boxes (before softmax) with shape (N, n+1), n is the max value
+ of the integral set `{0, ..., n}` in paper.
+ target (torch.Tensor): Target distance label for bounding boxes
+ with shape (N,).
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_cls = self.loss_weight * distribution_focal_loss(
+ pred, target, weight, reduction=reduction, avg_factor=avg_factor)
+ return loss_cls
diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4df9fe8e17c9f8aea75f4e995db491e929bd206
--- /dev/null
+++ b/mmdet/models/losses/ghm_loss.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+def _expand_onehot_labels(labels, label_weights, label_channels):
+ bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+ inds = torch.nonzero(
+ (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
+ if inds.numel() > 0:
+ bin_labels[inds, labels[inds]] = 1
+ bin_label_weights = label_weights.view(-1, 1).expand(
+ label_weights.size(0), label_channels)
+ return bin_labels, bin_label_weights
+
+
+# TODO: code refactoring to make it consistent with other losses
+@LOSSES.register_module()
+class GHMC(nn.Module):
+ """GHM Classification Loss.
+
+ Details of the theorem can be viewed in the paper
+ `Gradient Harmonized Single-stage Detector
+ <https://arxiv.org/abs/1811.05181>`_.
+
+ Args:
+ bins (int): Number of the unit regions for distribution calculation.
+ momentum (float): The parameter for moving average.
+ use_sigmoid (bool): Can only be true for BCE based loss now.
+ loss_weight (float): The weight of the total GHM-C loss.
+ reduction (str): Options are "none", "mean" and "sum".
+ Defaults to "mean"
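+
+ Example:
+ An illustrative sketch with one-hot targets and all samples valid
+ (integer class labels are expanded to one-hot internally if given):
+
+ >>> import torch
+ >>> loss_ghmc = GHMC(bins=10, momentum=0)
+ >>> pred = torch.randn(6, 4) # logits
+ >>> target = torch.eye(4)[torch.randint(0, 4, (6,))]
+ >>> label_weight = torch.ones(6, 4)
+ >>> loss = loss_ghmc(pred, target, label_weight)
+ >>> assert loss.ndim == 0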
+ """
+
+ def __init__(self,
+ bins=10,
+ momentum=0,
+ use_sigmoid=True,
+ loss_weight=1.0,
+ reduction='mean'):
+ super(GHMC, self).__init__()
+ self.bins = bins
+ self.momentum = momentum
+ edges = torch.arange(bins + 1).float() / bins
+ self.register_buffer('edges', edges)
+ self.edges[-1] += 1e-6
+ if momentum > 0:
+ acc_sum = torch.zeros(bins)
+ self.register_buffer('acc_sum', acc_sum)
+ self.use_sigmoid = use_sigmoid
+ if not self.use_sigmoid:
+ raise NotImplementedError
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self,
+ pred,
+ target,
+ label_weight,
+ reduction_override=None,
+ **kwargs):
+ """Calculate the GHM-C loss.
+
+ Args:
+ pred (float tensor of size [batch_num, class_num]):
+ The direct prediction of classification fc layer.
+ target (float tensor of size [batch_num, class_num]):
+ Binary class target for each sample.
+ label_weight (float tensor of size [batch_num, class_num]):
+ the value is 1 if the sample is valid and 0 if ignored.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ Returns:
+ The gradient harmonized loss.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ # the target should be binary class label
+ if pred.dim() != target.dim():
+ target, label_weight = _expand_onehot_labels(
+ target, label_weight, pred.size(-1))
+ target, label_weight = target.float(), label_weight.float()
+ edges = self.edges
+ mmt = self.momentum
+ weights = torch.zeros_like(pred)
+
+ # gradient length
+ g = torch.abs(pred.sigmoid().detach() - target)
+
+ valid = label_weight > 0
+ tot = max(valid.float().sum().item(), 1.0)
+ n = 0 # n valid bins
+ for i in range(self.bins):
+ inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
+ num_in_bin = inds.sum().item()
+ if num_in_bin > 0:
+ if mmt > 0:
+ self.acc_sum[i] = mmt * self.acc_sum[i] \
+ + (1 - mmt) * num_in_bin
+ weights[inds] = tot / self.acc_sum[i]
+ else:
+ weights[inds] = tot / num_in_bin
+ n += 1
+ if n > 0:
+ weights = weights / n
+
+ loss = F.binary_cross_entropy_with_logits(
+ pred, target, reduction='none')
+ loss = weight_reduce_loss(
+ loss, weights, reduction=reduction, avg_factor=tot)
+ return loss * self.loss_weight
+
+
+# TODO: code refactoring to make it consistent with other losses
+@LOSSES.register_module()
+class GHMR(nn.Module):
+ """GHM Regression Loss.
+
+ Details of the theorem can be viewed in the paper
+ `Gradient Harmonized Single-stage Detector
+ <https://arxiv.org/abs/1811.05181>`_.
+
+ Args:
+ mu (float): The parameter for the Authentic Smooth L1 loss.
+ bins (int): Number of the unit regions for distribution calculation.
+ momentum (float): The parameter for moving average.
+ loss_weight (float): The weight of the total GHM-R loss.
+ reduction (str): Options are "none", "mean" and "sum".
+ Defaults to "mean"
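+
+ Example:
+ A usage sketch for box regression deltas (shapes are illustrative):
+
+ >>> import torch
+ >>> loss_ghmr = GHMR(mu=0.02, bins=10)
+ >>> pred = torch.randn(6, 4)
+ >>> target = torch.randn(6, 4)
+ >>> label_weight = torch.ones(6, 4)
+ >>> loss = loss_ghmr(pred, target, label_weight)
+ >>> assert loss.ndim == 0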
+ """
+
+ def __init__(self,
+ mu=0.02,
+ bins=10,
+ momentum=0,
+ loss_weight=1.0,
+ reduction='mean'):
+ super(GHMR, self).__init__()
+ self.mu = mu
+ self.bins = bins
+ edges = torch.arange(bins + 1).float() / bins
+ self.register_buffer('edges', edges)
+ self.edges[-1] = 1e3
+ self.momentum = momentum
+ if momentum > 0:
+ acc_sum = torch.zeros(bins)
+ self.register_buffer('acc_sum', acc_sum)
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ # TODO: support reduction parameter
+ def forward(self,
+ pred,
+ target,
+ label_weight,
+ avg_factor=None,
+ reduction_override=None):
+ """Calculate the GHM-R loss.
+
+ Args:
+ pred (float tensor of size [batch_num, 4 (* class_num)]):
+ The prediction of box regression layer. Channel number can be 4
+ or 4 * class_num depending on whether it is class-agnostic.
+ target (float tensor of size [batch_num, 4 (* class_num)]):
+ The target regression values with the same size of pred.
+ label_weight (float tensor of size [batch_num, 4 (* class_num)]):
+ The weight of each sample, 0 if ignored.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ Returns:
+ The gradient harmonized loss.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ mu = self.mu
+ edges = self.edges
+ mmt = self.momentum
+
+ # ASL1 loss
+ diff = pred - target
+ loss = torch.sqrt(diff * diff + mu * mu) - mu
+
+ # gradient length
+ g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
+ weights = torch.zeros_like(g)
+
+ valid = label_weight > 0
+ tot = max(label_weight.float().sum().item(), 1.0)
+ n = 0 # n: valid bins
+ for i in range(self.bins):
+ inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
+ num_in_bin = inds.sum().item()
+ if num_in_bin > 0:
+ n += 1
+ if mmt > 0:
+ self.acc_sum[i] = mmt * self.acc_sum[i] \
+ + (1 - mmt) * num_in_bin
+ weights[inds] = tot / self.acc_sum[i]
+ else:
+ weights[inds] = tot / num_in_bin
+ if n > 0:
+ weights /= n
+ loss = weight_reduce_loss(
+ loss, weights, reduction=reduction, avg_factor=tot)
+ return loss * self.loss_weight
diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf1ed04e1903d19ee339bd131b897df5b51d311a
--- /dev/null
+++ b/mmdet/models/losses/iou_loss.py
@@ -0,0 +1,474 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import mmcv
+import torch
+import torch.nn as nn
+
+from mmdet.core import bbox_overlaps
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def iou_loss(pred, target, linear=False, mode='log', eps=1e-6):
+ """IoU loss.
+
+ Computing the IoU loss between a set of predicted bboxes and target bboxes.
+ The loss is calculated as negative log of IoU.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
+ linear (bool, optional): If True, use linear scale of loss instead of
+ log scale. Default: False.
+ mode (str): Loss scaling mode, including "linear", "square", and "log".
+ Default: 'log'
+ eps (float): Eps to avoid log(0).
+
+ Return:
+ torch.Tensor: Loss tensor.
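+
+ Example:
+ An illustrative sketch with aligned predicted and target boxes:
+
+ >>> import torch
+ >>> pred = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
+ >>> target = torch.tensor([[0., 0., 10., 10.], [0., 0., 10., 10.]])
+ >>> loss = iou_loss(pred, target, mode='linear', reduction='none')
+ >>> # the first pair overlaps perfectly, so its linear loss is ~0
+ >>> assert loss[0] < 1e-4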
+ """
+ assert mode in ['linear', 'square', 'log']
+ if linear:
+ mode = 'linear'
+ warnings.warn('DeprecationWarning: Setting "linear=True" in '
+ 'iou_loss is deprecated, please use "mode=`linear`" '
+ 'instead.')
+ ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps)
+ if mode == 'linear':
+ loss = 1 - ious
+ elif mode == 'square':
+ loss = 1 - ious**2
+ elif mode == 'log':
+ loss = -ious.log()
+ else:
+ raise NotImplementedError
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
+ """BIoULoss.
+
+ This is an implementation of the paper
+ `Improving Object Localization with Fitness NMS and Bounded IoU Loss
+ <https://arxiv.org/abs/1711.00164>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes.
+ target (torch.Tensor): Target bboxes.
+ beta (float): beta parameter in smoothl1.
+ eps (float): eps to avoid NaN.
+ """
+ pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
+ pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
+ pred_w = pred[:, 2] - pred[:, 0]
+ pred_h = pred[:, 3] - pred[:, 1]
+ with torch.no_grad():
+ target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
+ target_ctry = (target[:, 1] + target[:, 3]) * 0.5
+ target_w = target[:, 2] - target[:, 0]
+ target_h = target[:, 3] - target[:, 1]
+
+ dx = target_ctrx - pred_ctrx
+ dy = target_ctry - pred_ctry
+
+ loss_dx = 1 - torch.max(
+ (target_w - 2 * dx.abs()) /
+ (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx))
+ loss_dy = 1 - torch.max(
+ (target_h - 2 * dy.abs()) /
+ (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy))
+ loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w /
+ (target_w + eps))
+ loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h /
+ (target_h + eps))
+ # view(..., -1) does not work for empty tensor
+ loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh],
+ dim=-1).flatten(1)
+
+ loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta,
+ loss_comb - 0.5 * beta)
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def giou_loss(pred, target, eps=1e-7):
+ r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
+ Box Regression <https://arxiv.org/abs/1902.09630>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+
+ Return:
+ Tensor: Loss tensor.
+ """
+ gious = bbox_overlaps(pred, target, mode='giou', is_aligned=True, eps=eps)
+ loss = 1 - gious
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def diou_loss(pred, target, eps=1e-7):
+ r"""Implementation of `Distance-IoU Loss: Faster and Better
+ Learning for Bounding Box Regression <https://arxiv.org/abs/1911.08287>`_.
+
+ Code is modified from https://github.com/Zzh-tju/DIoU.
+
+ Args:
+ pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+ Return:
+ Tensor: Loss tensor.
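+
+ Example:
+ A worked sketch: two same-sized boxes whose centers are 5 px apart
+ horizontally; the loss combines the IoU term with the normalized
+ center-distance penalty rho^2 / c^2:
+
+ >>> import torch
+ >>> pred = torch.tensor([[0., 0., 10., 10.]])
+ >>> target = torch.tensor([[5., 0., 15., 10.]])
+ >>> loss = diou_loss(pred, target, reduction='none')
+ >>> # IoU = 50 / 150, rho^2 = 25, c^2 = 15**2 + 10**2 = 325
+ >>> expected = 1 - (50. / 150. - 25. / 325.)
+ >>> assert torch.allclose(loss, torch.tensor([expected]), atol=1e-4)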
+ """
+ # overlap
+ lt = torch.max(pred[:, :2], target[:, :2])
+ rb = torch.min(pred[:, 2:], target[:, 2:])
+ wh = (rb - lt).clamp(min=0)
+ overlap = wh[:, 0] * wh[:, 1]
+
+ # union
+ ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
+ ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
+ union = ap + ag - overlap + eps
+
+ # IoU
+ ious = overlap / union
+
+ # enclose area
+ enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
+ enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
+ enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
+
+ cw = enclose_wh[:, 0]
+ ch = enclose_wh[:, 1]
+
+ c2 = cw**2 + ch**2 + eps
+
+ b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
+ b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
+ b2_x1, b2_y1 = target[:, 0], target[:, 1]
+ b2_x2, b2_y2 = target[:, 2], target[:, 3]
+
+ left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
+ right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
+ rho2 = left + right
+
+ # DIoU
+ dious = ious - rho2 / c2
+ loss = 1 - dious
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def ciou_loss(pred, target, eps=1e-7):
+ r"""Implementation of the paper `Enhancing Geometric Factors into
+ Model Learning and Inference for Object Detection and Instance
+ Segmentation <https://arxiv.org/abs/2005.03572>`_.
+
+ Code is modified from https://github.com/Zzh-tju/CIoU.
+
+ Args:
+ pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+ Return:
+ Tensor: Loss tensor.
+ """
+ # overlap
+ lt = torch.max(pred[:, :2], target[:, :2])
+ rb = torch.min(pred[:, 2:], target[:, 2:])
+ wh = (rb - lt).clamp(min=0)
+ overlap = wh[:, 0] * wh[:, 1]
+
+ # union
+ ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
+ ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
+ union = ap + ag - overlap + eps
+
+ # IoU
+ ious = overlap / union
+
+ # enclose area
+ enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
+ enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
+ enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
+
+ cw = enclose_wh[:, 0]
+ ch = enclose_wh[:, 1]
+
+ c2 = cw**2 + ch**2 + eps
+
+ b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
+ b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
+ b2_x1, b2_y1 = target[:, 0], target[:, 1]
+ b2_x2, b2_y2 = target[:, 2], target[:, 3]
+
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+ left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
+ right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
+ rho2 = left + right
+
+ factor = 4 / math.pi**2
+ v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
+
+ with torch.no_grad():
+ alpha = (ious > 0.5).float() * v / (1 - ious + v)
+
+ # CIoU
+ cious = ious - (rho2 / c2 + alpha * v)
+ loss = 1 - cious.clamp(min=-1.0, max=1.0)
+ return loss
+
+
+@LOSSES.register_module()
+class IoULoss(nn.Module):
+ """IoULoss.
+
+ Computing the IoU loss between a set of predicted bboxes and target bboxes.
+
+ Args:
+ linear (bool): If True, use linear scale of loss else determined
+ by mode. Default: False.
+ eps (float): Eps to avoid log(0).
+ reduction (str): Options are "none", "mean" and "sum".
+ loss_weight (float): Weight of loss.
+ mode (str): Loss scaling mode, including "linear", "square", and "log".
+ Default: 'log'
+ """
+
+ def __init__(self,
+ linear=False,
+ eps=1e-6,
+ reduction='mean',
+ loss_weight=1.0,
+ mode='log'):
+ super(IoULoss, self).__init__()
+ assert mode in ['linear', 'square', 'log']
+ if linear:
+ mode = 'linear'
+ warnings.warn('DeprecationWarning: Setting "linear=True" in '
+ 'IOULoss is deprecated, please use "mode=`linear`" '
+ 'instead.')
+ self.mode = mode
+ self.linear = linear
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None. Options are "none", "mean" and "sum".
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if (weight is not None) and (not torch.any(weight > 0)) and (
+ reduction != 'none'):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # iou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * iou_loss(
+ pred,
+ target,
+ weight,
+ mode=self.mode,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+@LOSSES.register_module()
+class BoundedIoULoss(nn.Module):
+
+ def __init__(self, beta=0.2, eps=1e-3, reduction='mean', loss_weight=1.0):
+ super(BoundedIoULoss, self).__init__()
+ self.beta = beta
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss = self.loss_weight * bounded_iou_loss(
+ pred,
+ target,
+ weight,
+ beta=self.beta,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+@LOSSES.register_module()
+class GIoULoss(nn.Module):
+
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(GIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # giou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * giou_loss(
+ pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+@LOSSES.register_module()
+class DIoULoss(nn.Module):
+
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(DIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # giou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * diou_loss(
+ pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+@LOSSES.register_module()
+class CIoULoss(nn.Module):
+
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(CIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ if pred.dim() == weight.dim() + 1:
+ weight = weight.unsqueeze(1)
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # giou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * ciou_loss(
+ pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
diff --git a/mmdet/models/losses/kd_loss.py b/mmdet/models/losses/kd_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..75c19355fee4e20c03e553f2794e5d63446ad69b
--- /dev/null
+++ b/mmdet/models/losses/kd_loss.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def knowledge_distillation_kl_div_loss(pred,
+ soft_label,
+ T,
+ detach_target=True):
+ r"""Loss function for knowledge distilling using KL divergence.
+
+ Args:
+ pred (Tensor): Predicted logits with shape (N, n + 1).
+ soft_label (Tensor): Target logits with shape (N, n + 1).
+ T (int): Temperature for distillation.
+ detach_target (bool): Remove soft_label from automatic differentiation
+
+ Returns:
+ torch.Tensor: Loss tensor with shape (N,).
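+
+ Example:
+ A sketch distilling a student's logits towards a teacher's logits
+ (shapes are illustrative):
+
+ >>> import torch
+ >>> student_logits = torch.randn(4, 11)
+ >>> teacher_logits = torch.randn(4, 11)
+ >>> loss = knowledge_distillation_kl_div_loss(
+ ... student_logits, teacher_logits, T=10)
+ >>> assert loss.ndim == 0 # 'mean' reduction via @weighted_loss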
+ """
+ assert pred.size() == soft_label.size()
+ target = F.softmax(soft_label / T, dim=1)
+ if detach_target:
+ target = target.detach()
+
+ kd_loss = F.kl_div(
+ F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
+ T * T)
+
+ return kd_loss
+
+
+@LOSSES.register_module()
+class KnowledgeDistillationKLDivLoss(nn.Module):
+ """Loss function for knowledge distilling using KL divergence.
+
+ Args:
+ reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
+ loss_weight (float): Loss weight of current loss.
+ T (int): Temperature for distillation.
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0, T=10):
+ super(KnowledgeDistillationKLDivLoss, self).__init__()
+ assert T >= 1
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.T = T
+
+ def forward(self,
+ pred,
+ soft_label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (Tensor): Predicted logits with shape (N, n + 1).
+ soft_label (Tensor): Target logits with shape (N, n + 1).
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+
+ loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
+ pred,
+ soft_label,
+ weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ T=self.T)
+
+ return loss_kd
diff --git a/mmdet/models/losses/mse_loss.py b/mmdet/models/losses/mse_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ebd161f007a8cc6dea7b5cba1aac38ec342e3d2
--- /dev/null
+++ b/mmdet/models/losses/mse_loss.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@weighted_loss
+def mse_loss(pred, target):
+ """Wrapper of mse loss."""
+ return F.mse_loss(pred, target, reduction='none')
+
+
+@LOSSES.register_module()
+class MSELoss(nn.Module):
+ """MSELoss.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
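+
+ Example:
+ A minimal usage sketch:
+
+ >>> import torch
+ >>> loss_mse = MSELoss()
+ >>> pred = torch.randn(4, 2)
+ >>> target = torch.randn(4, 2)
+ >>> loss = loss_mse(pred, target)
+ >>> assert loss.ndim == 0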
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super().__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function of loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): Weight of the loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss = self.loss_weight * mse_loss(
+ pred, target, weight, reduction=reduction, avg_factor=avg_factor)
+ return loss
diff --git a/mmdet/models/losses/pisa_loss.py b/mmdet/models/losses/pisa_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6afea0e5d27ad5ca122a4d16e3fb627a92460772
--- /dev/null
+++ b/mmdet/models/losses/pisa_loss.py
@@ -0,0 +1,184 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+
+from mmdet.core import bbox_overlaps
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def isr_p(cls_score,
+ bbox_pred,
+ bbox_targets,
+ rois,
+ sampling_results,
+ loss_cls,
+ bbox_coder,
+ k=2,
+ bias=0,
+ num_class=80):
+ """Importance-based Sample Reweighting (ISR_P), positive part.
+
+ Args:
+ cls_score (Tensor): Predicted classification scores.
+ bbox_pred (Tensor): Predicted bbox deltas.
+ bbox_targets (tuple[Tensor]): A tuple of bbox targets, which are
+ labels, label_weights, bbox_targets and bbox_weights, respectively.
+ rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
+ (two_stage) in shape (n, 5).
+ sampling_results (obj): Sampling results.
+ loss_cls (func): Classification loss func of the head.
+ bbox_coder (obj): BBox coder of the head.
+ k (float): Power of the non-linear mapping.
+ bias (float): Shift of the non-linear mapping.
+ num_class (int): Number of classes, default: 80.
+
+ Return:
+ tuple([Tensor]): labels, imp_based_label_weights, bbox_targets,
+ bbox_target_weights
+ """
+
+ labels, label_weights, bbox_targets, bbox_weights = bbox_targets
+ pos_label_inds = ((labels >= 0) &
+ (labels < num_class)).nonzero().reshape(-1)
+ pos_labels = labels[pos_label_inds]
+
+ # if no positive samples, return the original targets
+ num_pos = float(pos_label_inds.size(0))
+ if num_pos == 0:
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ # merge pos_assigned_gt_inds of per image to a single tensor
+ gts = list()
+ last_max_gt = 0
+ for i in range(len(sampling_results)):
+ gt_i = sampling_results[i].pos_assigned_gt_inds
+ gts.append(gt_i + last_max_gt)
+ if len(gt_i) != 0:
+ last_max_gt = gt_i.max() + 1
+ gts = torch.cat(gts)
+ assert len(gts) == num_pos
+
+ cls_score = cls_score.detach()
+ bbox_pred = bbox_pred.detach()
+
+ # For single stage detectors, rois here indicate anchors, in shape (N, 4)
+ # For two stage detectors, rois are in shape (N, 5)
+ if rois.size(-1) == 5:
+ pos_rois = rois[pos_label_inds][:, 1:]
+ else:
+ pos_rois = rois[pos_label_inds]
+
+ if bbox_pred.size(-1) > 4:
+ bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
+ pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4)
+ else:
+ pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4)
+
+ # compute iou of the predicted bbox and the corresponding GT
+ pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4)
+ pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred)
+ target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target)
+ ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True)
+
+ pos_imp_weights = label_weights[pos_label_inds]
+ # Two steps to compute IoU-HLR. Samples are first sorted by IoU locally,
+ # then sorted again within the same-rank group
+ max_l_num = pos_labels.bincount().max()
+ for label in pos_labels.unique():
+ l_inds = (pos_labels == label).nonzero().view(-1)
+ l_gts = gts[l_inds]
+ for t in l_gts.unique():
+ t_inds = l_inds[l_gts == t]
+ t_ious = ious[t_inds]
+ _, t_iou_rank_idx = t_ious.sort(descending=True)
+ _, t_iou_rank = t_iou_rank_idx.sort()
+ ious[t_inds] += max_l_num - t_iou_rank.float()
+ l_ious = ious[l_inds]
+ _, l_iou_rank_idx = l_ious.sort(descending=True)
+ _, l_iou_rank = l_iou_rank_idx.sort() # IoU-HLR
+ # linearly map HLR to label weights
+ pos_imp_weights[l_inds] *= (max_l_num - l_iou_rank.float()) / max_l_num
+
+ pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k)
+
+ # normalize to make the new weighted loss value equal to the original loss
+ pos_loss_cls = loss_cls(
+ cls_score[pos_label_inds], pos_labels, reduction_override='none')
+ if pos_loss_cls.dim() > 1:
+ ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds][:,
+ None]
+ new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None]
+ else:
+ ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds]
+ new_pos_loss_cls = pos_loss_cls * pos_imp_weights
+ pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_pos_loss_cls.sum()
+ pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio
+ label_weights[pos_label_inds] = pos_imp_weights
+
+ bbox_targets = labels, label_weights, bbox_targets, bbox_weights
+ return bbox_targets
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def carl_loss(cls_score,
+ labels,
+ bbox_pred,
+ bbox_targets,
+ loss_bbox,
+ k=1,
+ bias=0.2,
+ avg_factor=None,
+ sigmoid=False,
+ num_class=80):
+ """Classification-Aware Regression Loss (CARL).
+
+ Args:
+ cls_score (Tensor): Predicted classification scores.
+ labels (Tensor): Targets of classification.
+ bbox_pred (Tensor): Predicted bbox deltas.
+ bbox_targets (Tensor): Target of bbox regression.
+ loss_bbox (func): Regression loss func of the head.
+ k (float): Power of the non-linear mapping.
+ bias (float): Shift of the non-linear mapping.
+ avg_factor (int): Average factor used in regression loss.
+ sigmoid (bool): Activation of the classification score.
+ num_class (int): Number of classes, default: 80.
+
+ Return:
+ dict: CARL loss dict.
+ """
+ pos_label_inds = ((labels >= 0) &
+ (labels < num_class)).nonzero().reshape(-1)
+ if pos_label_inds.numel() == 0:
+ return dict(loss_carl=cls_score.sum()[None] * 0.)
+ pos_labels = labels[pos_label_inds]
+
+ # multiply pos_cls_score with the corresponding bbox weight
+ # and remain gradient
+ if sigmoid:
+ pos_cls_score = cls_score.sigmoid()[pos_label_inds, pos_labels]
+ else:
+ pos_cls_score = cls_score.softmax(-1)[pos_label_inds, pos_labels]
+ carl_loss_weights = (bias + (1 - bias) * pos_cls_score).pow(k)
+
+ # normalize carl_loss_weight to make its sum equal to num positive
+ num_pos = float(pos_cls_score.size(0))
+ weight_ratio = num_pos / carl_loss_weights.sum()
+ carl_loss_weights *= weight_ratio
+
+ if avg_factor is None:
+ avg_factor = bbox_targets.size(0)
+ # if is class agnostic, bbox pred is in shape (N, 4)
+ # otherwise, bbox pred is in shape (N, #classes, 4)
+ if bbox_pred.size(-1) > 4:
+ bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
+ pos_bbox_preds = bbox_pred[pos_label_inds, pos_labels]
+ else:
+ pos_bbox_preds = bbox_pred[pos_label_inds]
+ ori_loss_reg = loss_bbox(
+ pos_bbox_preds,
+ bbox_targets[pos_label_inds],
+ reduction_override='none') / avg_factor
+ loss_carl = (ori_loss_reg * carl_loss_weights[:, None]).sum()
+ return dict(loss_carl=loss_carl[None])
diff --git a/mmdet/models/losses/seesaw_loss.py b/mmdet/models/losses/seesaw_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..01040472d85a79fbb1f78fecb403057c40703f0c
--- /dev/null
+++ b/mmdet/models/losses/seesaw_loss.py
@@ -0,0 +1,262 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .accuracy import accuracy
+from .cross_entropy_loss import cross_entropy
+from .utils import weight_reduce_loss
+
+
+def seesaw_ce_loss(cls_score,
+ labels,
+ label_weights,
+ cum_samples,
+ num_classes,
+ p,
+ q,
+ eps,
+ reduction='mean',
+ avg_factor=None):
+ """Calculate the Seesaw CrossEntropy loss.
+
+ Args:
+ cls_score (torch.Tensor): The prediction with shape (N, C),
+ C is the number of classes.
+ labels (torch.Tensor): The learning label of the prediction.
+ label_weights (torch.Tensor): Sample-wise loss weight.
+ cum_samples (torch.Tensor): Cumulative samples for each category.
+ num_classes (int): The number of classes.
+ p (float): The ``p`` in the mitigation factor.
+        q (float): The ``q`` in the compensation factor.
+        eps (float): The minimal value of the divisor used to smooth
+            the computation of the compensation factor.
+ reduction (str, optional): The method used to reduce the loss.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+
+ Returns:
+ torch.Tensor: The calculated loss
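+
+    The following is an illustrative sketch only; the class counts and
+    shapes are made up and do not correspond to any real dataset.
+
+    Example:
+        >>> import torch
+        >>> num_classes = 4
+        >>> cls_score = torch.randn(2, num_classes)
+        >>> labels = torch.tensor([0, 2])
+        >>> label_weights = torch.ones(2)
+        >>> cum_samples = torch.tensor([100., 10., 5., 1.])
+        >>> loss = seesaw_ce_loss(cls_score, labels, label_weights,
+        ...                       cum_samples, num_classes, p=0.8, q=2.0,
+        ...                       eps=1e-2)  # scalar, 'mean' reduction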
+ """
+ assert cls_score.size(-1) == num_classes
+ assert len(cum_samples) == num_classes
+
+ onehot_labels = F.one_hot(labels, num_classes)
+ seesaw_weights = cls_score.new_ones(onehot_labels.size())
+
+ # mitigation factor
+ if p > 0:
+ sample_ratio_matrix = cum_samples[None, :].clamp(
+ min=1) / cum_samples[:, None].clamp(min=1)
+ index = (sample_ratio_matrix < 1.0).float()
+ sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
+ mitigation_factor = sample_weights[labels.long(), :]
+ seesaw_weights = seesaw_weights * mitigation_factor
+
+ # compensation factor
+ if q > 0:
+ scores = F.softmax(cls_score.detach(), dim=1)
+ self_scores = scores[
+ torch.arange(0, len(scores)).to(scores.device).long(),
+ labels.long()]
+ score_matrix = scores / self_scores[:, None].clamp(min=eps)
+ index = (score_matrix > 1.0).float()
+ compensation_factor = score_matrix.pow(q) * index + (1 - index)
+ seesaw_weights = seesaw_weights * compensation_factor
+
+ cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))
+
+ loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')
+
+ if label_weights is not None:
+ label_weights = label_weights.float()
+ loss = weight_reduce_loss(
+ loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
+ return loss
+
+
+@LOSSES.register_module()
+class SeesawLoss(nn.Module):
+ """
+ Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
+ arXiv: https://arxiv.org/abs/2008.10032
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            or softmax. Only False is supported.
+ p (float, optional): The ``p`` in the mitigation factor.
+ Defaults to 0.8.
+        q (float, optional): The ``q`` in the compensation factor.
+            Defaults to 2.0.
+        num_classes (int, optional): The number of classes.
+            Defaults to 1203 for the LVIS v1 dataset.
+        eps (float, optional): The minimal value of the divisor used to
+            smooth the computation of the compensation factor.
+        reduction (str, optional): The method that reduces the loss to a
+            scalar. Options are "none", "mean" and "sum".
+        loss_weight (float, optional): The weight of the loss.
+            Defaults to 1.0.
+        return_dict (bool, optional): Whether to return the losses as a dict.
+            Defaults to True.
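+
+    The following is an illustrative sketch with a tiny 3-class setting
+    rather than LVIS; all numbers are arbitrary.
+
+    Example:
+        >>> import torch
+        >>> loss = SeesawLoss(num_classes=3)
+        >>> # the classifier is expected to output num_classes + 2 channels
+        >>> cls_score = torch.randn(4, 5)
+        >>> labels = torch.tensor([0, 1, 2, 3])  # 3 == background
+        >>> out = loss(cls_score, labels)
+        >>> sorted(out.keys())
+        ['loss_cls_classes', 'loss_cls_objectness']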
+ """
+
+ def __init__(self,
+ use_sigmoid=False,
+ p=0.8,
+ q=2.0,
+ num_classes=1203,
+ eps=1e-2,
+ reduction='mean',
+ loss_weight=1.0,
+ return_dict=True):
+ super(SeesawLoss, self).__init__()
+ assert not use_sigmoid
+ self.use_sigmoid = False
+ self.p = p
+ self.q = q
+ self.num_classes = num_classes
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.return_dict = return_dict
+
+ # 0 for pos, 1 for neg
+ self.cls_criterion = seesaw_ce_loss
+
+ # cumulative samples for each category
+ self.register_buffer(
+ 'cum_samples',
+ torch.zeros(self.num_classes + 1, dtype=torch.float))
+
+ # custom output channels of the classifier
+ self.custom_cls_channels = True
+ # custom activation of cls_score
+ self.custom_activation = True
+        # custom accuracy of the classifier
+ self.custom_accuracy = True
+
+ def _split_cls_score(self, cls_score):
+ # split cls_score to cls_score_classes and cls_score_objectness
+ assert cls_score.size(-1) == self.num_classes + 2
+ cls_score_classes = cls_score[..., :-2]
+ cls_score_objectness = cls_score[..., -2:]
+ return cls_score_classes, cls_score_objectness
+
+ def get_cls_channels(self, num_classes):
+ """Get custom classification channels.
+
+ Args:
+ num_classes (int): The number of classes.
+
+ Returns:
+ int: The custom classification channels.
+ """
+ assert num_classes == self.num_classes
+ return num_classes + 2
+
+ def get_activation(self, cls_score):
+ """Get custom activation of cls_score.
+
+ Args:
+ cls_score (torch.Tensor): The prediction with shape (N, C + 2).
+
+ Returns:
+ torch.Tensor: The custom activation of cls_score with shape
+ (N, C + 1).
+ """
+ cls_score_classes, cls_score_objectness = self._split_cls_score(
+ cls_score)
+ score_classes = F.softmax(cls_score_classes, dim=-1)
+ score_objectness = F.softmax(cls_score_objectness, dim=-1)
+ score_pos = score_objectness[..., [0]]
+ score_neg = score_objectness[..., [1]]
+ score_classes = score_classes * score_pos
+ scores = torch.cat([score_classes, score_neg], dim=-1)
+ return scores
+
+ def get_accuracy(self, cls_score, labels):
+ """Get custom accuracy w.r.t. cls_score and labels.
+
+ Args:
+ cls_score (torch.Tensor): The prediction with shape (N, C + 2).
+ labels (torch.Tensor): The learning label of the prediction.
+
+ Returns:
+            Dict[str, torch.Tensor]: The accuracy for objectness and classes,
+                respectively.
+ """
+ pos_inds = labels < self.num_classes
+ obj_labels = (labels == self.num_classes).long()
+ cls_score_classes, cls_score_objectness = self._split_cls_score(
+ cls_score)
+ acc_objectness = accuracy(cls_score_objectness, obj_labels)
+ acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds])
+ acc = dict()
+ acc['acc_objectness'] = acc_objectness
+ acc['acc_classes'] = acc_classes
+ return acc
+
+ def forward(self,
+ cls_score,
+ labels,
+ label_weights=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ cls_score (torch.Tensor): The prediction with shape (N, C + 2).
+ labels (torch.Tensor): The learning label of the prediction.
+ label_weights (torch.Tensor, optional): Sample-wise loss weight.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Options are "none", "mean" and "sum".
+
+        Returns:
+            torch.Tensor | Dict[str, torch.Tensor]:
+                If ``return_dict`` is False, the calculated loss;
+                if ``return_dict`` is True, a dict of the calculated losses
+                for objectness and classes, respectively.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ assert cls_score.size(-1) == self.num_classes + 2
+ pos_inds = labels < self.num_classes
+ # 0 for pos, 1 for neg
+ obj_labels = (labels == self.num_classes).long()
+
+ # accumulate the samples for each category
+ unique_labels = labels.unique()
+ for u_l in unique_labels:
+ inds_ = labels == u_l.item()
+ self.cum_samples[u_l] += inds_.sum()
+
+ if label_weights is not None:
+ label_weights = label_weights.float()
+ else:
+ label_weights = labels.new_ones(labels.size(), dtype=torch.float)
+
+ cls_score_classes, cls_score_objectness = self._split_cls_score(
+ cls_score)
+ # calculate loss_cls_classes (only need pos samples)
+ if pos_inds.sum() > 0:
+ loss_cls_classes = self.loss_weight * self.cls_criterion(
+ cls_score_classes[pos_inds], labels[pos_inds],
+ label_weights[pos_inds], self.cum_samples[:self.num_classes],
+ self.num_classes, self.p, self.q, self.eps, reduction,
+ avg_factor)
+ else:
+ loss_cls_classes = cls_score_classes[pos_inds].sum()
+ # calculate loss_cls_objectness
+ loss_cls_objectness = self.loss_weight * cross_entropy(
+ cls_score_objectness, obj_labels, label_weights, reduction,
+ avg_factor)
+
+ if self.return_dict:
+ loss_cls = dict()
+ loss_cls['loss_cls_objectness'] = loss_cls_objectness
+ loss_cls['loss_cls_classes'] = loss_cls_classes
+ else:
+ loss_cls = loss_cls_classes + loss_cls_objectness
+ return loss_cls
diff --git a/mmdet/models/losses/smooth_l1_loss.py b/mmdet/models/losses/smooth_l1_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..551174672933cb0d23c93cbe22053e3910a9dcfb
--- /dev/null
+++ b/mmdet/models/losses/smooth_l1_loss.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def smooth_l1_loss(pred, target, beta=1.0):
+ """Smooth L1 loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ beta (float, optional): The threshold in the piecewise function.
+ Defaults to 1.0.
+
+ Returns:
+ torch.Tensor: Calculated loss
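+
+    The following is an illustrative sketch only; shapes are arbitrary.
+
+    Example:
+        >>> import torch
+        >>> pred = torch.rand(4, 4)
+        >>> target = torch.rand(4, 4)
+        >>> loss = smooth_l1_loss(pred, target, beta=1.0)  # scalar ('mean')
+        >>> smooth_l1_loss(pred, target, reduction='none').shape
+        torch.Size([4, 4])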
+ """
+ assert beta > 0
+ if target.numel() == 0:
+ return pred.sum() * 0
+
+ assert pred.size() == target.size()
+ diff = torch.abs(pred - target)
+ loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
+ diff - 0.5 * beta)
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def l1_loss(pred, target):
+ """L1 loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+
+ Returns:
+ torch.Tensor: Calculated loss
+ """
+ if target.numel() == 0:
+ return pred.sum() * 0
+
+ assert pred.size() == target.size()
+ loss = torch.abs(pred - target)
+ return loss
+
+
+@LOSSES.register_module()
+class SmoothL1Loss(nn.Module):
+ """Smooth L1 loss.
+
+ Args:
+ beta (float, optional): The threshold in the piecewise function.
+ Defaults to 1.0.
+ reduction (str, optional): The method to reduce the loss.
+ Options are "none", "mean" and "sum". Defaults to "mean".
+ loss_weight (float, optional): The weight of loss.
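+
+    The following is an illustrative sketch only; shapes are arbitrary.
+
+    Example:
+        >>> import torch
+        >>> loss = SmoothL1Loss(beta=1.0, loss_weight=1.0)
+        >>> pred = torch.rand(3, 4)
+        >>> target = torch.rand(3, 4)
+        >>> weight = torch.ones(3, 4)
+        >>> out = loss(pred, target, weight)  # scalar tensor ('mean')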
+ """
+
+ def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
+ super(SmoothL1Loss, self).__init__()
+ self.beta = beta
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_bbox = self.loss_weight * smooth_l1_loss(
+ pred,
+ target,
+ weight,
+ beta=self.beta,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_bbox
+
+
+@LOSSES.register_module()
+class L1Loss(nn.Module):
+ """L1 loss.
+
+ Args:
+ reduction (str, optional): The method to reduce the loss.
+ Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of loss.
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super(L1Loss, self).__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_bbox = self.loss_weight * l1_loss(
+ pred, target, weight, reduction=reduction, avg_factor=avg_factor)
+ return loss_bbox
diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..778237ebfd57160a3533d6d82b3d8fd7a36bf481
--- /dev/null
+++ b/mmdet/models/losses/utils.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import mmcv
+import torch
+import torch.nn.functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are "none", "mean" and "sum".
+
+ Return:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ elif reduction_enum == 2:
+ return loss.sum()
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights.
+ reduction (str): Same as built-in losses of PyTorch.
+ avg_factor (float): Average factor when computing the mean of losses.
+
+ Returns:
+ Tensor: Processed loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ loss = loss * weight
+
+ # if avg_factor is not specified, just reduce the loss
+ if avg_factor is None:
+ loss = reduce_loss(loss, reduction)
+ else:
+ # if reduction is mean, then average the loss by avg_factor
+ if reduction == 'mean':
+ # Avoid causing ZeroDivisionError when avg_factor is 0.0,
+ # i.e., all labels of an image belong to ignore index.
+ eps = torch.finfo(torch.float32).eps
+ loss = loss.sum() / (avg_factor + eps)
+ # if reduction is 'none', then do nothing, otherwise raise an error
+ elif reduction != 'none':
+ raise ValueError('avg_factor can not be used with reduction="sum"')
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ avg_factor=None, **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, avg_factor=2)
+ tensor(1.5000)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred,
+ target,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+ return wrapper
diff --git a/mmdet/models/losses/varifocal_loss.py b/mmdet/models/losses/varifocal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..42f0eef9c62e2a66b97914cf8b43a25112c4e79f
--- /dev/null
+++ b/mmdet/models/losses/varifocal_loss.py
@@ -0,0 +1,134 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def varifocal_loss(pred,
+ target,
+ weight=None,
+ alpha=0.75,
+ gamma=2.0,
+ iou_weighted=True,
+ reduction='mean',
+ avg_factor=None):
+    """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the
+ number of classes
+ target (torch.Tensor): The learning target of the iou-aware
+ classification score with shape (N, C), C is the number of classes.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ alpha (float, optional): A balance factor for the negative part of
+ Varifocal Loss, which is different from the alpha of Focal Loss.
+ Defaults to 0.75.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ iou_weighted (bool, optional): Whether to weight the loss of the
+ positive example with the iou target. Defaults to True.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and
+ "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
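+
+    The following is an illustrative sketch only; ``target`` carries
+    IoU-aware classification scores in [0, 1] for positives and zeros for
+    negatives, and the shapes are arbitrary.
+
+    Example:
+        >>> import torch
+        >>> pred = torch.randn(8, 20)
+        >>> target = torch.zeros(8, 20)
+        >>> target[0, 3] = 0.9  # one positive with an IoU of 0.9
+        >>> loss = varifocal_loss(pred, target)  # scalar, 'mean' reduction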
+ """
+ # pred and target should be of the same size
+ assert pred.size() == target.size()
+ pred_sigmoid = pred.sigmoid()
+ target = target.type_as(pred)
+ if iou_weighted:
+ focal_weight = target * (target > 0.0).float() + \
+ alpha * (pred_sigmoid - target).abs().pow(gamma) * \
+ (target <= 0.0).float()
+ else:
+ focal_weight = (target > 0.0).float() + \
+ alpha * (pred_sigmoid - target).abs().pow(gamma) * \
+ (target <= 0.0).float()
+ loss = F.binary_cross_entropy_with_logits(
+ pred, target, reduction='none') * focal_weight
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+@LOSSES.register_module()
+class VarifocalLoss(nn.Module):
+
+ def __init__(self,
+ use_sigmoid=True,
+ alpha=0.75,
+ gamma=2.0,
+ iou_weighted=True,
+ reduction='mean',
+ loss_weight=1.0):
+        """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction is
+ used for sigmoid or softmax. Defaults to True.
+ alpha (float, optional): A balance factor for the negative part of
+ Varifocal Loss, which is different from the alpha of Focal
+ Loss. Defaults to 0.75.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ iou_weighted (bool, optional): Whether to weight the loss of the
+ positive examples with the iou target. Defaults to True.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and
+ "sum".
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ """
+ super(VarifocalLoss, self).__init__()
+ assert use_sigmoid is True, \
+ 'Only sigmoid varifocal loss supported now.'
+ assert alpha >= 0.0
+ self.use_sigmoid = use_sigmoid
+ self.alpha = alpha
+ self.gamma = gamma
+ self.iou_weighted = iou_weighted
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.use_sigmoid:
+ loss_cls = self.loss_weight * varifocal_loss(
+ pred,
+ target,
+ weight,
+ alpha=self.alpha,
+ gamma=self.gamma,
+ iou_weighted=self.iou_weighted,
+ reduction=reduction,
+ avg_factor=avg_factor)
+ else:
+ raise NotImplementedError
+ return loss_cls
diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2fa823fb35fdd90c07065cc93238d08385ce8b
--- /dev/null
+++ b/mmdet/models/necks/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .bfp import BFP
+from .channel_mapper import ChannelMapper
+from .ct_resnet_neck import CTResNetNeck
+from .dilated_encoder import DilatedEncoder
+from .dyhead import DyHead
+from .fpg import FPG
+from .fpn import FPN
+from .fpn_carafe import FPN_CARAFE
+from .hrfpn import HRFPN
+from .nas_fpn import NASFPN
+from .nasfcos_fpn import NASFCOS_FPN
+from .pafpn import PAFPN
+from .rfp import RFP
+from .ssd_neck import SSDNeck
+from .yolo_neck import YOLOV3Neck
+from .yolox_pafpn import YOLOXPAFPN
+
+__all__ = [
+ 'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
+ 'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder',
+ 'CTResNetNeck', 'SSDNeck', 'YOLOXPAFPN', 'DyHead'
+]
diff --git a/mmdet/models/necks/bfp.py b/mmdet/models/necks/bfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fdfa036ddf693bbb7fbf77fe2089c2f98a2bb93
--- /dev/null
+++ b/mmdet/models/necks/bfp.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.cnn.bricks import NonLocal2d
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class BFP(BaseModule):
+ """BFP (Balanced Feature Pyramids)
+
+    BFP takes multi-level features as inputs, gathers them into a single one,
+    then refines the gathered feature and scatters the refined results back
+    to the multi-level features. This module is used in Libra R-CNN
+    (CVPR 2019), see the paper `Libra R-CNN: Towards Balanced Learning for
+    Object Detection <https://arxiv.org/abs/1904.02701>`_ for details.
+
+ Args:
+ in_channels (int): Number of input channels (feature maps of all levels
+ should have the same channels).
+ num_levels (int): Number of input feature levels.
+ conv_cfg (dict): The config dict for convolution layers.
+ norm_cfg (dict): The config dict for normalization layers.
+ refine_level (int): Index of integration and refine level of BSF in
+ multi-level features from bottom to top.
+        refine_type (str): Type of the refine op, currently supporting
+ [None, 'conv', 'non_local'].
+ init_cfg (dict or list[dict], optional): Initialization config dict.
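+
+    The following is an illustrative sketch only; channel and scale numbers
+    are arbitrary.
+
+    Example:
+        >>> import torch
+        >>> in_channels = 4
+        >>> scales = [64, 32, 16, 8]
+        >>> inputs = [torch.rand(1, in_channels, s, s) for s in scales]
+        >>> self = BFP(in_channels, len(scales), refine_level=1)
+        >>> outputs = self.forward(inputs)
+        >>> [o.shape[-1] for o in outputs]
+        [64, 32, 16, 8]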
+ """
+
+ def __init__(self,
+ in_channels,
+ num_levels,
+ refine_level=2,
+ refine_type=None,
+ conv_cfg=None,
+ norm_cfg=None,
+ init_cfg=dict(
+ type='Xavier', layer='Conv2d', distribution='uniform')):
+ super(BFP, self).__init__(init_cfg)
+ assert refine_type in [None, 'conv', 'non_local']
+
+ self.in_channels = in_channels
+ self.num_levels = num_levels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.refine_level = refine_level
+ self.refine_type = refine_type
+ assert 0 <= self.refine_level < self.num_levels
+
+ if self.refine_type == 'conv':
+ self.refine = ConvModule(
+ self.in_channels,
+ self.in_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ elif self.refine_type == 'non_local':
+ self.refine = NonLocal2d(
+ self.in_channels,
+ reduction=1,
+ use_scale=False,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == self.num_levels
+
+ # step 1: gather multi-level features by resize and average
+ feats = []
+ gather_size = inputs[self.refine_level].size()[2:]
+ for i in range(self.num_levels):
+ if i < self.refine_level:
+ gathered = F.adaptive_max_pool2d(
+ inputs[i], output_size=gather_size)
+ else:
+ gathered = F.interpolate(
+ inputs[i], size=gather_size, mode='nearest')
+ feats.append(gathered)
+
+ bsf = sum(feats) / len(feats)
+
+ # step 2: refine gathered features
+ if self.refine_type is not None:
+ bsf = self.refine(bsf)
+
+ # step 3: scatter refined features to multi-levels by a residual path
+ outs = []
+ for i in range(self.num_levels):
+ out_size = inputs[i].size()[2:]
+ if i < self.refine_level:
+ residual = F.interpolate(bsf, size=out_size, mode='nearest')
+ else:
+ residual = F.adaptive_max_pool2d(bsf, output_size=out_size)
+ outs.append(residual + inputs[i])
+
+ return tuple(outs)
diff --git a/mmdet/models/necks/channel_mapper.py b/mmdet/models/necks/channel_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..774bdb1d7a522583df462fc09177a6a6ee899f17
--- /dev/null
+++ b/mmdet/models/necks/channel_mapper.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class ChannelMapper(BaseModule):
+ r"""Channel Mapper to reduce/increase channels of backbone features.
+
+ This is used to reduce/increase channels of backbone features.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ kernel_size (int, optional): kernel_size for reducing channels (used
+ at each scale). Default: 3.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ act_cfg (dict, optional): Config dict for activation layer in
+ ConvModule. Default: dict(type='ReLU').
+        num_outs (int, optional): Number of output feature maps. Extra
+            convs will be added when num_outs is larger than the length
+ of in_channels.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = ChannelMapper(in_channels, 11, 3).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ num_outs=None,
+ init_cfg=dict(
+ type='Xavier', layer='Conv2d', distribution='uniform')):
+ super(ChannelMapper, self).__init__(init_cfg)
+ assert isinstance(in_channels, list)
+ self.extra_convs = None
+ if num_outs is None:
+ num_outs = len(in_channels)
+ self.convs = nn.ModuleList()
+ for in_channel in in_channels:
+ self.convs.append(
+ ConvModule(
+ in_channel,
+ out_channels,
+ kernel_size,
+ padding=(kernel_size - 1) // 2,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ if num_outs > len(in_channels):
+ self.extra_convs = nn.ModuleList()
+ for i in range(len(in_channels), num_outs):
+ if i == len(in_channels):
+ in_channel = in_channels[-1]
+ else:
+ in_channel = out_channels
+ self.extra_convs.append(
+ ConvModule(
+ in_channel,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.convs)
+ outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
+ if self.extra_convs:
+ for i in range(len(self.extra_convs)):
+ if i == 0:
+ outs.append(self.extra_convs[0](inputs[-1]))
+ else:
+ outs.append(self.extra_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/mmdet/models/necks/ct_resnet_neck.py b/mmdet/models/necks/ct_resnet_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..40eb2685767fbf0f365529eefc160e735608bab5
--- /dev/null
+++ b/mmdet/models/necks/ct_resnet_neck.py
@@ -0,0 +1,94 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, auto_fp16
+
+from mmdet.models.builder import NECKS
+
+
+@NECKS.register_module()
+class CTResNetNeck(BaseModule):
+    """The neck used in `CenterNet <https://arxiv.org/abs/1904.07850>`_ for
+ object classification and box regression.
+
+ Args:
+ in_channel (int): Number of input channels.
+ num_deconv_filters (tuple[int]): Number of filters per stage.
+ num_deconv_kernels (tuple[int]): Number of kernels per stage.
+ use_dcn (bool): If True, use DCNv2. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
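+
+    The following is an illustrative sketch only; ``use_dcn=False`` is used
+    so the sketch does not require the DCNv2 op, and all numbers are
+    arbitrary.
+
+    Example:
+        >>> import torch
+        >>> self = CTResNetNeck(16, (8, 8, 8), (4, 4, 4), use_dcn=False)
+        >>> inputs = [torch.rand(1, 16, 16, 16)]
+        >>> outs = self.forward(inputs)
+        >>> outs[0].shape
+        torch.Size([1, 8, 128, 128])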
+ """
+
+ def __init__(self,
+ in_channel,
+ num_deconv_filters,
+ num_deconv_kernels,
+ use_dcn=True,
+ init_cfg=None):
+ super(CTResNetNeck, self).__init__(init_cfg)
+ assert len(num_deconv_filters) == len(num_deconv_kernels)
+ self.fp16_enabled = False
+ self.use_dcn = use_dcn
+ self.in_channel = in_channel
+ self.deconv_layers = self._make_deconv_layer(num_deconv_filters,
+ num_deconv_kernels)
+
+ def _make_deconv_layer(self, num_deconv_filters, num_deconv_kernels):
+        """Use deconv layers to upsample the backbone's output."""
+ layers = []
+ for i in range(len(num_deconv_filters)):
+ feat_channel = num_deconv_filters[i]
+ conv_module = ConvModule(
+ self.in_channel,
+ feat_channel,
+ 3,
+ padding=1,
+ conv_cfg=dict(type='DCNv2') if self.use_dcn else None,
+ norm_cfg=dict(type='BN'))
+ layers.append(conv_module)
+ upsample_module = ConvModule(
+ feat_channel,
+ feat_channel,
+ num_deconv_kernels[i],
+ stride=2,
+ padding=1,
+ conv_cfg=dict(type='deconv'),
+ norm_cfg=dict(type='BN'))
+ layers.append(upsample_module)
+ self.in_channel = feat_channel
+
+ return nn.Sequential(*layers)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.ConvTranspose2d):
+ # In order to be consistent with the source code,
+ # reset the ConvTranspose2d initialization parameters
+ m.reset_parameters()
+ # Simulated bilinear upsampling kernel
+ w = m.weight.data
+ f = math.ceil(w.size(2) / 2)
+ c = (2 * f - 1 - f % 2) / (2. * f)
+ for i in range(w.size(2)):
+ for j in range(w.size(3)):
+ w[0, 0, i, j] = \
+ (1 - math.fabs(i / f - c)) * (
+ 1 - math.fabs(j / f - c))
+ for c in range(1, w.size(0)):
+ w[c, 0, :, :] = w[0, 0, :, :]
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ # self.use_dcn is False
+ elif not self.use_dcn and isinstance(m, nn.Conv2d):
+ # In order to be consistent with the source code,
+ # reset the Conv2d initialization parameters
+ m.reset_parameters()
+
+ @auto_fp16()
+ def forward(self, inputs):
+ assert isinstance(inputs, (list, tuple))
+ outs = self.deconv_layers(inputs[-1])
+ return outs,
diff --git a/mmdet/models/necks/dilated_encoder.py b/mmdet/models/necks/dilated_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..79a8f4bb31b3387154a75c5c915df6bc59fc3638
--- /dev/null
+++ b/mmdet/models/necks/dilated_encoder.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import (ConvModule, caffe2_xavier_init, constant_init, is_norm,
+ normal_init)
+from torch.nn import BatchNorm2d
+
+from ..builder import NECKS
+
+
+class Bottleneck(nn.Module):
+    """Bottleneck block for DilatedEncoder used in
+    `YOLOF <https://arxiv.org/abs/2103.09460>`_.
+
+ The Bottleneck contains three ConvLayers and one residual connection.
+
+ Args:
+ in_channels (int): The number of input channels.
+ mid_channels (int): The number of middle output channels.
+ dilation (int): Dilation rate.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ """
+
+ def __init__(self,
+ in_channels,
+ mid_channels,
+ dilation,
+ norm_cfg=dict(type='BN', requires_grad=True)):
+ super(Bottleneck, self).__init__()
+ self.conv1 = ConvModule(
+ in_channels, mid_channels, 1, norm_cfg=norm_cfg)
+ self.conv2 = ConvModule(
+ mid_channels,
+ mid_channels,
+ 3,
+ padding=dilation,
+ dilation=dilation,
+ norm_cfg=norm_cfg)
+ self.conv3 = ConvModule(
+ mid_channels, in_channels, 1, norm_cfg=norm_cfg)
+
+ def forward(self, x):
+ identity = x
+ out = self.conv1(x)
+ out = self.conv2(out)
+ out = self.conv3(out)
+ out = out + identity
+ return out
+
+
+@NECKS.register_module()
+class DilatedEncoder(nn.Module):
+    """Dilated Encoder for `YOLOF <https://arxiv.org/abs/2103.09460>`_.
+
+ This module contains two types of components:
+ - the original FPN lateral convolution layer and fpn convolution layer,
+ which are 1x1 conv + 3x3 conv
+ - the dilated residual block
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+        block_mid_channels (int): The number of middle block output channels.
+        num_residual_blocks (int): The number of residual blocks.
+        block_dilations (list): The list of residual block dilations.
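+
+    The following is an illustrative sketch only; the numbers are arbitrary
+    and are not the YOLOF defaults.
+
+    Example:
+        >>> import torch
+        >>> self = DilatedEncoder(16, 8, 4, 2, [2, 4])
+        >>> feature = [torch.rand(1, 16, 32, 32)]
+        >>> out = self.forward(feature)
+        >>> out[0].shape
+        torch.Size([1, 8, 32, 32])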
+ """
+
+ def __init__(self, in_channels, out_channels, block_mid_channels,
+ num_residual_blocks, block_dilations):
+ super(DilatedEncoder, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.block_mid_channels = block_mid_channels
+ self.num_residual_blocks = num_residual_blocks
+ self.block_dilations = block_dilations
+ self._init_layers()
+
+ def _init_layers(self):
+ self.lateral_conv = nn.Conv2d(
+ self.in_channels, self.out_channels, kernel_size=1)
+ self.lateral_norm = BatchNorm2d(self.out_channels)
+ self.fpn_conv = nn.Conv2d(
+ self.out_channels, self.out_channels, kernel_size=3, padding=1)
+ self.fpn_norm = BatchNorm2d(self.out_channels)
+ encoder_blocks = []
+ for i in range(self.num_residual_blocks):
+ dilation = self.block_dilations[i]
+ encoder_blocks.append(
+ Bottleneck(
+ self.out_channels,
+ self.block_mid_channels,
+ dilation=dilation))
+ self.dilated_encoder_blocks = nn.Sequential(*encoder_blocks)
+
+ def init_weights(self):
+ caffe2_xavier_init(self.lateral_conv)
+ caffe2_xavier_init(self.fpn_conv)
+ for m in [self.lateral_norm, self.fpn_norm]:
+ constant_init(m, 1)
+ for m in self.dilated_encoder_blocks.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, mean=0, std=0.01)
+ if is_norm(m):
+ constant_init(m, 1)
+
+ def forward(self, feature):
+ out = self.lateral_norm(self.lateral_conv(feature[-1]))
+ out = self.fpn_norm(self.fpn_conv(out))
+ return self.dilated_encoder_blocks(out),
diff --git a/mmdet/models/necks/dyhead.py b/mmdet/models/necks/dyhead.py
new file mode 100644
index 0000000000000000000000000000000000000000..649bb4ca2f46e1e7ec9324083d5f7e7d7ec1ab3f
--- /dev/null
+++ b/mmdet/models/necks/dyhead.py
@@ -0,0 +1,176 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (build_activation_layer, build_norm_layer, constant_init,
+ normal_init)
+from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+from ..utils import DyReLU
+
+# Reference:
+# https://github.com/microsoft/DynamicHead
+# https://github.com/jshilong/SEPC
+
+
+class DyDCNv2(nn.Module):
+ """ModulatedDeformConv2d with normalization layer used in DyHead.
+
+ This module cannot be configured with `conv_cfg=dict(type='DCNv2')`
+ because DyHead calculates offset and mask from middle-level feature.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ stride (int | tuple[int], optional): Stride of the convolution.
+ Default: 1.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='GN', num_groups=16, requires_grad=True).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride=1,
+ norm_cfg=dict(type='GN', num_groups=16, requires_grad=True)):
+ super().__init__()
+ self.with_norm = norm_cfg is not None
+ bias = not self.with_norm
+ self.conv = ModulatedDeformConv2d(
+ in_channels, out_channels, 3, stride=stride, padding=1, bias=bias)
+ if self.with_norm:
+ self.norm = build_norm_layer(norm_cfg, out_channels)[1]
+
+ def forward(self, x, offset, mask):
+ """Forward function."""
+ x = self.conv(x.contiguous(), offset.contiguous(), mask)
+ if self.with_norm:
+ x = self.norm(x)
+ return x
+
+
+class DyHeadBlock(nn.Module):
+ """DyHead Block with three types of attention.
+
+ HSigmoid arguments in default act_cfg follow official code, not paper.
+ https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ zero_init_offset (bool, optional): Whether to use zero init for
+ `spatial_conv_offset`. Default: True.
+ act_cfg (dict, optional): Config dict for the last activation layer of
+ scale-aware attention. Default: dict(type='HSigmoid', bias=3.0,
+ divisor=6.0).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ zero_init_offset=True,
+ act_cfg=dict(type='HSigmoid', bias=3.0, divisor=6.0)):
+ super().__init__()
+ self.zero_init_offset = zero_init_offset
+ # (offset_x, offset_y, mask) * kernel_size_y * kernel_size_x
+ self.offset_and_mask_dim = 3 * 3 * 3
+ self.offset_dim = 2 * 3 * 3
+
+ self.spatial_conv_high = DyDCNv2(in_channels, out_channels)
+ self.spatial_conv_mid = DyDCNv2(in_channels, out_channels)
+ self.spatial_conv_low = DyDCNv2(in_channels, out_channels, stride=2)
+ self.spatial_conv_offset = nn.Conv2d(
+ in_channels, self.offset_and_mask_dim, 3, padding=1)
+ self.scale_attn_module = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(out_channels, 1, 1),
+ nn.ReLU(inplace=True), build_activation_layer(act_cfg))
+ self.task_attn_module = DyReLU(out_channels)
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ normal_init(m, 0, 0.01)
+ if self.zero_init_offset:
+ constant_init(self.spatial_conv_offset, 0)
+
+ def forward(self, x):
+ """Forward function."""
+ outs = []
+ for level in range(len(x)):
+ # calculate offset and mask of DCNv2 from middle-level feature
+ offset_and_mask = self.spatial_conv_offset(x[level])
+ offset = offset_and_mask[:, :self.offset_dim, :, :]
+ mask = offset_and_mask[:, self.offset_dim:, :, :].sigmoid()
+
+ mid_feat = self.spatial_conv_mid(x[level], offset, mask)
+ sum_feat = mid_feat * self.scale_attn_module(mid_feat)
+ summed_levels = 1
+ if level > 0:
+ low_feat = self.spatial_conv_low(x[level - 1], offset, mask)
+ sum_feat = sum_feat + \
+ low_feat * self.scale_attn_module(low_feat)
+ summed_levels += 1
+ if level < len(x) - 1:
+ # this upsample order is weird, but faster than natural order
+ # https://github.com/microsoft/DynamicHead/issues/25
+ high_feat = F.interpolate(
+ self.spatial_conv_high(x[level + 1], offset, mask),
+ size=x[level].shape[-2:],
+ mode='bilinear',
+ align_corners=True)
+ sum_feat = sum_feat + high_feat * \
+ self.scale_attn_module(high_feat)
+ summed_levels += 1
+ outs.append(self.task_attn_module(sum_feat / summed_levels))
+
+ return outs
+
+
+@NECKS.register_module()
+class DyHead(BaseModule):
+ """DyHead neck consisting of multiple DyHead Blocks.
+
+ See `Dynamic Head: Unifying Object Detection Heads with Attentions
+    <https://arxiv.org/abs/2106.08322>`_ for details.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_blocks (int, optional): Number of DyHead Blocks. Default: 6.
+ zero_init_offset (bool, optional): Whether to use zero init for
+ `spatial_conv_offset`. Default: True.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
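+
+    The following is an illustrative config snippet only; chaining DyHead
+    after an FPN and the 256-channel numbers follow a common setting and
+    are not prescriptive.
+
+    Example:
+        >>> neck = [
+        ...     dict(type='FPN', in_channels=[256, 512, 1024, 2048],
+        ...          out_channels=256, num_outs=5),
+        ...     dict(type='DyHead', in_channels=256, out_channels=256,
+        ...          num_blocks=6)
+        ... ]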
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_blocks=6,
+ zero_init_offset=True,
+ init_cfg=None):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_blocks = num_blocks
+ self.zero_init_offset = zero_init_offset
+
+ dyhead_blocks = []
+ for i in range(num_blocks):
+ in_channels = self.in_channels if i == 0 else self.out_channels
+ dyhead_blocks.append(
+ DyHeadBlock(
+ in_channels,
+ self.out_channels,
+ zero_init_offset=zero_init_offset))
+ self.dyhead_blocks = nn.Sequential(*dyhead_blocks)
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert isinstance(inputs, (tuple, list))
+ outs = self.dyhead_blocks(inputs)
+ return tuple(outs)
diff --git a/mmdet/models/necks/fpg.py b/mmdet/models/necks/fpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a2a12ed415bbb517b056d01172a83f6e30833d
--- /dev/null
+++ b/mmdet/models/necks/fpg.py
@@ -0,0 +1,406 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+
+
+class Transition(BaseModule):
+ """Base class for transition.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ """
+
+ def __init__(self, in_channels, out_channels, init_cfg=None):
+ super().__init__(init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+    def forward(self, x):
+ pass
+
+
+class UpInterpolationConv(Transition):
+ """A transition used for up-sampling.
+
+ Up-sample the input by interpolation then refines the feature by
+ a convolution layer.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ scale_factor (int): Up-sampling factor. Default: 2.
+        mode (str): Interpolation mode. Default: 'nearest'.
+ align_corners (bool): Whether align corners when interpolation.
+ Default: None.
+ kernel_size (int): Kernel size for the conv. Default: 3.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ scale_factor=2,
+ mode='nearest',
+ align_corners=None,
+ kernel_size=3,
+ init_cfg=None,
+ **kwargs):
+ super().__init__(in_channels, out_channels, init_cfg)
+ self.mode = mode
+ self.scale_factor = scale_factor
+ self.align_corners = align_corners
+ self.conv = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=(kernel_size - 1) // 2,
+ **kwargs)
+
+ def forward(self, x):
+ x = F.interpolate(
+ x,
+ scale_factor=self.scale_factor,
+ mode=self.mode,
+ align_corners=self.align_corners)
+ x = self.conv(x)
+ return x
+
+
+class LastConv(Transition):
+ """A transition used for refining the output of the last stage.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_inputs (int): Number of inputs of the FPN features.
+ kernel_size (int): Kernel size for the conv. Default: 3.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_inputs,
+ kernel_size=3,
+ init_cfg=None,
+ **kwargs):
+ super().__init__(in_channels, out_channels, init_cfg)
+ self.num_inputs = num_inputs
+ self.conv_out = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=(kernel_size - 1) // 2,
+ **kwargs)
+
+ def forward(self, inputs):
+ assert len(inputs) == self.num_inputs
+ return self.conv_out(inputs[-1])
+
+
+@NECKS.register_module()
+class FPG(BaseModule):
+ """FPG.
+
+ Implementation of `Feature Pyramid Grids (FPG)
+    <https://arxiv.org/abs/2004.03580>`_.
+ This implementation only gives the basic structure stated in the paper.
+    But users can implement different types of transitions to fully explore
+    the potential power of the structure of FPG.
+
+ Args:
+ in_channels (int): Number of input channels (feature maps of all levels
+ should have the same channels).
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ stack_times (int): The number of times the pyramid architecture will
+ be stacked.
+ paths (list[str]): Specify the path order of each stack level.
+ Each element in the list should be either 'bu' (bottom-up) or
+ 'td' (top-down).
+ inter_channels (int): Number of inter channels.
+ same_up_trans (dict): Transition that goes down at the same stage.
+ same_down_trans (dict): Transition that goes up at the same stage.
+        across_lateral_trans (dict): Across-pathway same-stage connection.
+ across_down_trans (dict): Across-pathway bottom-up connection.
+ across_up_trans (dict): Across-pathway top-down connection.
+ across_skip_trans (dict): Across-pathway skip connection.
+        output_trans (dict): Transition that transforms the output of the
+ last stage.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool): It decides whether to add conv
+            layers on top of the original feature maps. Defaults to False.
+ If True, its actual mode is specified by `extra_convs_on_inputs`.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ transition_types = {
+ 'conv': ConvModule,
+ 'interpolation_conv': UpInterpolationConv,
+ 'last_conv': LastConv,
+ }
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ stack_times,
+ paths,
+ inter_channels=None,
+ same_down_trans=None,
+ same_up_trans=dict(
+ type='conv', kernel_size=3, stride=2, padding=1),
+ across_lateral_trans=dict(type='conv', kernel_size=1),
+ across_down_trans=dict(type='conv', kernel_size=3),
+ across_up_trans=None,
+ across_skip_trans=dict(type='identity'),
+ output_trans=dict(type='last_conv', kernel_size=3),
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ norm_cfg=None,
+ skip_inds=None,
+ init_cfg=[
+ dict(type='Caffe2Xavier', layer='Conv2d'),
+ dict(
+ type='Constant',
+ layer=[
+ '_BatchNorm', '_InstanceNorm', 'GroupNorm',
+ 'LayerNorm'
+ ],
+ val=1.0)
+ ]):
+ super(FPG, self).__init__(init_cfg)
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ if inter_channels is None:
+ self.inter_channels = [out_channels for _ in range(num_outs)]
+ elif isinstance(inter_channels, int):
+ self.inter_channels = [inter_channels for _ in range(num_outs)]
+ else:
+ assert isinstance(inter_channels, list)
+ assert len(inter_channels) == num_outs
+ self.inter_channels = inter_channels
+ self.stack_times = stack_times
+ self.paths = paths
+ assert isinstance(paths, list) and len(paths) == stack_times
+ for d in paths:
+ assert d in ('bu', 'td')
+
+ self.same_down_trans = same_down_trans
+ self.same_up_trans = same_up_trans
+ self.across_lateral_trans = across_lateral_trans
+ self.across_down_trans = across_down_trans
+ self.across_up_trans = across_up_trans
+ self.output_trans = output_trans
+ self.across_skip_trans = across_skip_trans
+
+ self.with_bias = norm_cfg is None
+ # skip inds must be specified if across skip trans is not None
+ if self.across_skip_trans is not None:
+            assert skip_inds is not None
+ self.skip_inds = skip_inds
+ assert len(self.skip_inds[0]) <= self.stack_times
+
+ if end_level == -1 or end_level == self.num_ins - 1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level is not the last level, no extra level is allowed
+ self.backbone_end_level = end_level + 1
+ assert end_level < self.num_ins
+ assert num_outs == end_level - start_level + 1
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+
+ # build lateral 1x1 convs to reduce channels
+ self.lateral_convs = nn.ModuleList()
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = nn.Conv2d(self.in_channels[i],
+ self.inter_channels[i - self.start_level], 1)
+ self.lateral_convs.append(l_conv)
+
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ self.extra_downsamples = nn.ModuleList()
+ for i in range(extra_levels):
+ if self.add_extra_convs:
+ fpn_idx = self.backbone_end_level - self.start_level + i
+ extra_conv = nn.Conv2d(
+ self.inter_channels[fpn_idx - 1],
+ self.inter_channels[fpn_idx],
+ 3,
+ stride=2,
+ padding=1)
+ self.extra_downsamples.append(extra_conv)
+ else:
+ self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))
+
+ self.fpn_transitions = nn.ModuleList() # stack times
+ for s in range(self.stack_times):
+ stage_trans = nn.ModuleList() # num of feature levels
+ for i in range(self.num_outs):
+ # same, across_lateral, across_down, across_up
+ trans = nn.ModuleDict()
+ if s in self.skip_inds[i]:
+ stage_trans.append(trans)
+ continue
+ # build same-stage down trans (used in bottom-up paths)
+ if i == 0 or self.same_up_trans is None:
+ same_up_trans = None
+ else:
+ same_up_trans = self.build_trans(
+ self.same_up_trans, self.inter_channels[i - 1],
+ self.inter_channels[i])
+ trans['same_up'] = same_up_trans
+ # build same-stage up trans (used in top-down paths)
+ if i == self.num_outs - 1 or self.same_down_trans is None:
+ same_down_trans = None
+ else:
+ same_down_trans = self.build_trans(
+ self.same_down_trans, self.inter_channels[i + 1],
+ self.inter_channels[i])
+ trans['same_down'] = same_down_trans
+ # build across lateral trans
+ across_lateral_trans = self.build_trans(
+ self.across_lateral_trans, self.inter_channels[i],
+ self.inter_channels[i])
+ trans['across_lateral'] = across_lateral_trans
+ # build across down trans
+ if i == self.num_outs - 1 or self.across_down_trans is None:
+ across_down_trans = None
+ else:
+ across_down_trans = self.build_trans(
+ self.across_down_trans, self.inter_channels[i + 1],
+ self.inter_channels[i])
+ trans['across_down'] = across_down_trans
+ # build across up trans
+ if i == 0 or self.across_up_trans is None:
+ across_up_trans = None
+ else:
+ across_up_trans = self.build_trans(
+ self.across_up_trans, self.inter_channels[i - 1],
+ self.inter_channels[i])
+ trans['across_up'] = across_up_trans
+ if self.across_skip_trans is None:
+ across_skip_trans = None
+ else:
+ across_skip_trans = self.build_trans(
+ self.across_skip_trans, self.inter_channels[i - 1],
+ self.inter_channels[i])
+ trans['across_skip'] = across_skip_trans
+ # build across_skip trans
+ stage_trans.append(trans)
+ self.fpn_transitions.append(stage_trans)
+
+ self.output_transition = nn.ModuleList() # output levels
+ for i in range(self.num_outs):
+ trans = self.build_trans(
+ self.output_trans,
+ self.inter_channels[i],
+ self.out_channels,
+ num_inputs=self.stack_times + 1)
+ self.output_transition.append(trans)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def build_trans(self, cfg, in_channels, out_channels, **extra_args):
+ cfg_ = cfg.copy()
+ trans_type = cfg_.pop('type')
+ trans_cls = self.transition_types[trans_type]
+ return trans_cls(in_channels, out_channels, **cfg_, **extra_args)
+
+ def fuse(self, fuse_dict):
+ out = None
+ for item in fuse_dict.values():
+ if item is not None:
+ if out is None:
+ out = item
+ else:
+ out = out + item
+ return out
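+
+    # Note (illustrative): ``fuse`` simply sums every non-None entry of the
+    # dict, e.g. fuse(dict(same=a, lateral=b, across_up=None)) == a + b, so
+    # a missing transition at a level contributes nothing to the fused map.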
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+
+ # build all levels from original feature maps
+ feats = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ for downsample in self.extra_downsamples:
+ feats.append(downsample(feats[-1]))
+
+ outs = [feats]
+
+ for i in range(self.stack_times):
+ current_outs = outs[-1]
+ next_outs = []
+ direction = self.paths[i]
+ for j in range(self.num_outs):
+ if i in self.skip_inds[j]:
+ next_outs.append(outs[-1][j])
+ continue
+ # feature level
+ if direction == 'td':
+ lvl = self.num_outs - j - 1
+ else:
+ lvl = j
+ # get transitions
+ if direction == 'td':
+ same_trans = self.fpn_transitions[i][lvl]['same_down']
+ else:
+ same_trans = self.fpn_transitions[i][lvl]['same_up']
+ across_lateral_trans = self.fpn_transitions[i][lvl][
+ 'across_lateral']
+ across_down_trans = self.fpn_transitions[i][lvl]['across_down']
+ across_up_trans = self.fpn_transitions[i][lvl]['across_up']
+ across_skip_trans = self.fpn_transitions[i][lvl]['across_skip']
+ # init output
+ to_fuse = dict(
+ same=None, lateral=None, across_up=None, across_down=None)
+ # same downsample/upsample
+ if same_trans is not None:
+ to_fuse['same'] = same_trans(next_outs[-1])
+ # across lateral
+ if across_lateral_trans is not None:
+ to_fuse['lateral'] = across_lateral_trans(
+ current_outs[lvl])
+ # across downsample
+ if lvl > 0 and across_up_trans is not None:
+ to_fuse['across_up'] = across_up_trans(current_outs[lvl -
+ 1])
+ # across upsample
+ if (lvl < self.num_outs - 1 and across_down_trans is not None):
+ to_fuse['across_down'] = across_down_trans(
+ current_outs[lvl + 1])
+ if across_skip_trans is not None:
+ to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
+ x = self.fuse(to_fuse)
+ next_outs.append(x)
+
+ if direction == 'td':
+ outs.append(next_outs[::-1])
+ else:
+ outs.append(next_outs)
+
+ # output trans
+ final_outs = []
+ for i in range(self.num_outs):
+ lvl_out_list = []
+ for s in range(len(outs)):
+ lvl_out_list.append(outs[s][i])
+ lvl_out = self.output_transition[i](lvl_out_list)
+ final_outs.append(lvl_out)
+
+ return final_outs
diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bdb5b22156b579dc262894fd0c4a141f4479854
--- /dev/null
+++ b/mmdet/models/necks/fpn.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, auto_fp16
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN(BaseModule):
+ r"""Feature Pyramid Network.
+
+ This is an implementation of paper `Feature Pyramid Networks for Object
+    Detection <https://arxiv.org/abs/1612.03144>`_.
+
+ Args:
+ in_channels (list[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, it is equivalent to `add_extra_convs='on_input'`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs.
+ - 'on_output': The last output feature map after fpn convs.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer in ConvModule.
+ Default: None.
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: dict(mode='nearest').
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ upsample_cfg=dict(mode='nearest'),
+ init_cfg=dict(
+ type='Xavier', layer='Conv2d', distribution='uniform')):
+ super(FPN, self).__init__(init_cfg)
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ self.upsample_cfg = upsample_cfg.copy()
+
+ if end_level == -1 or end_level == self.num_ins - 1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level is not the last level, no extra level is allowed
+ self.backbone_end_level = end_level + 1
+ assert end_level < self.num_ins
+ assert num_outs == end_level - start_level + 1
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+ elif add_extra_convs: # True
+ self.add_extra_convs = 'on_input'
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.backbone_end_level - 1]
+ else:
+ in_channels = out_channels
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(extra_fpn_conv)
+
+ @auto_fp16()
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+ # it cannot co-exist with `size` in `F.interpolate`.
+ if 'scale_factor' in self.upsample_cfg:
+ # fix runtime error of "+=" inplace operation in PyTorch 1.10
+ laterals[i - 1] = laterals[i - 1] + F.interpolate(
+ laterals[i], **self.upsample_cfg)
+ else:
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] = laterals[i - 1] + F.interpolate(
+ laterals[i], size=prev_shape, **self.upsample_cfg)
+
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+ # part 2: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ extra_source = inputs[self.backbone_end_level - 1]
+ elif self.add_extra_convs == 'on_lateral':
+ extra_source = laterals[-1]
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+ outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/mmdet/models/necks/fpn_carafe.py b/mmdet/models/necks/fpn_carafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdd91f34c94129eefb477451dd7c1f7a7854135e
--- /dev/null
+++ b/mmdet/models/necks/fpn_carafe.py
@@ -0,0 +1,275 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_upsample_layer, xavier_init
+from mmcv.ops.carafe import CARAFEPack
+from mmcv.runner import BaseModule, ModuleList
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN_CARAFE(BaseModule):
+    """FPN_CARAFE is a more flexible implementation of FPN. It allows more
+    choices of upsampling methods in the top-down pathway.
+
+ It can reproduce the performance of ICCV 2019 paper
+ CARAFE: Content-Aware ReAssembly of FEatures
+ Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+ Args:
+ in_channels (list[int]): Number of channels for each input feature map.
+ out_channels (int): Output channels of feature pyramids.
+ num_outs (int): Number of output stages.
+ start_level (int): Start level of feature pyramids.
+ (Default: 0)
+ end_level (int): End level of feature pyramids.
+ (Default: -1 indicates the last level).
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+        act_cfg (dict): Config dict for activation layer in ConvModule
+            (Default: None indicates w/o activation).
+        order (tuple[str]): Order of components in ConvModule.
+        upsample_cfg (dict): Dictionary to construct and config the upsample
+            layer; its ``type`` key selects the upsampling method.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
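+
+    Example (a minimal usage sketch; it uses plain ``nearest`` upsampling
+    instead of CARAFE, and the channel counts and feature sizes below are
+    arbitrary illustrative values):
+        >>> import torch
+        >>> # illustrative inputs: each level halves the previous one
+        >>> in_channels = [2, 3, 5, 7]
+        >>> scales = [64, 32, 16, 8]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = FPN_CARAFE(in_channels, 11, len(in_channels),
+        ...                   upsample_cfg=dict(type='nearest')).eval()
+        >>> outputs = self.forward(inputs)
+        >>> [tuple(o.shape) for o in outputs]
+        [(1, 11, 64, 64), (1, 11, 32, 32), (1, 11, 16, 16), (1, 11, 8, 8)]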
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ norm_cfg=None,
+ act_cfg=None,
+ order=('conv', 'norm', 'act'),
+ upsample_cfg=dict(
+ type='carafe',
+ up_kernel=5,
+ up_group=1,
+ encoder_kernel=3,
+ encoder_dilation=1),
+ init_cfg=None):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super(FPN_CARAFE, self).__init__(init_cfg)
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.with_bias = norm_cfg is None
+ self.upsample_cfg = upsample_cfg.copy()
+ self.upsample = self.upsample_cfg.get('type')
+ self.relu = nn.ReLU(inplace=False)
+
+ self.order = order
+ assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')]
+
+ assert self.upsample in [
+ 'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None
+ ]
+ if self.upsample in ['deconv', 'pixel_shuffle']:
+ assert hasattr(
+ self.upsample_cfg,
+ 'upsample_kernel') and self.upsample_cfg.upsample_kernel > 0
+ self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel')
+
+ if end_level == -1 or end_level == self.num_ins - 1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level is not the last level, no extra level is allowed
+ self.backbone_end_level = end_level + 1
+ assert end_level < self.num_ins
+ assert num_outs == end_level - start_level + 1
+ self.start_level = start_level
+ self.end_level = end_level
+
+ self.lateral_convs = ModuleList()
+ self.fpn_convs = ModuleList()
+ self.upsample_modules = ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ norm_cfg=norm_cfg,
+ bias=self.with_bias,
+ act_cfg=act_cfg,
+ inplace=False,
+ order=self.order)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ bias=self.with_bias,
+ act_cfg=act_cfg,
+ inplace=False,
+ order=self.order)
+ if i != self.backbone_end_level - 1:
+ upsample_cfg_ = self.upsample_cfg.copy()
+ if self.upsample == 'deconv':
+ upsample_cfg_.update(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=self.upsample_kernel,
+ stride=2,
+ padding=(self.upsample_kernel - 1) // 2,
+ output_padding=(self.upsample_kernel - 1) // 2)
+ elif self.upsample == 'pixel_shuffle':
+ upsample_cfg_.update(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ scale_factor=2,
+ upsample_kernel=self.upsample_kernel)
+ elif self.upsample == 'carafe':
+ upsample_cfg_.update(channels=out_channels, scale_factor=2)
+ else:
+ # suppress warnings
+ align_corners = (None
+ if self.upsample == 'nearest' else False)
+ upsample_cfg_.update(
+ scale_factor=2,
+ mode=self.upsample,
+ align_corners=align_corners)
+ upsample_module = build_upsample_layer(upsample_cfg_)
+ self.upsample_modules.append(upsample_module)
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_out_levels = (
+ num_outs - self.backbone_end_level + self.start_level)
+ if extra_out_levels >= 1:
+ for i in range(extra_out_levels):
+ in_channels = (
+ self.in_channels[self.backbone_end_level -
+ 1] if i == 0 else out_channels)
+ extra_l_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ norm_cfg=norm_cfg,
+ bias=self.with_bias,
+ act_cfg=act_cfg,
+ inplace=False,
+ order=self.order)
+ if self.upsample == 'deconv':
+ upsampler_cfg_ = dict(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=self.upsample_kernel,
+ stride=2,
+ padding=(self.upsample_kernel - 1) // 2,
+ output_padding=(self.upsample_kernel - 1) // 2)
+ elif self.upsample == 'pixel_shuffle':
+ upsampler_cfg_ = dict(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ scale_factor=2,
+ upsample_kernel=self.upsample_kernel)
+ elif self.upsample == 'carafe':
+ upsampler_cfg_ = dict(
+ channels=out_channels,
+ scale_factor=2,
+ **self.upsample_cfg)
+ else:
+ # suppress warnings
+ align_corners = (None
+ if self.upsample == 'nearest' else False)
+ upsampler_cfg_ = dict(
+ scale_factor=2,
+ mode=self.upsample,
+ align_corners=align_corners)
+ upsampler_cfg_['type'] = self.upsample
+ upsample_module = build_upsample_layer(upsampler_cfg_)
+ extra_fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ bias=self.with_bias,
+ act_cfg=act_cfg,
+ inplace=False,
+ order=self.order)
+ self.upsample_modules.append(upsample_module)
+ self.fpn_convs.append(extra_fpn_conv)
+ self.lateral_convs.append(extra_l_conv)
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ """Initialize the weights of module."""
+ super(FPN_CARAFE, self).init_weights()
+ for m in self.modules():
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+ xavier_init(m, distribution='uniform')
+ for m in self.modules():
+ if isinstance(m, CARAFEPack):
+ m.init_weights()
+
+ def slice_as(self, src, dst):
+ """Slice ``src`` as ``dst``
+
+ Note:
+ ``src`` should have the same or larger size than ``dst``.
+
+ Args:
+ src (torch.Tensor): Tensors to be sliced.
+ dst (torch.Tensor): ``src`` will be sliced to have the same
+ size as ``dst``.
+
+ Returns:
+ torch.Tensor: Sliced tensor.
+ """
+ assert (src.size(2) >= dst.size(2)) and (src.size(3) >= dst.size(3))
+ if src.size(2) == dst.size(2) and src.size(3) == dst.size(3):
+ return src
+ else:
+ return src[:, :, :dst.size(2), :dst.size(3)]
+
+ def tensor_add(self, a, b):
+ """Add tensors ``a`` and ``b`` that might have different sizes."""
+ if a.size() == b.size():
+ c = a + b
+ else:
+ c = a + self.slice_as(b, a)
+ return c
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = []
+ for i, lateral_conv in enumerate(self.lateral_convs):
+ if i <= self.backbone_end_level - self.start_level:
+ input = inputs[min(i + self.start_level, len(inputs) - 1)]
+ else:
+ input = laterals[-1]
+ lateral = lateral_conv(input)
+ laterals.append(lateral)
+
+ # build top-down path
+ for i in range(len(laterals) - 1, 0, -1):
+ if self.upsample is not None:
+ upsample_feat = self.upsample_modules[i - 1](laterals[i])
+ else:
+ upsample_feat = laterals[i]
+ laterals[i - 1] = self.tensor_add(laterals[i - 1], upsample_feat)
+
+ # build outputs
+ num_conv_outs = len(self.fpn_convs)
+ outs = []
+ for i in range(num_conv_outs):
+ out = self.fpn_convs[i](laterals[i])
+ outs.append(out)
+ return tuple(outs)
diff --git a/mmdet/models/necks/hrfpn.py b/mmdet/models/necks/hrfpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca15be6b29877b1023fdd9f93226690f816504bf
--- /dev/null
+++ b/mmdet/models/necks/hrfpn.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+from torch.utils.checkpoint import checkpoint
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class HRFPN(BaseModule):
+ """HRFPN (High Resolution Feature Pyramids)
+
+    paper: `High-Resolution Representations for Labeling Pixels and Regions
+    <https://arxiv.org/abs/1904.04514>`_.
+
+ Args:
+ in_channels (list): number of channels for each branch.
+ out_channels (int): output channels of feature pyramids.
+ num_outs (int): number of output stages.
+ pooling_type (str): pooling for generating feature pyramids
+ from {MAX, AVG}.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ stride (int): stride of 3x3 convolutional layers
+ init_cfg (dict or list[dict], optional): Initialization config dict.
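+
+    Example (a minimal usage sketch; the branch channels and feature sizes
+    below are arbitrary illustrative values):
+        >>> import torch
+        >>> # each branch is half the resolution of the previous one
+        >>> in_channels = [18, 36, 72, 144]
+        >>> scales = [64, 32, 16, 8]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = HRFPN(in_channels, 256, num_outs=5).eval()
+        >>> outputs = self.forward(inputs)
+        >>> [tuple(o.shape[-2:]) for o in outputs]
+        [(64, 64), (32, 32), (16, 16), (8, 8), (4, 4)]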
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs=5,
+ pooling_type='AVG',
+ conv_cfg=None,
+ norm_cfg=None,
+ with_cp=False,
+ stride=1,
+ init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
+ super(HRFPN, self).__init__(init_cfg)
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.reduction_conv = ConvModule(
+ sum(in_channels),
+ out_channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ act_cfg=None)
+
+ self.fpn_convs = nn.ModuleList()
+ for i in range(self.num_outs):
+ self.fpn_convs.append(
+ ConvModule(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ conv_cfg=self.conv_cfg,
+ act_cfg=None))
+
+ if pooling_type == 'MAX':
+ self.pooling = F.max_pool2d
+ else:
+ self.pooling = F.avg_pool2d
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == self.num_ins
+ outs = [inputs[0]]
+ for i in range(1, self.num_ins):
+ outs.append(
+ F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
+ out = torch.cat(outs, dim=1)
+ if out.requires_grad and self.with_cp:
+ out = checkpoint(self.reduction_conv, out)
+ else:
+ out = self.reduction_conv(out)
+ outs = [out]
+ for i in range(1, self.num_outs):
+ outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
+ outputs = []
+
+ for i in range(self.num_outs):
+ if outs[i].requires_grad and self.with_cp:
+ tmp_out = checkpoint(self.fpn_convs[i], outs[i])
+ else:
+ tmp_out = self.fpn_convs[i](outs[i])
+ outputs.append(tmp_out)
+ return tuple(outputs)
diff --git a/mmdet/models/necks/nas_fpn.py b/mmdet/models/necks/nas_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..710592eccb4f483e64d4bc09b3d8669170dc8f0f
--- /dev/null
+++ b/mmdet/models/necks/nas_fpn.py
@@ -0,0 +1,158 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.ops.merge_cells import GlobalPoolingCell, SumCell
+from mmcv.runner import BaseModule, ModuleList
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class NASFPN(BaseModule):
+ """NAS-FPN.
+
+    Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture
+    for Object Detection <https://arxiv.org/abs/1904.07392>`_
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ stack_times (int): The number of times the pyramid architecture will
+ be stacked.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool): It decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, its actual mode is specified by `extra_convs_on_inputs`.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
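+
+    Example (a minimal usage sketch with a single stacked stage; the channel
+    counts and feature sizes below are arbitrary illustrative values):
+        >>> import torch
+        >>> # three backbone levels (e.g. C3-C5); P6/P7 are added by the neck
+        >>> in_channels = [8, 16, 32]
+        >>> scales = [64, 32, 16]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = NASFPN(in_channels, 8, num_outs=5, stack_times=1).eval()
+        >>> p3, p4, p5, p6, p7 = self.forward(inputs)
+        >>> [tuple(p.shape[-2:]) for p in (p3, p4, p5, p6, p7)]
+        [(64, 64), (32, 32), (16, 16), (8, 8), (4, 4)]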
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ stack_times,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ norm_cfg=None,
+ init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
+ super(NASFPN, self).__init__(init_cfg)
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels) # num of input feature levels
+ self.num_outs = num_outs # num of output feature levels
+ self.stack_times = stack_times
+ self.norm_cfg = norm_cfg
+
+ if end_level == -1 or end_level == self.num_ins - 1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level is not the last level, no extra level is allowed
+ self.backbone_end_level = end_level + 1
+ assert end_level < self.num_ins
+ assert num_outs == end_level - start_level + 1
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+
+ # add lateral connections
+ self.lateral_convs = nn.ModuleList()
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ self.lateral_convs.append(l_conv)
+
+ # add extra downsample layers (stride-2 pooling or conv)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ self.extra_downsamples = nn.ModuleList()
+ for i in range(extra_levels):
+ extra_conv = ConvModule(
+ out_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+ self.extra_downsamples.append(
+ nn.Sequential(extra_conv, nn.MaxPool2d(2, 2)))
+
+ # add NAS FPN connections
+ self.fpn_stages = ModuleList()
+ for _ in range(self.stack_times):
+ stage = nn.ModuleDict()
+ # gp(p6, p4) -> p4_1
+ stage['gp_64_4'] = GlobalPoolingCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p4_1, p4) -> p4_2
+ stage['sum_44_4'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p4_2, p3) -> p3_out
+ stage['sum_43_3'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p3_out, p4_2) -> p4_out
+ stage['sum_34_4'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p5, gp(p4_out, p3_out)) -> p5_out
+ stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False)
+ stage['sum_55_5'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p7, gp(p5_out, p4_2)) -> p7_out
+ stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False)
+ stage['sum_77_7'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # gp(p7_out, p5_out) -> p6_out
+ stage['gp_75_6'] = GlobalPoolingCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ self.fpn_stages.append(stage)
+
+ def forward(self, inputs):
+ """Forward function."""
+ # build P3-P5
+ feats = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ # build P6-P7 on top of P5
+ for downsample in self.extra_downsamples:
+ feats.append(downsample(feats[-1]))
+
+ p3, p4, p5, p6, p7 = feats
+
+ for stage in self.fpn_stages:
+ # gp(p6, p4) -> p4_1
+ p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
+ # sum(p4_1, p4) -> p4_2
+ p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
+ # sum(p4_2, p3) -> p3_out
+ p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
+ # sum(p3_out, p4_2) -> p4_out
+ p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
+ # sum(p5, gp(p4_out, p3_out)) -> p5_out
+ p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
+ p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
+ # sum(p7, gp(p5_out, p4_2)) -> p7_out
+ p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
+ p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
+ # gp(p7_out, p5_out) -> p6_out
+ p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])
+
+ return p3, p4, p5, p6, p7
diff --git a/mmdet/models/necks/nasfcos_fpn.py b/mmdet/models/necks/nasfcos_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4abfe7bde8a69c1219e7532669761c3e9e64e15
--- /dev/null
+++ b/mmdet/models/necks/nasfcos_fpn.py
@@ -0,0 +1,170 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, caffe2_xavier_init
+from mmcv.ops.merge_cells import ConcatCell
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class NASFCOS_FPN(BaseModule):
+    """FPN structure in NAS-FCOS.
+
+    Implementation of paper `NAS-FCOS: Fast Neural Architecture Search for
+    Object Detection <https://arxiv.org/abs/1906.04423>`_
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool): It decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, its actual mode is specified by `extra_convs_on_inputs`.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
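+
+    Example (a minimal usage sketch; the channel counts and feature sizes
+    below are arbitrary illustrative values, and the first level (C2) is
+    skipped because ``start_level`` defaults to 1):
+        >>> import torch
+        >>> in_channels = [8, 16, 32, 64]
+        >>> scales = [64, 32, 16, 8]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = NASFCOS_FPN(in_channels, 16, num_outs=5).eval()
+        >>> outputs = self.forward(inputs)
+        >>> [tuple(o.shape[-2:]) for o in outputs]
+        [(32, 32), (16, 16), (8, 8), (4, 4), (2, 2)]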
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=1,
+ end_level=-1,
+ add_extra_convs=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ init_cfg=None):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super(NASFCOS_FPN, self).__init__(init_cfg)
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+
+ if end_level == -1 or end_level == self.num_ins - 1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level is not the last level, no extra level is allowed
+ self.backbone_end_level = end_level + 1
+ assert end_level < self.num_ins
+ assert num_outs == end_level - start_level + 1
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+
+ self.adapt_convs = nn.ModuleList()
+ for i in range(self.start_level, self.backbone_end_level):
+ adapt_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ stride=1,
+ padding=0,
+ bias=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU', inplace=False))
+ self.adapt_convs.append(adapt_conv)
+
+ # C2 is omitted according to the paper
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+
+ def build_concat_cell(with_input1_conv, with_input2_conv):
+ cell_conv_cfg = dict(
+ kernel_size=1, padding=0, bias=False, groups=out_channels)
+ return ConcatCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ with_out_conv=True,
+ out_conv_cfg=cell_conv_cfg,
+ out_norm_cfg=dict(type='BN'),
+ out_conv_order=('norm', 'act', 'conv'),
+ with_input1_conv=with_input1_conv,
+ with_input2_conv=with_input2_conv,
+ input_conv_cfg=conv_cfg,
+ input_norm_cfg=norm_cfg,
+ upsample_mode='nearest')
+
+        # Denote c3=f0, c4=f1, c5=f2 for convenience
+ self.fpn = nn.ModuleDict()
+ self.fpn['c22_1'] = build_concat_cell(True, True)
+ self.fpn['c22_2'] = build_concat_cell(True, True)
+ self.fpn['c32'] = build_concat_cell(True, False)
+ self.fpn['c02'] = build_concat_cell(True, False)
+ self.fpn['c42'] = build_concat_cell(True, True)
+ self.fpn['c36'] = build_concat_cell(True, True)
+ self.fpn['c61'] = build_concat_cell(True, True) # f9
+ self.extra_downsamples = nn.ModuleList()
+ for i in range(extra_levels):
+ extra_act_cfg = None if i == 0 \
+ else dict(type='ReLU', inplace=False)
+ self.extra_downsamples.append(
+ ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ act_cfg=extra_act_cfg,
+ order=('act', 'norm', 'conv')))
+
+ def forward(self, inputs):
+ """Forward function."""
+ feats = [
+ adapt_conv(inputs[i + self.start_level])
+ for i, adapt_conv in enumerate(self.adapt_convs)
+ ]
+
+ for (i, module_name) in enumerate(self.fpn):
+ idx_1, idx_2 = int(module_name[1]), int(module_name[2])
+ res = self.fpn[module_name](feats[idx_1], feats[idx_2])
+ feats.append(res)
+
+ ret = []
+ for (idx, input_idx) in zip([9, 8, 7], [1, 2, 3]): # add P3, P4, P5
+ feats1, feats2 = feats[idx], feats[5]
+ feats2_resize = F.interpolate(
+ feats2,
+ size=feats1.size()[2:],
+ mode='bilinear',
+ align_corners=False)
+
+ feats_sum = feats1 + feats2_resize
+ ret.append(
+ F.interpolate(
+ feats_sum,
+ size=inputs[input_idx].size()[2:],
+ mode='bilinear',
+ align_corners=False))
+
+ for submodule in self.extra_downsamples:
+ ret.append(submodule(ret[-1]))
+
+ return tuple(ret)
+
+ def init_weights(self):
+ """Initialize the weights of module."""
+ super(NASFCOS_FPN, self).init_weights()
+ for module in self.fpn.values():
+            if hasattr(module, 'out_conv'):
+                caffe2_xavier_init(module.out_conv.conv)
+
+ for modules in [
+ self.adapt_convs.modules(),
+ self.extra_downsamples.modules()
+ ]:
+ for module in modules:
+ if isinstance(module, nn.Conv2d):
+ caffe2_xavier_init(module)
diff --git a/mmdet/models/necks/pafpn.py b/mmdet/models/necks/pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2edd34879425891a16b8e93b92fe2d653af07022
--- /dev/null
+++ b/mmdet/models/necks/pafpn.py
@@ -0,0 +1,159 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import auto_fp16
+
+from ..builder import NECKS
+from .fpn import FPN
+
+
+@NECKS.register_module()
+class PAFPN(FPN):
+ """Path Aggregation Network for Instance Segmentation.
+
+    This is an implementation of the `PAFPN in Path Aggregation Network
+    <https://arxiv.org/abs/1803.01534>`_.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, it is equivalent to `add_extra_convs='on_input'`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs.
+ - 'on_output': The last output feature map after fpn convs.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer in ConvModule.
+ Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
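+
+    Example (a minimal usage sketch; the channel counts and feature sizes
+    below are arbitrary, power-of-two sizes keep the extra bottom-up
+    downsampling aligned with the next level):
+        >>> import torch
+        >>> in_channels = [2, 3, 5, 7]
+        >>> scales = [64, 32, 16, 8]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = PAFPN(in_channels, 11, len(in_channels)).eval()
+        >>> outputs = self.forward(inputs)
+        >>> [tuple(o.shape) for o in outputs]
+        [(1, 11, 64, 64), (1, 11, 32, 32), (1, 11, 16, 16), (1, 11, 8, 8)]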
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ init_cfg=dict(
+ type='Xavier', layer='Conv2d', distribution='uniform')):
+ super(PAFPN, self).__init__(
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level,
+ end_level,
+ add_extra_convs,
+ relu_before_extra_convs,
+ no_norm_on_lateral,
+ conv_cfg,
+ norm_cfg,
+ act_cfg,
+ init_cfg=init_cfg)
+ # add extra bottom up pathway
+ self.downsample_convs = nn.ModuleList()
+ self.pafpn_convs = nn.ModuleList()
+ for i in range(self.start_level + 1, self.backbone_end_level):
+ d_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ pafpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.downsample_convs.append(d_conv)
+ self.pafpn_convs.append(pafpn_conv)
+
+ @auto_fp16()
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ # fix runtime error of "+=" inplace operation in PyTorch 1.10
+ laterals[i - 1] = laterals[i - 1] + F.interpolate(
+ laterals[i], size=prev_shape, mode='nearest')
+
+ # build outputs
+ # part 1: from original levels
+ inter_outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+
+ # part 2: add bottom-up path
+ for i in range(0, used_backbone_levels - 1):
+ inter_outs[i + 1] += self.downsample_convs[i](inter_outs[i])
+
+ outs = []
+ outs.append(inter_outs[0])
+ outs.extend([
+ self.pafpn_convs[i - 1](inter_outs[i])
+ for i in range(1, used_backbone_levels)
+ ])
+
+ # part 3: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ orig = inputs[self.backbone_end_level - 1]
+ outs.append(self.fpn_convs[used_backbone_levels](orig))
+ elif self.add_extra_convs == 'on_lateral':
+ outs.append(self.fpn_convs[used_backbone_levels](
+ laterals[-1]))
+ elif self.add_extra_convs == 'on_output':
+ outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
+ else:
+ raise NotImplementedError
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/mmdet/models/necks/rfp.py b/mmdet/models/necks/rfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..6976f4daf25a04f63f7570ec7ca7633c50fc725d
--- /dev/null
+++ b/mmdet/models/necks/rfp.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import constant_init, xavier_init
+from mmcv.runner import BaseModule, ModuleList
+
+from ..builder import NECKS, build_backbone
+from .fpn import FPN
+
+
+class ASPP(BaseModule):
+ """ASPP (Atrous Spatial Pyramid Pooling)
+
+ This is an implementation of the ASPP module used in DetectoRS
+ (https://arxiv.org/pdf/2006.02334.pdf)
+
+ Args:
+ in_channels (int): Number of input channels.
+        out_channels (int): Number of channels produced by each branch; the
+            concatenated output has ``len(dilations) * out_channels`` channels.
+ dilations (tuple[int]): Dilations of the four branches.
+ Default: (1, 3, 6, 1)
+ init_cfg (dict or list[dict], optional): Initialization config dict.
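+
+    Example (a minimal usage sketch; the channel counts and feature size are
+    arbitrary illustrative values):
+        >>> import torch
+        >>> self = ASPP(64, 16)
+        >>> x = torch.rand(1, 64, 32, 32)
+        >>> # the four branches are concatenated: 4 * 16 = 64 channels
+        >>> tuple(self(x).shape)
+        (1, 64, 32, 32)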
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ dilations=(1, 3, 6, 1),
+ init_cfg=dict(type='Kaiming', layer='Conv2d')):
+ super().__init__(init_cfg)
+ assert dilations[-1] == 1
+ self.aspp = nn.ModuleList()
+ for dilation in dilations:
+ kernel_size = 3 if dilation > 1 else 1
+ padding = dilation if dilation > 1 else 0
+ conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ dilation=dilation,
+ padding=padding,
+ bias=True)
+ self.aspp.append(conv)
+ self.gap = nn.AdaptiveAvgPool2d(1)
+
+ def forward(self, x):
+ avg_x = self.gap(x)
+ out = []
+ for aspp_idx in range(len(self.aspp)):
+ inp = avg_x if (aspp_idx == len(self.aspp) - 1) else x
+ out.append(F.relu_(self.aspp[aspp_idx](inp)))
+ out[-1] = out[-1].expand_as(out[-2])
+ out = torch.cat(out, dim=1)
+ return out
+
+
+@NECKS.register_module()
+class RFP(FPN):
+ """RFP (Recursive Feature Pyramid)
+
+    This is an implementation of RFP in `DetectoRS
+    <https://arxiv.org/abs/2006.02334>`_. Different from standard FPN, the
+ input of RFP should be multi level features along with origin input image
+ of backbone.
+
+ Args:
+ rfp_steps (int): Number of unrolled steps of RFP.
+ rfp_backbone (dict): Configuration of the backbone for RFP.
+ aspp_out_channels (int): Number of output channels of ASPP module.
+ aspp_dilations (tuple[int]): Dilation rates of four branches.
+ Default: (1, 3, 6, 1)
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ rfp_steps,
+ rfp_backbone,
+ aspp_out_channels,
+ aspp_dilations=(1, 3, 6, 1),
+ init_cfg=None,
+ **kwargs):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super().__init__(init_cfg=init_cfg, **kwargs)
+ self.rfp_steps = rfp_steps
+        # Be careful! Pretrained weights cannot be loaded when using a plain
+        # nn.ModuleList, so mmcv's ModuleList is used here.
+ self.rfp_modules = ModuleList()
+ for rfp_idx in range(1, rfp_steps):
+ rfp_module = build_backbone(rfp_backbone)
+ self.rfp_modules.append(rfp_module)
+ self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels,
+ aspp_dilations)
+ self.rfp_weight = nn.Conv2d(
+ self.out_channels,
+ 1,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True)
+
+ def init_weights(self):
+ # Avoid using super().init_weights(), which may alter the default
+ # initialization of the modules in self.rfp_modules that have missing
+ # keys in the pretrained checkpoint.
+ for convs in [self.lateral_convs, self.fpn_convs]:
+ for m in convs.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+ for rfp_idx in range(self.rfp_steps - 1):
+ self.rfp_modules[rfp_idx].init_weights()
+ constant_init(self.rfp_weight, 0)
+
+ def forward(self, inputs):
+ inputs = list(inputs)
+ assert len(inputs) == len(self.in_channels) + 1 # +1 for input image
+ img = inputs.pop(0)
+ # FPN forward
+ x = super().forward(tuple(inputs))
+ for rfp_idx in range(self.rfp_steps - 1):
+ rfp_feats = [x[0]] + list(
+ self.rfp_aspp(x[i]) for i in range(1, len(x)))
+ x_idx = self.rfp_modules[rfp_idx].rfp_forward(img, rfp_feats)
+ # FPN forward
+ x_idx = super().forward(x_idx)
+ x_new = []
+ for ft_idx in range(len(x_idx)):
+ add_weight = torch.sigmoid(self.rfp_weight(x_idx[ft_idx]))
+ x_new.append(add_weight * x_idx[ft_idx] +
+ (1 - add_weight) * x[ft_idx])
+ x = x_new
+ return x
diff --git a/mmdet/models/necks/ssd_neck.py b/mmdet/models/necks/ssd_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..179d575e172ef93dd42aecc9a55f216029db4aec
--- /dev/null
+++ b/mmdet/models/necks/ssd_neck.py
@@ -0,0 +1,129 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class SSDNeck(BaseModule):
+ """Extra layers of SSD backbone to generate multi-scale feature maps.
+
+ Args:
+ in_channels (Sequence[int]): Number of input channels per scale.
+ out_channels (Sequence[int]): Number of output channels per scale.
+ level_strides (Sequence[int]): Stride of 3x3 conv per level.
+ level_paddings (Sequence[int]): Padding size of 3x3 conv per level.
+        l2_norm_scale (float|None): L2 normalization layer init scale.
+            If None, L2 normalization is not used on the first input feature.
+ last_kernel_size (int): Kernel size of the last conv layer.
+ Default: 3.
+ use_depthwise (bool): Whether to use DepthwiseSeparableConv.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ init_cfg (dict or list[dict], optional): Initialization config dict.
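+
+    Example (a minimal usage sketch following the common SSD300 setting; the
+    spatial sizes of the two input features are illustrative):
+        >>> import torch
+        >>> feats = [torch.rand(1, 512, 38, 38),
+        ...          torch.rand(1, 1024, 19, 19)]
+        >>> self = SSDNeck(
+        ...     in_channels=(512, 1024),
+        ...     out_channels=(512, 1024, 512, 256, 256, 256),
+        ...     level_strides=(2, 2, 1, 1),
+        ...     level_paddings=(1, 1, 0, 0))
+        >>> outputs = self.forward(feats)
+        >>> [tuple(o.shape[-2:]) for o in outputs]
+        [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)]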
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ level_strides,
+ level_paddings,
+ l2_norm_scale=20.,
+ last_kernel_size=3,
+ use_depthwise=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ init_cfg=[
+ dict(
+ type='Xavier', distribution='uniform',
+ layer='Conv2d'),
+ dict(type='Constant', val=1, layer='BatchNorm2d'),
+ ]):
+ super(SSDNeck, self).__init__(init_cfg)
+ assert len(out_channels) > len(in_channels)
+ assert len(out_channels) - len(in_channels) == len(level_strides)
+ assert len(level_strides) == len(level_paddings)
+ assert in_channels == out_channels[:len(in_channels)]
+
+ if l2_norm_scale:
+ self.l2_norm = L2Norm(in_channels[0], l2_norm_scale)
+ self.init_cfg += [
+ dict(
+ type='Constant',
+ val=self.l2_norm.scale,
+ override=dict(name='l2_norm'))
+ ]
+
+ self.extra_layers = nn.ModuleList()
+ extra_layer_channels = out_channels[len(in_channels):]
+ second_conv = DepthwiseSeparableConvModule if \
+ use_depthwise else ConvModule
+
+ for i, (out_channel, stride, padding) in enumerate(
+ zip(extra_layer_channels, level_strides, level_paddings)):
+ kernel_size = last_kernel_size \
+ if i == len(extra_layer_channels) - 1 else 3
+ per_lvl_convs = nn.Sequential(
+ ConvModule(
+ out_channels[len(in_channels) - 1 + i],
+ out_channel // 2,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg),
+ second_conv(
+ out_channel // 2,
+ out_channel,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.extra_layers.append(per_lvl_convs)
+
+ def forward(self, inputs):
+ """Forward function."""
+ outs = [feat for feat in inputs]
+ if hasattr(self, 'l2_norm'):
+ outs[0] = self.l2_norm(outs[0])
+
+ feat = outs[-1]
+ for layer in self.extra_layers:
+ feat = layer(feat)
+ outs.append(feat)
+ return tuple(outs)
+
+
+class L2Norm(nn.Module):
+
+ def __init__(self, n_dims, scale=20., eps=1e-10):
+ """L2 normalization layer.
+
+ Args:
+ n_dims (int): Number of dimensions to be normalized
+ scale (float, optional): Defaults to 20..
+ eps (float, optional): Used to avoid division by zero.
+ Defaults to 1e-10.
+ """
+ super(L2Norm, self).__init__()
+ self.n_dims = n_dims
+ self.weight = nn.Parameter(torch.Tensor(self.n_dims))
+ self.eps = eps
+ self.scale = scale
+
+ def forward(self, x):
+ """Forward function."""
+        # the normalization layer is computed in FP32 during FP16 training
+ x_float = x.float()
+ norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
+ return (self.weight[None, :, None, None].float().expand_as(x_float) *
+ x_float / norm).type_as(x)
diff --git a/mmdet/models/necks/yolo_neck.py b/mmdet/models/necks/yolo_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8eeb5737cdf871fa415c1a207956ea7753c304e
--- /dev/null
+++ b/mmdet/models/necks/yolo_neck.py
@@ -0,0 +1,140 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+
+
+class DetectionBlock(BaseModule):
+ """Detection block in YOLO neck.
+
+    Let out_channels = n. The DetectionBlock contains six ConvLayers,
+    one Conv2D layer and one YoloLayer.
+    The first six ConvLayers are formed in the following way:
+    1x1xn, 3x3x2n, 1x1xn, 3x3x2n, 1x1xn, 3x3x2n.
+    The Conv2D layer is 1x1x255.
+    Some blocks have a branch after the fifth ConvLayer.
+    The number of input channels is arbitrary (in_channels).
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
+ init_cfg=None):
+ super(DetectionBlock, self).__init__(init_cfg)
+ double_out_channels = out_channels * 2
+
+ # shortcut
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+ self.conv1 = ConvModule(in_channels, out_channels, 1, **cfg)
+ self.conv2 = ConvModule(
+ out_channels, double_out_channels, 3, padding=1, **cfg)
+ self.conv3 = ConvModule(double_out_channels, out_channels, 1, **cfg)
+ self.conv4 = ConvModule(
+ out_channels, double_out_channels, 3, padding=1, **cfg)
+ self.conv5 = ConvModule(double_out_channels, out_channels, 1, **cfg)
+
+ def forward(self, x):
+ tmp = self.conv1(x)
+ tmp = self.conv2(tmp)
+ tmp = self.conv3(tmp)
+ tmp = self.conv4(tmp)
+ out = self.conv5(tmp)
+ return out
+
+
+@NECKS.register_module()
+class YOLOV3Neck(BaseModule):
+ """The neck of YOLOV3.
+
+    It can be treated as a simplified version of FPN. It takes the
+    multi-level features from the Darknet backbone, applies upsampling and
+    concatenation, and outputs the feature maps used by the YOLOV3 head.
+
+    Note:
+        The input feats should be ordered from low level to high level,
+        i.e., the last element is the highest-level feature map, while
+        YOLOV3Neck processes them in reversed order (starting from the
+        highest level).
+
+ Args:
+ num_scales (int): The number of scales / stages.
+ in_channels (List[int]): The number of input channels per scale.
+ out_channels (List[int]): The number of output channels per scale.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict, optional): Dictionary to construct and config norm
+ layer. Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict, optional): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
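+
+    Example (a minimal usage sketch with the usual Darknet-53 channel layout;
+    the feature sizes are arbitrary illustrative values):
+        >>> import torch
+        >>> # feats are ordered as returned by the backbone; the last one is
+        >>> # the highest-level (1024-channel) feature map
+        >>> feats = [torch.rand(1, 256, 32, 32),
+        ...          torch.rand(1, 512, 16, 16),
+        ...          torch.rand(1, 1024, 8, 8)]
+        >>> self = YOLOV3Neck(num_scales=3,
+        ...                   in_channels=[1024, 512, 256],
+        ...                   out_channels=[512, 256, 128]).eval()
+        >>> outputs = self.forward(feats)
+        >>> [tuple(o.shape) for o in outputs]
+        [(1, 512, 8, 8), (1, 256, 16, 16), (1, 128, 32, 32)]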
+ """
+
+ def __init__(self,
+ num_scales,
+ in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
+ init_cfg=None):
+ super(YOLOV3Neck, self).__init__(init_cfg)
+ assert (num_scales == len(in_channels) == len(out_channels))
+ self.num_scales = num_scales
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ # shortcut
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+ # To support arbitrary scales, the code looks awful, but it works.
+ # Better solution is welcomed.
+ self.detect1 = DetectionBlock(in_channels[0], out_channels[0], **cfg)
+ for i in range(1, self.num_scales):
+ in_c, out_c = self.in_channels[i], self.out_channels[i]
+ inter_c = out_channels[i - 1]
+ self.add_module(f'conv{i}', ConvModule(inter_c, out_c, 1, **cfg))
+ # in_c + out_c : High-lvl feats will be cat with low-lvl feats
+ self.add_module(f'detect{i+1}',
+ DetectionBlock(in_c + out_c, out_c, **cfg))
+
+ def forward(self, feats):
+ assert len(feats) == self.num_scales
+
+ # processed from bottom (high-lvl) to top (low-lvl)
+ outs = []
+ out = self.detect1(feats[-1])
+ outs.append(out)
+
+ for i, x in enumerate(reversed(feats[:-1])):
+ conv = getattr(self, f'conv{i+1}')
+ tmp = conv(out)
+
+ # Cat with low-lvl feats
+ tmp = F.interpolate(tmp, scale_factor=2)
+ tmp = torch.cat((tmp, x), 1)
+
+ detect = getattr(self, f'detect{i+2}')
+ out = detect(tmp)
+ outs.append(out)
+
+ return tuple(outs)
diff --git a/mmdet/models/necks/yolox_pafpn.py b/mmdet/models/necks/yolox_pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f6f7068645b6ce722556fa29c7e3d349934e74
--- /dev/null
+++ b/mmdet/models/necks/yolox_pafpn.py
@@ -0,0 +1,156 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+from ..utils import CSPLayer
+
+
+@NECKS.register_module()
+class YOLOXPAFPN(BaseModule):
+ """Path Aggregation Network used in YOLOX.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 3
+        use_depthwise (bool): Whether to use depthwise separable convolution
+            in blocks. Default: False
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: `dict(scale_factor=2, mode='nearest')`
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN')
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='Swish')
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
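+
+    Example (a minimal usage sketch; the channel counts and feature sizes are
+    arbitrary illustrative values, each level halving the previous one):
+        >>> import torch
+        >>> in_channels = [16, 32, 64]
+        >>> scales = [64, 32, 16]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = YOLOXPAFPN(in_channels, 16, num_csp_blocks=1).eval()
+        >>> outputs = self.forward(inputs)
+        >>> [tuple(o.shape) for o in outputs]
+        [(1, 16, 64, 64), (1, 16, 32, 32), (1, 16, 16, 16)]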
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_csp_blocks=3,
+ use_depthwise=False,
+ upsample_cfg=dict(scale_factor=2, mode='nearest'),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish'),
+ init_cfg=dict(
+ type='Kaiming',
+ layer='Conv2d',
+ a=math.sqrt(5),
+ distribution='uniform',
+ mode='fan_in',
+ nonlinearity='leaky_relu')):
+ super(YOLOXPAFPN, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+
+ # build top-down blocks
+ self.upsample = nn.Upsample(**upsample_cfg)
+ self.reduce_layers = nn.ModuleList()
+ self.top_down_blocks = nn.ModuleList()
+ for idx in range(len(in_channels) - 1, 0, -1):
+ self.reduce_layers.append(
+ ConvModule(
+ in_channels[idx],
+ in_channels[idx - 1],
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.top_down_blocks.append(
+ CSPLayer(
+ in_channels[idx - 1] * 2,
+ in_channels[idx - 1],
+ num_blocks=num_csp_blocks,
+ add_identity=False,
+ use_depthwise=use_depthwise,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ # build bottom-up blocks
+ self.downsamples = nn.ModuleList()
+ self.bottom_up_blocks = nn.ModuleList()
+ for idx in range(len(in_channels) - 1):
+ self.downsamples.append(
+ conv(
+ in_channels[idx],
+ in_channels[idx],
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottom_up_blocks.append(
+ CSPLayer(
+ in_channels[idx] * 2,
+ in_channels[idx + 1],
+ num_blocks=num_csp_blocks,
+ add_identity=False,
+ use_depthwise=use_depthwise,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ self.out_convs = nn.ModuleList()
+ for i in range(len(in_channels)):
+ self.out_convs.append(
+ ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, inputs):
+ """
+ Args:
+ inputs (tuple[Tensor]): input features.
+
+ Returns:
+ tuple[Tensor]: YOLOXPAFPN features.
+ """
+ assert len(inputs) == len(self.in_channels)
+
+ # top-down path
+ inner_outs = [inputs[-1]]
+ for idx in range(len(self.in_channels) - 1, 0, -1):
+            feat_high = inner_outs[0]
+            feat_low = inputs[idx - 1]
+            feat_high = self.reduce_layers[len(self.in_channels) - 1 - idx](
+                feat_high)
+            inner_outs[0] = feat_high
+
+            upsample_feat = self.upsample(feat_high)
+
+ inner_out = self.top_down_blocks[len(self.in_channels) - 1 - idx](
+ torch.cat([upsample_feat, feat_low], 1))
+ inner_outs.insert(0, inner_out)
+
+ # bottom-up path
+ outs = [inner_outs[0]]
+ for idx in range(len(self.in_channels) - 1):
+ feat_low = outs[-1]
+ feat_height = inner_outs[idx + 1]
+ downsample_feat = self.downsamples[idx](feat_low)
+ out = self.bottom_up_blocks[idx](
+ torch.cat([downsample_feat, feat_height], 1))
+ outs.append(out)
+
+ # out convs
+ for idx, conv in enumerate(self.out_convs):
+ outs[idx] = conv(outs[idx])
+
+ return tuple(outs)
diff --git a/mmdet/models/plugins/__init__.py b/mmdet/models/plugins/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a455c07bb99b9393e68b44d747cb5710b47c56fd
--- /dev/null
+++ b/mmdet/models/plugins/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .dropblock import DropBlock
+from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder
+from .pixel_decoder import PixelDecoder, TransformerEncoderPixelDecoder
+
+__all__ = [
+ 'DropBlock', 'PixelDecoder', 'TransformerEncoderPixelDecoder',
+ 'MSDeformAttnPixelDecoder'
+]
diff --git a/mmdet/models/plugins/dropblock.py b/mmdet/models/plugins/dropblock.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb00ade7384fadd086ad900d90c255b23f17a7da
--- /dev/null
+++ b/mmdet/models/plugins/dropblock.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import PLUGIN_LAYERS
+
+eps = 1e-6
+
+
+@PLUGIN_LAYERS.register_module()
+class DropBlock(nn.Module):
+ """Randomly drop some regions of feature maps.
+
+    Please refer to the method proposed in `DropBlock
+    <https://arxiv.org/abs/1810.12890>`_ for details.
+
+ Args:
+ drop_prob (float): The probability of dropping each block.
+ block_size (int): The size of dropped blocks.
+ warmup_iters (int): The drop probability will linearly increase
+ from `0` to `drop_prob` during the first `warmup_iters` iterations.
+ Default: 2000.
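+
+    Example (a minimal usage sketch; the tensor shape is illustrative and the
+    layer only takes effect in training mode):
+        >>> import torch
+        >>> dropblock = DropBlock(drop_prob=0.1, block_size=3, warmup_iters=0)
+        >>> x = torch.rand(2, 8, 16, 16)
+        >>> out = dropblock(x)  # module is in training mode by default
+        >>> out.shape == x.shape
+        True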
+ """
+
+ def __init__(self, drop_prob, block_size, warmup_iters=2000, **kwargs):
+ super(DropBlock, self).__init__()
+ assert block_size % 2 == 1
+ assert 0 < drop_prob <= 1
+ assert warmup_iters >= 0
+ self.drop_prob = drop_prob
+ self.block_size = block_size
+ self.warmup_iters = warmup_iters
+ self.iter_cnt = 0
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Input feature map on which some areas will be randomly
+ dropped.
+
+ Returns:
+ Tensor: The tensor after DropBlock layer.
+ """
+ if not self.training:
+ return x
+ self.iter_cnt += 1
+ N, C, H, W = list(x.shape)
+ gamma = self._compute_gamma((H, W))
+ mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1)
+ mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))
+
+ mask = F.pad(mask, [self.block_size // 2] * 4, value=0)
+ mask = F.max_pool2d(
+ input=mask,
+ stride=(1, 1),
+ kernel_size=(self.block_size, self.block_size),
+ padding=self.block_size // 2)
+ mask = 1 - mask
+ x = x * mask * mask.numel() / (eps + mask.sum())
+ return x
+
+ def _compute_gamma(self, feat_size):
+        """Compute the value of gamma according to the paper. gamma is the
+        parameter of the Bernoulli distribution, which controls the number of
+        features to drop.
+
+ gamma = (drop_prob * fm_area) / (drop_area * keep_area)
+
+ Args:
+ feat_size (tuple[int, int]): The height and width of feature map.
+
+ Returns:
+ float: The value of gamma.
+ """
+ gamma = (self.drop_prob * feat_size[0] * feat_size[1])
+ gamma /= ((feat_size[0] - self.block_size + 1) *
+ (feat_size[1] - self.block_size + 1))
+ gamma /= (self.block_size**2)
+ factor = (1.0 if self.iter_cnt > self.warmup_iters else self.iter_cnt /
+ self.warmup_iters)
+ return gamma * factor
+
+ def extra_repr(self):
+ return (f'drop_prob={self.drop_prob}, block_size={self.block_size}, '
+ f'warmup_iters={self.warmup_iters}')
diff --git a/mmdet/models/plugins/msdeformattn_pixel_decoder.py b/mmdet/models/plugins/msdeformattn_pixel_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d553582baefc898da4f07089ba034d21dbbfb6d7
--- /dev/null
+++ b/mmdet/models/plugins/msdeformattn_pixel_decoder.py
@@ -0,0 +1,269 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init,
+ normal_init, xavier_init)
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+ build_transformer_layer_sequence)
+from mmcv.runner import BaseModule, ModuleList
+
+from mmdet.core.anchor import MlvlPointGenerator
+from mmdet.models.utils.transformer import MultiScaleDeformableAttention
+
+
+@PLUGIN_LAYERS.register_module()
+class MSDeformAttnPixelDecoder(BaseModule):
+ """Pixel decoder with multi-scale deformable attention.
+
+ Args:
+ in_channels (list[int] | tuple[int]): Number of channels in the
+ input feature maps.
+ strides (list[int] | tuple[int]): Output strides of feature from
+ backbone.
+ feat_channels (int): Number of channels for feature.
+ out_channels (int): Number of channels for output.
+ num_outs (int): Number of output scales.
+ norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
+ Defaults to dict(type='GN', num_groups=32).
+ act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
+ Defaults to dict(type='ReLU').
+ encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
+ encoder. Defaults to `DetrTransformerEncoder`.
+ positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer encoder position encoding. Defaults to
+ dict(type='SinePositionalEncoding', num_feats=128,
+ normalize=True).
+ init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
+ """
+
+ def __init__(self,
+ in_channels=[256, 512, 1024, 2048],
+ strides=[4, 8, 16, 32],
+ feat_channels=256,
+ out_channels=256,
+ num_outs=3,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ encoder=dict(
+ type='DetrTransformerEncoder',
+ num_layers=6,
+ transformerlayers=dict(
+ type='BaseTransformerLayer',
+ attn_cfgs=dict(
+ type='MultiScaleDeformableAttention',
+ embed_dims=256,
+ num_heads=8,
+ num_levels=3,
+ num_points=4,
+ im2col_step=64,
+ dropout=0.0,
+ batch_first=False,
+ norm_cfg=None,
+ init_cfg=None),
+ feedforward_channels=1024,
+ ffn_dropout=0.0,
+ operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+ init_cfg=None),
+ positional_encoding=dict(
+ type='SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.strides = strides
+ self.num_input_levels = len(in_channels)
+ self.num_encoder_levels = \
+ encoder.transformerlayers.attn_cfgs.num_levels
+ assert self.num_encoder_levels >= 1, \
+ 'num_levels in attn_cfgs must be at least one'
+ input_conv_list = []
+        # from top to bottom (i.e., from low to high resolution)
+ for i in range(self.num_input_levels - 1,
+ self.num_input_levels - self.num_encoder_levels - 1,
+ -1):
+ input_conv = ConvModule(
+ in_channels[i],
+ feat_channels,
+ kernel_size=1,
+ norm_cfg=norm_cfg,
+ act_cfg=None,
+ bias=True)
+ input_conv_list.append(input_conv)
+ self.input_convs = ModuleList(input_conv_list)
+
+ self.encoder = build_transformer_layer_sequence(encoder)
+ self.postional_encoding = build_positional_encoding(
+ positional_encoding)
+ # high resolution to low resolution
+ self.level_encoding = nn.Embedding(self.num_encoder_levels,
+ feat_channels)
+
+ # fpn-like structure
+ self.lateral_convs = ModuleList()
+ self.output_convs = ModuleList()
+ self.use_bias = norm_cfg is None
+        # from top (low resolution) to bottom (high resolution)
+        # FPN for the remaining features that were not fed to the encoder
+ for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
+ -1):
+ lateral_conv = ConvModule(
+ in_channels[i],
+ feat_channels,
+ kernel_size=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ output_conv = ConvModule(
+ feat_channels,
+ feat_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.lateral_convs.append(lateral_conv)
+ self.output_convs.append(output_conv)
+
+ self.mask_feature = Conv2d(
+ feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ self.num_outs = num_outs
+ self.point_generator = MlvlPointGenerator(strides)
+
+ def init_weights(self):
+ """Initialize weights."""
+ for i in range(0, self.num_encoder_levels):
+ xavier_init(
+ self.input_convs[i].conv,
+ gain=1,
+ bias=0,
+ distribution='uniform')
+
+ for i in range(0, self.num_input_levels - self.num_encoder_levels):
+ caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
+ caffe2_xavier_init(self.output_convs[i].conv, bias=0)
+
+ caffe2_xavier_init(self.mask_feature, bias=0)
+
+ normal_init(self.level_encoding, mean=0, std=1)
+ for p in self.encoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_normal_(p)
+
+ # init_weights defined in MultiScaleDeformableAttention
+ for layer in self.encoder.layers:
+ for attn in layer.attentions:
+ if isinstance(attn, MultiScaleDeformableAttention):
+ attn.init_weights()
+
+ def forward(self, feats):
+ """
+ Args:
+ feats (list[Tensor]): Feature maps of each level. Each has
+ shape of (batch_size, c, h, w).
+
+ Returns:
+ tuple: A tuple containing the following:
+
+ - mask_feature (Tensor): shape (batch_size, c, h, w).
+ - multi_scale_features (list[Tensor]): Multi scale \
+ features, each in shape (batch_size, c, h, w).
+ """
+ # generate padding mask for each level, for each image
+ batch_size = feats[0].shape[0]
+ encoder_input_list = []
+ padding_mask_list = []
+ level_positional_encoding_list = []
+ spatial_shapes = []
+ reference_points_list = []
+ for i in range(self.num_encoder_levels):
+ level_idx = self.num_input_levels - i - 1
+ feat = feats[level_idx]
+ feat_projected = self.input_convs[i](feat)
+ h, w = feat.shape[-2:]
+
+ # no padding
+ padding_mask_resized = feat.new_zeros(
+ (batch_size, ) + feat.shape[-2:], dtype=torch.bool)
+ pos_embed = self.postional_encoding(padding_mask_resized)
+ level_embed = self.level_encoding.weight[i]
+ level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
+ # (h_i * w_i, 2)
+ reference_points = self.point_generator.single_level_grid_priors(
+ feat.shape[-2:], level_idx, device=feat.device)
+ # normalize
+ factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
+ reference_points = reference_points / factor
+
+ # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
+ feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
+ level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
+ padding_mask_resized = padding_mask_resized.flatten(1)
+
+ encoder_input_list.append(feat_projected)
+ padding_mask_list.append(padding_mask_resized)
+ level_positional_encoding_list.append(level_pos_embed)
+ spatial_shapes.append(feat.shape[-2:])
+ reference_points_list.append(reference_points)
+ # shape (batch_size, total_num_query),
+        # total_num_query = sum(h_i * w_i) over all encoder levels
+ padding_masks = torch.cat(padding_mask_list, dim=1)
+ # shape (total_num_query, batch_size, c)
+ encoder_inputs = torch.cat(encoder_input_list, dim=0)
+ level_positional_encodings = torch.cat(
+ level_positional_encoding_list, dim=0)
+ device = encoder_inputs.device
+ # shape (num_encoder_levels, 2), from low
+ # resolution to high resolution
+ spatial_shapes = torch.as_tensor(
+ spatial_shapes, dtype=torch.long, device=device)
+ # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
+ level_start_index = torch.cat((spatial_shapes.new_zeros(
+ (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ reference_points = torch.cat(reference_points_list, dim=0)
+ reference_points = reference_points[None, :, None].repeat(
+ batch_size, 1, self.num_encoder_levels, 1)
+ valid_radios = reference_points.new_ones(
+ (batch_size, self.num_encoder_levels, 2))
+ # shape (num_total_query, batch_size, c)
+ memory = self.encoder(
+ query=encoder_inputs,
+ key=None,
+ value=None,
+ query_pos=level_positional_encodings,
+ key_pos=None,
+ attn_masks=None,
+ key_padding_mask=None,
+ query_key_padding_mask=padding_masks,
+ spatial_shapes=spatial_shapes,
+ reference_points=reference_points,
+ level_start_index=level_start_index,
+ valid_radios=valid_radios)
+ # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
+ memory = memory.permute(1, 2, 0)
+
+ # from low resolution to high resolution
+ num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
+ outs = torch.split(memory, num_query_per_level, dim=-1)
+ outs = [
+ x.reshape(batch_size, -1, spatial_shapes[i][0],
+ spatial_shapes[i][1]) for i, x in enumerate(outs)
+ ]
+
+ for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
+ -1):
+ x = feats[i]
+ cur_feat = self.lateral_convs[i](x)
+ y = cur_feat + F.interpolate(
+ outs[-1],
+ size=cur_feat.shape[-2:],
+ mode='bilinear',
+ align_corners=False)
+ y = self.output_convs[i](y)
+ outs.append(y)
+ multi_scale_features = outs[:self.num_outs]
+
+ mask_feature = self.mask_feature(outs[-1])
+ return mask_feature, multi_scale_features
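A hedged usage sketch of the decoder above (not part of the patch). It assumes backbone-style inputs that match the default `in_channels`/`strides` and that mmcv's deformable-attention op (or its PyTorch fallback) is available; the printed shapes are illustrative.

```python
import torch

from mmdet.models.plugins.msdeformattn_pixel_decoder import \
    MSDeformAttnPixelDecoder

decoder = MSDeformAttnPixelDecoder(
    in_channels=[256, 512, 1024, 2048],
    strides=[4, 8, 16, 32],
    feat_channels=256,
    out_channels=256,
    num_outs=3)
decoder.init_weights()

# Dummy backbone features for a 256x256 input, high to low resolution.
feats = [
    torch.rand(2, c, 256 // s, 256 // s)
    for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32])
]
mask_feature, multi_scale_features = decoder(feats)
# mask_feature is produced at the highest resolution kept by the FPN part,
# e.g. (2, 256, 64, 64); multi_scale_features are ordered from low to high
# resolution and are typically consumed by a mask head.
print(mask_feature.shape, [tuple(f.shape) for f in multi_scale_features])
```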
diff --git a/mmdet/models/plugins/pixel_decoder.py b/mmdet/models/plugins/pixel_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..537a187dc5c53279afff377c548e224ac092de69
--- /dev/null
+++ b/mmdet/models/plugins/pixel_decoder.py
@@ -0,0 +1,243 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+ build_transformer_layer_sequence)
+from mmcv.runner import BaseModule, ModuleList
+
+
+@PLUGIN_LAYERS.register_module()
+class PixelDecoder(BaseModule):
+    """Pixel decoder with an FPN-like structure.
+
+ Args:
+ in_channels (list[int] | tuple[int]): Number of channels in the
+ input feature maps.
+        feat_channels (int): Number of channels for feature.
+        out_channels (int): Number of channels for output.
+ norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
+ Defaults to dict(type='GN', num_groups=32).
+ act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
+ Defaults to dict(type='ReLU').
+ init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ feat_channels,
+ out_channels,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.num_inputs = len(in_channels)
+ self.lateral_convs = ModuleList()
+ self.output_convs = ModuleList()
+ self.use_bias = norm_cfg is None
+ for i in range(0, self.num_inputs - 1):
+ lateral_conv = ConvModule(
+ in_channels[i],
+ feat_channels,
+ kernel_size=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ output_conv = ConvModule(
+ feat_channels,
+ feat_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.lateral_convs.append(lateral_conv)
+ self.output_convs.append(output_conv)
+
+ self.last_feat_conv = ConvModule(
+ in_channels[-1],
+ feat_channels,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.mask_feature = Conv2d(
+ feat_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ def init_weights(self):
+ """Initialize weights."""
+ for i in range(0, self.num_inputs - 2):
+ caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
+ caffe2_xavier_init(self.output_convs[i].conv, bias=0)
+
+ caffe2_xavier_init(self.mask_feature, bias=0)
+ caffe2_xavier_init(self.last_feat_conv, bias=0)
+
+ def forward(self, feats, img_metas):
+ """
+ Args:
+ feats (list[Tensor]): Feature maps of each level. Each has
+ shape of (batch_size, c, h, w).
+ img_metas (list[dict]): List of image information. Pass in
+ for creating more accurate padding mask. Not used here.
+
+ Returns:
+ tuple: a tuple containing the following:
+ - mask_feature (Tensor): Shape (batch_size, c, h, w).
+ - memory (Tensor): Output of last stage of backbone.\
+ Shape (batch_size, c, h, w).
+ """
+ y = self.last_feat_conv(feats[-1])
+ for i in range(self.num_inputs - 2, -1, -1):
+ x = feats[i]
+ cur_feat = self.lateral_convs[i](x)
+ y = cur_feat + \
+ F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
+ y = self.output_convs[i](y)
+
+ mask_feature = self.mask_feature(y)
+ memory = feats[-1]
+ return mask_feature, memory
+
+
+@PLUGIN_LAYERS.register_module()
+class TransformerEncoderPixelDecoder(PixelDecoder):
+    """Pixel decoder with a transformer encoder inside.
+
+ Args:
+ in_channels (list[int] | tuple[int]): Number of channels in the
+ input feature maps.
+        feat_channels (int): Number of channels for feature.
+        out_channels (int): Number of channels for output.
+ norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
+ Defaults to dict(type='GN', num_groups=32).
+ act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
+ Defaults to dict(type='ReLU').
+        encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
+            encoder. Defaults to None.
+ positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
+ transformer encoder position encoding. Defaults to
+ dict(type='SinePositionalEncoding', num_feats=128,
+ normalize=True).
+ init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ feat_channels,
+ out_channels,
+ norm_cfg=dict(type='GN', num_groups=32),
+ act_cfg=dict(type='ReLU'),
+ encoder=None,
+ positional_encoding=dict(
+ type='SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ init_cfg=None):
+ super(TransformerEncoderPixelDecoder, self).__init__(
+ in_channels,
+ feat_channels,
+ out_channels,
+ norm_cfg,
+ act_cfg,
+ init_cfg=init_cfg)
+ self.last_feat_conv = None
+
+ self.encoder = build_transformer_layer_sequence(encoder)
+ self.encoder_embed_dims = self.encoder.embed_dims
+ assert self.encoder_embed_dims == feat_channels, 'embed_dims({}) of ' \
+            'transformer encoder must equal feat_channels({})'.format(
+                self.encoder_embed_dims, feat_channels)
+ self.positional_encoding = build_positional_encoding(
+ positional_encoding)
+ self.encoder_in_proj = Conv2d(
+ in_channels[-1], feat_channels, kernel_size=1)
+ self.encoder_out_proj = ConvModule(
+ feat_channels,
+ feat_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=self.use_bias,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def init_weights(self):
+ """Initialize weights."""
+ for i in range(0, self.num_inputs - 2):
+ caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
+ caffe2_xavier_init(self.output_convs[i].conv, bias=0)
+
+ caffe2_xavier_init(self.mask_feature, bias=0)
+ caffe2_xavier_init(self.encoder_in_proj, bias=0)
+ caffe2_xavier_init(self.encoder_out_proj.conv, bias=0)
+
+ for p in self.encoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, feats, img_metas):
+ """
+ Args:
+ feats (list[Tensor]): Feature maps of each level. Each has
+ shape of (batch_size, c, h, w).
+ img_metas (list[dict]): List of image information. Pass in
+ for creating more accurate padding mask.
+
+ Returns:
+ tuple: a tuple containing the following:
+ - mask_feature (Tensor): shape (batch_size, c, h, w).
+ - memory (Tensor): shape (batch_size, c, h, w).
+ """
+ feat_last = feats[-1]
+ bs, c, h, w = feat_last.shape
+ input_img_h, input_img_w = img_metas[0]['batch_input_shape']
+ padding_mask = feat_last.new_ones((bs, input_img_h, input_img_w),
+ dtype=torch.float32)
+ for i in range(bs):
+ img_h, img_w, _ = img_metas[i]['img_shape']
+ padding_mask[i, :img_h, :img_w] = 0
+ padding_mask = F.interpolate(
+ padding_mask.unsqueeze(1),
+ size=feat_last.shape[-2:],
+ mode='nearest').to(torch.bool).squeeze(1)
+
+ pos_embed = self.positional_encoding(padding_mask)
+ feat_last = self.encoder_in_proj(feat_last)
+ # (batch_size, c, h, w) -> (num_queries, batch_size, c)
+ feat_last = feat_last.flatten(2).permute(2, 0, 1)
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ # (batch_size, h, w) -> (batch_size, h*w)
+ padding_mask = padding_mask.flatten(1)
+ memory = self.encoder(
+ query=feat_last,
+ key=None,
+ value=None,
+ query_pos=pos_embed,
+ query_key_padding_mask=padding_mask)
+ # (num_queries, batch_size, c) -> (batch_size, c, h, w)
+ memory = memory.permute(1, 2, 0).view(bs, self.encoder_embed_dims, h,
+ w)
+ y = self.encoder_out_proj(memory)
+ for i in range(self.num_inputs - 2, -1, -1):
+ x = feats[i]
+ cur_feat = self.lateral_convs[i](x)
+ y = cur_feat + \
+ F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
+ y = self.output_convs[i](y)
+
+ mask_feature = self.mask_feature(y)
+ return mask_feature, memory
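A hedged usage sketch of the plain `PixelDecoder` (not part of the patch); feature shapes and channel counts are illustrative.

```python
import torch

from mmdet.models.plugins.pixel_decoder import PixelDecoder

decoder = PixelDecoder(
    in_channels=[256, 512, 1024, 2048],
    feat_channels=256,
    out_channels=256)
decoder.init_weights()

feats = [
    torch.rand(2, c, 256 // s, 256 // s)
    for c, s in zip([256, 512, 1024, 2048], [4, 8, 16, 32])
]
# img_metas is accepted for API symmetry with the transformer variant but is
# not used by this decoder.
mask_feature, memory = decoder(feats, img_metas=[dict(), dict()])
print(mask_feature.shape)   # e.g. (2, 256, 64, 64)
print(memory.shape)         # same tensor as feats[-1], e.g. (2, 2048, 8, 8)
```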
diff --git a/mmdet/models/roi_heads/__init__.py b/mmdet/models/roi_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..baae2a0535327ae3289398ff8e2df020a55aab93
--- /dev/null
+++ b/mmdet/models/roi_heads/__init__.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_roi_head import BaseRoIHead
+from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DIIHead,
+ DoubleConvFCBBoxHead, SABLHead, SCNetBBoxHead,
+ Shared2FCBBoxHead, Shared4Conv1FCBBoxHead)
+from .cascade_roi_head import CascadeRoIHead
+from .double_roi_head import DoubleHeadRoIHead
+from .dynamic_roi_head import DynamicRoIHead
+from .grid_roi_head import GridRoIHead
+from .htc_roi_head import HybridTaskCascadeRoIHead
+from .mask_heads import (CoarseMaskHead, FCNMaskHead, FeatureRelayHead,
+ FusedSemanticHead, GlobalContextHead, GridHead,
+ HTCMaskHead, MaskIoUHead, MaskPointHead,
+ SCNetMaskHead, SCNetSemanticHead)
+from .mask_scoring_roi_head import MaskScoringRoIHead
+from .pisa_roi_head import PISARoIHead
+from .point_rend_roi_head import PointRendRoIHead
+from .roi_extractors import (BaseRoIExtractor, GenericRoIExtractor,
+ SingleRoIExtractor)
+from .scnet_roi_head import SCNetRoIHead
+from .shared_heads import ResLayer
+from .sparse_roi_head import SparseRoIHead
+from .standard_roi_head import StandardRoIHead
+from .trident_roi_head import TridentRoIHead
+
+__all__ = [
+ 'BaseRoIHead', 'CascadeRoIHead', 'DoubleHeadRoIHead', 'MaskScoringRoIHead',
+ 'HybridTaskCascadeRoIHead', 'GridRoIHead', 'ResLayer', 'BBoxHead',
+ 'ConvFCBBoxHead', 'DIIHead', 'SABLHead', 'Shared2FCBBoxHead',
+ 'StandardRoIHead', 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead',
+ 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
+ 'MaskIoUHead', 'BaseRoIExtractor', 'GenericRoIExtractor',
+ 'SingleRoIExtractor', 'PISARoIHead', 'PointRendRoIHead', 'MaskPointHead',
+ 'CoarseMaskHead', 'DynamicRoIHead', 'SparseRoIHead', 'TridentRoIHead',
+ 'SCNetRoIHead', 'SCNetMaskHead', 'SCNetSemanticHead', 'SCNetBBoxHead',
+ 'FeatureRelayHead', 'GlobalContextHead'
+]
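The heads exported here are normally built from config dicts through the `HEADS` registry rather than instantiated directly. A hedged sketch (not part of the patch; the values are illustrative and mirror typical Faster R-CNN settings, assuming a full mmcv build with compiled ops):

```python
from mmdet.models.builder import build_head

roi_head = build_head(
    dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80)))
print(type(roi_head).__name__)  # StandardRoIHead
```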
diff --git a/mmdet/models/roi_heads/base_roi_head.py b/mmdet/models/roi_heads/base_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4adbdef8f2f9ffb9a75c23c45481fc9bb3de9246
--- /dev/null
+++ b/mmdet/models/roi_heads/base_roi_head.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+from mmcv.runner import BaseModule
+
+from ..builder import build_shared_head
+
+
+class BaseRoIHead(BaseModule, metaclass=ABCMeta):
+ """Base class for RoIHeads."""
+
+ def __init__(self,
+ bbox_roi_extractor=None,
+ bbox_head=None,
+ mask_roi_extractor=None,
+ mask_head=None,
+ shared_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ super(BaseRoIHead, self).__init__(init_cfg)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ if shared_head is not None:
+ shared_head.pretrained = pretrained
+ self.shared_head = build_shared_head(shared_head)
+
+ if bbox_head is not None:
+ self.init_bbox_head(bbox_roi_extractor, bbox_head)
+
+ if mask_head is not None:
+ self.init_mask_head(mask_roi_extractor, mask_head)
+
+ self.init_assigner_sampler()
+
+ @property
+ def with_bbox(self):
+ """bool: whether the RoI head contains a `bbox_head`"""
+ return hasattr(self, 'bbox_head') and self.bbox_head is not None
+
+ @property
+ def with_mask(self):
+ """bool: whether the RoI head contains a `mask_head`"""
+ return hasattr(self, 'mask_head') and self.mask_head is not None
+
+ @property
+ def with_shared_head(self):
+ """bool: whether the RoI head contains a `shared_head`"""
+ return hasattr(self, 'shared_head') and self.shared_head is not None
+
+ @abstractmethod
+ def init_bbox_head(self):
+ """Initialize ``bbox_head``"""
+ pass
+
+ @abstractmethod
+ def init_mask_head(self):
+ """Initialize ``mask_head``"""
+ pass
+
+ @abstractmethod
+ def init_assigner_sampler(self):
+ """Initialize assigner and sampler."""
+ pass
+
+ @abstractmethod
+ def forward_train(self,
+ x,
+ img_meta,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ **kwargs):
+ """Forward function during training."""
+
+ async def async_simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False,
+ **kwargs):
+        """Asynchronous test function."""
+ raise NotImplementedError
+
+ def simple_test(self,
+ x,
+ proposal_list,
+ img_meta,
+ proposals=None,
+ rescale=False,
+ **kwargs):
+ """Test without augmentation."""
+
+ def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
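To make the subclass contract concrete, a hedged sketch (not part of the patch; `MinimalRoIHead` is hypothetical) of a bbox-only RoI head that implements the abstract hooks declared above:

```python
from mmdet.models.builder import HEADS, build_head, build_roi_extractor
from mmdet.models.roi_heads import BaseRoIHead


@HEADS.register_module()
class MinimalRoIHead(BaseRoIHead):
    """Bbox-only RoI head that fills in the abstract hooks of BaseRoIHead."""

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        self.bbox_roi_extractor = build_roi_extractor(bbox_roi_extractor)
        self.bbox_head = build_head(bbox_head)

    def init_mask_head(self, mask_roi_extractor, mask_head):
        self.mask_head = None  # this minimal head does not predict masks

    def init_assigner_sampler(self):
        self.bbox_assigner = None
        self.bbox_sampler = None

    def forward_train(self, x, img_meta, proposal_list, gt_bboxes, gt_labels,
                      gt_bboxes_ignore=None, gt_masks=None, **kwargs):
        raise NotImplementedError  # training logic lives in concrete heads


head = MinimalRoIHead()
print(head.with_bbox, head.with_mask)  # False False (no sub-heads configured)
```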
diff --git a/mmdet/models/roi_heads/bbox_heads/__init__.py b/mmdet/models/roi_heads/bbox_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1207dbeead6fedc24e6b497fb98558998a14396
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .bbox_head import BBoxHead
+from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead,
+ Shared4Conv1FCBBoxHead)
+from .dii_head import DIIHead
+from .double_bbox_head import DoubleConvFCBBoxHead
+from .sabl_head import SABLHead
+from .scnet_bbox_head import SCNetBBoxHead
+
+__all__ = [
+ 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
+ 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead', 'DIIHead',
+ 'SCNetBBoxHead'
+]
diff --git a/mmdet/models/roi_heads/bbox_heads/bbox_head.py b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..461b18b7fe4a408a2c01baf213ffa6170b7acc3a
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
@@ -0,0 +1,594 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.runner import BaseModule, auto_fp16, force_fp32
+from torch.nn.modules.utils import _pair
+
+from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.losses import accuracy
+from mmdet.models.utils import build_linear_layer
+
+
+@HEADS.register_module()
+class BBoxHead(BaseModule):
+ """Simplest RoI head, with only two fc layers for classification and
+ regression respectively."""
+
+ def __init__(self,
+ with_avg_pool=False,
+ with_cls=True,
+ with_reg=True,
+ roi_feat_size=7,
+ in_channels=256,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ clip_border=True,
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=False,
+ reg_decoded_bbox=False,
+ reg_predictor_cfg=dict(type='Linear'),
+ cls_predictor_cfg=dict(type='Linear'),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
+ init_cfg=None):
+ super(BBoxHead, self).__init__(init_cfg)
+ assert with_cls or with_reg
+ self.with_avg_pool = with_avg_pool
+ self.with_cls = with_cls
+ self.with_reg = with_reg
+ self.roi_feat_size = _pair(roi_feat_size)
+ self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.reg_class_agnostic = reg_class_agnostic
+ self.reg_decoded_bbox = reg_decoded_bbox
+ self.reg_predictor_cfg = reg_predictor_cfg
+ self.cls_predictor_cfg = cls_predictor_cfg
+ self.fp16_enabled = False
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+
+ in_channels = self.in_channels
+ if self.with_avg_pool:
+ self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
+ else:
+ in_channels *= self.roi_feat_area
+ if self.with_cls:
+ # need to add background class
+ if self.custom_cls_channels:
+ cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
+ else:
+ cls_channels = num_classes + 1
+ self.fc_cls = build_linear_layer(
+ self.cls_predictor_cfg,
+ in_features=in_channels,
+ out_features=cls_channels)
+ if self.with_reg:
+ out_dim_reg = 4 if reg_class_agnostic else 4 * num_classes
+ self.fc_reg = build_linear_layer(
+ self.reg_predictor_cfg,
+ in_features=in_channels,
+ out_features=out_dim_reg)
+ self.debug_imgs = None
+ if init_cfg is None:
+ self.init_cfg = []
+ if self.with_cls:
+ self.init_cfg += [
+ dict(
+ type='Normal', std=0.01, override=dict(name='fc_cls'))
+ ]
+ if self.with_reg:
+ self.init_cfg += [
+ dict(
+ type='Normal', std=0.001, override=dict(name='fc_reg'))
+ ]
+
+ @property
+ def custom_cls_channels(self):
+ return getattr(self.loss_cls, 'custom_cls_channels', False)
+
+ @property
+ def custom_activation(self):
+ return getattr(self.loss_cls, 'custom_activation', False)
+
+ @property
+ def custom_accuracy(self):
+ return getattr(self.loss_cls, 'custom_accuracy', False)
+
+ @auto_fp16()
+ def forward(self, x):
+ if self.with_avg_pool:
+ if x.numel() > 0:
+ x = self.avg_pool(x)
+ x = x.view(x.size(0), -1)
+ else:
+ # avg_pool does not support empty tensor,
+                # so use torch.mean instead
+ x = torch.mean(x, dim=(-1, -2))
+ cls_score = self.fc_cls(x) if self.with_cls else None
+ bbox_pred = self.fc_reg(x) if self.with_reg else None
+ return cls_score, bbox_pred
+
+ def _get_target_single(self, pos_bboxes, neg_bboxes, pos_gt_bboxes,
+ pos_gt_labels, cfg):
+ """Calculate the ground truth for proposals in the single image
+ according to the sampling results.
+
+ Args:
+ pos_bboxes (Tensor): Contains all the positive boxes,
+ has shape (num_pos, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ neg_bboxes (Tensor): Contains all the negative boxes,
+ has shape (num_neg, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ pos_gt_bboxes (Tensor): Contains gt_boxes for
+ all positive samples, has shape (num_pos, 4),
+ the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ pos_gt_labels (Tensor): Contains gt_labels for
+ all positive samples, has shape (num_pos, ).
+ cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.
+
+ Returns:
+ Tuple[Tensor]: Ground truth for proposals
+ in a single image. Containing the following Tensors:
+
+ - labels(Tensor): Gt_labels for all proposals, has
+ shape (num_proposals,).
+ - label_weights(Tensor): Labels_weights for all
+ proposals, has shape (num_proposals,).
+ - bbox_targets(Tensor):Regression target for all
+ proposals, has shape (num_proposals, 4), the
+ last dimension 4 represents [tl_x, tl_y, br_x, br_y].
+ - bbox_weights(Tensor):Regression weights for all
+ proposals, has shape (num_proposals, 4).
+ """
+ num_pos = pos_bboxes.size(0)
+ num_neg = neg_bboxes.size(0)
+ num_samples = num_pos + num_neg
+
+ # original implementation uses new_zeros since BG are set to be 0
+ # now use empty & fill because BG cat_id = num_classes,
+ # FG cat_id = [0, num_classes-1]
+ labels = pos_bboxes.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_bboxes.new_zeros(num_samples)
+ bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
+ bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
+ if num_pos > 0:
+ labels[:num_pos] = pos_gt_labels
+ pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
+ label_weights[:num_pos] = pos_weight
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ pos_bboxes, pos_gt_bboxes)
+ else:
+ # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+ # is applied directly on the decoded bounding boxes, both
+ # the predicted boxes and regression targets should be with
+ # absolute coordinate format.
+ pos_bbox_targets = pos_gt_bboxes
+ bbox_targets[:num_pos, :] = pos_bbox_targets
+ bbox_weights[:num_pos, :] = 1
+ if num_neg > 0:
+ label_weights[-num_neg:] = 1.0
+
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ def get_targets(self,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ rcnn_train_cfg,
+ concat=True):
+ """Calculate the ground truth for all samples in a batch according to
+ the sampling_results.
+
+ Args:
+ sampling_results (List[obj:SamplingResults]): Assign results of
+ all images in a batch after sampling.
+ gt_bboxes (list[Tensor]): Gt_bboxes of all images in a batch,
+ each tensor has shape (num_gt, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ gt_labels (list[Tensor]): Gt_labels of all images in a batch,
+ each tensor has shape (num_gt,).
+ rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
+ concat (bool): Whether to concatenate the results of all
+ the images in a single batch.
+
+ Returns:
+ Tuple[Tensor]: Ground truth for proposals in a single image.
+ Containing the following list of Tensors:
+
+ - labels (list[Tensor],Tensor): Gt_labels for all
+ proposals in a batch, each tensor in list has
+ shape (num_proposals,) when `concat=False`, otherwise
+ just a single tensor has shape (num_all_proposals,).
+ - label_weights (list[Tensor]): Labels_weights for
+ all proposals in a batch, each tensor in list has
+ shape (num_proposals,) when `concat=False`, otherwise
+ just a single tensor has shape (num_all_proposals,).
+ - bbox_targets (list[Tensor],Tensor): Regression target
+ for all proposals in a batch, each tensor in list
+ has shape (num_proposals, 4) when `concat=False`,
+ otherwise just a single tensor has shape
+ (num_all_proposals, 4), the last dimension 4 represents
+ [tl_x, tl_y, br_x, br_y].
+ - bbox_weights (list[tensor],Tensor): Regression weights for
+ all proposals in a batch, each tensor in list has shape
+ (num_proposals, 4) when `concat=False`, otherwise just a
+ single tensor has shape (num_all_proposals, 4).
+ """
+ pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
+ neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
+ pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
+ pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
+ labels, label_weights, bbox_targets, bbox_weights = multi_apply(
+ self._get_target_single,
+ pos_bboxes_list,
+ neg_bboxes_list,
+ pos_gt_bboxes_list,
+ pos_gt_labels_list,
+ cfg=rcnn_train_cfg)
+
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ bbox_targets = torch.cat(bbox_targets, 0)
+ bbox_weights = torch.cat(bbox_weights, 0)
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ @force_fp32(apply_to=('cls_score', 'bbox_pred'))
+ def loss(self,
+ cls_score,
+ bbox_pred,
+ rois,
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ reduction_override=None):
+ losses = dict()
+ if cls_score is not None:
+ avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
+ if cls_score.numel() > 0:
+ loss_cls_ = self.loss_cls(
+ cls_score,
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ if isinstance(loss_cls_, dict):
+ losses.update(loss_cls_)
+ else:
+ losses['loss_cls'] = loss_cls_
+ if self.custom_activation:
+ acc_ = self.loss_cls.get_accuracy(cls_score, labels)
+ losses.update(acc_)
+ else:
+ losses['acc'] = accuracy(cls_score, labels)
+ if bbox_pred is not None:
+ bg_class_ind = self.num_classes
+ # 0~self.num_classes-1 are FG, self.num_classes is BG
+ pos_inds = (labels >= 0) & (labels < bg_class_ind)
+ # do not perform bounding box regression for BG anymore.
+ if pos_inds.any():
+ if self.reg_decoded_bbox:
+ # When the regression loss (e.g. `IouLoss`,
+ # `GIouLoss`, `DIouLoss`) is applied directly on
+ # the decoded bounding boxes, it decodes the
+ # already encoded coordinates to absolute format.
+ bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
+ if self.reg_class_agnostic:
+ pos_bbox_pred = bbox_pred.view(
+ bbox_pred.size(0), 4)[pos_inds.type(torch.bool)]
+ else:
+ pos_bbox_pred = bbox_pred.view(
+ bbox_pred.size(0), -1,
+ 4)[pos_inds.type(torch.bool),
+ labels[pos_inds.type(torch.bool)]]
+ losses['loss_bbox'] = self.loss_bbox(
+ pos_bbox_pred,
+ bbox_targets[pos_inds.type(torch.bool)],
+ bbox_weights[pos_inds.type(torch.bool)],
+ avg_factor=bbox_targets.size(0),
+ reduction_override=reduction_override)
+ else:
+ losses['loss_bbox'] = bbox_pred[pos_inds].sum()
+ return losses
+
+ @force_fp32(apply_to=('cls_score', 'bbox_pred'))
+ def get_bboxes(self,
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+            rois (Tensor): Boxes to be transformed. Has shape (num_boxes, 5),
+                where the last dimension 5 is arranged as
+                (batch_index, x1, y1, x2, y2).
+            cls_score (Tensor): Box scores, has shape
+                (num_boxes, num_classes + 1).
+            bbox_pred (Tensor, optional): Box energies / deltas, has shape
+                (num_boxes, num_classes * 4).
+ img_shape (Sequence[int], optional): Maximum bounds for boxes,
+ specifies (H, W, C) or (H, W).
+ scale_factor (ndarray): Scale factor of the
+ image arrange as (w_scale, h_scale, w_scale, h_scale).
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
+
+ Returns:
+ tuple[Tensor, Tensor]:
+ First tensor is `det_bboxes`, has the shape
+ (num_boxes, 5) and last
+ dimension 5 represent (tl_x, tl_y, br_x, br_y, score).
+ Second tensor is the labels with shape (num_boxes, ).
+ """
+
+        # some losses (e.g. Seesaw loss) may have a custom activation
+ if self.custom_cls_channels:
+ scores = self.loss_cls.get_activation(cls_score)
+ else:
+ scores = F.softmax(
+ cls_score, dim=-1) if cls_score is not None else None
+ # bbox_pred would be None in some detector when with_reg is False,
+ # e.g. Grid R-CNN.
+ if bbox_pred is not None:
+ bboxes = self.bbox_coder.decode(
+ rois[..., 1:], bbox_pred, max_shape=img_shape)
+ else:
+ bboxes = rois[:, 1:].clone()
+ if img_shape is not None:
+ bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1])
+ bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0])
+
+ if rescale and bboxes.size(0) > 0:
+ scale_factor = bboxes.new_tensor(scale_factor)
+ bboxes = (bboxes.view(bboxes.size(0), -1, 4) / scale_factor).view(
+ bboxes.size()[0], -1)
+
+ if cfg is None:
+ return bboxes, scores
+ else:
+ det_bboxes, det_labels = multiclass_nms(bboxes, scores,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+
+ return det_bboxes, det_labels
+
+ @force_fp32(apply_to=('bbox_preds', ))
+ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
+ """Refine bboxes during training.
+
+ Args:
+ rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
+ and bs is the sampled RoIs per image. The first column is
+ the image id and the next 4 columns are x1, y1, x2, y2.
+ labels (Tensor): Shape (n*bs, ).
+ bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
+ pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
+ is a gt bbox.
+ img_metas (list[dict]): Meta info of each image.
+
+ Returns:
+ list[Tensor]: Refined bboxes of each image in a mini-batch.
+
+ Example:
+ >>> # xdoctest: +REQUIRES(module:kwarray)
+ >>> import kwarray
+ >>> import numpy as np
+ >>> from mmdet.core.bbox.demodata import random_boxes
+ >>> self = BBoxHead(reg_class_agnostic=True)
+ >>> n_roi = 2
+ >>> n_img = 4
+ >>> scale = 512
+ >>> rng = np.random.RandomState(0)
+ >>> img_metas = [{'img_shape': (scale, scale)}
+ ... for _ in range(n_img)]
+ >>> # Create rois in the expected format
+ >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
+ >>> img_ids = torch.randint(0, n_img, (n_roi,))
+ >>> img_ids = img_ids.float()
+ >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
+ >>> # Create other args
+ >>> labels = torch.randint(0, 2, (n_roi,)).long()
+ >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
+ >>> # For each image, pretend random positive boxes are gts
+ >>> is_label_pos = (labels.numpy() > 0).astype(np.int)
+ >>> lbl_per_img = kwarray.group_items(is_label_pos,
+ ... img_ids.numpy())
+ >>> pos_per_img = [sum(lbl_per_img.get(gid, []))
+ ... for gid in range(n_img)]
+ >>> pos_is_gts = [
+ >>> torch.randint(0, 2, (npos,)).byte().sort(
+ >>> descending=True)[0]
+ >>> for npos in pos_per_img
+ >>> ]
+ >>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
+ >>> pos_is_gts, img_metas)
+ >>> print(bboxes_list)
+ """
+ img_ids = rois[:, 0].long().unique(sorted=True)
+ assert img_ids.numel() <= len(img_metas)
+
+ bboxes_list = []
+ for i in range(len(img_metas)):
+ inds = torch.nonzero(
+ rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
+ num_rois = inds.numel()
+
+ bboxes_ = rois[inds, 1:]
+ label_ = labels[inds]
+ bbox_pred_ = bbox_preds[inds]
+ img_meta_ = img_metas[i]
+ pos_is_gts_ = pos_is_gts[i]
+
+ bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
+ img_meta_)
+
+ # filter gt bboxes
+ pos_keep = 1 - pos_is_gts_
+ keep_inds = pos_is_gts_.new_ones(num_rois)
+ keep_inds[:len(pos_is_gts_)] = pos_keep
+
+ bboxes_list.append(bboxes[keep_inds.type(torch.bool)])
+
+ return bboxes_list
+
+ @force_fp32(apply_to=('bbox_pred', ))
+ def regress_by_class(self, rois, label, bbox_pred, img_meta):
+ """Regress the bbox for the predicted class. Used in Cascade R-CNN.
+
+ Args:
+ rois (Tensor): Rois from `rpn_head` or last stage
+ `bbox_head`, has shape (num_proposals, 4) or
+ (num_proposals, 5).
+ label (Tensor): Only used when `self.reg_class_agnostic`
+ is False, has shape (num_proposals, ).
+ bbox_pred (Tensor): Regression prediction of
+ current stage `bbox_head`. When `self.reg_class_agnostic`
+ is False, it has shape (n, num_classes * 4), otherwise
+ it has shape (n, 4).
+ img_meta (dict): Image meta info.
+
+ Returns:
+ Tensor: Regressed bboxes, the same shape as input rois.
+ """
+
+ assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape)
+
+ if not self.reg_class_agnostic:
+ label = label * 4
+ inds = torch.stack((label, label + 1, label + 2, label + 3), 1)
+ bbox_pred = torch.gather(bbox_pred, 1, inds)
+ assert bbox_pred.size(1) == 4
+
+ max_shape = img_meta['img_shape']
+
+ if rois.size(1) == 4:
+ new_rois = self.bbox_coder.decode(
+ rois, bbox_pred, max_shape=max_shape)
+ else:
+ bboxes = self.bbox_coder.decode(
+ rois[:, 1:], bbox_pred, max_shape=max_shape)
+ new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
+
+ return new_rois
+
+ def onnx_export(self,
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ cfg=None,
+ **kwargs):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ rois (Tensor): Boxes to be transformed.
+ Has shape (B, num_boxes, 5)
+            cls_score (Tensor): Box scores, has shape
+                (B, num_boxes, num_classes + 1), where the extra 1 represents
+                the background class.
+            bbox_pred (Tensor, optional): Box energies / deltas, has shape
+                (B, num_boxes, num_classes * 4).
+ img_shape (torch.Tensor): Shape of image.
+ cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
+
+ Returns:
+ tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
+ and class labels of shape [N, num_det].
+ """
+
+        assert rois.ndim == 3, 'Only support exporting two-stage ' \
+                               'models to ONNX ' \
+                               'with a batch dimension. '
+ if self.custom_cls_channels:
+ scores = self.loss_cls.get_activation(cls_score)
+ else:
+ scores = F.softmax(
+ cls_score, dim=-1) if cls_score is not None else None
+
+ if bbox_pred is not None:
+ bboxes = self.bbox_coder.decode(
+ rois[..., 1:], bbox_pred, max_shape=img_shape)
+ else:
+ bboxes = rois[..., 1:].clone()
+ if img_shape is not None:
+ max_shape = bboxes.new_tensor(img_shape)[..., :2]
+ min_xy = bboxes.new_tensor(0)
+ max_xy = torch.cat(
+ [max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ # Replace multiclass_nms with ONNX::NonMaxSuppression in deployment
+ from mmdet.core.export import add_dummy_nms_for_onnx
+ max_output_boxes_per_class = cfg.nms.get('max_output_boxes_per_class',
+ cfg.max_per_img)
+ iou_threshold = cfg.nms.get('iou_threshold', 0.5)
+ score_threshold = cfg.score_thr
+ nms_pre = cfg.get('deploy_nms_pre', -1)
+
+ scores = scores[..., :self.num_classes]
+ if self.reg_class_agnostic:
+ return add_dummy_nms_for_onnx(
+ bboxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ pre_top_k=nms_pre,
+ after_top_k=cfg.max_per_img)
+ else:
+ batch_size = scores.shape[0]
+ labels = torch.arange(
+ self.num_classes, dtype=torch.long).to(scores.device)
+ labels = labels.view(1, 1, -1).expand_as(scores)
+ labels = labels.reshape(batch_size, -1)
+ scores = scores.reshape(batch_size, -1)
+ bboxes = bboxes.reshape(batch_size, -1, 4)
+
+ max_size = torch.max(img_shape)
+ # Offset bboxes of each class so that bboxes of different labels
+ # do not overlap.
+ offsets = (labels * max_size + 1).unsqueeze(2)
+ bboxes_for_nms = bboxes + offsets
+
+ batch_dets, labels = add_dummy_nms_for_onnx(
+ bboxes_for_nms,
+ scores.unsqueeze(2),
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ pre_top_k=nms_pre,
+ after_top_k=cfg.max_per_img,
+ labels=labels)
+ # Offset the bboxes back after dummy nms.
+ offsets = (labels * max_size + 1).unsqueeze(2)
+ # Indexing + inplace operation fails with dynamic shape in ONNX
+ # original style: batch_dets[..., :4] -= offsets
+ bboxes, scores = batch_dets[..., 0:4], batch_dets[..., 4:5]
+ bboxes -= offsets
+ batch_dets = torch.cat([bboxes, scores], dim=2)
+ return batch_dets, labels
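A hedged sketch (not part of the patch) of the tensor contract of `BBoxHead.forward` on random data; the shapes are illustrative, and `BBoxHead` is usually subclassed rather than used directly.

```python
import torch

from mmdet.models.roi_heads.bbox_heads import BBoxHead

# With `with_avg_pool=True` the head pools the (7, 7) RoI features itself;
# without it, callers are expected to feed already flattened features.
head = BBoxHead(
    with_avg_pool=True, in_channels=256, roi_feat_size=7, num_classes=80)
roi_feats = torch.rand(16, 256, 7, 7)   # 16 RoIs after RoI pooling
cls_score, bbox_pred = head(roi_feats)
print(cls_score.shape)   # (16, 81): 80 classes + 1 background class
print(bbox_pred.shape)   # (16, 320): 4 deltas per class (class-specific)
```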
diff --git a/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..21124b9c9f266d404a8dbbcf72630601d1376beb
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py
@@ -0,0 +1,229 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from mmdet.models.builder import HEADS
+from mmdet.models.utils import build_linear_layer
+from .bbox_head import BBoxHead
+
+
+@HEADS.register_module()
+class ConvFCBBoxHead(BBoxHead):
+ r"""More general bbox head, with shared conv and fc layers and two optional
+ separated branches.
+
+ .. code-block:: none
+
+ /-> cls convs -> cls fcs -> cls
+ shared convs -> shared fcs
+ \-> reg convs -> reg fcs -> reg
+ """ # noqa: W605
+
+ def __init__(self,
+ num_shared_convs=0,
+ num_shared_fcs=0,
+ num_cls_convs=0,
+ num_cls_fcs=0,
+ num_reg_convs=0,
+ num_reg_fcs=0,
+ conv_out_channels=256,
+ fc_out_channels=1024,
+ conv_cfg=None,
+ norm_cfg=None,
+ init_cfg=None,
+ *args,
+ **kwargs):
+ super(ConvFCBBoxHead, self).__init__(
+ *args, init_cfg=init_cfg, **kwargs)
+ assert (num_shared_convs + num_shared_fcs + num_cls_convs +
+ num_cls_fcs + num_reg_convs + num_reg_fcs > 0)
+ if num_cls_convs > 0 or num_reg_convs > 0:
+ assert num_shared_fcs == 0
+ if not self.with_cls:
+ assert num_cls_convs == 0 and num_cls_fcs == 0
+ if not self.with_reg:
+ assert num_reg_convs == 0 and num_reg_fcs == 0
+ self.num_shared_convs = num_shared_convs
+ self.num_shared_fcs = num_shared_fcs
+ self.num_cls_convs = num_cls_convs
+ self.num_cls_fcs = num_cls_fcs
+ self.num_reg_convs = num_reg_convs
+ self.num_reg_fcs = num_reg_fcs
+ self.conv_out_channels = conv_out_channels
+ self.fc_out_channels = fc_out_channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ # add shared convs and fcs
+ self.shared_convs, self.shared_fcs, last_layer_dim = \
+ self._add_conv_fc_branch(
+ self.num_shared_convs, self.num_shared_fcs, self.in_channels,
+ True)
+ self.shared_out_channels = last_layer_dim
+
+ # add cls specific branch
+ self.cls_convs, self.cls_fcs, self.cls_last_dim = \
+ self._add_conv_fc_branch(
+ self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)
+
+ # add reg specific branch
+ self.reg_convs, self.reg_fcs, self.reg_last_dim = \
+ self._add_conv_fc_branch(
+ self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)
+
+ if self.num_shared_fcs == 0 and not self.with_avg_pool:
+ if self.num_cls_fcs == 0:
+ self.cls_last_dim *= self.roi_feat_area
+ if self.num_reg_fcs == 0:
+ self.reg_last_dim *= self.roi_feat_area
+
+ self.relu = nn.ReLU(inplace=True)
+ # reconstruct fc_cls and fc_reg since input channels are changed
+ if self.with_cls:
+ if self.custom_cls_channels:
+ cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
+ else:
+ cls_channels = self.num_classes + 1
+ self.fc_cls = build_linear_layer(
+ self.cls_predictor_cfg,
+ in_features=self.cls_last_dim,
+ out_features=cls_channels)
+ if self.with_reg:
+ out_dim_reg = (4 if self.reg_class_agnostic else 4 *
+ self.num_classes)
+ self.fc_reg = build_linear_layer(
+ self.reg_predictor_cfg,
+ in_features=self.reg_last_dim,
+ out_features=out_dim_reg)
+
+ if init_cfg is None:
+ # when init_cfg is None,
+            # it has been set to
+            # [dict(type='Normal', std=0.01, override=dict(name='fc_cls')),
+            #  dict(type='Normal', std=0.001, override=dict(name='fc_reg'))]
+ # after `super(ConvFCBBoxHead, self).__init__()`
+ # we only need to append additional configuration
+ # for `shared_fcs`, `cls_fcs` and `reg_fcs`
+ self.init_cfg += [
+ dict(
+ type='Xavier',
+ distribution='uniform',
+ override=[
+ dict(name='shared_fcs'),
+ dict(name='cls_fcs'),
+ dict(name='reg_fcs')
+ ])
+ ]
+
+ def _add_conv_fc_branch(self,
+ num_branch_convs,
+ num_branch_fcs,
+ in_channels,
+ is_shared=False):
+        """Add a shared or separated branch.
+
+ convs -> avg pool (optional) -> fcs
+ """
+ last_layer_dim = in_channels
+ # add branch specific conv layers
+ branch_convs = nn.ModuleList()
+ if num_branch_convs > 0:
+ for i in range(num_branch_convs):
+ conv_in_channels = (
+ last_layer_dim if i == 0 else self.conv_out_channels)
+ branch_convs.append(
+ ConvModule(
+ conv_in_channels,
+ self.conv_out_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ last_layer_dim = self.conv_out_channels
+ # add branch specific fc layers
+ branch_fcs = nn.ModuleList()
+ if num_branch_fcs > 0:
+ # for shared branch, only consider self.with_avg_pool
+ # for separated branches, also consider self.num_shared_fcs
+ if (is_shared
+ or self.num_shared_fcs == 0) and not self.with_avg_pool:
+ last_layer_dim *= self.roi_feat_area
+ for i in range(num_branch_fcs):
+ fc_in_channels = (
+ last_layer_dim if i == 0 else self.fc_out_channels)
+ branch_fcs.append(
+ nn.Linear(fc_in_channels, self.fc_out_channels))
+ last_layer_dim = self.fc_out_channels
+ return branch_convs, branch_fcs, last_layer_dim
+
+ def forward(self, x):
+ # shared part
+ if self.num_shared_convs > 0:
+ for conv in self.shared_convs:
+ x = conv(x)
+
+ if self.num_shared_fcs > 0:
+ if self.with_avg_pool:
+ x = self.avg_pool(x)
+
+ x = x.flatten(1)
+
+ for fc in self.shared_fcs:
+ x = self.relu(fc(x))
+ # separate branches
+ x_cls = x
+ x_reg = x
+
+ for conv in self.cls_convs:
+ x_cls = conv(x_cls)
+ if x_cls.dim() > 2:
+ if self.with_avg_pool:
+ x_cls = self.avg_pool(x_cls)
+ x_cls = x_cls.flatten(1)
+ for fc in self.cls_fcs:
+ x_cls = self.relu(fc(x_cls))
+
+ for conv in self.reg_convs:
+ x_reg = conv(x_reg)
+ if x_reg.dim() > 2:
+ if self.with_avg_pool:
+ x_reg = self.avg_pool(x_reg)
+ x_reg = x_reg.flatten(1)
+ for fc in self.reg_fcs:
+ x_reg = self.relu(fc(x_reg))
+
+ cls_score = self.fc_cls(x_cls) if self.with_cls else None
+ bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
+ return cls_score, bbox_pred
+
+
+@HEADS.register_module()
+class Shared2FCBBoxHead(ConvFCBBoxHead):
+
+ def __init__(self, fc_out_channels=1024, *args, **kwargs):
+ super(Shared2FCBBoxHead, self).__init__(
+ num_shared_convs=0,
+ num_shared_fcs=2,
+ num_cls_convs=0,
+ num_cls_fcs=0,
+ num_reg_convs=0,
+ num_reg_fcs=0,
+ fc_out_channels=fc_out_channels,
+ *args,
+ **kwargs)
+
+
+@HEADS.register_module()
+class Shared4Conv1FCBBoxHead(ConvFCBBoxHead):
+
+ def __init__(self, fc_out_channels=1024, *args, **kwargs):
+ super(Shared4Conv1FCBBoxHead, self).__init__(
+ num_shared_convs=4,
+ num_shared_fcs=1,
+ num_cls_convs=0,
+ num_cls_fcs=0,
+ num_reg_convs=0,
+ num_reg_fcs=0,
+ fc_out_channels=fc_out_channels,
+ *args,
+ **kwargs)
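A hedged sketch (not part of the patch) showing that `Shared2FCBBoxHead` consumes pooled RoI features directly, flattening them before the shared fcs; the shapes are illustrative.

```python
import torch

from mmdet.models.roi_heads.bbox_heads import Shared2FCBBoxHead

head = Shared2FCBBoxHead(
    in_channels=256, fc_out_channels=1024, roi_feat_size=7, num_classes=80)
roi_feats = torch.rand(8, 256, 7, 7)   # 8 RoIs from a RoI extractor
cls_score, bbox_pred = head(roi_feats)
print(cls_score.shape)   # (8, 81)
print(bbox_pred.shape)   # (8, 320)
```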
diff --git a/mmdet/models/roi_heads/bbox_heads/dii_head.py b/mmdet/models/roi_heads/bbox_heads/dii_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3777f52be4a9580662e6e7f5338229aedd310c7c
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/dii_head.py
@@ -0,0 +1,426 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import (bias_init_with_prob, build_activation_layer,
+ build_norm_layer)
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.runner import auto_fp16, force_fp32
+
+from mmdet.core import multi_apply
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.dense_heads.atss_head import reduce_mean
+from mmdet.models.losses import accuracy
+from mmdet.models.utils import build_transformer
+from .bbox_head import BBoxHead
+
+
+@HEADS.register_module()
+class DIIHead(BBoxHead):
+ r"""Dynamic Instance Interactive Head for `Sparse R-CNN: End-to-End Object
+    Detection with Learnable Proposals <https://arxiv.org/abs/2011.12450>`_
+
+ Args:
+        num_classes (int): Number of classes in the dataset.
+ Defaults to 80.
+ num_ffn_fcs (int): The number of fully-connected
+ layers in FFNs. Defaults to 2.
+        num_heads (int): The number of attention heads in
+            MultiheadAttention. Defaults to 8.
+ num_cls_fcs (int): The number of fully-connected
+ layers in classification subnet. Defaults to 1.
+ num_reg_fcs (int): The number of fully-connected
+ layers in regression subnet. Defaults to 3.
+        feedforward_channels (int): The hidden dimension
+            of FFNs. Defaults to 2048.
+        in_channels (int): Hidden channels of MultiheadAttention.
+            Defaults to 256.
+        dropout (float): Dropout probability.
+            Defaults to 0.0.
+ ffn_act_cfg (dict): The activation config for FFNs.
+ dynamic_conv_cfg (dict): The convolution config
+ for DynamicConv.
+ loss_iou (dict): The config for iou or giou loss.
+
+ """
+
+ def __init__(self,
+ num_classes=80,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_reg_fcs=3,
+ feedforward_channels=2048,
+ in_channels=256,
+ dropout=0.0,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ dynamic_conv_cfg=dict(
+ type='DynamicConv',
+ in_channels=256,
+ feat_channels=64,
+ out_channels=256,
+ input_feat_shape=7,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')),
+ loss_iou=dict(type='GIoULoss', loss_weight=2.0),
+ init_cfg=None,
+ **kwargs):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super(DIIHead, self).__init__(
+ num_classes=num_classes,
+ reg_decoded_bbox=True,
+ reg_class_agnostic=True,
+ init_cfg=init_cfg,
+ **kwargs)
+ self.loss_iou = build_loss(loss_iou)
+ self.in_channels = in_channels
+ self.fp16_enabled = False
+ self.attention = MultiheadAttention(in_channels, num_heads, dropout)
+ self.attention_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
+
+ self.instance_interactive_conv = build_transformer(dynamic_conv_cfg)
+ self.instance_interactive_conv_dropout = nn.Dropout(dropout)
+ self.instance_interactive_conv_norm = build_norm_layer(
+ dict(type='LN'), in_channels)[1]
+
+ self.ffn = FFN(
+ in_channels,
+ feedforward_channels,
+ num_ffn_fcs,
+ act_cfg=ffn_act_cfg,
+ dropout=dropout)
+ self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
+
+ self.cls_fcs = nn.ModuleList()
+ for _ in range(num_cls_fcs):
+ self.cls_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.cls_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.cls_fcs.append(
+ build_activation_layer(dict(type='ReLU', inplace=True)))
+
+        # override self.fc_cls defined in BBoxHead
+ if self.loss_cls.use_sigmoid:
+ self.fc_cls = nn.Linear(in_channels, self.num_classes)
+ else:
+ self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)
+
+ self.reg_fcs = nn.ModuleList()
+ for _ in range(num_reg_fcs):
+ self.reg_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.reg_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.reg_fcs.append(
+ build_activation_layer(dict(type='ReLU', inplace=True)))
+        # override self.fc_reg defined in BBoxHead
+ self.fc_reg = nn.Linear(in_channels, 4)
+
+        assert self.reg_class_agnostic, 'DIIHead only ' \
+            'supports `reg_class_agnostic=True` '
+        assert self.reg_decoded_bbox, 'DIIHead only ' \
+            'supports `reg_decoded_bbox=True`'
+
+ def init_weights(self):
+ """Use xavier initialization for all weight parameter and set
+ classification head bias as a specific value when use focal loss."""
+ super(DIIHead, self).init_weights()
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ else:
+ # adopt the default initialization for
+ # the weight and bias of the layer norm
+ pass
+ if self.loss_cls.use_sigmoid:
+ bias_init = bias_init_with_prob(0.01)
+ nn.init.constant_(self.fc_cls.bias, bias_init)
+
+ @auto_fp16()
+ def forward(self, roi_feat, proposal_feat):
+ """Forward function of Dynamic Instance Interactive Head.
+
+ Args:
+            roi_feat (Tensor): RoI-pooled features with shape
+                (batch_size*num_proposals, feature_dimensions,
+                pooling_h, pooling_w).
+            proposal_feat (Tensor): Intermediate feature obtained from the
+                DII head of the last stage, has shape
+                (batch_size, num_proposals, feature_dimensions)
+
+ Returns:
+ tuple[Tensor]: Usually a tuple of classification scores
+                and bbox predictions and an intermediate feature.
+
+ - cls_scores (Tensor): Classification scores for
+ all proposals, has shape
+ (batch_size, num_proposals, num_classes).
+ - bbox_preds (Tensor): Box energies / deltas for
+ all proposals, has shape
+ (batch_size, num_proposals, 4).
+ - obj_feat (Tensor): Object feature before classification
+ and regression subnet, has shape
+ (batch_size, num_proposal, feature_dimensions).
+ """
+ N, num_proposals = proposal_feat.shape[:2]
+
+ # Self attention
+ proposal_feat = proposal_feat.permute(1, 0, 2)
+ proposal_feat = self.attention_norm(self.attention(proposal_feat))
+ attn_feats = proposal_feat.permute(1, 0, 2)
+
+ # instance interactive
+ proposal_feat = attn_feats.reshape(-1, self.in_channels)
+ proposal_feat_iic = self.instance_interactive_conv(
+ proposal_feat, roi_feat)
+ proposal_feat = proposal_feat + self.instance_interactive_conv_dropout(
+ proposal_feat_iic)
+ obj_feat = self.instance_interactive_conv_norm(proposal_feat)
+
+ # FFN
+ obj_feat = self.ffn_norm(self.ffn(obj_feat))
+
+ cls_feat = obj_feat
+ reg_feat = obj_feat
+
+ for cls_layer in self.cls_fcs:
+ cls_feat = cls_layer(cls_feat)
+ for reg_layer in self.reg_fcs:
+ reg_feat = reg_layer(reg_feat)
+
+ cls_score = self.fc_cls(cls_feat).view(
+ N, num_proposals, self.num_classes
+ if self.loss_cls.use_sigmoid else self.num_classes + 1)
+ bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, 4)
+
+ return cls_score, bbox_delta, obj_feat.view(
+ N, num_proposals, self.in_channels), attn_feats
+
+ @force_fp32(apply_to=('cls_score', 'bbox_pred'))
+ def loss(self,
+ cls_score,
+ bbox_pred,
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ imgs_whwh=None,
+ reduction_override=None,
+ **kwargs):
+        """Loss function of DIIHead, computing the loss of all images.
+
+ Args:
+ cls_score (Tensor): Classification prediction
+ results of all class, has shape
+ (batch_size * num_proposals_single_image, num_classes)
+ bbox_pred (Tensor): Regression prediction results,
+ has shape
+ (batch_size * num_proposals_single_image, 4), the last
+ dimension 4 represents [tl_x, tl_y, br_x, br_y].
+            labels (Tensor): Label of each proposal, has shape
+                (batch_size * num_proposals_single_image, ).
+            label_weights (Tensor): Classification loss
+                weight of each proposal, has shape
+                (batch_size * num_proposals_single_image, ).
+ bbox_targets (Tensor): Regression targets of each
+ proposals, has shape
+ (batch_size * num_proposals_single_image, 4),
+ the last dimension 4 represents
+ [tl_x, tl_y, br_x, br_y].
+ bbox_weights (Tensor): Regression loss weight of each
+                proposal's coordinates, has shape
+                (batch_size * num_proposals_single_image, 4).
+            imgs_whwh (Tensor): Tensor with shape
+                (batch_size, num_proposals, 4), the last dimension means
+                [img_width, img_height, img_width, img_height].
+ reduction_override (str, optional): The reduction
+ method used to override the original reduction
+ method of the loss. Options are "none",
+                "mean" and "sum". Defaults to None.
+
+ Returns:
+ dict[str, Tensor]: Dictionary of loss components
+ """
+ losses = dict()
+ bg_class_ind = self.num_classes
+        # note: in Sparse R-CNN, num_gt == num_pos
+ pos_inds = (labels >= 0) & (labels < bg_class_ind)
+ num_pos = pos_inds.sum().float()
+ avg_factor = reduce_mean(num_pos)
+ if cls_score is not None:
+ if cls_score.numel() > 0:
+ losses['loss_cls'] = self.loss_cls(
+ cls_score,
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ losses['pos_acc'] = accuracy(cls_score[pos_inds],
+ labels[pos_inds])
+ if bbox_pred is not None:
+ # 0~self.num_classes-1 are FG, self.num_classes is BG
+ # do not perform bounding box regression for BG anymore.
+ if pos_inds.any():
+ pos_bbox_pred = bbox_pred.reshape(bbox_pred.size(0),
+ 4)[pos_inds.type(torch.bool)]
+ imgs_whwh = imgs_whwh.reshape(bbox_pred.size(0),
+ 4)[pos_inds.type(torch.bool)]
+ losses['loss_bbox'] = self.loss_bbox(
+ pos_bbox_pred / imgs_whwh,
+ bbox_targets[pos_inds.type(torch.bool)] / imgs_whwh,
+ bbox_weights[pos_inds.type(torch.bool)],
+ avg_factor=avg_factor)
+ losses['loss_iou'] = self.loss_iou(
+ pos_bbox_pred,
+ bbox_targets[pos_inds.type(torch.bool)],
+ bbox_weights[pos_inds.type(torch.bool)],
+ avg_factor=avg_factor)
+ else:
+ losses['loss_bbox'] = bbox_pred.sum() * 0
+ losses['loss_iou'] = bbox_pred.sum() * 0
+ return losses
+
+ def _get_target_single(self, pos_inds, neg_inds, pos_bboxes, neg_bboxes,
+ pos_gt_bboxes, pos_gt_labels, cfg):
+ """Calculate the ground truth for proposals in the single image
+ according to the sampling results.
+
+ Almost the same as the implementation in `bbox_head`,
+ we add pos_inds and neg_inds to select positive and
+ negative samples instead of selecting the first num_pos
+ as positive samples.
+
+ Args:
+ pos_inds (Tensor): The length is equal to the
+ positive sample numbers contain all index
+ of the positive sample in the origin proposal set.
+ neg_inds (Tensor): The length is equal to the
+ negative sample numbers contain all index
+ of the negative sample in the origin proposal set.
+ pos_bboxes (Tensor): Contains all the positive boxes,
+ has shape (num_pos, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ neg_bboxes (Tensor): Contains all the negative boxes,
+ has shape (num_neg, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ pos_gt_bboxes (Tensor): Contains gt_boxes for
+ all positive samples, has shape (num_pos, 4),
+ the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ pos_gt_labels (Tensor): Contains gt_labels for
+ all positive samples, has shape (num_pos, ).
+ cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.
+
+ Returns:
+ Tuple[Tensor]: Ground truth for proposals in a single image.
+ Containing the following Tensors:
+
+ - labels(Tensor): Gt_labels for all proposals, has
+ shape (num_proposals,).
+ - label_weights(Tensor): Labels_weights for all proposals, has
+ shape (num_proposals,).
+ - bbox_targets(Tensor):Regression target for all proposals, has
+ shape (num_proposals, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ - bbox_weights(Tensor):Regression weights for all proposals,
+ has shape (num_proposals, 4).
+ """
+ num_pos = pos_bboxes.size(0)
+ num_neg = neg_bboxes.size(0)
+ num_samples = num_pos + num_neg
+
+ # original implementation uses new_zeros since BG are set to be 0
+ # now use empty & fill because BG cat_id = num_classes,
+ # FG cat_id = [0, num_classes-1]
+ labels = pos_bboxes.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_bboxes.new_zeros(num_samples)
+ bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
+ bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
+ if num_pos > 0:
+ labels[pos_inds] = pos_gt_labels
+ pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
+ label_weights[pos_inds] = pos_weight
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ pos_bboxes, pos_gt_bboxes)
+ else:
+ pos_bbox_targets = pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1
+ if num_neg > 0:
+ label_weights[neg_inds] = 1.0
+
+ return labels, label_weights, bbox_targets, bbox_weights
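+ # Layout sketch (illustrative): with num_classes=80, the returned
+ # `labels` tensor is filled with 80 (the background index) and the
+ # entries at `pos_inds` are overwritten with the matched gt labels;
+ # `bbox_targets` and `bbox_weights` are zero everywhere except at
+ # `pos_inds`, where the targets are the (optionally encoded) gt boxes
+ # and the weights are 1.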
+
+ def get_targets(self,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ rcnn_train_cfg,
+ concat=True):
+ """Calculate the ground truth for all samples in a batch according to
+ the sampling_results.
+
+ Almost the same as the implementation in bbox_head, except that
+ we pass the additional parameters pos_inds_list and neg_inds_list
+ to the `_get_target_single` function.
+
+ Args:
+ sampling_results (List[obj:SamplingResults]): Assign results of
+ all images in a batch after sampling.
+ gt_bboxes (list[Tensor]): Gt_bboxes of all images in a batch,
+ each tensor has shape (num_gt, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ gt_labels (list[Tensor]): Gt_labels of all images in a batch,
+ each tensor has shape (num_gt,).
+ rcnn_train_cfg (obj:`ConfigDict`): `train_cfg` of RCNN.
+ concat (bool): Whether to concatenate the results of all
+ the images in a single batch.
+
+ Returns:
+ Tuple[Tensor]: Ground truth for all proposals in a batch,
+ containing the following list of Tensors:
+
+ - labels (list[Tensor],Tensor): Gt_labels for all
+ proposals in a batch, each tensor in list has
+ shape (num_proposals,) when `concat=False`, otherwise just
+ a single tensor has shape (num_all_proposals,).
+ - label_weights (list[Tensor]): Labels_weights for
+ all proposals in a batch, each tensor in list has shape
+ (num_proposals,) when `concat=False`, otherwise just a
+ single tensor has shape (num_all_proposals,).
+ - bbox_targets (list[Tensor],Tensor): Regression target
+ for all proposals in a batch, each tensor in list has
+ shape (num_proposals, 4) when `concat=False`, otherwise
+ just a single tensor has shape (num_all_proposals, 4),
+ the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
+ - bbox_weights (list[Tensor], Tensor): Regression weights for
+ all proposals in a batch, each tensor in list has shape
+ (num_proposals, 4) when `concat=False`, otherwise just a
+ single tensor has shape (num_all_proposals, 4).
+ """
+ pos_inds_list = [res.pos_inds for res in sampling_results]
+ neg_inds_list = [res.neg_inds for res in sampling_results]
+ pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
+ neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
+ pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
+ pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
+ labels, label_weights, bbox_targets, bbox_weights = multi_apply(
+ self._get_target_single,
+ pos_inds_list,
+ neg_inds_list,
+ pos_bboxes_list,
+ neg_bboxes_list,
+ pos_gt_bboxes_list,
+ pos_gt_labels_list,
+ cfg=rcnn_train_cfg)
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ bbox_targets = torch.cat(bbox_targets, 0)
+ bbox_weights = torch.cat(bbox_weights, 0)
+ return labels, label_weights, bbox_targets, bbox_weights
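+ # Usage sketch (illustrative; the image/proposal counts are assumed):
+ # with two images and 100 sampled proposals each, `concat=True`
+ # concatenates the per-image outputs of `_get_target_single` along
+ # dim 0, so `labels` and `label_weights` have shape (200,) and
+ # `bbox_targets` / `bbox_weights` have shape (200, 4), matching the
+ # flattened predictions consumed by `loss`.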
diff --git a/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a38d591f8c8c44a93985762a8d7c7389f448ec1
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py
@@ -0,0 +1,178 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, ModuleList
+
+from mmdet.models.backbones.resnet import Bottleneck
+from mmdet.models.builder import HEADS
+from .bbox_head import BBoxHead
+
+
+class BasicResBlock(BaseModule):
+ """Basic residual block.
+
+ This block is a little different from the block in the ResNet backbone:
+ conv2 uses a 1x1 kernel here (3x3 in the ResNet BasicBlock), and the
+ identity path uses a 1x1 conv to match the output channels.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ out_channels (int): Channels of the output feature map.
+ conv_cfg (dict): The config dict for convolution layers.
+ norm_cfg (dict): The config dict for normalization layers.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ init_cfg=None):
+ super(BasicResBlock, self).__init__(init_cfg)
+
+ # main path
+ self.conv1 = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg)
+ self.conv2 = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ # identity path
+ self.conv_identity = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ identity = x
+
+ x = self.conv1(x)
+ x = self.conv2(x)
+
+ identity = self.conv_identity(identity)
+ out = x + identity
+
+ out = self.relu(out)
+ return out
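+ # Shape sketch (illustrative): for an input of shape
+ # (N, in_channels, H, W), the main path applies a 3x3 conv
+ # (in_channels -> in_channels) followed by a 1x1 conv
+ # (in_channels -> out_channels), the identity path a single 1x1
+ # projection to out_channels, and the two are summed before the final
+ # ReLU, giving an output of shape (N, out_channels, H, W).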
+
+
+@HEADS.register_module()
+class DoubleConvFCBBoxHead(BBoxHead):
+ r"""Bbox head used in Double-Head R-CNN
+
+ .. code-block:: none
+
+ /-> cls
+ /-> shared convs ->
+ \-> reg
+ roi features
+ /-> cls
+ \-> shared fc ->
+ \-> reg
+ """ # noqa: W605
+
+ def __init__(self,
+ num_convs=0,
+ num_fcs=0,
+ conv_out_channels=1024,
+ fc_out_channels=1024,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ init_cfg=dict(
+ type='Normal',
+ override=[
+ dict(type='Normal', name='fc_cls', std=0.01),
+ dict(type='Normal', name='fc_reg', std=0.001),
+ dict(
+ type='Xavier',
+ name='fc_branch',
+ distribution='uniform')
+ ]),
+ **kwargs):
+ kwargs.setdefault('with_avg_pool', True)
+ super(DoubleConvFCBBoxHead, self).__init__(init_cfg=init_cfg, **kwargs)
+ assert self.with_avg_pool
+ assert num_convs > 0
+ assert num_fcs > 0
+ self.num_convs = num_convs
+ self.num_fcs = num_fcs
+ self.conv_out_channels = conv_out_channels
+ self.fc_out_channels = fc_out_channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ # increase the channel of input features
+ self.res_block = BasicResBlock(self.in_channels,
+ self.conv_out_channels)
+
+ # add conv heads
+ self.conv_branch = self._add_conv_branch()
+ # add fc heads
+ self.fc_branch = self._add_fc_branch()
+
+ out_dim_reg = 4 if self.reg_class_agnostic else 4 * self.num_classes
+ self.fc_reg = nn.Linear(self.conv_out_channels, out_dim_reg)
+
+ self.fc_cls = nn.Linear(self.fc_out_channels, self.num_classes + 1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def _add_conv_branch(self):
+ """Add the fc branch which consists of a sequential of conv layers."""
+ branch_convs = ModuleList()
+ for i in range(self.num_convs):
+ branch_convs.append(
+ Bottleneck(
+ inplanes=self.conv_out_channels,
+ planes=self.conv_out_channels // 4,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ return branch_convs
+
+ def _add_fc_branch(self):
+ """Add the fc branch which consists of a sequential of fc layers."""
+ branch_fcs = ModuleList()
+ for i in range(self.num_fcs):
+ fc_in_channels = (
+ self.in_channels *
+ self.roi_feat_area if i == 0 else self.fc_out_channels)
+ branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))
+ return branch_fcs
+
+ def forward(self, x_cls, x_reg):
+ # conv head
+ x_conv = self.res_block(x_reg)
+
+ for conv in self.conv_branch:
+ x_conv = conv(x_conv)
+
+ if self.with_avg_pool:
+ x_conv = self.avg_pool(x_conv)
+
+ x_conv = x_conv.view(x_conv.size(0), -1)
+ bbox_pred = self.fc_reg(x_conv)
+
+ # fc head
+ x_fc = x_cls.view(x_cls.size(0), -1)
+ for fc in self.fc_branch:
+ x_fc = self.relu(fc(x_fc))
+
+ cls_score = self.fc_cls(x_fc)
+
+ return cls_score, bbox_pred
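+ # Usage sketch (illustrative; the RoI feature size is an assumption):
+ # x_cls and x_reg are pooled RoI features, e.g. of shape
+ # (num_rois, in_channels, 7, 7); in DoubleHeadRoIHead, x_reg is
+ # typically extracted from RoIs scaled by `reg_roi_scale_factor`.
+ # The conv branch (res_block + Bottlenecks + avg pool) consumes x_reg
+ # and feeds fc_reg, while the fc branch flattens x_cls and feeds
+ # fc_cls, giving cls_score of shape (num_rois, num_classes + 1) and
+ # bbox_pred of shape (num_rois, 4) or (num_rois, 4 * num_classes)
+ # depending on `reg_class_agnostic`.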
diff --git a/mmdet/models/roi_heads/bbox_heads/sabl_head.py b/mmdet/models/roi_heads/bbox_heads/sabl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ce986b9a29ed2264e48ac4df89b407dfc66eeca
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/sabl_head.py
@@ -0,0 +1,596 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, force_fp32
+
+from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.losses import accuracy
+
+
+@HEADS.register_module()
+class SABLHead(BaseModule):
+ """Side-Aware Boundary Localization (SABL) for RoI-Head.
+
+ Side-Aware features are extracted by conv layers
+ with an attention mechanism.
+ Boundary Localization with Bucketing and Bucketing Guided Rescoring
+ are implemented in BucketingBBoxCoder.
+
+ Please refer to https://arxiv.org/abs/1912.04260 for more details.
+
+ Args:
+ cls_in_channels (int): Input channels of cls RoI feature. \
+ Defaults to 256.
+ reg_in_channels (int): Input channels of reg RoI feature. \
+ Defaults to 256.
+ roi_feat_size (int): Size of RoI features. Defaults to 7.
+ reg_feat_up_ratio (int): Upsample ratio of reg features. \
+ Defaults to 2.
+ reg_pre_kernel (int): Kernel of 2D conv layers before \
+ attention pooling. Defaults to 3.
+ reg_post_kernel (int): Kernel of 1D conv layers after \
+ attention pooling. Defaults to 3.
+ reg_pre_num (int): Number of pre convs. Defaults to 2.
+ reg_post_num (int): Number of post convs. Defaults to 1.
+ num_classes (int): Number of classes in dataset. Defaults to 80.
+ cls_out_channels (int): Hidden channels in cls fcs. Defaults to 1024.
+ reg_offset_out_channels (int): Hidden and output channel \
+ of reg offset branch. Defaults to 256.
+ reg_cls_out_channels (int): Hidden and output channel \
+ of reg cls branch. Defaults to 256.
+ num_cls_fcs (int): Number of fcs for cls branch. Defaults to 1.
+ num_reg_fcs (int): Number of fcs for reg branch. Defaults to 0.
+ reg_class_agnostic (bool): Class agnostic regression or not. \
+ Defaults to True.
+ norm_cfg (dict): Config of norm layers. Defaults to None.
+ bbox_coder (dict): Config of bbox coder. Defaults to 'BucketingBBoxCoder'.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox_cls (dict): Config of classification loss for bbox branch.
+ loss_bbox_reg (dict): Config of regression loss for bbox branch.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ num_classes,
+ cls_in_channels=256,
+ reg_in_channels=256,
+ roi_feat_size=7,
+ reg_feat_up_ratio=2,
+ reg_pre_kernel=3,
+ reg_post_kernel=3,
+ reg_pre_num=2,
+ reg_post_num=1,
+ cls_out_channels=1024,
+ reg_offset_out_channels=256,
+ reg_cls_out_channels=256,
+ num_cls_fcs=1,
+ num_reg_fcs=0,
+ reg_class_agnostic=True,
+ norm_cfg=None,
+ bbox_coder=dict(
+ type='BucketingBBoxCoder',
+ num_buckets=14,
+ scale_factor=1.7),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ loss_bbox_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_bbox_reg=dict(
+ type='SmoothL1Loss', beta=0.1, loss_weight=1.0),
+ init_cfg=None):
+ super(SABLHead, self).__init__(init_cfg)
+ self.cls_in_channels = cls_in_channels
+ self.reg_in_channels = reg_in_channels
+ self.roi_feat_size = roi_feat_size
+ self.reg_feat_up_ratio = int(reg_feat_up_ratio)
+ self.num_buckets = bbox_coder['num_buckets']
+ assert self.reg_feat_up_ratio // 2 >= 1
+ self.up_reg_feat_size = roi_feat_size * self.reg_feat_up_ratio
+ assert self.up_reg_feat_size == bbox_coder['num_buckets']
+ self.reg_pre_kernel = reg_pre_kernel
+ self.reg_post_kernel = reg_post_kernel
+ self.reg_pre_num = reg_pre_num
+ self.reg_post_num = reg_post_num
+ self.num_classes = num_classes
+ self.cls_out_channels = cls_out_channels
+ self.reg_offset_out_channels = reg_offset_out_channels
+ self.reg_cls_out_channels = reg_cls_out_channels
+ self.num_cls_fcs = num_cls_fcs
+ self.num_reg_fcs = num_reg_fcs
+ self.reg_class_agnostic = reg_class_agnostic
+ assert self.reg_class_agnostic
+ self.norm_cfg = norm_cfg
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox_cls = build_loss(loss_bbox_cls)
+ self.loss_bbox_reg = build_loss(loss_bbox_reg)
+
+ self.cls_fcs = self._add_fc_branch(self.num_cls_fcs,
+ self.cls_in_channels,
+ self.roi_feat_size,
+ self.cls_out_channels)
+
+ self.side_num = int(np.ceil(self.num_buckets / 2))
+
+ if self.reg_feat_up_ratio > 1:
+ self.upsample_x = nn.ConvTranspose1d(
+ reg_in_channels,
+ reg_in_channels,
+ self.reg_feat_up_ratio,
+ stride=self.reg_feat_up_ratio)
+ self.upsample_y = nn.ConvTranspose1d(
+ reg_in_channels,
+ reg_in_channels,
+ self.reg_feat_up_ratio,
+ stride=self.reg_feat_up_ratio)
+
+ self.reg_pre_convs = nn.ModuleList()
+ for i in range(self.reg_pre_num):
+ reg_pre_conv = ConvModule(
+ reg_in_channels,
+ reg_in_channels,
+ kernel_size=reg_pre_kernel,
+ padding=reg_pre_kernel // 2,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'))
+ self.reg_pre_convs.append(reg_pre_conv)
+
+ self.reg_post_conv_xs = nn.ModuleList()
+ for i in range(self.reg_post_num):
+ reg_post_conv_x = ConvModule(
+ reg_in_channels,
+ reg_in_channels,
+ kernel_size=(1, reg_post_kernel),
+ padding=(0, reg_post_kernel // 2),
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'))
+ self.reg_post_conv_xs.append(reg_post_conv_x)
+ self.reg_post_conv_ys = nn.ModuleList()
+ for i in range(self.reg_post_num):
+ reg_post_conv_y = ConvModule(
+ reg_in_channels,
+ reg_in_channels,
+ kernel_size=(reg_post_kernel, 1),
+ padding=(reg_post_kernel // 2, 0),
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'))
+ self.reg_post_conv_ys.append(reg_post_conv_y)
+
+ self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1)
+ self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1)
+
+ self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.reg_cls_fcs = self._add_fc_branch(self.num_reg_fcs,
+ self.reg_in_channels, 1,
+ self.reg_cls_out_channels)
+ self.reg_offset_fcs = self._add_fc_branch(self.num_reg_fcs,
+ self.reg_in_channels, 1,
+ self.reg_offset_out_channels)
+ self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1)
+ self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1)
+
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(
+ type='Xavier',
+ layer='Linear',
+ distribution='uniform',
+ override=[
+ dict(type='Normal', name='reg_conv_att_x', std=0.01),
+ dict(type='Normal', name='reg_conv_att_y', std=0.01),
+ dict(type='Normal', name='fc_reg_cls', std=0.01),
+ dict(type='Normal', name='fc_cls', std=0.01),
+ dict(type='Normal', name='fc_reg_offset', std=0.001)
+ ])
+ ]
+ if self.reg_feat_up_ratio > 1:
+ self.init_cfg += [
+ dict(
+ type='Kaiming',
+ distribution='normal',
+ override=[
+ dict(name='upsample_x'),
+ dict(name='upsample_y')
+ ])
+ ]
+
+ @property
+ def custom_cls_channels(self):
+ return getattr(self.loss_cls, 'custom_cls_channels', False)
+
+ @property
+ def custom_activation(self):
+ return getattr(self.loss_cls, 'custom_activation', False)
+
+ @property
+ def custom_accuracy(self):
+ return getattr(self.loss_cls, 'custom_accuracy', False)
+
+ def _add_fc_branch(self, num_branch_fcs, in_channels, roi_feat_size,
+ fc_out_channels):
+ in_channels = in_channels * roi_feat_size * roi_feat_size
+ branch_fcs = nn.ModuleList()
+ for i in range(num_branch_fcs):
+ fc_in_channels = (in_channels if i == 0 else fc_out_channels)
+ branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels))
+ return branch_fcs
+
+ def cls_forward(self, cls_x):
+ cls_x = cls_x.view(cls_x.size(0), -1)
+ for fc in self.cls_fcs:
+ cls_x = self.relu(fc(cls_x))
+ cls_score = self.fc_cls(cls_x)
+ return cls_score
+
+ def attention_pool(self, reg_x):
+ """Extract direction-specific features fx and fy with attention
+ methanism."""
+ reg_fx = reg_x
+ reg_fy = reg_x
+ reg_fx_att = self.reg_conv_att_x(reg_fx).sigmoid()
+ reg_fy_att = self.reg_conv_att_y(reg_fy).sigmoid()
+ reg_fx_att = reg_fx_att / reg_fx_att.sum(dim=2).unsqueeze(2)
+ reg_fy_att = reg_fy_att / reg_fy_att.sum(dim=3).unsqueeze(3)
+ reg_fx = (reg_fx * reg_fx_att).sum(dim=2)
+ reg_fy = (reg_fy * reg_fy_att).sum(dim=3)
+ return reg_fx, reg_fy
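+ # Sketch (illustrative): for reg_x of shape (n, C, H, W), the two
+ # 1-channel attention maps are sigmoid-activated and normalized over
+ # H (for fx) and W (for fy); the weighted sums then collapse those
+ # dimensions, giving reg_fx of shape (n, C, W) and reg_fy of shape
+ # (n, C, H).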
+
+ def side_aware_feature_extractor(self, reg_x):
+ """Refine and extract side-aware features without split them."""
+ for reg_pre_conv in self.reg_pre_convs:
+ reg_x = reg_pre_conv(reg_x)
+ reg_fx, reg_fy = self.attention_pool(reg_x)
+
+ if self.reg_post_num > 0:
+ reg_fx = reg_fx.unsqueeze(2)
+ reg_fy = reg_fy.unsqueeze(3)
+ for i in range(self.reg_post_num):
+ reg_fx = self.reg_post_conv_xs[i](reg_fx)
+ reg_fy = self.reg_post_conv_ys[i](reg_fy)
+ reg_fx = reg_fx.squeeze(2)
+ reg_fy = reg_fy.squeeze(3)
+ if self.reg_feat_up_ratio > 1:
+ reg_fx = self.relu(self.upsample_x(reg_fx))
+ reg_fy = self.relu(self.upsample_y(reg_fy))
+ reg_fx = torch.transpose(reg_fx, 1, 2)
+ reg_fy = torch.transpose(reg_fy, 1, 2)
+ return reg_fx.contiguous(), reg_fy.contiguous()
+
+ def reg_pred(self, x, offset_fcs, cls_fcs):
+ """Predict bucketing estimation (cls_pred) and fine regression (offset
+ pred) with side-aware features."""
+ x_offset = x.view(-1, self.reg_in_channels)
+ x_cls = x.view(-1, self.reg_in_channels)
+
+ for fc in offset_fcs:
+ x_offset = self.relu(fc(x_offset))
+ for fc in cls_fcs:
+ x_cls = self.relu(fc(x_cls))
+ offset_pred = self.fc_reg_offset(x_offset)
+ cls_pred = self.fc_reg_cls(x_cls)
+
+ offset_pred = offset_pred.view(x.size(0), -1)
+ cls_pred = cls_pred.view(x.size(0), -1)
+
+ return offset_pred, cls_pred
+
+ def side_aware_split(self, feat):
+ """Split side-aware features aligned with orders of bucketing
+ targets."""
+ l_end = int(np.ceil(self.up_reg_feat_size / 2))
+ r_start = int(np.floor(self.up_reg_feat_size / 2))
+ feat_fl = feat[:, :l_end]
+ feat_fr = feat[:, r_start:].flip(dims=(1, ))
+ feat_fl = feat_fl.contiguous()
+ feat_fr = feat_fr.contiguous()
+ feat = torch.cat([feat_fl, feat_fr], dim=-1)
+ return feat
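+ # Worked example (illustrative): with up_reg_feat_size=14, each row
+ # of length 14 is split into the left half feat[:, :7] and the right
+ # half feat[:, 7:] flipped along its last dimension, then the two
+ # halves are concatenated so both sides are ordered from the outer
+ # boundary inwards, matching the ordering of the bucketing targets.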
+
+ def bbox_pred_split(self, bbox_pred, num_proposals_per_img):
+ """Split batch bbox prediction back to each image."""
+ bucket_cls_preds, bucket_offset_preds = bbox_pred
+ bucket_cls_preds = bucket_cls_preds.split(num_proposals_per_img, 0)
+ bucket_offset_preds = bucket_offset_preds.split(
+ num_proposals_per_img, 0)
+ bbox_pred = tuple(zip(bucket_cls_preds, bucket_offset_preds))
+ return bbox_pred
+
+ def reg_forward(self, reg_x):
+ outs = self.side_aware_feature_extractor(reg_x)
+ edge_offset_preds = []
+ edge_cls_preds = []
+ reg_fx = outs[0]
+ reg_fy = outs[1]
+ offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs,
+ self.reg_cls_fcs)
+ offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs,
+ self.reg_cls_fcs)
+ offset_pred_x = self.side_aware_split(offset_pred_x)
+ offset_pred_y = self.side_aware_split(offset_pred_y)
+ cls_pred_x = self.side_aware_split(cls_pred_x)
+ cls_pred_y = self.side_aware_split(cls_pred_y)
+ edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1)
+ edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1)
+
+ return (edge_cls_preds, edge_offset_preds)
+
+ def forward(self, x):
+
+ bbox_pred = self.reg_forward(x)
+ cls_score = self.cls_forward(x)
+
+ return cls_score, bbox_pred
+
+ def get_targets(self, sampling_results, gt_bboxes, gt_labels,
+ rcnn_train_cfg):
+ pos_proposals = [res.pos_bboxes for res in sampling_results]
+ neg_proposals = [res.neg_bboxes for res in sampling_results]
+ pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
+ pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
+ cls_reg_targets = self.bucket_target(pos_proposals, neg_proposals,
+ pos_gt_bboxes, pos_gt_labels,
+ rcnn_train_cfg)
+ (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+ bucket_offset_targets, bucket_offset_weights) = cls_reg_targets
+ return (labels, label_weights, (bucket_cls_targets,
+ bucket_offset_targets),
+ (bucket_cls_weights, bucket_offset_weights))
+
+ def bucket_target(self,
+ pos_proposals_list,
+ neg_proposals_list,
+ pos_gt_bboxes_list,
+ pos_gt_labels_list,
+ rcnn_train_cfg,
+ concat=True):
+ (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+ bucket_offset_targets, bucket_offset_weights) = multi_apply(
+ self._bucket_target_single,
+ pos_proposals_list,
+ neg_proposals_list,
+ pos_gt_bboxes_list,
+ pos_gt_labels_list,
+ cfg=rcnn_train_cfg)
+
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ bucket_cls_targets = torch.cat(bucket_cls_targets, 0)
+ bucket_cls_weights = torch.cat(bucket_cls_weights, 0)
+ bucket_offset_targets = torch.cat(bucket_offset_targets, 0)
+ bucket_offset_weights = torch.cat(bucket_offset_weights, 0)
+ return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+ bucket_offset_targets, bucket_offset_weights)
+
+ def _bucket_target_single(self, pos_proposals, neg_proposals,
+ pos_gt_bboxes, pos_gt_labels, cfg):
+ """Compute bucketing estimation targets and fine regression targets for
+ a single image.
+
+ Args:
+ pos_proposals (Tensor): positive proposals of a single image,
+ Shape (n_pos, 4)
+ neg_proposals (Tensor): negative proposals of a single image,
+ Shape (n_neg, 4).
+ pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals
+ of a single image, Shape (n_pos, 4).
+ pos_gt_labels (Tensor): gt labels assigned to positive proposals
+ of a single image, Shape (n_pos, ).
+ cfg (dict): Config of calculating targets
+
+ Returns:
+ tuple:
+
+ - labels (Tensor): Labels in a single image. \
+ Shape (n,).
+ - label_weights (Tensor): Label weights in a single image.\
+ Shape (n,)
+ - bucket_cls_targets (Tensor): Bucket cls targets in \
+ a single image. Shape (n, num_buckets*2).
+ - bucket_cls_weights (Tensor): Bucket cls weights in \
+ a single image. Shape (n, num_buckets*2).
+ - bucket_offset_targets (Tensor): Bucket offset targets \
+ in a single image. Shape (n, num_buckets*2).
+ - bucket_offset_weights (Tensor): Bucket offset weights \
+ in a single image. Shape (n, num_buckets*2).
+ """
+ num_pos = pos_proposals.size(0)
+ num_neg = neg_proposals.size(0)
+ num_samples = num_pos + num_neg
+ labels = pos_gt_bboxes.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_proposals.new_zeros(num_samples)
+ bucket_cls_targets = pos_proposals.new_zeros(num_samples,
+ 4 * self.side_num)
+ bucket_cls_weights = pos_proposals.new_zeros(num_samples,
+ 4 * self.side_num)
+ bucket_offset_targets = pos_proposals.new_zeros(
+ num_samples, 4 * self.side_num)
+ bucket_offset_weights = pos_proposals.new_zeros(
+ num_samples, 4 * self.side_num)
+ if num_pos > 0:
+ labels[:num_pos] = pos_gt_labels
+ label_weights[:num_pos] = 1.0
+ (pos_bucket_offset_targets, pos_bucket_offset_weights,
+ pos_bucket_cls_targets,
+ pos_bucket_cls_weights) = self.bbox_coder.encode(
+ pos_proposals, pos_gt_bboxes)
+ bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets
+ bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights
+ bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets
+ bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights
+ if num_neg > 0:
+ label_weights[-num_neg:] = 1.0
+ return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+ bucket_offset_targets, bucket_offset_weights)
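+ # Layout sketch (illustrative): samples are ordered positives first,
+ # then negatives. `labels` is filled with num_classes (the background
+ # index) and the first num_pos entries get the gt labels; the bucket
+ # cls/offset targets and weights, each of shape
+ # (num_samples, 4 * side_num), are produced by the bucketing coder
+ # for the positive rows only and remain zero for negatives.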
+
+ def loss(self,
+ cls_score,
+ bbox_pred,
+ rois,
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ reduction_override=None):
+ losses = dict()
+ if cls_score is not None:
+ avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
+ losses['loss_cls'] = self.loss_cls(
+ cls_score,
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ losses['acc'] = accuracy(cls_score, labels)
+
+ if bbox_pred is not None:
+ bucket_cls_preds, bucket_offset_preds = bbox_pred
+ bucket_cls_targets, bucket_offset_targets = bbox_targets
+ bucket_cls_weights, bucket_offset_weights = bbox_weights
+ # edge cls
+ bucket_cls_preds = bucket_cls_preds.view(-1, self.side_num)
+ bucket_cls_targets = bucket_cls_targets.view(-1, self.side_num)
+ bucket_cls_weights = bucket_cls_weights.view(-1, self.side_num)
+ losses['loss_bbox_cls'] = self.loss_bbox_cls(
+ bucket_cls_preds,
+ bucket_cls_targets,
+ bucket_cls_weights,
+ avg_factor=bucket_cls_targets.size(0),
+ reduction_override=reduction_override)
+
+ losses['loss_bbox_reg'] = self.loss_bbox_reg(
+ bucket_offset_preds,
+ bucket_offset_targets,
+ bucket_offset_weights,
+ avg_factor=bucket_offset_targets.size(0),
+ reduction_override=reduction_override)
+
+ return losses
+
+ @force_fp32(apply_to=('cls_score', 'bbox_pred'))
+ def get_bboxes(self,
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None):
+ if isinstance(cls_score, list):
+ cls_score = sum(cls_score) / float(len(cls_score))
+ scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
+
+ if bbox_pred is not None:
+ bboxes, confidences = self.bbox_coder.decode(
+ rois[:, 1:], bbox_pred, img_shape)
+ else:
+ bboxes = rois[:, 1:].clone()
+ confidences = None
+ if img_shape is not None:
+ bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1)
+ bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1)
+
+ if rescale and bboxes.size(0) > 0:
+ if isinstance(scale_factor, float):
+ bboxes /= scale_factor
+ else:
+ bboxes /= torch.from_numpy(scale_factor).to(bboxes.device)
+
+ if cfg is None:
+ return bboxes, scores
+ else:
+ det_bboxes, det_labels = multiclass_nms(
+ bboxes,
+ scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=confidences)
+
+ return det_bboxes, det_labels
+
+ @force_fp32(apply_to=('bbox_preds', ))
+ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
+ """Refine bboxes during training.
+
+ Args:
+ rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
+ and bs is the sampled RoIs per image.
+ labels (Tensor): Shape (n*bs, ).
+ bbox_preds (list[Tensor]): Shape [(n*bs, num_buckets*2), \
+ (n*bs, num_buckets*2)].
+ pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
+ is a gt bbox.
+ img_metas (list[dict]): Meta info of each image.
+
+ Returns:
+ list[Tensor]: Refined bboxes of each image in a mini-batch.
+ """
+ img_ids = rois[:, 0].long().unique(sorted=True)
+ assert img_ids.numel() == len(img_metas)
+
+ bboxes_list = []
+ for i in range(len(img_metas)):
+ inds = torch.nonzero(
+ rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
+ num_rois = inds.numel()
+
+ bboxes_ = rois[inds, 1:]
+ label_ = labels[inds]
+ edge_cls_preds, edge_offset_preds = bbox_preds
+ edge_cls_preds_ = edge_cls_preds[inds]
+ edge_offset_preds_ = edge_offset_preds[inds]
+ bbox_pred_ = [edge_cls_preds_, edge_offset_preds_]
+ img_meta_ = img_metas[i]
+ pos_is_gts_ = pos_is_gts[i]
+
+ bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
+ img_meta_)
+ # filter gt bboxes
+ pos_keep = 1 - pos_is_gts_
+ keep_inds = pos_is_gts_.new_ones(num_rois)
+ keep_inds[:len(pos_is_gts_)] = pos_keep
+
+ bboxes_list.append(bboxes[keep_inds.type(torch.bool)])
+
+ return bboxes_list
+
+ @force_fp32(apply_to=('bbox_pred', ))
+ def regress_by_class(self, rois, label, bbox_pred, img_meta):
+ """Regress the bbox for the predicted class. Used in Cascade R-CNN.
+
+ Args:
+ rois (Tensor): shape (n, 4) or (n, 5)
+ label (Tensor): shape (n, )
+ bbox_pred (list[Tensor]): shape [(n, num_buckets *2), \
+ (n, num_buckets *2)]
+ img_meta (dict): Image meta info.
+
+ Returns:
+ Tensor: Regressed bboxes, the same shape as input rois.
+ """
+ assert rois.size(1) == 4 or rois.size(1) == 5
+
+ if rois.size(1) == 4:
+ new_rois, _ = self.bbox_coder.decode(rois, bbox_pred,
+ img_meta['img_shape'])
+ else:
+ bboxes, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
+ img_meta['img_shape'])
+ new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
+
+ return new_rois
diff --git a/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf39ebef2fa26f69bb56e6d08384991975ad1cc2
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.models.builder import HEADS
+from .convfc_bbox_head import ConvFCBBoxHead
+
+
+@HEADS.register_module()
+class SCNetBBoxHead(ConvFCBBoxHead):
+ """BBox head for `SCNet `_.
+
+ This inherits ``ConvFCBBoxHead`` with a modified forward() function, allowing
+ us to get the intermediate shared feature.
+ """
+
+ def _forward_shared(self, x):
+ """Forward function for shared part."""
+ if self.num_shared_convs > 0:
+ for conv in self.shared_convs:
+ x = conv(x)
+
+ if self.num_shared_fcs > 0:
+ if self.with_avg_pool:
+ x = self.avg_pool(x)
+
+ x = x.flatten(1)
+
+ for fc in self.shared_fcs:
+ x = self.relu(fc(x))
+
+ return x
+
+ def _forward_cls_reg(self, x):
+ """Forward function for classification and regression parts."""
+ x_cls = x
+ x_reg = x
+
+ for conv in self.cls_convs:
+ x_cls = conv(x_cls)
+ if x_cls.dim() > 2:
+ if self.with_avg_pool:
+ x_cls = self.avg_pool(x_cls)
+ x_cls = x_cls.flatten(1)
+ for fc in self.cls_fcs:
+ x_cls = self.relu(fc(x_cls))
+
+ for conv in self.reg_convs:
+ x_reg = conv(x_reg)
+ if x_reg.dim() > 2:
+ if self.with_avg_pool:
+ x_reg = self.avg_pool(x_reg)
+ x_reg = x_reg.flatten(1)
+ for fc in self.reg_fcs:
+ x_reg = self.relu(fc(x_reg))
+
+ cls_score = self.fc_cls(x_cls) if self.with_cls else None
+ bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
+
+ return cls_score, bbox_pred
+
+ def forward(self, x, return_shared_feat=False):
+ """Forward function.
+
+ Args:
+ x (Tensor): input features
+ return_shared_feat (bool): If True, return cls-reg-shared feature.
+
+ Returns:
+ out (tuple[Tensor]): contains ``cls_score`` and ``bbox_pred``;
+ if ``return_shared_feat`` is True, ``x_shared`` is appended to
+ the returned tuple.
+ """
+ x_shared = self._forward_shared(x)
+ out = self._forward_cls_reg(x_shared)
+
+ if return_shared_feat:
+ out += (x_shared, )
+
+ return out
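+ # Usage sketch (illustrative):
+ #   cls_score, bbox_pred = head(roi_feats)
+ #   cls_score, bbox_pred, x_shared = head(
+ #       roi_feats, return_shared_feat=True)
+ # where `x_shared` is the output of the shared convs/fcs that other
+ # SCNet components can reuse.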
diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e17313f20724263864cb8cf068e889ed71822b59
--- /dev/null
+++ b/mmdet/models/roi_heads/cascade_roi_head.py
@@ -0,0 +1,631 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.runner import ModuleList
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
+ build_sampler, merge_aug_bboxes, merge_aug_masks,
+ multiclass_nms)
+from ..builder import HEADS, build_head, build_roi_extractor
+from .base_roi_head import BaseRoIHead
+from .test_mixins import BBoxTestMixin, MaskTestMixin
+
+
+@HEADS.register_module()
+class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+ """Cascade roi head including one bbox head and one mask head.
+
+ https://arxiv.org/abs/1712.00726
+ """
+
+ def __init__(self,
+ num_stages,
+ stage_loss_weights,
+ bbox_roi_extractor=None,
+ bbox_head=None,
+ mask_roi_extractor=None,
+ mask_head=None,
+ shared_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ assert bbox_roi_extractor is not None
+ assert bbox_head is not None
+ assert shared_head is None, \
+ 'Shared head is not supported in Cascade RCNN anymore'
+
+ self.num_stages = num_stages
+ self.stage_loss_weights = stage_loss_weights
+ super(CascadeRoIHead, self).__init__(
+ bbox_roi_extractor=bbox_roi_extractor,
+ bbox_head=bbox_head,
+ mask_roi_extractor=mask_roi_extractor,
+ mask_head=mask_head,
+ shared_head=shared_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
+
+ def init_bbox_head(self, bbox_roi_extractor, bbox_head):
+ """Initialize box head and box roi extractor.
+
+ Args:
+ bbox_roi_extractor (dict): Config of box roi extractor.
+ bbox_head (dict): Config of box in box head.
+ """
+ self.bbox_roi_extractor = ModuleList()
+ self.bbox_head = ModuleList()
+ if not isinstance(bbox_roi_extractor, list):
+ bbox_roi_extractor = [
+ bbox_roi_extractor for _ in range(self.num_stages)
+ ]
+ if not isinstance(bbox_head, list):
+ bbox_head = [bbox_head for _ in range(self.num_stages)]
+ assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages
+ for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
+ self.bbox_roi_extractor.append(build_roi_extractor(roi_extractor))
+ self.bbox_head.append(build_head(head))
+
+ def init_mask_head(self, mask_roi_extractor, mask_head):
+ """Initialize mask head and mask roi extractor.
+
+ Args:
+ mask_roi_extractor (dict): Config of mask roi extractor.
+ mask_head (dict): Config of mask in mask head.
+ """
+ self.mask_head = nn.ModuleList()
+ if not isinstance(mask_head, list):
+ mask_head = [mask_head for _ in range(self.num_stages)]
+ assert len(mask_head) == self.num_stages
+ for head in mask_head:
+ self.mask_head.append(build_head(head))
+ if mask_roi_extractor is not None:
+ self.share_roi_extractor = False
+ self.mask_roi_extractor = ModuleList()
+ if not isinstance(mask_roi_extractor, list):
+ mask_roi_extractor = [
+ mask_roi_extractor for _ in range(self.num_stages)
+ ]
+ assert len(mask_roi_extractor) == self.num_stages
+ for roi_extractor in mask_roi_extractor:
+ self.mask_roi_extractor.append(
+ build_roi_extractor(roi_extractor))
+ else:
+ self.share_roi_extractor = True
+ self.mask_roi_extractor = self.bbox_roi_extractor
+
+ def init_assigner_sampler(self):
+ """Initialize assigner and sampler for each stage."""
+ self.bbox_assigner = []
+ self.bbox_sampler = []
+ if self.train_cfg is not None:
+ for idx, rcnn_train_cfg in enumerate(self.train_cfg):
+ self.bbox_assigner.append(
+ build_assigner(rcnn_train_cfg.assigner))
+ self.current_stage = idx
+ self.bbox_sampler.append(
+ build_sampler(rcnn_train_cfg.sampler, context=self))
+
+ def forward_dummy(self, x, proposals):
+ """Dummy forward function."""
+ # bbox head
+ outs = ()
+ rois = bbox2roi([proposals])
+ if self.with_bbox:
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(i, x, rois)
+ outs = outs + (bbox_results['cls_score'],
+ bbox_results['bbox_pred'])
+ # mask heads
+ if self.with_mask:
+ mask_rois = rois[:100]
+ for i in range(self.num_stages):
+ mask_results = self._mask_forward(i, x, mask_rois)
+ outs = outs + (mask_results['mask_pred'], )
+ return outs
+
+ def _bbox_forward(self, stage, x, rois):
+ """Box head forward function used in both training and testing."""
+ bbox_roi_extractor = self.bbox_roi_extractor[stage]
+ bbox_head = self.bbox_head[stage]
+ bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
+ rois)
+ # do not support caffe_c4 model anymore
+ cls_score, bbox_pred = bbox_head(bbox_feats)
+
+ bbox_results = dict(
+ cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
+ return bbox_results
+
+ def _bbox_forward_train(self, stage, x, sampling_results, gt_bboxes,
+ gt_labels, rcnn_train_cfg):
+ """Run forward function and calculate loss for box head in training."""
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(stage, x, rois)
+ bbox_targets = self.bbox_head[stage].get_targets(
+ sampling_results, gt_bboxes, gt_labels, rcnn_train_cfg)
+ loss_bbox = self.bbox_head[stage].loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(
+ loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
+ return bbox_results
+
+ def _mask_forward(self, stage, x, rois):
+ """Mask head forward function used in both training and testing."""
+ mask_roi_extractor = self.mask_roi_extractor[stage]
+ mask_head = self.mask_head[stage]
+ mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
+ rois)
+ # do not support caffe_c4 model anymore
+ mask_pred = mask_head(mask_feats)
+
+ mask_results = dict(mask_pred=mask_pred)
+ return mask_results
+
+ def _mask_forward_train(self,
+ stage,
+ x,
+ sampling_results,
+ gt_masks,
+ rcnn_train_cfg,
+ bbox_feats=None):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ mask_results = self._mask_forward(stage, x, pos_rois)
+
+ mask_targets = self.mask_head[stage].get_targets(
+ sampling_results, gt_masks, rcnn_train_cfg)
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ loss_mask = self.mask_head[stage].loss(mask_results['mask_pred'],
+ mask_targets, pos_labels)
+
+ mask_results.update(loss_mask=loss_mask)
+ return mask_results
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+ proposal_list (list[Tensor]): list of region proposals.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ losses = dict()
+ for i in range(self.num_stages):
+ self.current_stage = i
+ rcnn_train_cfg = self.train_cfg[i]
+ lw = self.stage_loss_weights[i]
+
+ # assign gts and sample proposals
+ sampling_results = []
+ if self.with_bbox or self.with_mask:
+ bbox_assigner = self.bbox_assigner[i]
+ bbox_sampler = self.bbox_sampler[i]
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+ for j in range(num_imgs):
+ assign_result = bbox_assigner.assign(
+ proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j],
+ gt_labels[j])
+ sampling_result = bbox_sampler.sample(
+ assign_result,
+ proposal_list[j],
+ gt_bboxes[j],
+ gt_labels[j],
+ feats=[lvl_feat[j][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+
+ # bbox head forward and loss
+ bbox_results = self._bbox_forward_train(i, x, sampling_results,
+ gt_bboxes, gt_labels,
+ rcnn_train_cfg)
+
+ for name, value in bbox_results['loss_bbox'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # mask head forward and loss
+ if self.with_mask:
+ mask_results = self._mask_forward_train(
+ i, x, sampling_results, gt_masks, rcnn_train_cfg,
+ bbox_results['bbox_feats'])
+ for name, value in mask_results['loss_mask'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # refine bboxes
+ if i < self.num_stages - 1:
+ pos_is_gts = [res.pos_is_gt for res in sampling_results]
+ # bbox_targets is a tuple
+ roi_labels = bbox_results['bbox_targets'][0]
+ with torch.no_grad():
+ cls_score = bbox_results['cls_score']
+ if self.bbox_head[i].custom_activation:
+ cls_score = self.bbox_head[i].loss_cls.get_activation(
+ cls_score)
+
+ # Empty proposal.
+ if cls_score.numel() == 0:
+ break
+
+ roi_labels = torch.where(
+ roi_labels == self.bbox_head[i].num_classes,
+ cls_score[:, :-1].argmax(1), roi_labels)
+ proposal_list = self.bbox_head[i].refine_bboxes(
+ bbox_results['rois'], roi_labels,
+ bbox_results['bbox_pred'], pos_is_gts, img_metas)
+
+ return losses
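+ # Sketch of the returned dict (illustrative, assuming 3 stages with a
+ # mask head): loss terms are scaled by the matching
+ # `stage_loss_weights` entry and prefixed with the stage index, e.g.
+ #   {'s0.loss_cls': ..., 's0.loss_bbox': ..., 's0.loss_mask': ...,
+ #    's1.loss_cls': ..., ..., 's2.loss_mask': ...}
+ # Between stages, proposals are refined with the current stage's box
+ # predictions (refine_bboxes) and re-assigned for the next stage.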
+
+ def simple_test(self, x, proposal_list, img_metas, rescale=False):
+ """Test without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Features from upstream network. Each
+ has shape (batch_size, c, h, w).
+ proposal_list (list[Tensor]): Proposals from the RPN head.
+ Each has shape (num_proposals, 5), where the last dimension
+ 5 represents (x1, y1, x2, y2, score).
+ img_metas (list[dict]): Meta information of images.
+ rescale (bool): Whether to rescale the results to
+ the original image. Default: False.
+
+ Returns:
+ list[list[np.ndarray]] or list[tuple]: When no mask branch,
+ it is bbox results of each image and classes with type
+ `list[list[np.ndarray]]`. The outer list
+ corresponds to each image. The inner list
+ corresponds to each class. When the model has mask branch,
+ it contains bbox results and mask results.
+ The outer list corresponds to each image, and first element
+ of tuple is bbox results, second element is mask results.
+ """
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ num_imgs = len(proposal_list)
+ img_shapes = tuple(meta['img_shape'] for meta in img_metas)
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # "ms" in variable names means multi-stage
+ ms_bbox_result = {}
+ ms_segm_result = {}
+ ms_scores = []
+ rcnn_test_cfg = self.test_cfg
+
+ rois = bbox2roi(proposal_list)
+
+ if rois.shape[0] == 0:
+ # There is no proposal in the whole batch
+ bbox_results = [[
+ np.zeros((0, 5), dtype=np.float32)
+ for _ in range(self.bbox_head[-1].num_classes)
+ ]] * num_imgs
+
+ if self.with_mask:
+ mask_classes = self.mask_head[-1].num_classes
+ segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ results = list(zip(bbox_results, segm_results))
+ else:
+ results = bbox_results
+
+ return results
+
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(i, x, rois)
+
+ # split batch bbox prediction back to each image
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+ num_proposals_per_img = tuple(
+ len(proposals) for proposals in proposal_list)
+ rois = rois.split(num_proposals_per_img, 0)
+ cls_score = cls_score.split(num_proposals_per_img, 0)
+ if isinstance(bbox_pred, torch.Tensor):
+ bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
+ else:
+ bbox_pred = self.bbox_head[i].bbox_pred_split(
+ bbox_pred, num_proposals_per_img)
+ ms_scores.append(cls_score)
+
+ if i < self.num_stages - 1:
+ if self.bbox_head[i].custom_activation:
+ cls_score = [
+ self.bbox_head[i].loss_cls.get_activation(s)
+ for s in cls_score
+ ]
+ refine_rois_list = []
+ for j in range(num_imgs):
+ if rois[j].shape[0] > 0:
+ bbox_label = cls_score[j][:, :-1].argmax(dim=1)
+ refined_rois = self.bbox_head[i].regress_by_class(
+ rois[j], bbox_label, bbox_pred[j], img_metas[j])
+ refine_rois_list.append(refined_rois)
+ rois = torch.cat(refine_rois_list)
+
+ # average scores of each image by stages
+ cls_score = [
+ sum([score[i] for score in ms_scores]) / float(len(ms_scores))
+ for i in range(num_imgs)
+ ]
+
+ # apply bbox post-processing to each image individually
+ det_bboxes = []
+ det_labels = []
+ for i in range(num_imgs):
+ det_bbox, det_label = self.bbox_head[-1].get_bboxes(
+ rois[i],
+ cls_score[i],
+ bbox_pred[i],
+ img_shapes[i],
+ scale_factors[i],
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head[-1].num_classes)
+ for i in range(num_imgs)
+ ]
+ ms_bbox_result['ensemble'] = bbox_results
+
+ if self.with_mask:
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ mask_classes = self.mask_head[-1].num_classes
+ segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ else:
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i][:, :4]
+ for i in range(len(det_bboxes))
+ ]
+ mask_rois = bbox2roi(_bboxes)
+ num_mask_rois_per_img = tuple(
+ _bbox.size(0) for _bbox in _bboxes)
+ aug_masks = []
+ for i in range(self.num_stages):
+ mask_results = self._mask_forward(i, x, mask_rois)
+ mask_pred = mask_results['mask_pred']
+ # split batch mask prediction back to each image
+ mask_pred = mask_pred.split(num_mask_rois_per_img, 0)
+ aug_masks.append([
+ m.sigmoid().cpu().detach().numpy() for m in mask_pred
+ ])
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[]
+ for _ in range(self.mask_head[-1].num_classes)])
+ else:
+ aug_mask = [mask[i] for mask in aug_masks]
+ merged_masks = merge_aug_masks(
+ aug_mask, [[img_metas[i]]] * self.num_stages,
+ rcnn_test_cfg)
+ segm_result = self.mask_head[-1].get_seg_masks(
+ merged_masks, _bboxes[i], det_labels[i],
+ rcnn_test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ segm_results.append(segm_result)
+ ms_segm_result['ensemble'] = segm_results
+
+ if self.with_mask:
+ results = list(
+ zip(ms_bbox_result['ensemble'], ms_segm_result['ensemble']))
+ else:
+ results = ms_bbox_result['ensemble']
+
+ return results
+
+ def aug_test(self, features, proposal_list, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ rcnn_test_cfg = self.test_cfg
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta in zip(features, img_metas):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ # "ms" in variable names means multi-stage
+ ms_scores = []
+
+ rois = bbox2roi([proposals])
+
+ if rois.shape[0] == 0:
+ # There is no proposal in the single image
+ aug_bboxes.append(rois.new_zeros(0, 4))
+ aug_scores.append(rois.new_zeros(0, 1))
+ continue
+
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(i, x, rois)
+ ms_scores.append(bbox_results['cls_score'])
+
+ if i < self.num_stages - 1:
+ cls_score = bbox_results['cls_score']
+ if self.bbox_head[i].custom_activation:
+ cls_score = self.bbox_head[i].loss_cls.get_activation(
+ cls_score)
+ bbox_label = cls_score[:, :-1].argmax(dim=1)
+ rois = self.bbox_head[i].regress_by_class(
+ rois, bbox_label, bbox_results['bbox_pred'],
+ img_meta[0])
+
+ cls_score = sum(ms_scores) / float(len(ms_scores))
+ bboxes, scores = self.bbox_head[-1].get_bboxes(
+ rois,
+ cls_score,
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ aug_bboxes.append(bboxes)
+ aug_scores.append(scores)
+
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+
+ bbox_result = bbox2result(det_bboxes, det_labels,
+ self.bbox_head[-1].num_classes)
+
+ if self.with_mask:
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[]
+ for _ in range(self.mask_head[-1].num_classes)]
+ else:
+ aug_masks = []
+ aug_img_metas = []
+ for x, img_meta in zip(features, img_metas):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ mask_rois = bbox2roi([_bboxes])
+ for i in range(self.num_stages):
+ mask_results = self._mask_forward(i, x, mask_rois)
+ aug_masks.append(
+ mask_results['mask_pred'].sigmoid().cpu().numpy())
+ aug_img_metas.append(img_meta)
+ merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
+ self.test_cfg)
+
+ ori_shape = img_metas[0][0]['ori_shape']
+ dummy_scale_factor = np.ones(4)
+ segm_result = self.mask_head[-1].get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ rcnn_test_cfg,
+ ori_shape,
+ scale_factor=dummy_scale_factor,
+ rescale=False)
+ return [(bbox_result, segm_result)]
+ else:
+ return [bbox_result]
+
+ def onnx_export(self, x, proposals, img_metas):
+
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ assert proposals.shape[0] == 1, 'Only support one input image ' \
+ 'while in exporting to ONNX'
+ # remove the scores
+ rois = proposals[..., :-1]
+ batch_size = rois.shape[0]
+ num_proposals_per_img = rois.shape[1]
+ # Eliminate the batch dimension
+ rois = rois.view(-1, 4)
+
+ # add dummy batch index
+ rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
+
+ max_shape = img_metas[0]['img_shape_for_onnx']
+ ms_scores = []
+ rcnn_test_cfg = self.test_cfg
+
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(i, x, rois)
+
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+ # Recover the batch dimension
+ rois = rois.reshape(batch_size, num_proposals_per_img,
+ rois.size(-1))
+ cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
+ cls_score.size(-1))
+ bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
+ ms_scores.append(cls_score)
+ if i < self.num_stages - 1:
+ assert self.bbox_head[i].reg_class_agnostic
+ new_rois = self.bbox_head[i].bbox_coder.decode(
+ rois[..., 1:], bbox_pred, max_shape=max_shape)
+ rois = new_rois.reshape(-1, new_rois.shape[-1])
+ # add dummy batch index
+ rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois],
+ dim=-1)
+
+ cls_score = sum(ms_scores) / float(len(ms_scores))
+ bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
+ rois = rois.reshape(batch_size, num_proposals_per_img, -1)
+ det_bboxes, det_labels = self.bbox_head[-1].onnx_export(
+ rois, cls_score, bbox_pred, max_shape, cfg=rcnn_test_cfg)
+
+ if not self.with_mask:
+ return det_bboxes, det_labels
+ else:
+ batch_index = torch.arange(
+ det_bboxes.size(0),
+ device=det_bboxes.device).float().view(-1, 1, 1).expand(
+ det_bboxes.size(0), det_bboxes.size(1), 1)
+ rois = det_bboxes[..., :4]
+ mask_rois = torch.cat([batch_index, rois], dim=-1)
+ mask_rois = mask_rois.view(-1, 5)
+ aug_masks = []
+ for i in range(self.num_stages):
+ mask_results = self._mask_forward(i, x, mask_rois)
+ mask_pred = mask_results['mask_pred']
+ aug_masks.append(mask_pred)
+ max_shape = img_metas[0]['img_shape_for_onnx']
+ # calculate the mean of masks from several stage
+ mask_pred = sum(aug_masks) / len(aug_masks)
+ segm_results = self.mask_head[-1].onnx_export(
+ mask_pred, rois.reshape(-1, 4), det_labels.reshape(-1),
+ self.test_cfg, max_shape)
+ segm_results = segm_results.reshape(batch_size,
+ det_bboxes.shape[1],
+ max_shape[0], max_shape[1])
+ return det_bboxes, det_labels, segm_results
diff --git a/mmdet/models/roi_heads/double_roi_head.py b/mmdet/models/roi_heads/double_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..895b5d3067846e023f21482fb1628e9bdb0035fd
--- /dev/null
+++ b/mmdet/models/roi_heads/double_roi_head.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import HEADS
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class DoubleHeadRoIHead(StandardRoIHead):
+ """RoI head for Double Head RCNN.
+
+ https://arxiv.org/abs/1904.06493
+ """
+
+ def __init__(self, reg_roi_scale_factor, **kwargs):
+ super(DoubleHeadRoIHead, self).__init__(**kwargs)
+ self.reg_roi_scale_factor = reg_roi_scale_factor
+
+ def _bbox_forward(self, x, rois):
+ """Box head forward function used in both training and testing time."""
+ bbox_cls_feats = self.bbox_roi_extractor(
+ x[:self.bbox_roi_extractor.num_inputs], rois)
+ bbox_reg_feats = self.bbox_roi_extractor(
+ x[:self.bbox_roi_extractor.num_inputs],
+ rois,
+ roi_scale_factor=self.reg_roi_scale_factor)
+ if self.with_shared_head:
+ bbox_cls_feats = self.shared_head(bbox_cls_feats)
+ bbox_reg_feats = self.shared_head(bbox_reg_feats)
+ cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
+
+ bbox_results = dict(
+ cls_score=cls_score,
+ bbox_pred=bbox_pred,
+ bbox_feats=bbox_cls_feats)
+ return bbox_results
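+ # Sketch (illustrative): the same RoI extractor is applied twice,
+ # once with the original RoIs for the classification branch and once
+ # with RoIs scaled by `reg_roi_scale_factor` for the regression
+ # branch; the two feature maps are then passed to the double bbox
+ # head as (x_cls, x_reg).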
diff --git a/mmdet/models/roi_heads/dynamic_roi_head.py b/mmdet/models/roi_heads/dynamic_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c2b6cdac1e38a00a810be03275f66e5257fd6fb
--- /dev/null
+++ b/mmdet/models/roi_heads/dynamic_roi_head.py
@@ -0,0 +1,155 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from mmdet.core import bbox2roi
+from mmdet.models.losses import SmoothL1Loss
+from ..builder import HEADS
+from .standard_roi_head import StandardRoIHead
+
+EPS = 1e-15
+
+
+@HEADS.register_module()
+class DynamicRoIHead(StandardRoIHead):
+ """RoI head for `Dynamic R-CNN `_."""
+
+ def __init__(self, **kwargs):
+ super(DynamicRoIHead, self).__init__(**kwargs)
+ assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss)
+ # the IoU history of the past `update_iter_interval` iterations
+ self.iou_history = []
+ # the beta history of the past `update_iter_interval` iterations
+ self.beta_history = []
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """Forward function for training.
+
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+ proposal_list (list[Tensor]): list of region proposals.
+
+ gt_bboxes (list[Tensor]): each item are the truth boxes for each
+ image in [tl_x, tl_y, br_x, br_y] format.
+
+ gt_labels (list[Tensor]): class indices corresponding to each box
+
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # assign gts and sample proposals
+ if self.with_bbox or self.with_mask:
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+ sampling_results = []
+ cur_iou = []
+ for i in range(num_imgs):
+ assign_result = self.bbox_assigner.assign(
+ proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
+ gt_labels[i])
+ sampling_result = self.bbox_sampler.sample(
+ assign_result,
+ proposal_list[i],
+ gt_bboxes[i],
+ gt_labels[i],
+ feats=[lvl_feat[i][None] for lvl_feat in x])
+ # record the `iou_topk`-th largest IoU in an image
+ iou_topk = min(self.train_cfg.dynamic_rcnn.iou_topk,
+ len(assign_result.max_overlaps))
+ ious, _ = torch.topk(assign_result.max_overlaps, iou_topk)
+ cur_iou.append(ious[-1].item())
+ sampling_results.append(sampling_result)
+ # average the current IoUs over images
+ cur_iou = np.mean(cur_iou)
+ self.iou_history.append(cur_iou)
+
+ losses = dict()
+ # bbox head forward and loss
+ if self.with_bbox:
+ bbox_results = self._bbox_forward_train(x, sampling_results,
+ gt_bboxes, gt_labels,
+ img_metas)
+ losses.update(bbox_results['loss_bbox'])
+
+ # mask head forward and loss
+ if self.with_mask:
+ mask_results = self._mask_forward_train(x, sampling_results,
+ bbox_results['bbox_feats'],
+ gt_masks, img_metas)
+ losses.update(mask_results['loss_mask'])
+
+ # update IoU threshold and SmoothL1 beta
+ update_iter_interval = self.train_cfg.dynamic_rcnn.update_iter_interval
+ if len(self.iou_history) % update_iter_interval == 0:
+ new_iou_thr, new_beta = self.update_hyperparameters()
+
+ return losses
+
+ def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+ img_metas):
+ num_imgs = len(img_metas)
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(x, rois)
+
+ bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, self.train_cfg)
+ # record the `beta_topk`-th smallest target
+ # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets
+ # and bbox_weights, respectively
+ pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
+ num_pos = len(pos_inds)
+ cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
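+        # `cur_target` now holds, for every positive RoI, the mean of |dx| and
+        # |dy| from its regression target; the `beta_topk`-th smallest of
+        # these values is recorded and later used to shrink the SmoothL1 beta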
+ beta_topk = min(self.train_cfg.dynamic_rcnn.beta_topk * num_imgs,
+ num_pos)
+ cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
+ self.beta_history.append(cur_target)
+ loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(loss_bbox=loss_bbox)
+ return bbox_results
+
+ def update_hyperparameters(self):
+ """Update hyperparameters like IoU thresholds for assigner and beta for
+ SmoothL1 loss based on the training statistics.
+
+ Returns:
+ tuple[float]: the updated ``iou_thr`` and ``beta``.
+ """
+ new_iou_thr = max(self.train_cfg.dynamic_rcnn.initial_iou,
+ np.mean(self.iou_history))
+ self.iou_history = []
+ self.bbox_assigner.pos_iou_thr = new_iou_thr
+ self.bbox_assigner.neg_iou_thr = new_iou_thr
+ self.bbox_assigner.min_pos_iou = new_iou_thr
+ if (np.median(self.beta_history) < EPS):
+ # avoid 0 or too small value for new_beta
+ new_beta = self.bbox_head.loss_bbox.beta
+ else:
+ new_beta = min(self.train_cfg.dynamic_rcnn.initial_beta,
+ np.median(self.beta_history))
+ self.beta_history = []
+ self.bbox_head.loss_bbox.beta = new_beta
+ return new_iou_thr, new_beta
diff --git a/mmdet/models/roi_heads/grid_roi_head.py b/mmdet/models/roi_heads/grid_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..333f62975c693fd00d2fa4605be7cef11aa404e1
--- /dev/null
+++ b/mmdet/models/roi_heads/grid_roi_head.py
@@ -0,0 +1,170 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from mmdet.core import bbox2result, bbox2roi
+from ..builder import HEADS, build_head, build_roi_extractor
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class GridRoIHead(StandardRoIHead):
+ """Grid roi head for Grid R-CNN.
+
+ https://arxiv.org/abs/1811.12030
+ """
+
+ def __init__(self, grid_roi_extractor, grid_head, **kwargs):
+ assert grid_head is not None
+ super(GridRoIHead, self).__init__(**kwargs)
+ if grid_roi_extractor is not None:
+ self.grid_roi_extractor = build_roi_extractor(grid_roi_extractor)
+ self.share_roi_extractor = False
+ else:
+ self.share_roi_extractor = True
+ self.grid_roi_extractor = self.bbox_roi_extractor
+ self.grid_head = build_head(grid_head)
+
+ def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
+ """Ramdom jitter positive proposals for training."""
+ for sampling_result, img_meta in zip(sampling_results, img_metas):
+ bboxes = sampling_result.pos_bboxes
+ random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
+ -amplitude, amplitude)
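+            # offsets are relative to the box size: centers move by at most
+            # `amplitude` * wh per axis and widths/heights are scaled by a
+            # factor in [1 - amplitude, 1 + amplitude]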
+ # before jittering
+ cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
+ wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
+ # after jittering
+ new_cxcy = cxcy + wh * random_offsets[:, :2]
+ new_wh = wh * (1 + random_offsets[:, 2:])
+ # xywh to xyxy
+ new_x1y1 = (new_cxcy - new_wh / 2)
+ new_x2y2 = (new_cxcy + new_wh / 2)
+ new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
+ # clip bboxes
+ max_shape = img_meta['img_shape']
+ if max_shape is not None:
+ new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
+ new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)
+
+ sampling_result.pos_bboxes = new_bboxes
+ return sampling_results
+
+ def forward_dummy(self, x, proposals):
+ """Dummy forward function."""
+ # bbox head
+ outs = ()
+ rois = bbox2roi([proposals])
+ if self.with_bbox:
+ bbox_results = self._bbox_forward(x, rois)
+ outs = outs + (bbox_results['cls_score'],
+ bbox_results['bbox_pred'])
+
+ # grid head
+ grid_rois = rois[:100]
+ grid_feats = self.grid_roi_extractor(
+ x[:self.grid_roi_extractor.num_inputs], grid_rois)
+ if self.with_shared_head:
+ grid_feats = self.shared_head(grid_feats)
+ grid_pred = self.grid_head(grid_feats)
+ outs = outs + (grid_pred, )
+
+ # mask head
+ if self.with_mask:
+ mask_rois = rois[:100]
+ mask_results = self._mask_forward(x, mask_rois)
+ outs = outs + (mask_results['mask_pred'], )
+ return outs
+
+ def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+ img_metas):
+ """Run forward function and calculate loss for box head in training."""
+ bbox_results = super(GridRoIHead,
+ self)._bbox_forward_train(x, sampling_results,
+ gt_bboxes, gt_labels,
+ img_metas)
+
+ # Grid head forward and loss
+ sampling_results = self._random_jitter(sampling_results, img_metas)
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+
+ # GN in head does not support zero shape input
+ if pos_rois.shape[0] == 0:
+ return bbox_results
+
+ grid_feats = self.grid_roi_extractor(
+ x[:self.grid_roi_extractor.num_inputs], pos_rois)
+ if self.with_shared_head:
+ grid_feats = self.shared_head(grid_feats)
+        # Accelerate training by sub-sampling at most `max_num_grid`
+        # positive RoIs for the grid head
+ max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
+ sample_idx = torch.randperm(
+ grid_feats.shape[0])[:min(grid_feats.shape[0], max_sample_num_grid
+ )]
+ grid_feats = grid_feats[sample_idx]
+
+ grid_pred = self.grid_head(grid_feats)
+
+ grid_targets = self.grid_head.get_targets(sampling_results,
+ self.train_cfg)
+ grid_targets = grid_targets[sample_idx]
+
+ loss_grid = self.grid_head.loss(grid_pred, grid_targets)
+
+ bbox_results['loss_bbox'].update(loss_grid)
+ return bbox_results
+
+ def simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False):
+ """Test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+
+ det_bboxes, det_labels = self.simple_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=False)
+ # pack rois into bboxes
+ grid_rois = bbox2roi([det_bbox[:, :4] for det_bbox in det_bboxes])
+ if grid_rois.shape[0] != 0:
+ grid_feats = self.grid_roi_extractor(
+ x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
+ self.grid_head.test_mode = True
+ grid_pred = self.grid_head(grid_feats)
+ # split batch grid head prediction back to each image
+ num_roi_per_img = tuple(len(det_bbox) for det_bbox in det_bboxes)
+ grid_pred = {
+ k: v.split(num_roi_per_img, 0)
+ for k, v in grid_pred.items()
+ }
+
+ # apply bbox post-processing to each image individually
+ bbox_results = []
+ num_imgs = len(det_bboxes)
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ bbox_results.append([
+ np.zeros((0, 5), dtype=np.float32)
+ for _ in range(self.bbox_head.num_classes)
+ ])
+ else:
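+                # refine the detected boxes with the fused grid-point
+                # predictions from the grid head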
+ det_bbox = self.grid_head.get_bboxes(
+ det_bboxes[i], grid_pred['fused'][i], [img_metas[i]])
+ if rescale:
+ det_bbox[:, :4] /= img_metas[i]['scale_factor']
+ bbox_results.append(
+ bbox2result(det_bbox, det_labels[i],
+ self.bbox_head.num_classes))
+ else:
+ bbox_results = [[
+ np.zeros((0, 5), dtype=np.float32)
+ for _ in range(self.bbox_head.num_classes)
+ ] for _ in range(len(det_bboxes))]
+
+ if not self.with_mask:
+ return bbox_results
+ else:
+ segm_results = self.simple_test_mask(
+ x, img_metas, det_bboxes, det_labels, rescale=rescale)
+ return list(zip(bbox_results, segm_results))
diff --git a/mmdet/models/roi_heads/htc_roi_head.py b/mmdet/models/roi_heads/htc_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..86a6db10d4ac26901fbd44941de8107e67819d42
--- /dev/null
+++ b/mmdet/models/roi_heads/htc_roi_head.py
@@ -0,0 +1,628 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
+ merge_aug_masks, multiclass_nms)
+from ..builder import HEADS, build_head, build_roi_extractor
+from ..utils.brick_wrappers import adaptive_avg_pool2d
+from .cascade_roi_head import CascadeRoIHead
+
+
+@HEADS.register_module()
+class HybridTaskCascadeRoIHead(CascadeRoIHead):
+ """Hybrid task cascade roi head including one bbox head and one mask head.
+
+ https://arxiv.org/abs/1901.07518
+ """
+
+ def __init__(self,
+ num_stages,
+ stage_loss_weights,
+ semantic_roi_extractor=None,
+ semantic_head=None,
+ semantic_fusion=('bbox', 'mask'),
+ interleaved=True,
+ mask_info_flow=True,
+ **kwargs):
+ super(HybridTaskCascadeRoIHead,
+ self).__init__(num_stages, stage_loss_weights, **kwargs)
+ assert self.with_bbox
+ assert not self.with_shared_head # shared head is not supported
+
+ if semantic_head is not None:
+ self.semantic_roi_extractor = build_roi_extractor(
+ semantic_roi_extractor)
+ self.semantic_head = build_head(semantic_head)
+
+ self.semantic_fusion = semantic_fusion
+ self.interleaved = interleaved
+ self.mask_info_flow = mask_info_flow
+
+ @property
+ def with_semantic(self):
+ """bool: whether the head has semantic head"""
+ if hasattr(self, 'semantic_head') and self.semantic_head is not None:
+ return True
+ else:
+ return False
+
+ def forward_dummy(self, x, proposals):
+ """Dummy forward function."""
+ outs = ()
+ # semantic head
+ if self.with_semantic:
+ _, semantic_feat = self.semantic_head(x)
+ else:
+ semantic_feat = None
+ # bbox heads
+ rois = bbox2roi([proposals])
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(
+ i, x, rois, semantic_feat=semantic_feat)
+ outs = outs + (bbox_results['cls_score'],
+ bbox_results['bbox_pred'])
+ # mask heads
+ if self.with_mask:
+ mask_rois = rois[:100]
+ mask_roi_extractor = self.mask_roi_extractor[-1]
+ mask_feats = mask_roi_extractor(
+ x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+ if self.with_semantic and 'mask' in self.semantic_fusion:
+ mask_semantic_feat = self.semantic_roi_extractor(
+ [semantic_feat], mask_rois)
+ mask_feats = mask_feats + mask_semantic_feat
+ last_feat = None
+ for i in range(self.num_stages):
+ mask_head = self.mask_head[i]
+ if self.mask_info_flow:
+ mask_pred, last_feat = mask_head(mask_feats, last_feat)
+ else:
+ mask_pred = mask_head(mask_feats)
+ outs = outs + (mask_pred, )
+ return outs
+
+ def _bbox_forward_train(self,
+ stage,
+ x,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ rcnn_train_cfg,
+ semantic_feat=None):
+ """Run forward function and calculate loss for box head in training."""
+ bbox_head = self.bbox_head[stage]
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(
+ stage, x, rois, semantic_feat=semantic_feat)
+
+ bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, rcnn_train_cfg)
+ loss_bbox = bbox_head.loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(
+ loss_bbox=loss_bbox,
+ rois=rois,
+ bbox_targets=bbox_targets,
+ )
+ return bbox_results
+
+ def _mask_forward_train(self,
+ stage,
+ x,
+ sampling_results,
+ gt_masks,
+ rcnn_train_cfg,
+ semantic_feat=None):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ mask_roi_extractor = self.mask_roi_extractor[stage]
+ mask_head = self.mask_head[stage]
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
+ pos_rois)
+
+ # semantic feature fusion
+ # element-wise sum for original features and pooled semantic features
+ if self.with_semantic and 'mask' in self.semantic_fusion:
+ mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ pos_rois)
+ if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
+ mask_semantic_feat = F.adaptive_avg_pool2d(
+ mask_semantic_feat, mask_feats.shape[-2:])
+ mask_feats = mask_feats + mask_semantic_feat
+
+ # mask information flow
+ # forward all previous mask heads to obtain last_feat, and fuse it
+ # with the normal mask feature
+ if self.mask_info_flow:
+ last_feat = None
+ for i in range(stage):
+ last_feat = self.mask_head[i](
+ mask_feats, last_feat, return_logits=False)
+ mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
+ else:
+ mask_pred = mask_head(mask_feats, return_feat=False)
+
+ mask_targets = mask_head.get_targets(sampling_results, gt_masks,
+ rcnn_train_cfg)
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)
+
+ mask_results = dict(loss_mask=loss_mask)
+ return mask_results
+
+ def _bbox_forward(self, stage, x, rois, semantic_feat=None):
+ """Box head forward function used in both training and testing."""
+ bbox_roi_extractor = self.bbox_roi_extractor[stage]
+ bbox_head = self.bbox_head[stage]
+ bbox_feats = bbox_roi_extractor(
+ x[:len(bbox_roi_extractor.featmap_strides)], rois)
+ if self.with_semantic and 'bbox' in self.semantic_fusion:
+ bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ rois)
+ if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
+ bbox_semantic_feat = adaptive_avg_pool2d(
+ bbox_semantic_feat, bbox_feats.shape[-2:])
+ bbox_feats = bbox_feats + bbox_semantic_feat
+ cls_score, bbox_pred = bbox_head(bbox_feats)
+
+ bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
+ return bbox_results
+
+ def _mask_forward_test(self, stage, x, bboxes, semantic_feat=None):
+ """Mask head forward function for testing."""
+ mask_roi_extractor = self.mask_roi_extractor[stage]
+ mask_head = self.mask_head[stage]
+ mask_rois = bbox2roi([bboxes])
+ mask_feats = mask_roi_extractor(
+ x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+ if self.with_semantic and 'mask' in self.semantic_fusion:
+ mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ mask_rois)
+ if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
+ mask_semantic_feat = F.adaptive_avg_pool2d(
+ mask_semantic_feat, mask_feats.shape[-2:])
+ mask_feats = mask_feats + mask_semantic_feat
+ if self.mask_info_flow:
+ last_feat = None
+ last_pred = None
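+            # `last_feat` carries the mask information flow between stage
+            # heads, while `last_pred` accumulates (sums) the mask logits of
+            # all stages into an ensembled test-time prediction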
+ for i in range(stage):
+ mask_pred, last_feat = self.mask_head[i](mask_feats, last_feat)
+ if last_pred is not None:
+ mask_pred = mask_pred + last_pred
+ last_pred = mask_pred
+ mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
+ if last_pred is not None:
+ mask_pred = mask_pred + last_pred
+ else:
+ mask_pred = mask_head(mask_feats)
+ return mask_pred
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ gt_semantic_seg=None):
+ """
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+            proposal_list (list[Tensor]): list of region proposals.
+
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+
+ gt_labels (list[Tensor]): class indices corresponding to each box
+
+ gt_bboxes_ignore (None, list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ gt_masks (None, Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ gt_semantic_seg (None, list[Tensor]): semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # semantic segmentation part
+ # 2 outputs: segmentation prediction and embedded features
+ losses = dict()
+ if self.with_semantic:
+ semantic_pred, semantic_feat = self.semantic_head(x)
+ loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
+ losses['loss_semantic_seg'] = loss_seg
+ else:
+ semantic_feat = None
+
+ for i in range(self.num_stages):
+ self.current_stage = i
+ rcnn_train_cfg = self.train_cfg[i]
+ lw = self.stage_loss_weights[i]
+
+ # assign gts and sample proposals
+ sampling_results = []
+ bbox_assigner = self.bbox_assigner[i]
+ bbox_sampler = self.bbox_sampler[i]
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+ for j in range(num_imgs):
+ assign_result = bbox_assigner.assign(proposal_list[j],
+ gt_bboxes[j],
+ gt_bboxes_ignore[j],
+ gt_labels[j])
+ sampling_result = bbox_sampler.sample(
+ assign_result,
+ proposal_list[j],
+ gt_bboxes[j],
+ gt_labels[j],
+ feats=[lvl_feat[j][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+
+ # bbox head forward and loss
+ bbox_results = \
+ self._bbox_forward_train(
+ i, x, sampling_results, gt_bboxes, gt_labels,
+ rcnn_train_cfg, semantic_feat)
+ roi_labels = bbox_results['bbox_targets'][0]
+
+ for name, value in bbox_results['loss_bbox'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # mask head forward and loss
+ if self.with_mask:
+ # interleaved execution: use regressed bboxes by the box branch
+ # to train the mask branch
+ if self.interleaved:
+ pos_is_gts = [res.pos_is_gt for res in sampling_results]
+ with torch.no_grad():
+ proposal_list = self.bbox_head[i].refine_bboxes(
+ bbox_results['rois'], roi_labels,
+ bbox_results['bbox_pred'], pos_is_gts, img_metas)
+                    # re-assign labels and re-sample RoIs using the refined
+                    # proposals before training the mask branch
+ sampling_results = []
+ for j in range(num_imgs):
+ assign_result = bbox_assigner.assign(
+ proposal_list[j], gt_bboxes[j],
+ gt_bboxes_ignore[j], gt_labels[j])
+ sampling_result = bbox_sampler.sample(
+ assign_result,
+ proposal_list[j],
+ gt_bboxes[j],
+ gt_labels[j],
+ feats=[lvl_feat[j][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+ mask_results = self._mask_forward_train(
+ i, x, sampling_results, gt_masks, rcnn_train_cfg,
+ semantic_feat)
+ for name, value in mask_results['loss_mask'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # refine bboxes (same as Cascade R-CNN)
+ if i < self.num_stages - 1 and not self.interleaved:
+ pos_is_gts = [res.pos_is_gt for res in sampling_results]
+ with torch.no_grad():
+ proposal_list = self.bbox_head[i].refine_bboxes(
+ bbox_results['rois'], roi_labels,
+ bbox_results['bbox_pred'], pos_is_gts, img_metas)
+
+ return losses
+
+ def simple_test(self, x, proposal_list, img_metas, rescale=False):
+ """Test without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Features from upstream network. Each
+ has shape (batch_size, c, h, w).
+            proposal_list (list[Tensor]): Proposals from rpn head.
+                Each has shape (num_proposals, 5), where the last dimension
+                represents (x1, y1, x2, y2, score).
+            img_metas (list[dict]): Meta information of images.
+            rescale (bool): Whether to rescale the results to
+                the original image. Default: False.
+
+ Returns:
+ list[list[np.ndarray]] or list[tuple]: When no mask branch,
+ it is bbox results of each image and classes with type
+ `list[list[np.ndarray]]`. The outer list
+ corresponds to each image. The inner list
+ corresponds to each class. When the model has mask branch,
+ it contains bbox results and mask results.
+ The outer list corresponds to each image, and first element
+ of tuple is bbox results, second element is mask results.
+ """
+ if self.with_semantic:
+ _, semantic_feat = self.semantic_head(x)
+ else:
+ semantic_feat = None
+
+ num_imgs = len(proposal_list)
+ img_shapes = tuple(meta['img_shape'] for meta in img_metas)
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # "ms" in variable names means multi-stage
+ ms_bbox_result = {}
+ ms_segm_result = {}
+ ms_scores = []
+ rcnn_test_cfg = self.test_cfg
+
+ rois = bbox2roi(proposal_list)
+
+ if rois.shape[0] == 0:
+ # There is no proposal in the whole batch
+ bbox_results = [[
+ np.zeros((0, 5), dtype=np.float32)
+ for _ in range(self.bbox_head[-1].num_classes)
+ ]] * num_imgs
+
+ if self.with_mask:
+ mask_classes = self.mask_head[-1].num_classes
+ segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ results = list(zip(bbox_results, segm_results))
+ else:
+ results = bbox_results
+
+ return results
+
+ for i in range(self.num_stages):
+ bbox_head = self.bbox_head[i]
+ bbox_results = self._bbox_forward(
+ i, x, rois, semantic_feat=semantic_feat)
+ # split batch bbox prediction back to each image
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+ num_proposals_per_img = tuple(len(p) for p in proposal_list)
+ rois = rois.split(num_proposals_per_img, 0)
+ cls_score = cls_score.split(num_proposals_per_img, 0)
+ bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
+ ms_scores.append(cls_score)
+
+ if i < self.num_stages - 1:
+ refine_rois_list = []
+ for j in range(num_imgs):
+ if rois[j].shape[0] > 0:
+ bbox_label = cls_score[j][:, :-1].argmax(dim=1)
+ refine_rois = bbox_head.regress_by_class(
+ rois[j], bbox_label, bbox_pred[j], img_metas[j])
+ refine_rois_list.append(refine_rois)
+ rois = torch.cat(refine_rois_list)
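+                # the refined boxes become the RoIs fed to the next stage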
+
+ # average scores of each image by stages
+ cls_score = [
+ sum([score[i] for score in ms_scores]) / float(len(ms_scores))
+ for i in range(num_imgs)
+ ]
+
+ # apply bbox post-processing to each image individually
+ det_bboxes = []
+ det_labels = []
+ for i in range(num_imgs):
+ det_bbox, det_label = self.bbox_head[-1].get_bboxes(
+ rois[i],
+ cls_score[i],
+ bbox_pred[i],
+ img_shapes[i],
+ scale_factors[i],
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+ bbox_result = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head[-1].num_classes)
+ for i in range(num_imgs)
+ ]
+ ms_bbox_result['ensemble'] = bbox_result
+
+ if self.with_mask:
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ mask_classes = self.mask_head[-1].num_classes
+ segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ else:
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i]
+ for i in range(num_imgs)
+ ]
+ mask_rois = bbox2roi(_bboxes)
+ aug_masks = []
+ mask_roi_extractor = self.mask_roi_extractor[-1]
+ mask_feats = mask_roi_extractor(
+ x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+ if self.with_semantic and 'mask' in self.semantic_fusion:
+ mask_semantic_feat = self.semantic_roi_extractor(
+ [semantic_feat], mask_rois)
+ mask_feats = mask_feats + mask_semantic_feat
+ last_feat = None
+
+ num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes)
+ for i in range(self.num_stages):
+ mask_head = self.mask_head[i]
+ if self.mask_info_flow:
+ mask_pred, last_feat = mask_head(mask_feats, last_feat)
+ else:
+ mask_pred = mask_head(mask_feats)
+
+ # split batch mask prediction back to each image
+ mask_pred = mask_pred.split(num_bbox_per_img, 0)
+ aug_masks.append(
+ [mask.sigmoid().cpu().numpy() for mask in mask_pred])
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[]
+ for _ in range(self.mask_head[-1].num_classes)])
+ else:
+ aug_mask = [mask[i] for mask in aug_masks]
+ merged_mask = merge_aug_masks(
+ aug_mask, [[img_metas[i]]] * self.num_stages,
+ rcnn_test_cfg)
+ segm_result = self.mask_head[-1].get_seg_masks(
+ merged_mask, _bboxes[i], det_labels[i],
+ rcnn_test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ segm_results.append(segm_result)
+ ms_segm_result['ensemble'] = segm_results
+
+ if self.with_mask:
+ results = list(
+ zip(ms_bbox_result['ensemble'], ms_segm_result['ensemble']))
+ else:
+ results = ms_bbox_result['ensemble']
+
+ return results
+
+ def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ if self.with_semantic:
+ semantic_feats = [
+ self.semantic_head(feat)[1] for feat in img_feats
+ ]
+ else:
+ semantic_feats = [None] * len(img_metas)
+
+ rcnn_test_cfg = self.test_cfg
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta, semantic in zip(img_feats, img_metas, semantic_feats):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ # "ms" in variable names means multi-stage
+ ms_scores = []
+
+ rois = bbox2roi([proposals])
+
+ if rois.shape[0] == 0:
+ # There is no proposal in the single image
+ aug_bboxes.append(rois.new_zeros(0, 4))
+ aug_scores.append(rois.new_zeros(0, 1))
+ continue
+
+ for i in range(self.num_stages):
+ bbox_head = self.bbox_head[i]
+ bbox_results = self._bbox_forward(
+ i, x, rois, semantic_feat=semantic)
+ ms_scores.append(bbox_results['cls_score'])
+
+ if i < self.num_stages - 1:
+ bbox_label = bbox_results['cls_score'].argmax(dim=1)
+ rois = bbox_head.regress_by_class(
+ rois, bbox_label, bbox_results['bbox_pred'],
+ img_meta[0])
+
+ cls_score = sum(ms_scores) / float(len(ms_scores))
+ bboxes, scores = self.bbox_head[-1].get_bboxes(
+ rois,
+ cls_score,
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ aug_bboxes.append(bboxes)
+ aug_scores.append(scores)
+
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+
+ bbox_result = bbox2result(det_bboxes, det_labels,
+ self.bbox_head[-1].num_classes)
+
+ if self.with_mask:
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[]
+ for _ in range(self.mask_head[-1].num_classes)]
+ else:
+ aug_masks = []
+ aug_img_metas = []
+ for x, img_meta, semantic in zip(img_feats, img_metas,
+ semantic_feats):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ mask_rois = bbox2roi([_bboxes])
+ mask_feats = self.mask_roi_extractor[-1](
+ x[:len(self.mask_roi_extractor[-1].featmap_strides)],
+ mask_rois)
+ if self.with_semantic:
+ semantic_feat = semantic
+ mask_semantic_feat = self.semantic_roi_extractor(
+ [semantic_feat], mask_rois)
+ if mask_semantic_feat.shape[-2:] != mask_feats.shape[
+ -2:]:
+ mask_semantic_feat = F.adaptive_avg_pool2d(
+ mask_semantic_feat, mask_feats.shape[-2:])
+ mask_feats = mask_feats + mask_semantic_feat
+ last_feat = None
+ for i in range(self.num_stages):
+ mask_head = self.mask_head[i]
+ if self.mask_info_flow:
+ mask_pred, last_feat = mask_head(
+ mask_feats, last_feat)
+ else:
+ mask_pred = mask_head(mask_feats)
+ aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+ aug_img_metas.append(img_meta)
+ merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
+ self.test_cfg)
+
+ ori_shape = img_metas[0][0]['ori_shape']
+ segm_result = self.mask_head[-1].get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ rcnn_test_cfg,
+ ori_shape,
+ scale_factor=1.0,
+ rescale=False)
+ return [(bbox_result, segm_result)]
+ else:
+ return [bbox_result]
diff --git a/mmdet/models/roi_heads/mask_heads/__init__.py b/mmdet/models/roi_heads/mask_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..48a5d4227be41b8985403251e1803f78cf500636
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .coarse_mask_head import CoarseMaskHead
+from .dynamic_mask_head import DynamicMaskHead
+from .fcn_mask_head import FCNMaskHead
+from .feature_relay_head import FeatureRelayHead
+from .fused_semantic_head import FusedSemanticHead
+from .global_context_head import GlobalContextHead
+from .grid_head import GridHead
+from .htc_mask_head import HTCMaskHead
+from .mask_point_head import MaskPointHead
+from .maskiou_head import MaskIoUHead
+from .scnet_mask_head import SCNetMaskHead
+from .scnet_semantic_head import SCNetSemanticHead
+
+__all__ = [
+ 'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
+ 'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead', 'SCNetMaskHead',
+ 'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead',
+ 'DynamicMaskHead'
+]
diff --git a/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py b/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..946254cb4fe2544a0c6d390afbf40e2c50720f9e
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn import ConvModule, Linear
+from mmcv.runner import ModuleList, auto_fp16
+
+from mmdet.models.builder import HEADS
+from .fcn_mask_head import FCNMaskHead
+
+
+@HEADS.register_module()
+class CoarseMaskHead(FCNMaskHead):
+ """Coarse mask head used in PointRend.
+
+    Compared with the standard ``FCNMaskHead``, ``CoarseMaskHead`` downsamples
+    the input feature map instead of upsampling it.
+
+ Args:
+ num_convs (int): Number of conv layers in the head. Default: 0.
+ num_fcs (int): Number of fc layers in the head. Default: 2.
+ fc_out_channels (int): Number of output channels of fc layer.
+ Default: 1024.
+        downsample_factor (int): The factor by which the feature map is
+            downsampled. Default: 2.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_convs=0,
+ num_fcs=2,
+ fc_out_channels=1024,
+ downsample_factor=2,
+ init_cfg=dict(
+ type='Xavier',
+ override=[
+ dict(name='fcs'),
+ dict(type='Constant', val=0.001, name='fc_logits')
+ ]),
+ *arg,
+ **kwarg):
+ super(CoarseMaskHead, self).__init__(
+ *arg,
+ num_convs=num_convs,
+ upsample_cfg=dict(type=None),
+ init_cfg=None,
+ **kwarg)
+ self.init_cfg = init_cfg
+ self.num_fcs = num_fcs
+ assert self.num_fcs > 0
+ self.fc_out_channels = fc_out_channels
+ self.downsample_factor = downsample_factor
+ assert self.downsample_factor >= 1
+ # remove conv_logit
+ delattr(self, 'conv_logits')
+
+ if downsample_factor > 1:
+ downsample_in_channels = (
+ self.conv_out_channels
+ if self.num_convs > 0 else self.in_channels)
+ self.downsample_conv = ConvModule(
+ downsample_in_channels,
+ self.conv_out_channels,
+ kernel_size=downsample_factor,
+ stride=downsample_factor,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ else:
+ self.downsample_conv = None
+
+ self.output_size = (self.roi_feat_size[0] // downsample_factor,
+ self.roi_feat_size[1] // downsample_factor)
+ self.output_area = self.output_size[0] * self.output_size[1]
+
+ last_layer_dim = self.conv_out_channels * self.output_area
+
+ self.fcs = ModuleList()
+ for i in range(num_fcs):
+ fc_in_channels = (
+ last_layer_dim if i == 0 else self.fc_out_channels)
+ self.fcs.append(Linear(fc_in_channels, self.fc_out_channels))
+ last_layer_dim = self.fc_out_channels
+ output_channels = self.num_classes * self.output_area
+ self.fc_logits = Linear(last_layer_dim, output_channels)
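+        # `fc_logits` predicts one logit per class per output-grid cell; the
+        # flat output is reshaped to (N, num_classes, *output_size) in forward()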
+
+ def init_weights(self):
+ super(FCNMaskHead, self).init_weights()
+
+ @auto_fp16()
+ def forward(self, x):
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.downsample_conv is not None:
+ x = self.downsample_conv(x)
+
+ x = x.flatten(1)
+ for fc in self.fcs:
+ x = self.relu(fc(x))
+ mask_pred = self.fc_logits(x).view(
+ x.size(0), self.num_classes, *self.output_size)
+ return mask_pred
diff --git a/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py b/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bbe7eea49cae55ef3c4bdbb17e41f5788e45c79
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py
@@ -0,0 +1,147 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.runner import auto_fp16, force_fp32
+
+from mmdet.core import mask_target
+from mmdet.models.builder import HEADS
+from mmdet.models.dense_heads.atss_head import reduce_mean
+from mmdet.models.utils import build_transformer
+from .fcn_mask_head import FCNMaskHead
+
+
+@HEADS.register_module()
+class DynamicMaskHead(FCNMaskHead):
+ r"""Dynamic Mask Head for
+    `Instances as Queries <https://arxiv.org/abs/2105.01928>`_
+
+ Args:
+ num_convs (int): Number of convolution layer.
+ Defaults to 4.
+ roi_feat_size (int): The output size of RoI extractor,
+ Defaults to 14.
+ in_channels (int): Input feature channels.
+ Defaults to 256.
+ conv_kernel_size (int): Kernel size of convolution layers.
+ Defaults to 3.
+ conv_out_channels (int): Output channels of convolution layers.
+ Defaults to 256.
+        num_classes (int): Number of classes.
+            Defaults to 80.
+        class_agnostic (bool): Whether to generate class-agnostic predictions.
+            Defaults to False.
+        dropout (float): Probability of dropping a channel.
+            Defaults to 0.0.
+ upsample_cfg (dict): The config for upsample layer.
+ conv_cfg (dict): The convolution layer config.
+ norm_cfg (dict): The norm layer config.
+ dynamic_conv_cfg (dict): The dynamic convolution layer config.
+ loss_mask (dict): The config for mask loss.
+ """
+
+ def __init__(self,
+ num_convs=4,
+ roi_feat_size=14,
+ in_channels=256,
+ conv_kernel_size=3,
+ conv_out_channels=256,
+ num_classes=80,
+ class_agnostic=False,
+ upsample_cfg=dict(type='deconv', scale_factor=2),
+ conv_cfg=None,
+ norm_cfg=None,
+ dynamic_conv_cfg=dict(
+ type='DynamicConv',
+ in_channels=256,
+ feat_channels=64,
+ out_channels=256,
+ input_feat_shape=14,
+ with_proj=False,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')),
+ loss_mask=dict(type='DiceLoss', loss_weight=8.0),
+ **kwargs):
+ super(DynamicMaskHead, self).__init__(
+ num_convs=num_convs,
+ roi_feat_size=roi_feat_size,
+ in_channels=in_channels,
+ conv_kernel_size=conv_kernel_size,
+ conv_out_channels=conv_out_channels,
+ num_classes=num_classes,
+ class_agnostic=class_agnostic,
+ upsample_cfg=upsample_cfg,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ loss_mask=loss_mask,
+ **kwargs)
+ assert class_agnostic is False, \
+ 'DynamicMaskHead only support class_agnostic=False'
+ self.fp16_enabled = False
+
+ self.instance_interactive_conv = build_transformer(dynamic_conv_cfg)
+
+ def init_weights(self):
+ """Use xavier initialization for all weight parameter and set
+ classification head bias as a specific value when use focal loss."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ nn.init.constant_(self.conv_logits.bias, 0.)
+
+ @auto_fp16()
+ def forward(self, roi_feat, proposal_feat):
+ """Forward function of DynamicMaskHead.
+
+ Args:
+            roi_feat (Tensor): Roi-pooling features with shape
+                (batch_size*num_proposals, feature_dimensions,
+                pooling_h, pooling_w).
+            proposal_feat (Tensor): Intermediate feature obtained from the
+                DII head in the last stage, with shape
+                (batch_size*num_proposals, feature_dimensions).
+
+ Returns:
+ mask_pred (Tensor): Predicted foreground masks with shape
+ (batch_size*num_proposals, num_classes,
+ pooling_h*2, pooling_w*2).
+ """
+
+ proposal_feat = proposal_feat.reshape(-1, self.in_channels)
+ proposal_feat_iic = self.instance_interactive_conv(
+ proposal_feat, roi_feat)
+
+ x = proposal_feat_iic.permute(0, 2, 1).reshape(roi_feat.size())
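+        # the dynamic interaction yields per-location features of shape
+        # (N, H*W, C); permute/reshape restores the (N, C, H, W) RoI layout
+        # expected by the convolution tower below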
+
+ for conv in self.convs:
+ x = conv(x)
+ if self.upsample is not None:
+ x = self.upsample(x)
+ if self.upsample_method == 'deconv':
+ x = self.relu(x)
+ mask_pred = self.conv_logits(x)
+ return mask_pred
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def loss(self, mask_pred, mask_targets, labels):
+ num_pos = labels.new_ones(labels.size()).float().sum()
+ avg_factor = torch.clamp(reduce_mean(num_pos), min=1.).item()
+ loss = dict()
+ if mask_pred.size(0) == 0:
+ loss_mask = mask_pred.sum()
+ else:
+ loss_mask = self.loss_mask(
+ mask_pred[torch.arange(num_pos).long(), labels, ...].sigmoid(),
+ mask_targets,
+ avg_factor=avg_factor)
+ loss['loss_mask'] = loss_mask
+ return loss
+
+ def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg):
+
+ pos_proposals = [res.pos_bboxes for res in sampling_results]
+ pos_assigned_gt_inds = [
+ res.pos_assigned_gt_inds for res in sampling_results
+ ]
+ mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
+ gt_masks, rcnn_train_cfg)
+ return mask_targets
diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..355d88221403f01a36a9e99d1a12d877a877790a
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from warnings import warn
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, build_conv_layer, build_upsample_layer
+from mmcv.ops.carafe import CARAFEPack
+from mmcv.runner import BaseModule, ModuleList, auto_fp16, force_fp32
+from torch.nn.modules.utils import _pair
+
+from mmdet.core import mask_target
+from mmdet.models.builder import HEADS, build_loss
+
+BYTES_PER_FLOAT = 4
+# TODO: This memory limit may be too much or too little. It would be better to
+# determine it based on available resources.
+GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit
+
+
+@HEADS.register_module()
+class FCNMaskHead(BaseModule):
+
+ def __init__(self,
+ num_convs=4,
+ roi_feat_size=14,
+ in_channels=256,
+ conv_kernel_size=3,
+ conv_out_channels=256,
+ num_classes=80,
+ class_agnostic=False,
+ upsample_cfg=dict(type='deconv', scale_factor=2),
+ conv_cfg=None,
+ norm_cfg=None,
+ predictor_cfg=dict(type='Conv'),
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
+ init_cfg=None):
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
+ 'behavior, init_cfg is not allowed to be set'
+ super(FCNMaskHead, self).__init__(init_cfg)
+ self.upsample_cfg = upsample_cfg.copy()
+ if self.upsample_cfg['type'] not in [
+ None, 'deconv', 'nearest', 'bilinear', 'carafe'
+ ]:
+ raise ValueError(
+ f'Invalid upsample method {self.upsample_cfg["type"]}, '
+ 'accepted methods are "deconv", "nearest", "bilinear", '
+ '"carafe"')
+ self.num_convs = num_convs
+ # WARN: roi_feat_size is reserved and not used
+ self.roi_feat_size = _pair(roi_feat_size)
+ self.in_channels = in_channels
+ self.conv_kernel_size = conv_kernel_size
+ self.conv_out_channels = conv_out_channels
+ self.upsample_method = self.upsample_cfg.get('type')
+ self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
+ self.num_classes = num_classes
+ self.class_agnostic = class_agnostic
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.predictor_cfg = predictor_cfg
+ self.fp16_enabled = False
+ self.loss_mask = build_loss(loss_mask)
+
+ self.convs = ModuleList()
+ for i in range(self.num_convs):
+ in_channels = (
+ self.in_channels if i == 0 else self.conv_out_channels)
+ padding = (self.conv_kernel_size - 1) // 2
+ self.convs.append(
+ ConvModule(
+ in_channels,
+ self.conv_out_channels,
+ self.conv_kernel_size,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg))
+ upsample_in_channels = (
+ self.conv_out_channels if self.num_convs > 0 else in_channels)
+ upsample_cfg_ = self.upsample_cfg.copy()
+ if self.upsample_method is None:
+ self.upsample = None
+ elif self.upsample_method == 'deconv':
+ upsample_cfg_.update(
+ in_channels=upsample_in_channels,
+ out_channels=self.conv_out_channels,
+ kernel_size=self.scale_factor,
+ stride=self.scale_factor)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+ elif self.upsample_method == 'carafe':
+ upsample_cfg_.update(
+ channels=upsample_in_channels, scale_factor=self.scale_factor)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+ else:
+ # suppress warnings
+ align_corners = (None
+ if self.upsample_method == 'nearest' else False)
+ upsample_cfg_.update(
+ scale_factor=self.scale_factor,
+ mode=self.upsample_method,
+ align_corners=align_corners)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+
+ out_channels = 1 if self.class_agnostic else self.num_classes
+ logits_in_channel = (
+ self.conv_out_channels
+ if self.upsample_method == 'deconv' else upsample_in_channels)
+ self.conv_logits = build_conv_layer(self.predictor_cfg,
+ logits_in_channel, out_channels, 1)
+ self.relu = nn.ReLU(inplace=True)
+ self.debug_imgs = None
+
+ def init_weights(self):
+ super(FCNMaskHead, self).init_weights()
+ for m in [self.upsample, self.conv_logits]:
+ if m is None:
+ continue
+ elif isinstance(m, CARAFEPack):
+ m.init_weights()
+ elif hasattr(m, 'weight') and hasattr(m, 'bias'):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ nn.init.constant_(m.bias, 0)
+
+ @auto_fp16()
+ def forward(self, x):
+ for conv in self.convs:
+ x = conv(x)
+ if self.upsample is not None:
+ x = self.upsample(x)
+ if self.upsample_method == 'deconv':
+ x = self.relu(x)
+ mask_pred = self.conv_logits(x)
+ return mask_pred
+
+ def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg):
+ pos_proposals = [res.pos_bboxes for res in sampling_results]
+ pos_assigned_gt_inds = [
+ res.pos_assigned_gt_inds for res in sampling_results
+ ]
+ mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
+ gt_masks, rcnn_train_cfg)
+ return mask_targets
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def loss(self, mask_pred, mask_targets, labels):
+ """
+ Example:
+ >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
+ >>> N = 7 # N = number of extracted ROIs
+ >>> C, H, W = 11, 32, 32
+ >>> # Create example instance of FCN Mask Head.
+ >>> # There are lots of variations depending on the configuration
+ >>> self = FCNMaskHead(num_classes=C, num_convs=1)
+ >>> inputs = torch.rand(N, self.in_channels, H, W)
+ >>> mask_pred = self.forward(inputs)
+ >>> sf = self.scale_factor
+ >>> labels = torch.randint(0, C, size=(N,))
+ >>> # With the default properties the mask targets should indicate
+ >>> # a (potentially soft) single-class label
+ >>> mask_targets = torch.rand(N, H * sf, W * sf)
+ >>> loss = self.loss(mask_pred, mask_targets, labels)
+ >>> print('loss = {!r}'.format(loss))
+ """
+ loss = dict()
+ if mask_pred.size(0) == 0:
+ loss_mask = mask_pred.sum()
+ else:
+ if self.class_agnostic:
+ loss_mask = self.loss_mask(mask_pred, mask_targets,
+ torch.zeros_like(labels))
+ else:
+ loss_mask = self.loss_mask(mask_pred, mask_targets, labels)
+ loss['loss_mask'] = loss_mask
+ return loss
+
+ def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
+ ori_shape, scale_factor, rescale):
+ """Get segmentation masks from mask_pred and bboxes.
+
+ Args:
+ mask_pred (Tensor or ndarray): shape (n, #class, h, w).
+ For single-scale testing, mask_pred is the direct output of
+ model, whose type is Tensor, while for multi-scale testing,
+ it will be converted to numpy array outside of this method.
+ det_bboxes (Tensor): shape (n, 4/5)
+ det_labels (Tensor): shape (n, )
+ rcnn_test_cfg (dict): rcnn testing config
+ ori_shape (Tuple): original image height and width, shape (2,)
+            scale_factor (ndarray | Tensor): If ``rescale is True``, box
+ coordinates are divided by this scale factor to fit
+ ``ori_shape``.
+ rescale (bool): If True, the resulting masks will be rescaled to
+ ``ori_shape``.
+
+ Returns:
+ list[list]: encoded masks. The c-th item in the outer list
+ corresponds to the c-th class. Given the c-th outer list, the
+ i-th item in that inner list is the mask for the i-th box with
+ class label c.
+
+ Example:
+ >>> import mmcv
+ >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
+ >>> N = 7 # N = number of extracted ROIs
+ >>> C, H, W = 11, 32, 32
+ >>> # Create example instance of FCN Mask Head.
+ >>> self = FCNMaskHead(num_classes=C, num_convs=0)
+ >>> inputs = torch.rand(N, self.in_channels, H, W)
+ >>> mask_pred = self.forward(inputs)
+ >>> # Each input is associated with some bounding box
+ >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
+ >>> det_labels = torch.randint(0, C, size=(N,))
+ >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, })
+ >>> ori_shape = (H * 4, W * 4)
+ >>> scale_factor = torch.FloatTensor((1, 1))
+ >>> rescale = False
+ >>> # Encoded masks are a list for each category.
+ >>> encoded_masks = self.get_seg_masks(
+ >>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape,
+ >>> scale_factor, rescale
+ >>> )
+ >>> assert len(encoded_masks) == C
+ >>> assert sum(list(map(len, encoded_masks))) == N
+ """
+ if isinstance(mask_pred, torch.Tensor):
+ mask_pred = mask_pred.sigmoid()
+ else:
+            # In AugTest, mask_pred has already been passed through sigmoid
+ mask_pred = det_bboxes.new_tensor(mask_pred)
+
+ device = mask_pred.device
+ cls_segms = [[] for _ in range(self.num_classes)
+ ] # BG is not included in num_classes
+ bboxes = det_bboxes[:, :4]
+ labels = det_labels
+
+ # In most cases, scale_factor should have been
+ # converted to Tensor when rescale the bbox
+ if not isinstance(scale_factor, torch.Tensor):
+ if isinstance(scale_factor, float):
+ scale_factor = np.array([scale_factor] * 4)
+ warn('Scale_factor should be a Tensor or ndarray '
+                     'with shape (4,); float will be deprecated. ')
+ assert isinstance(scale_factor, np.ndarray)
+ scale_factor = torch.Tensor(scale_factor)
+
+ if rescale:
+ img_h, img_w = ori_shape[:2]
+ bboxes = bboxes / scale_factor.to(bboxes)
+ else:
+ w_scale, h_scale = scale_factor[0], scale_factor[1]
+ img_h = np.round(ori_shape[0] * h_scale.item()).astype(np.int32)
+ img_w = np.round(ori_shape[1] * w_scale.item()).astype(np.int32)
+
+ N = len(mask_pred)
+ # The actual implementation split the input into chunks,
+ # and paste them chunk by chunk.
+ if device.type == 'cpu':
+ # CPU is most efficient when they are pasted one by one with
+ # skip_empty=True, so that it performs minimal number of
+ # operations.
+ num_chunks = N
+ else:
+ # GPU benefits from parallelism for larger chunks,
+ # but may have memory issue
+ # the types of img_w and img_h are np.int32,
+ # when the image resolution is large,
+ # the calculation of num_chunks will overflow.
+ # so we need to change the types of img_w and img_h to int.
+ # See https://github.com/open-mmlab/mmdetection/pull/5191
+ num_chunks = int(
+ np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT /
+ GPU_MEM_LIMIT))
+ assert (num_chunks <=
+ N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
+ chunks = torch.chunk(torch.arange(N, device=device), num_chunks)
+
+ threshold = rcnn_test_cfg.mask_thr_binary
+ im_mask = torch.zeros(
+ N,
+ img_h,
+ img_w,
+ device=device,
+ dtype=torch.bool if threshold >= 0 else torch.uint8)
+
+ if not self.class_agnostic:
+ mask_pred = mask_pred[range(N), labels][:, None]
+
+ for inds in chunks:
+ masks_chunk, spatial_inds = _do_paste_mask(
+ mask_pred[inds],
+ bboxes[inds],
+ img_h,
+ img_w,
+ skip_empty=device.type == 'cpu')
+
+ if threshold >= 0:
+ masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
+ else:
+ # for visualization and debugging
+ masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)
+
+ im_mask[(inds, ) + spatial_inds] = masks_chunk
+
+ for i in range(N):
+ cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
+ return cls_segms
+
+ def onnx_export(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
+ ori_shape, **kwargs):
+ """Get segmentation masks from mask_pred and bboxes.
+
+ Args:
+ mask_pred (Tensor): shape (n, #class, h, w).
+ det_bboxes (Tensor): shape (n, 4/5)
+ det_labels (Tensor): shape (n, )
+ rcnn_test_cfg (dict): rcnn testing config
+ ori_shape (Tuple): original image height and width, shape (2,)
+
+ Returns:
+ Tensor: a mask of shape (N, img_h, img_w).
+ """
+
+ mask_pred = mask_pred.sigmoid()
+ bboxes = det_bboxes[:, :4]
+ labels = det_labels
+ # No need to consider rescale and scale_factor while exporting to ONNX
+ img_h, img_w = ori_shape[:2]
+ threshold = rcnn_test_cfg.mask_thr_binary
+ if not self.class_agnostic:
+ box_inds = torch.arange(mask_pred.shape[0])
+ mask_pred = mask_pred[box_inds, labels][:, None]
+ masks, _ = _do_paste_mask(
+ mask_pred, bboxes, img_h, img_w, skip_empty=False)
+ if threshold >= 0:
+ # should convert to float to avoid problems in TRT
+ masks = (masks >= threshold).to(dtype=torch.float)
+ return masks
+
+
+def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
+ """Paste instance masks according to boxes.
+
+ This implementation is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ Args:
+ masks (Tensor): N, 1, H, W
+ boxes (Tensor): N, 4
+ img_h (int): Height of the image to be pasted.
+ img_w (int): Width of the image to be pasted.
+ skip_empty (bool): Only paste masks within the region that
+            tightly bounds all boxes, and return the results for this
+            region only.
+ An important optimization for CPU.
+
+ Returns:
+ tuple: (Tensor, tuple). The first item is mask tensor, the second one
+ is the slice object.
+ If skip_empty == False, the whole image will be pasted. It will
+ return a mask of shape (N, img_h, img_w) and an empty tuple.
+ If skip_empty == True, only area around the mask will be pasted.
+ A mask of shape (N, h', w') and its start and end coordinates
+ in the original image will be returned.
+ """
+ # On GPU, paste all masks together (up to chunk size)
+ # by using the entire image to sample the masks
+ # Compared to pasting them one by one,
+ # this has more operations but is faster on COCO-scale dataset.
+ device = masks.device
+ if skip_empty:
+ x0_int, y0_int = torch.clamp(
+ boxes.min(dim=0).values.floor()[:2] - 1,
+ min=0).to(dtype=torch.int32)
+ x1_int = torch.clamp(
+ boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
+ y1_int = torch.clamp(
+ boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
+ else:
+ x0_int, y0_int = 0, 0
+ x1_int, y1_int = img_w, img_h
+ x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1
+
+ N = masks.shape[0]
+
+ img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5
+ img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5
+ img_y = (img_y - y0) / (y1 - y0) * 2 - 1
+ img_x = (img_x - x0) / (x1 - x0) * 2 - 1
+ # img_x, img_y have shapes (N, w), (N, h)
+ # IsInf op is not supported with ONNX<=1.7.0
+ if not torch.onnx.is_in_onnx_export():
+ if torch.isinf(img_x).any():
+ inds = torch.where(torch.isinf(img_x))
+ img_x[inds] = 0
+ if torch.isinf(img_y).any():
+ inds = torch.where(torch.isinf(img_y))
+ img_y[inds] = 0
+
+ gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
+ gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
+ grid = torch.stack([gx, gy], dim=3)
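+    # `grid` holds, for every box, the sampling location of each output pixel
+    # expressed in the normalized [-1, 1] coordinate frame of that box, so a
+    # single grid_sample call pastes all masks in the chunk at once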
+
+ img_masks = F.grid_sample(
+ masks.to(dtype=torch.float32), grid, align_corners=False)
+
+ if skip_empty:
+ return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
+ else:
+ return img_masks[:, 0], ()
diff --git a/mmdet/models/roi_heads/mask_heads/feature_relay_head.py b/mmdet/models/roi_heads/mask_heads/feature_relay_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..452f37afdb6c8232aac0a68dcb7ccbd256d788b6
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/feature_relay_head.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.runner import BaseModule, auto_fp16
+
+from mmdet.models.builder import HEADS
+
+
+@HEADS.register_module()
+class FeatureRelayHead(BaseModule):
+ """Feature Relay Head used in `SCNet `_.
+
+ Args:
+        in_channels (int, optional): number of input channels. Default: 1024.
+        out_conv_channels (int, optional): number of output channels before
+ classification layer. Default: 256.
+ roi_feat_size (int, optional): roi feat size at box head. Default: 7.
+ scale_factor (int, optional): scale factor to match roi feat size
+ at mask head. Default: 2.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ in_channels=1024,
+ out_conv_channels=256,
+ roi_feat_size=7,
+ scale_factor=2,
+ init_cfg=dict(type='Kaiming', layer='Linear')):
+ super(FeatureRelayHead, self).__init__(init_cfg)
+ assert isinstance(roi_feat_size, int)
+
+ self.in_channels = in_channels
+ self.out_conv_channels = out_conv_channels
+ self.roi_feat_size = roi_feat_size
+ self.out_channels = (roi_feat_size**2) * out_conv_channels
+ self.scale_factor = scale_factor
+ self.fp16_enabled = False
+
+ self.fc = nn.Linear(self.in_channels, self.out_channels)
+ self.upsample = nn.Upsample(
+ scale_factor=scale_factor, mode='bilinear', align_corners=True)
+
+ @auto_fp16()
+ def forward(self, x):
+ """Forward function."""
+ N, in_C = x.shape
+ if N > 0:
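+            # relay the flattened box-head feature back into a spatial map:
+            # fc -> (N, out_C * S * S), reshape to (N, out_C, S, S), then
+            # bilinear upsampling by `scale_factor` to the mask RoI resolution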
+ out_C = self.out_conv_channels
+ out_HW = self.roi_feat_size
+ x = self.fc(x)
+ x = x.reshape(N, out_C, out_HW, out_HW)
+ x = self.upsample(x)
+ return x
+ return None
diff --git a/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py b/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6eaa54ae8c90e305e5ec498a8af7c05db4a831f
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, auto_fp16, force_fp32
+
+from mmdet.models.builder import HEADS, build_loss
+
+
+@HEADS.register_module()
+class FusedSemanticHead(BaseModule):
+ r"""Multi-level fused semantic segmentation head.
+
+ .. code-block:: none
+
+ in_1 -> 1x1 conv ---
+ |
+ in_2 -> 1x1 conv -- |
+ ||
+ in_3 -> 1x1 conv - ||
+ ||| /-> 1x1 conv (mask prediction)
+ in_4 -> 1x1 conv -----> 3x3 convs (*4)
+ | \-> 1x1 conv (feature)
+ in_5 -> 1x1 conv ---
+ """ # noqa: W605
+
+ def __init__(self,
+ num_ins,
+ fusion_level,
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=183,
+ conv_cfg=None,
+ norm_cfg=None,
+ ignore_label=None,
+ loss_weight=None,
+ loss_seg=dict(
+ type='CrossEntropyLoss',
+ ignore_index=255,
+ loss_weight=0.2),
+ init_cfg=dict(
+ type='Kaiming', override=dict(name='conv_logits'))):
+ super(FusedSemanticHead, self).__init__(init_cfg)
+ self.num_ins = num_ins
+ self.fusion_level = fusion_level
+ self.num_convs = num_convs
+ self.in_channels = in_channels
+ self.conv_out_channels = conv_out_channels
+ self.num_classes = num_classes
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.fp16_enabled = False
+
+ self.lateral_convs = nn.ModuleList()
+ for i in range(self.num_ins):
+ self.lateral_convs.append(
+ ConvModule(
+ self.in_channels,
+ self.in_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=False))
+
+ self.convs = nn.ModuleList()
+ for i in range(self.num_convs):
+ in_channels = self.in_channels if i == 0 else conv_out_channels
+ self.convs.append(
+ ConvModule(
+ in_channels,
+ conv_out_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.conv_embedding = ConvModule(
+ conv_out_channels,
+ conv_out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
+ if ignore_label:
+ loss_seg['ignore_index'] = ignore_label
+ if loss_weight:
+ loss_seg['loss_weight'] = loss_weight
+ if ignore_label or loss_weight:
+ warnings.warn('``ignore_label`` and ``loss_weight`` would be '
+                          'deprecated soon. Please set ``ignore_index`` and '
+ '``loss_weight`` in ``loss_seg`` instead.')
+ self.criterion = build_loss(loss_seg)
+
+ @auto_fp16()
+ def forward(self, feats):
+ x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
+ fused_size = tuple(x.shape[-2:])
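+        # resize every other level to the fusion level's resolution and sum,
+        # producing a single fused semantic feature map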
+ for i, feat in enumerate(feats):
+ if i != self.fusion_level:
+ feat = F.interpolate(
+ feat, size=fused_size, mode='bilinear', align_corners=True)
+ # fix runtime error of "+=" inplace operation in PyTorch 1.10
+ x = x + self.lateral_convs[i](feat)
+
+ for i in range(self.num_convs):
+ x = self.convs[i](x)
+
+ mask_pred = self.conv_logits(x)
+ x = self.conv_embedding(x)
+ return mask_pred, x
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def loss(self, mask_pred, labels):
+ labels = labels.squeeze(1).long()
+ loss_semantic_seg = self.criterion(mask_pred, labels)
+ return loss_semantic_seg
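+# A minimal usage sketch (illustrative only; assumes mmcv/mmdet are importable
+# and the default arguments above). With ``fusion_level=1`` every other level
+# is bilinearly resized to the spatial size of ``feats[fusion_level]`` before
+# being summed:
+#
+#     import torch
+#     head = FusedSemanticHead(num_ins=5, fusion_level=1)
+#     feats = [torch.rand(2, 256, s, s) for s in (64, 32, 16, 8, 4)]
+#     mask_pred, sem_feat = head(feats)
+#     # mask_pred: (2, 183, 32, 32), sem_feat: (2, 256, 32, 32)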
diff --git a/mmdet/models/roi_heads/mask_heads/global_context_head.py b/mmdet/models/roi_heads/mask_heads/global_context_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..af76a174be4dadf603f82b44a64ce487c9c64ca7
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/global_context_head.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, auto_fp16, force_fp32
+
+from mmdet.models.builder import HEADS
+from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
+
+
+@HEADS.register_module()
+class GlobalContextHead(BaseModule):
+ """Global context head used in `SCNet `_.
+
+ Args:
+        num_convs (int, optional): number of convolutional layers.
+ Default: 4.
+ in_channels (int, optional): number of input channels. Default: 256.
+ conv_out_channels (int, optional): number of output channels before
+ classification layer. Default: 256.
+ num_classes (int, optional): number of classes. Default: 80.
+ loss_weight (float, optional): global context loss weight. Default: 1.
+ conv_cfg (dict, optional): config to init conv layer. Default: None.
+ norm_cfg (dict, optional): config to init norm layer. Default: None.
+ conv_to_res (bool, optional): if True, 2 convs will be grouped into
+ 1 `SimplifiedBasicBlock` using a skip connection. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=80,
+ loss_weight=1.0,
+ conv_cfg=None,
+ norm_cfg=None,
+ conv_to_res=False,
+ init_cfg=dict(
+ type='Normal', std=0.01, override=dict(name='fc'))):
+ super(GlobalContextHead, self).__init__(init_cfg)
+ self.num_convs = num_convs
+ self.in_channels = in_channels
+ self.conv_out_channels = conv_out_channels
+ self.num_classes = num_classes
+ self.loss_weight = loss_weight
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.conv_to_res = conv_to_res
+ self.fp16_enabled = False
+
+ if self.conv_to_res:
+ num_res_blocks = num_convs // 2
+ self.convs = ResLayer(
+ SimplifiedBasicBlock,
+ in_channels,
+ self.conv_out_channels,
+ num_res_blocks,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ self.num_convs = num_res_blocks
+ else:
+ self.convs = nn.ModuleList()
+ for i in range(self.num_convs):
+ in_channels = self.in_channels if i == 0 else conv_out_channels
+ self.convs.append(
+ ConvModule(
+ in_channels,
+ conv_out_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+
+ self.pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Linear(conv_out_channels, num_classes)
+
+ self.criterion = nn.BCEWithLogitsLoss()
+
+ @auto_fp16()
+ def forward(self, feats):
+ """Forward function."""
+ x = feats[-1]
+ for i in range(self.num_convs):
+ x = self.convs[i](x)
+ x = self.pool(x)
+
+ # multi-class prediction
+ mc_pred = x.reshape(x.size(0), -1)
+ mc_pred = self.fc(mc_pred)
+
+ return mc_pred, x
+
+ @force_fp32(apply_to=('pred', ))
+ def loss(self, pred, labels):
+ """Loss function."""
+ labels = [lbl.unique() for lbl in labels]
+ targets = pred.new_zeros(pred.size())
+ for i, label in enumerate(labels):
+ targets[i, label] = 1.0
+ loss = self.loss_weight * self.criterion(pred, targets)
+ return loss
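+# A small illustration of how ``loss`` builds its multi-label targets
+# (illustrative only; assumes the defaults above): for two images whose gt
+# labels are [1, 1, 3] and [2], ``targets`` becomes a (2, num_classes)
+# multi-hot matrix with ones at columns {1, 3} in the first row and {2} in the
+# second, penalized by ``BCEWithLogitsLoss`` scaled with ``loss_weight``:
+#
+#     import torch
+#     head = GlobalContextHead(num_classes=80)
+#     feats = [torch.rand(2, 256, s, s) for s in (64, 32, 16, 8)]
+#     mc_pred, ctx = head(feats)      # (2, 80) logits, (2, 256, 1, 1) context
+#     labels = [torch.tensor([1, 1, 3]), torch.tensor([2])]
+#     loss = head.loss(mc_pred, labels)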
diff --git a/mmdet/models/roi_heads/mask_heads/grid_head.py b/mmdet/models/roi_heads/mask_heads/grid_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c0702d2a3f8bb7f2292307b907260bdecf1a164
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/grid_head.py
@@ -0,0 +1,363 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+
+from mmdet.models.builder import HEADS, build_loss
+
+
+@HEADS.register_module()
+class GridHead(BaseModule):
+
+ def __init__(self,
+ grid_points=9,
+ num_convs=8,
+ roi_feat_size=14,
+ in_channels=256,
+ conv_kernel_size=3,
+ point_feat_channels=64,
+ deconv_kernel_size=4,
+ class_agnostic=False,
+ loss_grid=dict(
+ type='CrossEntropyLoss', use_sigmoid=True,
+ loss_weight=15),
+ conv_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=36),
+ init_cfg=[
+ dict(type='Kaiming', layer=['Conv2d', 'Linear']),
+ dict(
+ type='Normal',
+ layer='ConvTranspose2d',
+ std=0.001,
+ override=dict(
+ type='Normal',
+ name='deconv2',
+ std=0.001,
+ bias=-np.log(0.99 / 0.01)))
+ ]):
+ super(GridHead, self).__init__(init_cfg)
+ self.grid_points = grid_points
+ self.num_convs = num_convs
+ self.roi_feat_size = roi_feat_size
+ self.in_channels = in_channels
+ self.conv_kernel_size = conv_kernel_size
+ self.point_feat_channels = point_feat_channels
+ self.conv_out_channels = self.point_feat_channels * self.grid_points
+ self.class_agnostic = class_agnostic
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ if isinstance(norm_cfg, dict) and norm_cfg['type'] == 'GN':
+ assert self.conv_out_channels % norm_cfg['num_groups'] == 0
+
+ assert self.grid_points >= 4
+ self.grid_size = int(np.sqrt(self.grid_points))
+ if self.grid_size * self.grid_size != self.grid_points:
+ raise ValueError('grid_points must be a square number')
+
+ # the predicted heatmap is half of whole_map_size
+ if not isinstance(self.roi_feat_size, int):
+            raise ValueError('Only square RoIs are supported in Grid R-CNN')
+ self.whole_map_size = self.roi_feat_size * 4
+
+ # compute point-wise sub-regions
+ self.sub_regions = self.calc_sub_regions()
+
+ self.convs = []
+ for i in range(self.num_convs):
+ in_channels = (
+ self.in_channels if i == 0 else self.conv_out_channels)
+ stride = 2 if i == 0 else 1
+ padding = (self.conv_kernel_size - 1) // 2
+ self.convs.append(
+ ConvModule(
+ in_channels,
+ self.conv_out_channels,
+ self.conv_kernel_size,
+ stride=stride,
+ padding=padding,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=True))
+ self.convs = nn.Sequential(*self.convs)
+
+ self.deconv1 = nn.ConvTranspose2d(
+ self.conv_out_channels,
+ self.conv_out_channels,
+ kernel_size=deconv_kernel_size,
+ stride=2,
+ padding=(deconv_kernel_size - 2) // 2,
+ groups=grid_points)
+ self.norm1 = nn.GroupNorm(grid_points, self.conv_out_channels)
+ self.deconv2 = nn.ConvTranspose2d(
+ self.conv_out_channels,
+ grid_points,
+ kernel_size=deconv_kernel_size,
+ stride=2,
+ padding=(deconv_kernel_size - 2) // 2,
+ groups=grid_points)
+
+ # find the 4-neighbor of each grid point
+ self.neighbor_points = []
+ grid_size = self.grid_size
+ for i in range(grid_size): # i-th column
+ for j in range(grid_size): # j-th row
+ neighbors = []
+ if i > 0: # left: (i - 1, j)
+ neighbors.append((i - 1) * grid_size + j)
+ if j > 0: # up: (i, j - 1)
+ neighbors.append(i * grid_size + j - 1)
+ if j < grid_size - 1: # down: (i, j + 1)
+ neighbors.append(i * grid_size + j + 1)
+ if i < grid_size - 1: # right: (i + 1, j)
+ neighbors.append((i + 1) * grid_size + j)
+ self.neighbor_points.append(tuple(neighbors))
+ # total edges in the grid
+ self.num_edges = sum([len(p) for p in self.neighbor_points])
+
+ self.forder_trans = nn.ModuleList() # first-order feature transition
+ self.sorder_trans = nn.ModuleList() # second-order feature transition
+ for neighbors in self.neighbor_points:
+ fo_trans = nn.ModuleList()
+ so_trans = nn.ModuleList()
+ for _ in range(len(neighbors)):
+ # each transition module consists of a 5x5 depth-wise conv and
+ # 1x1 conv.
+ fo_trans.append(
+ nn.Sequential(
+ nn.Conv2d(
+ self.point_feat_channels,
+ self.point_feat_channels,
+ 5,
+ stride=1,
+ padding=2,
+ groups=self.point_feat_channels),
+ nn.Conv2d(self.point_feat_channels,
+ self.point_feat_channels, 1)))
+ so_trans.append(
+ nn.Sequential(
+ nn.Conv2d(
+ self.point_feat_channels,
+ self.point_feat_channels,
+ 5,
+ 1,
+ 2,
+ groups=self.point_feat_channels),
+ nn.Conv2d(self.point_feat_channels,
+ self.point_feat_channels, 1)))
+ self.forder_trans.append(fo_trans)
+ self.sorder_trans.append(so_trans)
+
+ self.loss_grid = build_loss(loss_grid)
+
+ def forward(self, x):
+ assert x.shape[-1] == x.shape[-2] == self.roi_feat_size
+ # RoI feature transformation, downsample 2x
+ x = self.convs(x)
+
+ c = self.point_feat_channels
+ # first-order fusion
+ x_fo = [None for _ in range(self.grid_points)]
+ for i, points in enumerate(self.neighbor_points):
+ x_fo[i] = x[:, i * c:(i + 1) * c]
+ for j, point_idx in enumerate(points):
+ x_fo[i] = x_fo[i] + self.forder_trans[i][j](
+ x[:, point_idx * c:(point_idx + 1) * c])
+
+ # second-order fusion
+ x_so = [None for _ in range(self.grid_points)]
+ for i, points in enumerate(self.neighbor_points):
+ x_so[i] = x[:, i * c:(i + 1) * c]
+ for j, point_idx in enumerate(points):
+ x_so[i] = x_so[i] + self.sorder_trans[i][j](x_fo[point_idx])
+
+ # predicted heatmap with fused features
+ x2 = torch.cat(x_so, dim=1)
+ x2 = self.deconv1(x2)
+ x2 = F.relu(self.norm1(x2), inplace=True)
+ heatmap = self.deconv2(x2)
+
+ # predicted heatmap with original features (applicable during training)
+ if self.training:
+ x1 = x
+ x1 = self.deconv1(x1)
+ x1 = F.relu(self.norm1(x1), inplace=True)
+ heatmap_unfused = self.deconv2(x1)
+ else:
+ heatmap_unfused = heatmap
+
+ return dict(fused=heatmap, unfused=heatmap_unfused)
+
+ def calc_sub_regions(self):
+ """Compute point specific representation regions.
+
+ See Grid R-CNN Plus (https://arxiv.org/abs/1906.05688) for details.
+ """
+ # to make it consistent with the original implementation, half_size
+ # is computed as 2 * quarter_size, which is smaller
+ half_size = self.whole_map_size // 4 * 2
+ sub_regions = []
+ for i in range(self.grid_points):
+ x_idx = i // self.grid_size
+ y_idx = i % self.grid_size
+ if x_idx == 0:
+ sub_x1 = 0
+ elif x_idx == self.grid_size - 1:
+ sub_x1 = half_size
+ else:
+ ratio = x_idx / (self.grid_size - 1) - 0.25
+ sub_x1 = max(int(ratio * self.whole_map_size), 0)
+
+ if y_idx == 0:
+ sub_y1 = 0
+ elif y_idx == self.grid_size - 1:
+ sub_y1 = half_size
+ else:
+ ratio = y_idx / (self.grid_size - 1) - 0.25
+ sub_y1 = max(int(ratio * self.whole_map_size), 0)
+ sub_regions.append(
+ (sub_x1, sub_y1, sub_x1 + half_size, sub_y1 + half_size))
+ return sub_regions
+
+ def get_targets(self, sampling_results, rcnn_train_cfg):
+ # mix all samples (across images) together.
+ pos_bboxes = torch.cat([res.pos_bboxes for res in sampling_results],
+ dim=0).cpu()
+ pos_gt_bboxes = torch.cat(
+ [res.pos_gt_bboxes for res in sampling_results], dim=0).cpu()
+ assert pos_bboxes.shape == pos_gt_bboxes.shape
+
+ # expand pos_bboxes to 2x of original size
+ x1 = pos_bboxes[:, 0] - (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
+ y1 = pos_bboxes[:, 1] - (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
+ x2 = pos_bboxes[:, 2] + (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
+ y2 = pos_bboxes[:, 3] + (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
+ pos_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ pos_bbox_ws = (pos_bboxes[:, 2] - pos_bboxes[:, 0]).unsqueeze(-1)
+ pos_bbox_hs = (pos_bboxes[:, 3] - pos_bboxes[:, 1]).unsqueeze(-1)
+
+ num_rois = pos_bboxes.shape[0]
+ map_size = self.whole_map_size
+ # this is not the final target shape
+ targets = torch.zeros((num_rois, self.grid_points, map_size, map_size),
+ dtype=torch.float)
+
+ # pre-compute interpolation factors for all grid points.
+ # the first item is the factor of x-dim, and the second is y-dim.
+ # for a 9-point grid, factors are like (1, 0), (0.5, 0.5), (0, 1)
+ factors = []
+ for j in range(self.grid_points):
+ x_idx = j // self.grid_size
+ y_idx = j % self.grid_size
+ factors.append((1 - x_idx / (self.grid_size - 1),
+ 1 - y_idx / (self.grid_size - 1)))
+
+ radius = rcnn_train_cfg.pos_radius
+ radius2 = radius**2
+ for i in range(num_rois):
+ # ignore small bboxes
+ if (pos_bbox_ws[i] <= self.grid_size
+ or pos_bbox_hs[i] <= self.grid_size):
+ continue
+ # for each grid point, mark a small circle as positive
+ for j in range(self.grid_points):
+ factor_x, factor_y = factors[j]
+ gridpoint_x = factor_x * pos_gt_bboxes[i, 0] + (
+ 1 - factor_x) * pos_gt_bboxes[i, 2]
+ gridpoint_y = factor_y * pos_gt_bboxes[i, 1] + (
+ 1 - factor_y) * pos_gt_bboxes[i, 3]
+
+ cx = int((gridpoint_x - pos_bboxes[i, 0]) / pos_bbox_ws[i] *
+ map_size)
+ cy = int((gridpoint_y - pos_bboxes[i, 1]) / pos_bbox_hs[i] *
+ map_size)
+
+ for x in range(cx - radius, cx + radius + 1):
+ for y in range(cy - radius, cy + radius + 1):
+ if x >= 0 and x < map_size and y >= 0 and y < map_size:
+ if (x - cx)**2 + (y - cy)**2 <= radius2:
+ targets[i, j, y, x] = 1
+ # reduce the target heatmap size by a half
+ # proposed in Grid R-CNN Plus (https://arxiv.org/abs/1906.05688).
+ sub_targets = []
+ for i in range(self.grid_points):
+ sub_x1, sub_y1, sub_x2, sub_y2 = self.sub_regions[i]
+ sub_targets.append(targets[:, [i], sub_y1:sub_y2, sub_x1:sub_x2])
+ sub_targets = torch.cat(sub_targets, dim=1)
+ sub_targets = sub_targets.to(sampling_results[0].pos_bboxes.device)
+ return sub_targets
+
+ def loss(self, grid_pred, grid_targets):
+ loss_fused = self.loss_grid(grid_pred['fused'], grid_targets)
+ loss_unfused = self.loss_grid(grid_pred['unfused'], grid_targets)
+ loss_grid = loss_fused + loss_unfused
+ return dict(loss_grid=loss_grid)
+
+ def get_bboxes(self, det_bboxes, grid_pred, img_metas):
+ # TODO: refactoring
+ assert det_bboxes.shape[0] == grid_pred.shape[0]
+ det_bboxes = det_bboxes.cpu()
+ cls_scores = det_bboxes[:, [4]]
+ det_bboxes = det_bboxes[:, :4]
+ grid_pred = grid_pred.sigmoid().cpu()
+
+ R, c, h, w = grid_pred.shape
+ half_size = self.whole_map_size // 4 * 2
+ assert h == w == half_size
+ assert c == self.grid_points
+
+ # find the point with max scores in the half-sized heatmap
+ grid_pred = grid_pred.view(R * c, h * w)
+ pred_scores, pred_position = grid_pred.max(dim=1)
+ xs = pred_position % w
+ ys = pred_position // w
+
+ # get the position in the whole heatmap instead of half-sized heatmap
+ for i in range(self.grid_points):
+ xs[i::self.grid_points] += self.sub_regions[i][0]
+ ys[i::self.grid_points] += self.sub_regions[i][1]
+
+ # reshape to (num_rois, grid_points)
+ pred_scores, xs, ys = tuple(
+ map(lambda x: x.view(R, c), [pred_scores, xs, ys]))
+
+ # get expanded pos_bboxes
+ widths = (det_bboxes[:, 2] - det_bboxes[:, 0]).unsqueeze(-1)
+ heights = (det_bboxes[:, 3] - det_bboxes[:, 1]).unsqueeze(-1)
+ x1 = (det_bboxes[:, 0, None] - widths / 2)
+ y1 = (det_bboxes[:, 1, None] - heights / 2)
+ # map the grid point to the absolute coordinates
+ abs_xs = (xs.float() + 0.5) / w * widths + x1
+ abs_ys = (ys.float() + 0.5) / h * heights + y1
+
+ # get the grid points indices that fall on the bbox boundaries
+ x1_inds = [i for i in range(self.grid_size)]
+ y1_inds = [i * self.grid_size for i in range(self.grid_size)]
+ x2_inds = [
+ self.grid_points - self.grid_size + i
+ for i in range(self.grid_size)
+ ]
+ y2_inds = [(i + 1) * self.grid_size - 1 for i in range(self.grid_size)]
+
+ # voting of all grid points on some boundary
+ bboxes_x1 = (abs_xs[:, x1_inds] * pred_scores[:, x1_inds]).sum(
+ dim=1, keepdim=True) / (
+ pred_scores[:, x1_inds].sum(dim=1, keepdim=True))
+ bboxes_y1 = (abs_ys[:, y1_inds] * pred_scores[:, y1_inds]).sum(
+ dim=1, keepdim=True) / (
+ pred_scores[:, y1_inds].sum(dim=1, keepdim=True))
+ bboxes_x2 = (abs_xs[:, x2_inds] * pred_scores[:, x2_inds]).sum(
+ dim=1, keepdim=True) / (
+ pred_scores[:, x2_inds].sum(dim=1, keepdim=True))
+ bboxes_y2 = (abs_ys[:, y2_inds] * pred_scores[:, y2_inds]).sum(
+ dim=1, keepdim=True) / (
+ pred_scores[:, y2_inds].sum(dim=1, keepdim=True))
+
+ bbox_res = torch.cat(
+ [bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2, cls_scores], dim=1)
+ bbox_res[:, [0, 2]].clamp_(min=0, max=img_metas[0]['img_shape'][1])
+ bbox_res[:, [1, 3]].clamp_(min=0, max=img_metas[0]['img_shape'][0])
+
+ return bbox_res
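+# A worked example of the grid geometry above (defaults: grid_points=9 and
+# roi_feat_size=14, so grid_size=3, whole_map_size=56 and half_size=28).
+# The interpolation factors of the 9 points run over {1.0, 0.5, 0.0} in each
+# of the x and y dimensions, i.e. the grid points are the corners, edge
+# midpoints and center of the gt box, and each point is supervised in (and
+# read back from) its own 28x28 sub-region returned by ``calc_sub_regions()``:
+#
+#     head = GridHead()
+#     head.sub_regions[0]   # (0, 0, 28, 28)     top-left grid point
+#     head.sub_regions[4]   # (14, 14, 42, 42)   center grid point
+#     head.sub_regions[8]   # (28, 28, 56, 56)   bottom-right grid point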
diff --git a/mmdet/models/roi_heads/mask_heads/htc_mask_head.py b/mmdet/models/roi_heads/mask_heads/htc_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ad8592b4c35e4d1c483fe6bc372ee1facb8fde2
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/htc_mask_head.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn import ConvModule
+
+from mmdet.models.builder import HEADS
+from .fcn_mask_head import FCNMaskHead
+
+
+@HEADS.register_module()
+class HTCMaskHead(FCNMaskHead):
+
+ def __init__(self, with_conv_res=True, *args, **kwargs):
+ super(HTCMaskHead, self).__init__(*args, **kwargs)
+ self.with_conv_res = with_conv_res
+ if self.with_conv_res:
+ self.conv_res = ConvModule(
+ self.conv_out_channels,
+ self.conv_out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+
+ def forward(self, x, res_feat=None, return_logits=True, return_feat=True):
+ if res_feat is not None:
+ assert self.with_conv_res
+ res_feat = self.conv_res(res_feat)
+ x = x + res_feat
+ for conv in self.convs:
+ x = conv(x)
+ res_feat = x
+ outs = []
+ if return_logits:
+ x = self.upsample(x)
+ if self.upsample_method == 'deconv':
+ x = self.relu(x)
+ mask_pred = self.conv_logits(x)
+ outs.append(mask_pred)
+ if return_feat:
+ outs.append(res_feat)
+ return outs if len(outs) > 1 else outs[0]
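+# A minimal sketch of how the returned ``res_feat`` is chained across HTC
+# stages (illustrative; ``heads`` is a hypothetical list of per-stage
+# ``HTCMaskHead`` instances and ``roi_feats`` the shared RoI features):
+#
+#     res_feat = None
+#     for head in heads:
+#         mask_pred, res_feat = head(roi_feats, res_feat,
+#                                    return_logits=True, return_feat=True)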
diff --git a/mmdet/models/roi_heads/mask_heads/mask_point_head.py b/mmdet/models/roi_heads/mask_heads/mask_point_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c77c46d2c6fc73872535597068441cdb608e481c
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/mask_point_head.py
@@ -0,0 +1,253 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
+from mmcv.runner import BaseModule
+
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.utils import (get_uncertain_point_coords_with_randomness,
+ get_uncertainty)
+
+
+@HEADS.register_module()
+class MaskPointHead(BaseModule):
+ """A mask point head use in PointRend.
+
+ ``MaskPointHead`` use shared multi-layer perceptron (equivalent to
+ nn.Conv1d) to predict the logit of input points. The fine-grained feature
+ and coarse feature will be concatenate together for predication.
+
+ Args:
+ num_fcs (int): Number of fc layers in the head. Default: 3.
+ in_channels (int): Number of input channels. Default: 256.
+ fc_channels (int): Number of fc channels. Default: 256.
+ num_classes (int): Number of classes for logits. Default: 80.
+        class_agnostic (bool): Whether to use class-agnostic classification.
+            If so, the output channels of logits will be 1. Default: False.
+        coarse_pred_each_layer (bool): Whether to concatenate the coarse
+            feature with the output of each fc layer. Default: True.
+ conv_cfg (dict | None): Dictionary to construct and config conv layer.
+ Default: dict(type='Conv1d'))
+ norm_cfg (dict | None): Dictionary to construct and config norm layer.
+ Default: None.
+ loss_point (dict): Dictionary to construct and config loss layer of
+ point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
+ loss_weight=1.0).
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_classes,
+ num_fcs=3,
+ in_channels=256,
+ fc_channels=256,
+ class_agnostic=False,
+ coarse_pred_each_layer=True,
+ conv_cfg=dict(type='Conv1d'),
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ loss_point=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0),
+ init_cfg=dict(
+ type='Normal', std=0.001,
+ override=dict(name='fc_logits'))):
+ super().__init__(init_cfg)
+ self.num_fcs = num_fcs
+ self.in_channels = in_channels
+ self.fc_channels = fc_channels
+ self.num_classes = num_classes
+ self.class_agnostic = class_agnostic
+ self.coarse_pred_each_layer = coarse_pred_each_layer
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.loss_point = build_loss(loss_point)
+
+ fc_in_channels = in_channels + num_classes
+ self.fcs = nn.ModuleList()
+ for _ in range(num_fcs):
+ fc = ConvModule(
+ fc_in_channels,
+ fc_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.fcs.append(fc)
+ fc_in_channels = fc_channels
+ fc_in_channels += num_classes if self.coarse_pred_each_layer else 0
+
+ out_channels = 1 if self.class_agnostic else self.num_classes
+ self.fc_logits = nn.Conv1d(
+ fc_in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, fine_grained_feats, coarse_feats):
+ """Classify each point base on fine grained and coarse feats.
+
+ Args:
+ fine_grained_feats (Tensor): Fine grained feature sampled from FPN,
+ shape (num_rois, in_channels, num_points).
+ coarse_feats (Tensor): Coarse feature sampled from CoarseMaskHead,
+ shape (num_rois, num_classes, num_points).
+
+ Returns:
+            Tensor: Point classification results,
+                shape (num_rois, num_classes, num_points).
+ """
+
+ x = torch.cat([fine_grained_feats, coarse_feats], dim=1)
+ for fc in self.fcs:
+ x = fc(x)
+ if self.coarse_pred_each_layer:
+ x = torch.cat((x, coarse_feats), dim=1)
+ return self.fc_logits(x)
+
+ def get_targets(self, rois, rel_roi_points, sampling_results, gt_masks,
+ cfg):
+ """Get training targets of MaskPointHead for all images.
+
+ Args:
+ rois (Tensor): Region of Interest, shape (num_rois, 5).
+ rel_roi_points: Points coordinates relative to RoI, shape
+ (num_rois, num_points, 2).
+ sampling_results (:obj:`SamplingResult`): Sampling result after
+ sampling and assignment.
+ gt_masks (Tensor) : Ground truth segmentation masks of
+ corresponding boxes, shape (num_rois, height, width).
+ cfg (dict): Training cfg.
+
+ Returns:
+ Tensor: Point target, shape (num_rois, num_points).
+ """
+
+ num_imgs = len(sampling_results)
+ rois_list = []
+ rel_roi_points_list = []
+ for batch_ind in range(num_imgs):
+ inds = (rois[:, 0] == batch_ind)
+ rois_list.append(rois[inds])
+ rel_roi_points_list.append(rel_roi_points[inds])
+ pos_assigned_gt_inds_list = [
+ res.pos_assigned_gt_inds for res in sampling_results
+ ]
+ cfg_list = [cfg for _ in range(num_imgs)]
+
+ point_targets = map(self._get_target_single, rois_list,
+ rel_roi_points_list, pos_assigned_gt_inds_list,
+ gt_masks, cfg_list)
+ point_targets = list(point_targets)
+
+ if len(point_targets) > 0:
+ point_targets = torch.cat(point_targets)
+
+ return point_targets
+
+ def _get_target_single(self, rois, rel_roi_points, pos_assigned_gt_inds,
+ gt_masks, cfg):
+ """Get training target of MaskPointHead for each image."""
+ num_pos = rois.size(0)
+ num_points = cfg.num_points
+ if num_pos > 0:
+ gt_masks_th = (
+ gt_masks.to_tensor(rois.dtype, rois.device).index_select(
+ 0, pos_assigned_gt_inds))
+ gt_masks_th = gt_masks_th.unsqueeze(1)
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois, rel_roi_points, gt_masks_th)
+ point_targets = point_sample(gt_masks_th,
+ rel_img_points).squeeze(1)
+ else:
+ point_targets = rois.new_zeros((0, num_points))
+ return point_targets
+
+ def loss(self, point_pred, point_targets, labels):
+ """Calculate loss for MaskPointHead.
+
+ Args:
+            point_pred (Tensor): Point prediction result, shape
+                (num_rois, num_classes, num_points).
+            point_targets (Tensor): Point targets, shape
+                (num_rois, num_points).
+ labels (Tensor): Class label of corresponding boxes,
+ shape (num_rois, )
+
+ Returns:
+ dict[str, Tensor]: a dictionary of point loss components
+ """
+
+ loss = dict()
+ if self.class_agnostic:
+ loss_point = self.loss_point(point_pred, point_targets,
+ torch.zeros_like(labels))
+ else:
+ loss_point = self.loss_point(point_pred, point_targets, labels)
+ loss['loss_point'] = loss_point
+ return loss
+
+ def get_roi_rel_points_train(self, mask_pred, labels, cfg):
+ """Get ``num_points`` most uncertain points with random points during
+ train.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+        uncertainty. The uncertainties are calculated for each point using
+        the ``get_uncertainty()`` function, which takes the point's logit
+        prediction as input.
+
+ Args:
+ mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
+ mask_height, mask_width) for class-specific or class-agnostic
+ prediction.
+ labels (list): The ground truth class for each instance.
+ cfg (dict): Training config of point head.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+                that contains the coordinates of sampled points.
+ """
+ point_coords = get_uncertain_point_coords_with_randomness(
+ mask_pred, labels, cfg.num_points, cfg.oversample_ratio,
+ cfg.importance_sample_ratio)
+ return point_coords
+
+ def get_roi_rel_points_test(self, mask_pred, pred_label, cfg):
+ """Get ``num_points`` most uncertain points during test.
+
+ Args:
+ mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
+ mask_height, mask_width) for class-specific or class-agnostic
+ prediction.
+            pred_label (list): The predicted class for each instance.
+ cfg (dict): Testing config of point head.
+
+ Returns:
+ point_indices (Tensor): A tensor of shape (num_rois, num_points)
+ that contains indices from [0, mask_height x mask_width) of the
+ most uncertain points.
+ point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+ that contains [0, 1] x [0, 1] normalized coordinates of the
+                most uncertain points from the [mask_height, mask_width] grid.
+ """
+ num_points = cfg.subdivision_num_points
+ uncertainty_map = get_uncertainty(mask_pred, pred_label)
+ num_rois, _, mask_height, mask_width = uncertainty_map.shape
+
+        # During ONNX exporting, the type of each element of 'shape' is
+ # `Tensor(float)`, while it is `float` during PyTorch inference.
+ if isinstance(mask_height, torch.Tensor):
+ h_step = 1.0 / mask_height.float()
+ w_step = 1.0 / mask_width.float()
+ else:
+ h_step = 1.0 / mask_height
+ w_step = 1.0 / mask_width
+ # cast to int to avoid dynamic K for TopK op in ONNX
+ mask_size = int(mask_height * mask_width)
+ uncertainty_map = uncertainty_map.view(num_rois, mask_size)
+ num_points = min(mask_size, num_points)
+ point_indices = uncertainty_map.topk(num_points, dim=1)[1]
+ xs = w_step / 2.0 + (point_indices % mask_width).float() * w_step
+ ys = h_step / 2.0 + (point_indices // mask_width).float() * h_step
+ point_coords = torch.stack([xs, ys], dim=2)
+ return point_indices, point_coords
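+# A minimal shape sketch of the point head forward (illustrative; assumes the
+# defaults above with num_classes=80, so the first fc sees 256 + 80 channels
+# and each later fc re-appends the 80 coarse channels):
+#
+#     import torch
+#     head = MaskPointHead(num_classes=80)
+#     fine = torch.rand(4, 256, 196)    # (num_rois, in_channels, num_points)
+#     coarse = torch.rand(4, 80, 196)   # (num_rois, num_classes, num_points)
+#     logits = head(fine, coarse)       # (num_rois, num_classes, num_points)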
diff --git a/mmdet/models/roi_heads/mask_heads/maskiou_head.py b/mmdet/models/roi_heads/mask_heads/maskiou_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7ff7c7c4e70bd3c033731f9bc0bf40ca74a4bba
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/maskiou_head.py
@@ -0,0 +1,183 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import Conv2d, Linear, MaxPool2d
+from mmcv.runner import BaseModule, force_fp32
+from torch.nn.modules.utils import _pair
+
+from mmdet.models.builder import HEADS, build_loss
+
+
+@HEADS.register_module()
+class MaskIoUHead(BaseModule):
+ """Mask IoU Head.
+
+ This head predicts the IoU of predicted masks and corresponding gt masks.
+ """
+
+ def __init__(self,
+ num_convs=4,
+ num_fcs=2,
+ roi_feat_size=14,
+ in_channels=256,
+ conv_out_channels=256,
+ fc_out_channels=1024,
+ num_classes=80,
+ loss_iou=dict(type='MSELoss', loss_weight=0.5),
+ init_cfg=[
+ dict(type='Kaiming', override=dict(name='convs')),
+ dict(type='Caffe2Xavier', override=dict(name='fcs')),
+ dict(
+ type='Normal',
+ std=0.01,
+ override=dict(name='fc_mask_iou'))
+ ]):
+ super(MaskIoUHead, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.conv_out_channels = conv_out_channels
+ self.fc_out_channels = fc_out_channels
+ self.num_classes = num_classes
+ self.fp16_enabled = False
+
+ self.convs = nn.ModuleList()
+ for i in range(num_convs):
+ if i == 0:
+ # concatenation of mask feature and mask prediction
+ in_channels = self.in_channels + 1
+ else:
+ in_channels = self.conv_out_channels
+ stride = 2 if i == num_convs - 1 else 1
+ self.convs.append(
+ Conv2d(
+ in_channels,
+ self.conv_out_channels,
+ 3,
+ stride=stride,
+ padding=1))
+
+ roi_feat_size = _pair(roi_feat_size)
+ pooled_area = (roi_feat_size[0] // 2) * (roi_feat_size[1] // 2)
+ self.fcs = nn.ModuleList()
+ for i in range(num_fcs):
+ in_channels = (
+ self.conv_out_channels *
+ pooled_area if i == 0 else self.fc_out_channels)
+ self.fcs.append(Linear(in_channels, self.fc_out_channels))
+
+ self.fc_mask_iou = Linear(self.fc_out_channels, self.num_classes)
+ self.relu = nn.ReLU()
+ self.max_pool = MaxPool2d(2, 2)
+ self.loss_iou = build_loss(loss_iou)
+
+ def forward(self, mask_feat, mask_pred):
+ mask_pred = mask_pred.sigmoid()
+ mask_pred_pooled = self.max_pool(mask_pred.unsqueeze(1))
+
+ x = torch.cat((mask_feat, mask_pred_pooled), 1)
+
+ for conv in self.convs:
+ x = self.relu(conv(x))
+ x = x.flatten(1)
+ for fc in self.fcs:
+ x = self.relu(fc(x))
+ mask_iou = self.fc_mask_iou(x)
+ return mask_iou
+
+ @force_fp32(apply_to=('mask_iou_pred', ))
+ def loss(self, mask_iou_pred, mask_iou_targets):
+ pos_inds = mask_iou_targets > 0
+ if pos_inds.sum() > 0:
+ loss_mask_iou = self.loss_iou(mask_iou_pred[pos_inds],
+ mask_iou_targets[pos_inds])
+ else:
+ loss_mask_iou = mask_iou_pred.sum() * 0
+ return dict(loss_mask_iou=loss_mask_iou)
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def get_targets(self, sampling_results, gt_masks, mask_pred, mask_targets,
+ rcnn_train_cfg):
+ """Compute target of mask IoU.
+
+        The mask IoU target is the IoU of the predicted mask (inside a bbox)
+        and the gt mask of the corresponding gt instance (the whole instance).
+        The intersection area is computed inside the bbox, and the gt mask
+        area is computed in two steps: first compute the gt area inside the
+        bbox, then divide it by the ratio of the gt area inside the bbox to
+        the gt area of the whole instance.
+
+ Args:
+ sampling_results (list[:obj:`SamplingResult`]): sampling results.
+ gt_masks (BitmapMask | PolygonMask): Gt masks (the whole instance)
+ of each image, with the same shape of the input image.
+ mask_pred (Tensor): Predicted masks of each positive proposal,
+ shape (num_pos, h, w).
+ mask_targets (Tensor): Gt mask of each positive proposal,
+ binary map of the shape (num_pos, h, w).
+ rcnn_train_cfg (dict): Training config for R-CNN part.
+
+ Returns:
+ Tensor: mask iou target (length == num positive).
+ """
+ pos_proposals = [res.pos_bboxes for res in sampling_results]
+ pos_assigned_gt_inds = [
+ res.pos_assigned_gt_inds for res in sampling_results
+ ]
+
+ # compute the area ratio of gt areas inside the proposals and
+ # the whole instance
+ area_ratios = map(self._get_area_ratio, pos_proposals,
+ pos_assigned_gt_inds, gt_masks)
+ area_ratios = torch.cat(list(area_ratios))
+ assert mask_targets.size(0) == area_ratios.size(0)
+
+ mask_pred = (mask_pred > rcnn_train_cfg.mask_thr_binary).float()
+ mask_pred_areas = mask_pred.sum((-1, -2))
+
+ # mask_pred and mask_targets are binary maps
+ overlap_areas = (mask_pred * mask_targets).sum((-1, -2))
+
+ # compute the mask area of the whole instance
+ gt_full_areas = mask_targets.sum((-1, -2)) / (area_ratios + 1e-7)
+
+ mask_iou_targets = overlap_areas / (
+ mask_pred_areas + gt_full_areas - overlap_areas)
+ return mask_iou_targets
+
+ def _get_area_ratio(self, pos_proposals, pos_assigned_gt_inds, gt_masks):
+ """Compute area ratio of the gt mask inside the proposal and the gt
+ mask of the corresponding instance."""
+ num_pos = pos_proposals.size(0)
+ if num_pos > 0:
+ area_ratios = []
+ proposals_np = pos_proposals.cpu().numpy()
+ pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
+ # compute mask areas of gt instances (batch processing for speedup)
+ gt_instance_mask_area = gt_masks.areas
+ for i in range(num_pos):
+ gt_mask = gt_masks[pos_assigned_gt_inds[i]]
+
+ # crop the gt mask inside the proposal
+ bbox = proposals_np[i, :].astype(np.int32)
+ gt_mask_in_proposal = gt_mask.crop(bbox)
+
+ ratio = gt_mask_in_proposal.areas[0] / (
+ gt_instance_mask_area[pos_assigned_gt_inds[i]] + 1e-7)
+ area_ratios.append(ratio)
+ area_ratios = torch.from_numpy(np.stack(area_ratios)).float().to(
+ pos_proposals.device)
+ else:
+ area_ratios = pos_proposals.new_zeros((0, ))
+ return area_ratios
+
+ @force_fp32(apply_to=('mask_iou_pred', ))
+ def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels):
+ """Get the mask scores.
+
+ mask_score = bbox_score * mask_iou
+ """
+ inds = range(det_labels.size(0))
+ mask_scores = mask_iou_pred[inds, det_labels] * det_bboxes[inds, -1]
+ mask_scores = mask_scores.cpu().numpy()
+ det_labels = det_labels.cpu().numpy()
+ return [mask_scores[det_labels == i] for i in range(self.num_classes)]
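+# A worked example of the mask IoU target defined in ``get_targets`` (numbers
+# are illustrative): suppose a positive RoI has a binarized prediction
+# covering 120 px, the gt mask inside the box covers 100 px, their overlap is
+# 80 px, and only 50% of the gt instance lies inside the box
+# (area_ratio = 0.5). Then the full gt area is 100 / 0.5 = 200 px and
+#
+#     mask_iou_target = 80 / (120 + 200 - 80) = 1 / 3
+#
+# which is the IoU of the prediction against the whole instance, under the
+# assumption that the prediction is empty outside the box.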
diff --git a/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py b/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca62486615a3c99fe09f4c71758b4dd01dc2fc3a
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.models.builder import HEADS
+from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
+from .fcn_mask_head import FCNMaskHead
+
+
+@HEADS.register_module()
+class SCNetMaskHead(FCNMaskHead):
+ """Mask head for `SCNet `_.
+
+ Args:
+ conv_to_res (bool, optional): if True, change the conv layers to
+ ``SimplifiedBasicBlock``.
+ """
+
+ def __init__(self, conv_to_res=True, **kwargs):
+ super(SCNetMaskHead, self).__init__(**kwargs)
+ self.conv_to_res = conv_to_res
+ if conv_to_res:
+ assert self.conv_kernel_size == 3
+ self.num_res_blocks = self.num_convs // 2
+ self.convs = ResLayer(
+ SimplifiedBasicBlock,
+ self.in_channels,
+ self.conv_out_channels,
+ self.num_res_blocks,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
diff --git a/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py b/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8c5c32bbb7604426d774674eb9fecb51e1d789
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.models.builder import HEADS
+from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
+from .fused_semantic_head import FusedSemanticHead
+
+
+@HEADS.register_module()
+class SCNetSemanticHead(FusedSemanticHead):
+ """Mask head for `SCNet `_.
+
+ Args:
+ conv_to_res (bool, optional): if True, change the conv layers to
+ ``SimplifiedBasicBlock``.
+ """
+
+ def __init__(self, conv_to_res=True, **kwargs):
+ super(SCNetSemanticHead, self).__init__(**kwargs)
+ self.conv_to_res = conv_to_res
+ if self.conv_to_res:
+ num_res_blocks = self.num_convs // 2
+ self.convs = ResLayer(
+ SimplifiedBasicBlock,
+ self.in_channels,
+ self.conv_out_channels,
+ num_res_blocks,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ self.num_convs = num_res_blocks
diff --git a/mmdet/models/roi_heads/mask_scoring_roi_head.py b/mmdet/models/roi_heads/mask_scoring_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4617988e30abebe9ede13e04dda72632724ce159
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_scoring_roi_head.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core import bbox2roi
+from ..builder import HEADS, build_head
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class MaskScoringRoIHead(StandardRoIHead):
+ """Mask Scoring RoIHead for Mask Scoring RCNN.
+
+ https://arxiv.org/abs/1903.00241
+ """
+
+ def __init__(self, mask_iou_head, **kwargs):
+ assert mask_iou_head is not None
+ super(MaskScoringRoIHead, self).__init__(**kwargs)
+ self.mask_iou_head = build_head(mask_iou_head)
+
+ def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
+ img_metas):
+ """Run forward function and calculate loss for Mask head in
+ training."""
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ mask_results = super(MaskScoringRoIHead,
+ self)._mask_forward_train(x, sampling_results,
+ bbox_feats, gt_masks,
+ img_metas)
+ if mask_results['loss_mask'] is None:
+ return mask_results
+
+ # mask iou head forward and loss
+ pos_mask_pred = mask_results['mask_pred'][
+ range(mask_results['mask_pred'].size(0)), pos_labels]
+ mask_iou_pred = self.mask_iou_head(mask_results['mask_feats'],
+ pos_mask_pred)
+ pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)),
+ pos_labels]
+
+ mask_iou_targets = self.mask_iou_head.get_targets(
+ sampling_results, gt_masks, pos_mask_pred,
+ mask_results['mask_targets'], self.train_cfg)
+ loss_mask_iou = self.mask_iou_head.loss(pos_mask_iou_pred,
+ mask_iou_targets)
+ mask_results['loss_mask'].update(loss_mask_iou)
+ return mask_results
+
+ def simple_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False):
+ """Obtain mask prediction without augmentation."""
+ # image shapes of images in the batch
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ num_imgs = len(det_bboxes)
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ num_classes = self.mask_head.num_classes
+ segm_results = [[[] for _ in range(num_classes)]
+ for _ in range(num_imgs)]
+ mask_scores = [[[] for _ in range(num_classes)]
+ for _ in range(num_imgs)]
+ else:
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i]
+ for i in range(num_imgs)
+ ]
+ mask_rois = bbox2roi(_bboxes)
+ mask_results = self._mask_forward(x, mask_rois)
+ concat_det_labels = torch.cat(det_labels)
+ # get mask scores with mask iou head
+ mask_feats = mask_results['mask_feats']
+ mask_pred = mask_results['mask_pred']
+ mask_iou_pred = self.mask_iou_head(
+ mask_feats, mask_pred[range(concat_det_labels.size(0)),
+ concat_det_labels])
+ # split batch mask prediction back to each image
+ num_bboxes_per_img = tuple(len(_bbox) for _bbox in _bboxes)
+ mask_preds = mask_pred.split(num_bboxes_per_img, 0)
+ mask_iou_preds = mask_iou_pred.split(num_bboxes_per_img, 0)
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ mask_scores = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ mask_scores.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ else:
+ segm_result = self.mask_head.get_seg_masks(
+ mask_preds[i], _bboxes[i], det_labels[i],
+ self.test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ # get mask scores with mask iou head
+ mask_score = self.mask_iou_head.get_mask_scores(
+ mask_iou_preds[i], det_bboxes[i], det_labels[i])
+ segm_results.append(segm_result)
+ mask_scores.append(mask_score)
+ return list(zip(segm_results, mask_scores))
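+# Note on the return value of ``simple_test_mask`` above (illustrative;
+# ``roi_head`` stands for an instance of this head): each image yields a
+# ``(segm_result, mask_score)`` pair, where both entries are per-class lists
+# of length ``mask_head.num_classes`` and the mask scores are the detection
+# scores reweighted by the predicted mask IoU:
+#
+#     results = roi_head.simple_test_mask(x, img_metas, det_bboxes,
+#                                         det_labels, rescale=True)
+#     segm_result, mask_score = results[0]   # first image in the batch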
diff --git a/mmdet/models/roi_heads/pisa_roi_head.py b/mmdet/models/roi_heads/pisa_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a51186e28bf25ba71474536fc211037999d0f8
--- /dev/null
+++ b/mmdet/models/roi_heads/pisa_roi_head.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdet.core import bbox2roi
+from ..builder import HEADS
+from ..losses.pisa_loss import carl_loss, isr_p
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class PISARoIHead(StandardRoIHead):
+ r"""The RoI head for `Prime Sample Attention in Object Detection
+ `_."""
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """Forward function for training.
+
+ Args:
+ x (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+            proposal_list (list[Tensor]): List of region proposals.
+ gt_bboxes (list[Tensor]): Each item are the truth boxes for each
+ image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (list[Tensor], optional): Specify which bounding
+ boxes can be ignored when computing the loss.
+            gt_masks (None | Tensor): True segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # assign gts and sample proposals
+ if self.with_bbox or self.with_mask:
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+ sampling_results = []
+ neg_label_weights = []
+ for i in range(num_imgs):
+ assign_result = self.bbox_assigner.assign(
+ proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
+ gt_labels[i])
+ sampling_result = self.bbox_sampler.sample(
+ assign_result,
+ proposal_list[i],
+ gt_bboxes[i],
+ gt_labels[i],
+ feats=[lvl_feat[i][None] for lvl_feat in x])
+ # neg label weight is obtained by sampling when using ISR-N
+ neg_label_weight = None
+ if isinstance(sampling_result, tuple):
+ sampling_result, neg_label_weight = sampling_result
+ sampling_results.append(sampling_result)
+ neg_label_weights.append(neg_label_weight)
+
+ losses = dict()
+ # bbox head forward and loss
+ if self.with_bbox:
+ bbox_results = self._bbox_forward_train(
+ x,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ neg_label_weights=neg_label_weights)
+ losses.update(bbox_results['loss_bbox'])
+
+ # mask head forward and loss
+ if self.with_mask:
+ mask_results = self._mask_forward_train(x, sampling_results,
+ bbox_results['bbox_feats'],
+ gt_masks, img_metas)
+ losses.update(mask_results['loss_mask'])
+
+ return losses
+
+ def _bbox_forward(self, x, rois):
+ """Box forward function used in both training and testing."""
+ # TODO: a more flexible way to decide which feature maps to use
+ bbox_feats = self.bbox_roi_extractor(
+ x[:self.bbox_roi_extractor.num_inputs], rois)
+ if self.with_shared_head:
+ bbox_feats = self.shared_head(bbox_feats)
+ cls_score, bbox_pred = self.bbox_head(bbox_feats)
+
+ bbox_results = dict(
+ cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
+ return bbox_results
+
+ def _bbox_forward_train(self,
+ x,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ neg_label_weights=None):
+ """Run forward function and calculate loss for box head in training."""
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+
+ bbox_results = self._bbox_forward(x, rois)
+
+ bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, self.train_cfg)
+
+ # neg_label_weights obtained by sampler is image-wise, mapping back to
+ # the corresponding location in label weights
+ if neg_label_weights[0] is not None:
+ label_weights = bbox_targets[1]
+ cur_num_rois = 0
+ for i in range(len(sampling_results)):
+ num_pos = sampling_results[i].pos_inds.size(0)
+ num_neg = sampling_results[i].neg_inds.size(0)
+ label_weights[cur_num_rois + num_pos:cur_num_rois + num_pos +
+ num_neg] = neg_label_weights[i]
+ cur_num_rois += num_pos + num_neg
+
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+
+ # Apply ISR-P
+ isr_cfg = self.train_cfg.get('isr', None)
+ if isr_cfg is not None:
+ bbox_targets = isr_p(
+ cls_score,
+ bbox_pred,
+ bbox_targets,
+ rois,
+ sampling_results,
+ self.bbox_head.loss_cls,
+ self.bbox_head.bbox_coder,
+ **isr_cfg,
+ num_class=self.bbox_head.num_classes)
+ loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, rois,
+ *bbox_targets)
+
+ # Add CARL Loss
+ carl_cfg = self.train_cfg.get('carl', None)
+ if carl_cfg is not None:
+ loss_carl = carl_loss(
+ cls_score,
+ bbox_targets[0],
+ bbox_pred,
+ bbox_targets[2],
+ self.bbox_head.loss_bbox,
+ **carl_cfg,
+ num_class=self.bbox_head.num_classes)
+ loss_bbox.update(loss_carl)
+
+ bbox_results.update(loss_bbox=loss_bbox)
+ return bbox_results
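+# A small illustration of the label-weight bookkeeping in
+# ``_bbox_forward_train`` above (numbers are made up): the concatenated bbox
+# targets are laid out image by image as [pos_0 | neg_0 | pos_1 | neg_1 | ...],
+# so if two images sample (3 pos, 5 neg) and (2 pos, 6 neg) proposals
+# respectively, the ISR-N negative weights of image 0 overwrite
+# ``label_weights[3:8]`` and those of image 1 overwrite ``label_weights[10:16]``.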
diff --git a/mmdet/models/roi_heads/point_rend_roi_head.py b/mmdet/models/roi_heads/point_rend_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f667793f48abd948592d1c0f50f8975ae2c4b89
--- /dev/null
+++ b/mmdet/models/roi_heads/point_rend_roi_head.py
@@ -0,0 +1,393 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
+import os
+import warnings
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
+
+from mmdet.core import bbox2roi, bbox_mapping, merge_aug_masks
+from .. import builder
+from ..builder import HEADS
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class PointRendRoIHead(StandardRoIHead):
+ """`PointRend `_."""
+
+ def __init__(self, point_head, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.with_bbox and self.with_mask
+ self.init_point_head(point_head)
+
+ def init_point_head(self, point_head):
+ """Initialize ``point_head``"""
+ self.point_head = builder.build_head(point_head)
+
+ def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
+ img_metas):
+ """Run forward function and calculate loss for mask head and point head
+ in training."""
+ mask_results = super()._mask_forward_train(x, sampling_results,
+ bbox_feats, gt_masks,
+ img_metas)
+ if mask_results['loss_mask'] is not None:
+ loss_point = self._mask_point_forward_train(
+ x, sampling_results, mask_results['mask_pred'], gt_masks,
+ img_metas)
+ mask_results['loss_mask'].update(loss_point)
+
+ return mask_results
+
+ def _mask_point_forward_train(self, x, sampling_results, mask_pred,
+ gt_masks, img_metas):
+ """Run forward function and calculate loss for point head in
+ training."""
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ rel_roi_points = self.point_head.get_roi_rel_points_train(
+ mask_pred, pos_labels, cfg=self.train_cfg)
+ rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, rois, rel_roi_points, img_metas)
+ coarse_point_feats = point_sample(mask_pred, rel_roi_points)
+ mask_point_pred = self.point_head(fine_grained_point_feats,
+ coarse_point_feats)
+ mask_point_target = self.point_head.get_targets(
+ rois, rel_roi_points, sampling_results, gt_masks, self.train_cfg)
+ loss_mask_point = self.point_head.loss(mask_point_pred,
+ mask_point_target, pos_labels)
+
+ return loss_mask_point
+
+ def _get_fine_grained_point_feats(self, x, rois, rel_roi_points,
+ img_metas):
+ """Sample fine grained feats from each level feature map and
+ concatenate them together.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ rois (Tensor): shape (num_rois, 5).
+ rel_roi_points (Tensor): A tensor of shape (num_rois, num_points,
+ 2) that contains [0, 1] x [0, 1] normalized coordinates of the
+ most uncertain points from the [mask_height, mask_width] grid.
+ img_metas (list[dict]): Image meta info.
+
+ Returns:
+ Tensor: The fine grained features for each points,
+ has shape (num_rois, feats_channels, num_points).
+ """
+ num_imgs = len(img_metas)
+ fine_grained_feats = []
+ for idx in range(self.mask_roi_extractor.num_inputs):
+ feats = x[idx]
+ spatial_scale = 1. / float(
+ self.mask_roi_extractor.featmap_strides[idx])
+ point_feats = []
+ for batch_ind in range(num_imgs):
+ # unravel batch dim
+ feat = feats[batch_ind].unsqueeze(0)
+ inds = (rois[:, 0].long() == batch_ind)
+ if inds.any():
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois[inds], rel_roi_points[inds], feat.shape[2:],
+ spatial_scale).unsqueeze(0)
+ point_feat = point_sample(feat, rel_img_points)
+ point_feat = point_feat.squeeze(0).transpose(0, 1)
+ point_feats.append(point_feat)
+ fine_grained_feats.append(torch.cat(point_feats, dim=0))
+ return torch.cat(fine_grained_feats, dim=1)
+
+ def _mask_point_forward_test(self, x, rois, label_pred, mask_pred,
+ img_metas):
+ """Mask refining process with point head in testing.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ rois (Tensor): shape (num_rois, 5).
+            label_pred (Tensor): The predicted class for each RoI.
+            mask_pred (Tensor): The predicted coarse masks of
+ shape (num_rois, num_classes, small_size, small_size).
+ img_metas (list[dict]): Image meta info.
+
+ Returns:
+ Tensor: The refined masks of shape (num_rois, num_classes,
+ large_size, large_size).
+ """
+ refined_mask_pred = mask_pred.clone()
+ for subdivision_step in range(self.test_cfg.subdivision_steps):
+ refined_mask_pred = F.interpolate(
+ refined_mask_pred,
+ scale_factor=self.test_cfg.scale_factor,
+ mode='bilinear',
+ align_corners=False)
+ # If `subdivision_num_points` is larger or equal to the
+ # resolution of the next step, then we can skip this step
+ num_rois, channels, mask_height, mask_width = \
+ refined_mask_pred.shape
+ if (self.test_cfg.subdivision_num_points >=
+ self.test_cfg.scale_factor**2 * mask_height * mask_width
+ and
+ subdivision_step < self.test_cfg.subdivision_steps - 1):
+ continue
+ point_indices, rel_roi_points = \
+ self.point_head.get_roi_rel_points_test(
+ refined_mask_pred, label_pred, cfg=self.test_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, rois, rel_roi_points, img_metas)
+ coarse_point_feats = point_sample(mask_pred, rel_roi_points)
+ mask_point_pred = self.point_head(fine_grained_point_feats,
+ coarse_point_feats)
+
+ point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
+ refined_mask_pred = refined_mask_pred.reshape(
+ num_rois, channels, mask_height * mask_width)
+ refined_mask_pred = refined_mask_pred.scatter_(
+ 2, point_indices, mask_point_pred)
+ refined_mask_pred = refined_mask_pred.view(num_rois, channels,
+ mask_height, mask_width)
+
+ return refined_mask_pred
+
+ def simple_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False):
+ """Obtain mask prediction without augmentation."""
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ if isinstance(scale_factors[0], float):
+ warnings.warn(
+                'Scale factor in img_metas should be an '
+                'ndarray with shape (4,), '
+                'arranged as (factor_w, factor_h, factor_w, factor_h). '
+                'The scale_factor with float type has been deprecated.')
+ scale_factors = np.array([scale_factors] * 4, dtype=np.float32)
+
+ num_imgs = len(det_bboxes)
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ segm_results = [[[] for _ in range(self.mask_head.num_classes)]
+ for _ in range(num_imgs)]
+ else:
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
+ if rescale:
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ _bboxes[i] * scale_factors[i] for i in range(len(_bboxes))
+ ]
+
+ mask_rois = bbox2roi(_bboxes)
+ mask_results = self._mask_forward(x, mask_rois)
+ # split batch mask prediction back to each image
+ mask_pred = mask_results['mask_pred']
+ num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
+ mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+ mask_rois = mask_rois.split(num_mask_roi_per_img, 0)
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ else:
+ x_i = [xx[[i]] for xx in x]
+ mask_rois_i = mask_rois[i]
+ mask_rois_i[:, 0] = 0 # TODO: remove this hack
+ mask_pred_i = self._mask_point_forward_test(
+ x_i, mask_rois_i, det_labels[i], mask_preds[i],
+ [img_metas])
+ segm_result = self.mask_head.get_seg_masks(
+ mask_pred_i, _bboxes[i], det_labels[i], self.test_cfg,
+ ori_shapes[i], scale_factors[i], rescale)
+ segm_results.append(segm_result)
+ return segm_results
+
+ def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
+ """Test for mask head with test time augmentation."""
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[] for _ in range(self.mask_head.num_classes)]
+ else:
+ aug_masks = []
+ for x, img_meta in zip(feats, img_metas):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip)
+ mask_rois = bbox2roi([_bboxes])
+ mask_results = self._mask_forward(x, mask_rois)
+ mask_results['mask_pred'] = self._mask_point_forward_test(
+ x, mask_rois, det_labels, mask_results['mask_pred'],
+ img_meta)
+ # convert to numpy array to save memory
+ aug_masks.append(
+ mask_results['mask_pred'].sigmoid().cpu().numpy())
+ merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)
+
+ ori_shape = img_metas[0][0]['ori_shape']
+ segm_result = self.mask_head.get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ self.test_cfg,
+ ori_shape,
+ scale_factor=1.0,
+ rescale=False)
+ return segm_result
+
+ def _onnx_get_fine_grained_point_feats(self, x, rois, rel_roi_points):
+ """Export the process of sampling fine grained feats to onnx.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ rois (Tensor): shape (num_rois, 5).
+ rel_roi_points (Tensor): A tensor of shape (num_rois, num_points,
+ 2) that contains [0, 1] x [0, 1] normalized coordinates of the
+ most uncertain points from the [mask_height, mask_width] grid.
+
+ Returns:
+ Tensor: The fine grained features for each points,
+ has shape (num_rois, feats_channels, num_points).
+ """
+ batch_size = x[0].shape[0]
+ num_rois = rois.shape[0]
+ fine_grained_feats = []
+ for idx in range(self.mask_roi_extractor.num_inputs):
+ feats = x[idx]
+ spatial_scale = 1. / float(
+ self.mask_roi_extractor.featmap_strides[idx])
+
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois, rel_roi_points, feats, spatial_scale)
+ channels = feats.shape[1]
+ num_points = rel_img_points.shape[1]
+ rel_img_points = rel_img_points.reshape(batch_size, -1, num_points,
+ 2)
+ point_feats = point_sample(feats, rel_img_points)
+ point_feats = point_feats.transpose(1, 2).reshape(
+ num_rois, channels, num_points)
+ fine_grained_feats.append(point_feats)
+ return torch.cat(fine_grained_feats, dim=1)
+
+ def _mask_point_onnx_export(self, x, rois, label_pred, mask_pred):
+ """Export mask refining process with point head to onnx.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ rois (Tensor): shape (num_rois, 5).
+            label_pred (Tensor): The predicted class for each RoI.
+            mask_pred (Tensor): The predicted coarse masks of
+ shape (num_rois, num_classes, small_size, small_size).
+
+ Returns:
+ Tensor: The refined masks of shape (num_rois, num_classes,
+ large_size, large_size).
+ """
+ refined_mask_pred = mask_pred.clone()
+ for subdivision_step in range(self.test_cfg.subdivision_steps):
+ refined_mask_pred = F.interpolate(
+ refined_mask_pred,
+ scale_factor=self.test_cfg.scale_factor,
+ mode='bilinear',
+ align_corners=False)
+ # If `subdivision_num_points` is larger or equal to the
+ # resolution of the next step, then we can skip this step
+ num_rois, channels, mask_height, mask_width = \
+ refined_mask_pred.shape
+ if (self.test_cfg.subdivision_num_points >=
+ self.test_cfg.scale_factor**2 * mask_height * mask_width
+ and
+ subdivision_step < self.test_cfg.subdivision_steps - 1):
+ continue
+ point_indices, rel_roi_points = \
+ self.point_head.get_roi_rel_points_test(
+ refined_mask_pred, label_pred, cfg=self.test_cfg)
+ fine_grained_point_feats = self._onnx_get_fine_grained_point_feats(
+ x, rois, rel_roi_points)
+ coarse_point_feats = point_sample(mask_pred, rel_roi_points)
+ mask_point_pred = self.point_head(fine_grained_point_feats,
+ coarse_point_feats)
+
+ point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
+ refined_mask_pred = refined_mask_pred.reshape(
+ num_rois, channels, mask_height * mask_width)
+
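+            # Write the refined point logits back into the flattened
+            # (num_rois, channels, H*W) mask. For the TensorRT backend the
+            # same update is expressed with 1-D indices
+            # (roi * C * H * W + channel * H * W + point) so that no
+            # ScatterElements op is emitted; it is equivalent to the
+            # `scatter_` branch below.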
+ is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+ # avoid ScatterElements op in ONNX for TensorRT
+ if is_trt_backend:
+ mask_shape = refined_mask_pred.shape
+ point_shape = point_indices.shape
+ inds_dim0 = torch.arange(point_shape[0]).reshape(
+ point_shape[0], 1, 1).expand_as(point_indices)
+ inds_dim1 = torch.arange(point_shape[1]).reshape(
+ 1, point_shape[1], 1).expand_as(point_indices)
+ inds_1d = inds_dim0.reshape(
+ -1) * mask_shape[1] * mask_shape[2] + inds_dim1.reshape(
+ -1) * mask_shape[2] + point_indices.reshape(-1)
+ refined_mask_pred = refined_mask_pred.reshape(-1)
+ refined_mask_pred[inds_1d] = mask_point_pred.reshape(-1)
+ refined_mask_pred = refined_mask_pred.reshape(*mask_shape)
+ else:
+ refined_mask_pred = refined_mask_pred.scatter_(
+ 2, point_indices, mask_point_pred)
+
+ refined_mask_pred = refined_mask_pred.view(num_rois, channels,
+ mask_height, mask_width)
+
+ return refined_mask_pred
+
+ def mask_onnx_export(self, x, img_metas, det_bboxes, det_labels, **kwargs):
+        """Export the mask branch to ONNX, which supports batch inference.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ img_metas (list[dict]): Image meta info.
+            det_bboxes (Tensor): Bboxes and corresponding scores,
+                has shape [N, num_bboxes, 5].
+ det_labels (Tensor): class labels of
+ shape [N, num_bboxes].
+
+ Returns:
+ Tensor: The segmentation results of shape [N, num_bboxes,
+ image_height, image_width].
+ """
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ raise RuntimeError('[ONNX Error] Can not record MaskHead '
+ 'as it has not been executed this time')
+ batch_size = det_bboxes.size(0)
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ det_bboxes = det_bboxes[..., :4]
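+        # build rois of shape (N * num_bboxes, 5) whose first column is the
+        # batch image index expected by the RoI extractor.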
+ batch_index = torch.arange(
+ det_bboxes.size(0), device=det_bboxes.device).float().view(
+ -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
+ mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
+ mask_rois = mask_rois.view(-1, 5)
+ mask_results = self._mask_forward(x, mask_rois)
+ mask_pred = mask_results['mask_pred']
+ max_shape = img_metas[0]['img_shape_for_onnx']
+ num_det = det_bboxes.shape[1]
+ det_bboxes = det_bboxes.reshape(-1, 4)
+ det_labels = det_labels.reshape(-1)
+
+ mask_pred = self._mask_point_onnx_export(x, mask_rois, det_labels,
+ mask_pred)
+
+ segm_results = self.mask_head.onnx_export(mask_pred, det_bboxes,
+ det_labels, self.test_cfg,
+ max_shape)
+ segm_results = segm_results.reshape(batch_size, num_det, max_shape[0],
+ max_shape[1])
+ return segm_results
diff --git a/mmdet/models/roi_heads/roi_extractors/__init__.py b/mmdet/models/roi_heads/roi_extractors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f60214991b0ed14cdbc3964aee15356c6aaf2aa
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_roi_extractor import BaseRoIExtractor
+from .generic_roi_extractor import GenericRoIExtractor
+from .single_level_roi_extractor import SingleRoIExtractor
+
+__all__ = ['BaseRoIExtractor', 'SingleRoIExtractor', 'GenericRoIExtractor']
diff --git a/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..82629757decc4bc4c374369641f4b742abd47c12
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+from mmcv import ops
+from mmcv.runner import BaseModule
+
+
+class BaseRoIExtractor(BaseModule, metaclass=ABCMeta):
+ """Base class for RoI extractor.
+
+ Args:
+ roi_layer (dict): Specify RoI layer type and arguments.
+ out_channels (int): Output channels of RoI layers.
+        featmap_strides (List[int]): Strides of input feature maps.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ roi_layer,
+ out_channels,
+ featmap_strides,
+ init_cfg=None):
+ super(BaseRoIExtractor, self).__init__(init_cfg)
+ self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
+ self.out_channels = out_channels
+ self.featmap_strides = featmap_strides
+ self.fp16_enabled = False
+
+ @property
+ def num_inputs(self):
+ """int: Number of input feature maps."""
+ return len(self.featmap_strides)
+
+ def build_roi_layers(self, layer_cfg, featmap_strides):
+ """Build RoI operator to extract feature from each level feature map.
+
+ Args:
+ layer_cfg (dict): Dictionary to construct and config RoI layer
+ operation. Options are modules under ``mmcv/ops`` such as
+ ``RoIAlign``.
+            featmap_strides (List[int]): The stride of input feature maps
+                w.r.t. the original image size, which is used to scale RoI
+                coordinates (original image coordinate system) to the
+                feature coordinate system.
+
+ Returns:
+ nn.ModuleList: The RoI extractor modules for each level feature
+ map.
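+
+        Example:
+            >>> # Hypothetical values; ``RoIAlign`` is one of the modules
+            >>> # available under ``mmcv.ops``.
+            >>> layer_cfg = dict(type='RoIAlign', output_size=7,
+            ...                  sampling_ratio=0)
+            >>> roi_layers = self.build_roi_layers(layer_cfg, [4, 8, 16, 32])
+            >>> len(roi_layers)  # one RoI layer per feature level
+            4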
+ """
+
+ cfg = layer_cfg.copy()
+ layer_type = cfg.pop('type')
+ assert hasattr(ops, layer_type)
+ layer_cls = getattr(ops, layer_type)
+ roi_layers = nn.ModuleList(
+ [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
+ return roi_layers
+
+ def roi_rescale(self, rois, scale_factor):
+ """Scale RoI coordinates by scale factor.
+
+ Args:
+ rois (torch.Tensor): RoI (Region of Interest), shape (n, 5)
+ scale_factor (float): Scale factor that RoI will be multiplied by.
+
+ Returns:
+ torch.Tensor: Scaled RoI.
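+
+        Example:
+            >>> # Illustrative sketch: the RoI center stays fixed while width
+            >>> # and height are multiplied by ``scale_factor`` (no clipping).
+            >>> rois = torch.tensor([[0., 20., 20., 40., 60.]])
+            >>> self.roi_rescale(rois, 1.5)  # -> [[0., 15., 10., 45., 70.]]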
+ """
+
+ cx = (rois[:, 1] + rois[:, 3]) * 0.5
+ cy = (rois[:, 2] + rois[:, 4]) * 0.5
+ w = rois[:, 3] - rois[:, 1]
+ h = rois[:, 4] - rois[:, 2]
+ new_w = w * scale_factor
+ new_h = h * scale_factor
+ x1 = cx - new_w * 0.5
+ x2 = cx + new_w * 0.5
+ y1 = cy - new_h * 0.5
+ y2 = cy + new_h * 0.5
+ new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
+ return new_rois
+
+ @abstractmethod
+ def forward(self, feats, rois, roi_scale_factor=None):
+ pass
diff --git a/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..89a9f891e1e5aa52d85531dc62e7f518124df2f4
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn.bricks import build_plugin_layer
+from mmcv.runner import force_fp32
+
+from mmdet.models.builder import ROI_EXTRACTORS
+from .base_roi_extractor import BaseRoIExtractor
+
+
+@ROI_EXTRACTORS.register_module()
+class GenericRoIExtractor(BaseRoIExtractor):
+    """Extract RoI features from all feature map levels.
+
+ This is the implementation of `A novel Region of Interest Extraction Layer
+ for Instance Segmentation `_.
+
+ Args:
+ aggregation (str): The method to aggregate multiple feature maps.
+ Options are 'sum', 'concat'. Default: 'sum'.
+ pre_cfg (dict | None): Specify pre-processing modules. Default: None.
+ post_cfg (dict | None): Specify post-processing modules. Default: None.
+ kwargs (keyword arguments): Arguments that are the same
+ as :class:`BaseRoIExtractor`.
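+
+    Example:
+        >>> # A hypothetical config snippet; all values are illustrative.
+        >>> roi_extractor_cfg = dict(
+        ...     type='GenericRoIExtractor',
+        ...     aggregation='sum',
+        ...     roi_layer=dict(type='RoIAlign', output_size=7,
+        ...                    sampling_ratio=2),
+        ...     out_channels=256,
+        ...     featmap_strides=[4, 8, 16, 32])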
+ """
+
+ def __init__(self,
+ aggregation='sum',
+ pre_cfg=None,
+ post_cfg=None,
+ **kwargs):
+ super(GenericRoIExtractor, self).__init__(**kwargs)
+
+ assert aggregation in ['sum', 'concat']
+
+ self.aggregation = aggregation
+ self.with_post = post_cfg is not None
+ self.with_pre = pre_cfg is not None
+ # build pre/post processing modules
+ if self.with_post:
+ self.post_module = build_plugin_layer(post_cfg, '_post_module')[1]
+ if self.with_pre:
+ self.pre_module = build_plugin_layer(pre_cfg, '_pre_module')[1]
+
+ @force_fp32(apply_to=('feats', ), out_fp16=True)
+ def forward(self, feats, rois, roi_scale_factor=None):
+ """Forward function."""
+ if len(feats) == 1:
+ return self.roi_layers[0](feats[0], rois)
+
+ out_size = self.roi_layers[0].output_size
+ num_levels = len(feats)
+ roi_feats = feats[0].new_zeros(
+ rois.size(0), self.out_channels, *out_size)
+
+        # sometimes rois is an empty tensor
+ if roi_feats.shape[0] == 0:
+ return roi_feats
+
+ if roi_scale_factor is not None:
+ rois = self.roi_rescale(rois, roi_scale_factor)
+
+ # mark the starting channels for concat mode
+ start_channels = 0
+ for i in range(num_levels):
+ roi_feats_t = self.roi_layers[i](feats[i], rois)
+ end_channels = start_channels + roi_feats_t.size(1)
+ if self.with_pre:
+ # apply pre-processing to a RoI extracted from each layer
+ roi_feats_t = self.pre_module(roi_feats_t)
+ if self.aggregation == 'sum':
+ # and sum them all
+ roi_feats = roi_feats + roi_feats_t
+ else:
+ # and concat them along channel dimension
+ roi_feats[:, start_channels:end_channels] = roi_feats_t
+ # update channels starting position
+ start_channels = end_channels
+ # check if concat channels match at the end
+ if self.aggregation == 'concat':
+ assert start_channels == self.out_channels
+
+ if self.with_post:
+            # apply post-processing before returning the result
+ roi_feats = self.post_module(roi_feats)
+ return roi_feats
diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc5aef15b94d8606d60a6c4d0a1c8e6f6e60a41
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.models.builder import ROI_EXTRACTORS
+from .base_roi_extractor import BaseRoIExtractor
+
+
+@ROI_EXTRACTORS.register_module()
+class SingleRoIExtractor(BaseRoIExtractor):
+ """Extract RoI features from a single level feature map.
+
+ If there are multiple input feature levels, each RoI is mapped to a level
+ according to its scale. The mapping rule is proposed in
+ `FPN `_.
+
+ Args:
+ roi_layer (dict): Specify RoI layer type and arguments.
+ out_channels (int): Output channels of RoI layers.
+ featmap_strides (List[int]): Strides of input feature maps.
+ finest_scale (int): Scale threshold of mapping to level 0. Default: 56.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ roi_layer,
+ out_channels,
+ featmap_strides,
+ finest_scale=56,
+ init_cfg=None):
+ super(SingleRoIExtractor, self).__init__(roi_layer, out_channels,
+ featmap_strides, init_cfg)
+ self.finest_scale = finest_scale
+
+ def map_roi_levels(self, rois, num_levels):
+ """Map rois to corresponding feature levels by scales.
+
+ - scale < finest_scale * 2: level 0
+ - finest_scale * 2 <= scale < finest_scale * 4: level 1
+ - finest_scale * 4 <= scale < finest_scale * 8: level 2
+ - scale >= finest_scale * 8: level 3
+
+ Args:
+ rois (Tensor): Input RoIs, shape (k, 5).
+ num_levels (int): Total level number.
+
+ Returns:
+ Tensor: Level index (0-based) of each RoI, shape (k, )
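+
+        Example:
+            >>> # Illustrative only, assuming the default finest_scale=56;
+            >>> # the scale of a RoI is sqrt(w * h).
+            >>> rois = torch.tensor([[0., 0., 0., 64., 64.],
+            ...                      [0., 0., 0., 224., 224.],
+            ...                      [0., 0., 0., 512., 512.]])
+            >>> self.map_roi_levels(rois, num_levels=4)
+            tensor([0, 2, 3])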
+ """
+ scale = torch.sqrt(
+ (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
+ target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
+ target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
+ return target_lvls
+
+ @force_fp32(apply_to=('feats', ), out_fp16=True)
+ def forward(self, feats, rois, roi_scale_factor=None):
+ """Forward function."""
+ out_size = self.roi_layers[0].output_size
+ num_levels = len(feats)
+ expand_dims = (-1, self.out_channels * out_size[0] * out_size[1])
+ if torch.onnx.is_in_onnx_export():
+ # Work around to export mask-rcnn to onnx
+ roi_feats = rois[:, :1].clone().detach()
+ roi_feats = roi_feats.expand(*expand_dims)
+ roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size)
+ roi_feats = roi_feats * 0
+ else:
+ roi_feats = feats[0].new_zeros(
+ rois.size(0), self.out_channels, *out_size)
+
+ if num_levels == 1:
+ if len(rois) == 0:
+ return roi_feats
+ return self.roi_layers[0](feats[0], rois)
+
+ target_lvls = self.map_roi_levels(rois, num_levels)
+
+ if roi_scale_factor is not None:
+ rois = self.roi_rescale(rois, roi_scale_factor)
+
+ for i in range(num_levels):
+ mask = target_lvls == i
+ if torch.onnx.is_in_onnx_export():
+ # To keep all roi_align nodes exported to onnx
+ # and skip nonzero op
+ mask = mask.float().unsqueeze(-1)
+ # select target level rois and reset the rest rois to zero.
+ rois_i = rois.clone().detach()
+ rois_i = rois_i * mask
+ mask_exp = mask.expand(*expand_dims).reshape(roi_feats.shape)
+ roi_feats_t = self.roi_layers[i](feats[i], rois_i)
+ roi_feats_t = roi_feats_t * mask_exp
+ roi_feats = roi_feats + roi_feats_t
+ continue
+ inds = mask.nonzero(as_tuple=False).squeeze(1)
+ if inds.numel() > 0:
+ rois_ = rois[inds]
+ roi_feats_t = self.roi_layers[i](feats[i], rois_)
+ roi_feats[inds] = roi_feats_t
+ else:
+ # Sometimes some pyramid levels will not be used for RoI
+ # feature extraction and this will cause an incomplete
+ # computation graph in one GPU, which is different from those
+ # in other GPUs and will cause a hanging error.
+ # Therefore, we add it to ensure each feature pyramid is
+ # included in the computation graph to avoid runtime bugs.
+ roi_feats = roi_feats + sum(
+ x.view(-1)[0]
+ for x in self.parameters()) * 0. + feats[i].sum() * 0.
+ return roi_feats
diff --git a/mmdet/models/roi_heads/scnet_roi_head.py b/mmdet/models/roi_heads/scnet_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..32f56aa8a24d5f825351b714a99fce836eacbf18
--- /dev/null
+++ b/mmdet/models/roi_heads/scnet_roi_head.py
@@ -0,0 +1,605 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
+ merge_aug_masks, multiclass_nms)
+from ..builder import HEADS, build_head, build_roi_extractor
+from ..utils.brick_wrappers import adaptive_avg_pool2d
+from .cascade_roi_head import CascadeRoIHead
+
+
+@HEADS.register_module()
+class SCNetRoIHead(CascadeRoIHead):
+ """RoIHead for `SCNet `_.
+
+ Args:
+ num_stages (int): number of cascade stages.
+ stage_loss_weights (list): loss weight of cascade stages.
+ semantic_roi_extractor (dict): config to init semantic roi extractor.
+ semantic_head (dict): config to init semantic head.
+ feat_relay_head (dict): config to init feature_relay_head.
+ glbctx_head (dict): config to init global context head.
+ """
+
+ def __init__(self,
+ num_stages,
+ stage_loss_weights,
+ semantic_roi_extractor=None,
+ semantic_head=None,
+ feat_relay_head=None,
+ glbctx_head=None,
+ **kwargs):
+ super(SCNetRoIHead, self).__init__(num_stages, stage_loss_weights,
+ **kwargs)
+ assert self.with_bbox and self.with_mask
+ assert not self.with_shared_head # shared head is not supported
+
+ if semantic_head is not None:
+ self.semantic_roi_extractor = build_roi_extractor(
+ semantic_roi_extractor)
+ self.semantic_head = build_head(semantic_head)
+
+ if feat_relay_head is not None:
+ self.feat_relay_head = build_head(feat_relay_head)
+
+ if glbctx_head is not None:
+ self.glbctx_head = build_head(glbctx_head)
+
+ def init_mask_head(self, mask_roi_extractor, mask_head):
+ """Initialize ``mask_head``"""
+ if mask_roi_extractor is not None:
+ self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
+ self.mask_head = build_head(mask_head)
+
+ @property
+ def with_semantic(self):
+ """bool: whether the head has semantic head"""
+ return hasattr(self,
+ 'semantic_head') and self.semantic_head is not None
+
+ @property
+ def with_feat_relay(self):
+ """bool: whether the head has feature relay head"""
+ return (hasattr(self, 'feat_relay_head')
+ and self.feat_relay_head is not None)
+
+ @property
+ def with_glbctx(self):
+ """bool: whether the head has global context head"""
+ return hasattr(self, 'glbctx_head') and self.glbctx_head is not None
+
+ def _fuse_glbctx(self, roi_feats, glbctx_feat, rois):
+ """Fuse global context feats with roi feats."""
+ assert roi_feats.size(0) == rois.size(0)
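+        # rois[:, 0] holds the batch image index of each RoI, so every RoI
+        # feature gets the global context feature of its own image
+        # broadcast-added to it.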
+ img_inds = torch.unique(rois[:, 0].cpu(), sorted=True).long()
+ fused_feats = torch.zeros_like(roi_feats)
+ for img_id in img_inds:
+ inds = (rois[:, 0] == img_id.item())
+ fused_feats[inds] = roi_feats[inds] + glbctx_feat[img_id]
+ return fused_feats
+
+ def _slice_pos_feats(self, feats, sampling_results):
+ """Get features from pos rois."""
+ num_rois = [res.bboxes.size(0) for res in sampling_results]
+ num_pos_rois = [res.pos_bboxes.size(0) for res in sampling_results]
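+        # Sampled rois are concatenated per image with positive samples
+        # first, so a boolean mask over the flat roi dimension can pick out
+        # the positive block of every image.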
+ inds = torch.zeros(sum(num_rois), dtype=torch.bool)
+ start = 0
+ for i in range(len(num_rois)):
+ start = 0 if i == 0 else start + num_rois[i - 1]
+ stop = start + num_pos_rois[i]
+ inds[start:stop] = 1
+ sliced_feats = feats[inds]
+ return sliced_feats
+
+ def _bbox_forward(self,
+ stage,
+ x,
+ rois,
+ semantic_feat=None,
+ glbctx_feat=None):
+ """Box head forward function used in both training and testing."""
+ bbox_roi_extractor = self.bbox_roi_extractor[stage]
+ bbox_head = self.bbox_head[stage]
+ bbox_feats = bbox_roi_extractor(
+ x[:len(bbox_roi_extractor.featmap_strides)], rois)
+ if self.with_semantic and semantic_feat is not None:
+ bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ rois)
+ if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
+ bbox_semantic_feat = adaptive_avg_pool2d(
+ bbox_semantic_feat, bbox_feats.shape[-2:])
+ bbox_feats = bbox_feats + bbox_semantic_feat
+ if self.with_glbctx and glbctx_feat is not None:
+ bbox_feats = self._fuse_glbctx(bbox_feats, glbctx_feat, rois)
+ cls_score, bbox_pred, relayed_feat = bbox_head(
+ bbox_feats, return_shared_feat=True)
+
+ bbox_results = dict(
+ cls_score=cls_score,
+ bbox_pred=bbox_pred,
+ relayed_feat=relayed_feat)
+ return bbox_results
+
+ def _mask_forward(self,
+ x,
+ rois,
+ semantic_feat=None,
+ glbctx_feat=None,
+ relayed_feat=None):
+ """Mask head forward function used in both training and testing."""
+ mask_feats = self.mask_roi_extractor(
+ x[:self.mask_roi_extractor.num_inputs], rois)
+ if self.with_semantic and semantic_feat is not None:
+ mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ rois)
+ if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
+ mask_semantic_feat = F.adaptive_avg_pool2d(
+ mask_semantic_feat, mask_feats.shape[-2:])
+ mask_feats = mask_feats + mask_semantic_feat
+ if self.with_glbctx and glbctx_feat is not None:
+ mask_feats = self._fuse_glbctx(mask_feats, glbctx_feat, rois)
+ if self.with_feat_relay and relayed_feat is not None:
+ mask_feats = mask_feats + relayed_feat
+ mask_pred = self.mask_head(mask_feats)
+ mask_results = dict(mask_pred=mask_pred)
+
+ return mask_results
+
+ def _bbox_forward_train(self,
+ stage,
+ x,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ rcnn_train_cfg,
+ semantic_feat=None,
+ glbctx_feat=None):
+ """Run forward function and calculate loss for box head in training."""
+ bbox_head = self.bbox_head[stage]
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(
+ stage,
+ x,
+ rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+
+ bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, rcnn_train_cfg)
+ loss_bbox = bbox_head.loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(
+ loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
+ return bbox_results
+
+ def _mask_forward_train(self,
+ x,
+ sampling_results,
+ gt_masks,
+ rcnn_train_cfg,
+ semantic_feat=None,
+ glbctx_feat=None,
+ relayed_feat=None):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ mask_results = self._mask_forward(
+ x,
+ pos_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat,
+ relayed_feat=relayed_feat)
+
+ mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
+ rcnn_train_cfg)
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ loss_mask = self.mask_head.loss(mask_results['mask_pred'],
+ mask_targets, pos_labels)
+
+ mask_results = loss_mask
+ return mask_results
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ gt_semantic_seg=None):
+ """
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+            proposal_list (list[Tensor]): List of region proposals.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None, list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_masks (None, Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+ gt_semantic_seg (None, list[Tensor]): semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ losses = dict()
+
+ # semantic segmentation branch
+ if self.with_semantic:
+ semantic_pred, semantic_feat = self.semantic_head(x)
+ loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
+ losses['loss_semantic_seg'] = loss_seg
+ else:
+ semantic_feat = None
+
+ # global context branch
+ if self.with_glbctx:
+ mc_pred, glbctx_feat = self.glbctx_head(x)
+ loss_glbctx = self.glbctx_head.loss(mc_pred, gt_labels)
+ losses['loss_glbctx'] = loss_glbctx
+ else:
+ glbctx_feat = None
+
+ for i in range(self.num_stages):
+ self.current_stage = i
+ rcnn_train_cfg = self.train_cfg[i]
+ lw = self.stage_loss_weights[i]
+
+ # assign gts and sample proposals
+ sampling_results = []
+ bbox_assigner = self.bbox_assigner[i]
+ bbox_sampler = self.bbox_sampler[i]
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+ for j in range(num_imgs):
+ assign_result = bbox_assigner.assign(proposal_list[j],
+ gt_bboxes[j],
+ gt_bboxes_ignore[j],
+ gt_labels[j])
+ sampling_result = bbox_sampler.sample(
+ assign_result,
+ proposal_list[j],
+ gt_bboxes[j],
+ gt_labels[j],
+ feats=[lvl_feat[j][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+
+ bbox_results = \
+ self._bbox_forward_train(
+ i, x, sampling_results, gt_bboxes, gt_labels,
+ rcnn_train_cfg, semantic_feat, glbctx_feat)
+ roi_labels = bbox_results['bbox_targets'][0]
+
+ for name, value in bbox_results['loss_bbox'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # refine boxes
+ if i < self.num_stages - 1:
+ pos_is_gts = [res.pos_is_gt for res in sampling_results]
+ with torch.no_grad():
+ proposal_list = self.bbox_head[i].refine_bboxes(
+ bbox_results['rois'], roi_labels,
+ bbox_results['bbox_pred'], pos_is_gts, img_metas)
+
+ if self.with_feat_relay:
+ relayed_feat = self._slice_pos_feats(bbox_results['relayed_feat'],
+ sampling_results)
+ relayed_feat = self.feat_relay_head(relayed_feat)
+ else:
+ relayed_feat = None
+
+ mask_results = self._mask_forward_train(x, sampling_results, gt_masks,
+ rcnn_train_cfg, semantic_feat,
+ glbctx_feat, relayed_feat)
+ mask_lw = sum(self.stage_loss_weights)
+ losses['loss_mask'] = mask_lw * mask_results['loss_mask']
+
+ return losses
+
+ def simple_test(self, x, proposal_list, img_metas, rescale=False):
+ """Test without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Features from upstream network. Each
+ has shape (batch_size, c, h, w).
+ proposal_list (list(Tensor)): Proposals from rpn head.
+ Each has shape (num_proposals, 5), last dimension
+ 5 represent (x1, y1, x2, y2, score).
+ img_metas (list[dict]): Meta information of images.
+ rescale (bool): Whether to rescale the results to
+                the original image. Default: False.
+
+ Returns:
+ list[list[np.ndarray]] or list[tuple]: When no mask branch,
+ it is bbox results of each image and classes with type
+ `list[list[np.ndarray]]`. The outer list
+ corresponds to each image. The inner list
+            corresponds to each class. When the model has a mask branch,
+ it contains bbox results and mask results.
+ The outer list corresponds to each image, and first element
+ of tuple is bbox results, second element is mask results.
+ """
+ if self.with_semantic:
+ _, semantic_feat = self.semantic_head(x)
+ else:
+ semantic_feat = None
+
+ if self.with_glbctx:
+ mc_pred, glbctx_feat = self.glbctx_head(x)
+ else:
+ glbctx_feat = None
+
+ num_imgs = len(proposal_list)
+ img_shapes = tuple(meta['img_shape'] for meta in img_metas)
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # "ms" in variable names means multi-stage
+ ms_scores = []
+ rcnn_test_cfg = self.test_cfg
+
+ rois = bbox2roi(proposal_list)
+
+ if rois.shape[0] == 0:
+ # There is no proposal in the whole batch
+ bbox_results = [[
+ np.zeros((0, 5), dtype=np.float32)
+ for _ in range(self.bbox_head[-1].num_classes)
+ ]] * num_imgs
+
+ if self.with_mask:
+ mask_classes = self.mask_head.num_classes
+ segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ results = list(zip(bbox_results, segm_results))
+ else:
+ results = bbox_results
+
+ return results
+
+ for i in range(self.num_stages):
+ bbox_head = self.bbox_head[i]
+ bbox_results = self._bbox_forward(
+ i,
+ x,
+ rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+ # split batch bbox prediction back to each image
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+ num_proposals_per_img = tuple(len(p) for p in proposal_list)
+ rois = rois.split(num_proposals_per_img, 0)
+ cls_score = cls_score.split(num_proposals_per_img, 0)
+ bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
+ ms_scores.append(cls_score)
+
+ if i < self.num_stages - 1:
+ refine_rois_list = []
+ for j in range(num_imgs):
+ if rois[j].shape[0] > 0:
+ bbox_label = cls_score[j][:, :-1].argmax(dim=1)
+ refine_rois = bbox_head.regress_by_class(
+ rois[j], bbox_label, bbox_pred[j], img_metas[j])
+ refine_rois_list.append(refine_rois)
+ rois = torch.cat(refine_rois_list)
+
+ # average scores of each image by stages
+ cls_score = [
+ sum([score[i] for score in ms_scores]) / float(len(ms_scores))
+ for i in range(num_imgs)
+ ]
+
+ # apply bbox post-processing to each image individually
+ det_bboxes = []
+ det_labels = []
+ for i in range(num_imgs):
+ det_bbox, det_label = self.bbox_head[-1].get_bboxes(
+ rois[i],
+ cls_score[i],
+ bbox_pred[i],
+ img_shapes[i],
+ scale_factors[i],
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+ det_bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head[-1].num_classes)
+ for i in range(num_imgs)
+ ]
+
+ if self.with_mask:
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ mask_classes = self.mask_head.num_classes
+ det_segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ else:
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i]
+ for i in range(num_imgs)
+ ]
+ mask_rois = bbox2roi(_bboxes)
+
+ # get relay feature on mask_rois
+ bbox_results = self._bbox_forward(
+ -1,
+ x,
+ mask_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+ relayed_feat = bbox_results['relayed_feat']
+ relayed_feat = self.feat_relay_head(relayed_feat)
+
+ mask_results = self._mask_forward(
+ x,
+ mask_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat,
+ relayed_feat=relayed_feat)
+ mask_pred = mask_results['mask_pred']
+
+ # split batch mask prediction back to each image
+ num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes)
+ mask_preds = mask_pred.split(num_bbox_per_img, 0)
+
+ # apply mask post-processing to each image individually
+ det_segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ det_segm_results.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ else:
+ segm_result = self.mask_head.get_seg_masks(
+ mask_preds[i], _bboxes[i], det_labels[i],
+ self.test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ det_segm_results.append(segm_result)
+
+ # return results
+ if self.with_mask:
+ return list(zip(det_bbox_results, det_segm_results))
+ else:
+ return det_bbox_results
+
+ def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
+ if self.with_semantic:
+ semantic_feats = [
+ self.semantic_head(feat)[1] for feat in img_feats
+ ]
+ else:
+ semantic_feats = [None] * len(img_metas)
+
+ if self.with_glbctx:
+ glbctx_feats = [self.glbctx_head(feat)[1] for feat in img_feats]
+ else:
+ glbctx_feats = [None] * len(img_metas)
+
+ rcnn_test_cfg = self.test_cfg
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta, semantic_feat, glbctx_feat in zip(
+ img_feats, img_metas, semantic_feats, glbctx_feats):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip)
+ # "ms" in variable names means multi-stage
+ ms_scores = []
+
+ rois = bbox2roi([proposals])
+
+ if rois.shape[0] == 0:
+ # There is no proposal in the single image
+ aug_bboxes.append(rois.new_zeros(0, 4))
+ aug_scores.append(rois.new_zeros(0, 1))
+ continue
+
+ for i in range(self.num_stages):
+ bbox_head = self.bbox_head[i]
+ bbox_results = self._bbox_forward(
+ i,
+ x,
+ rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+ ms_scores.append(bbox_results['cls_score'])
+ if i < self.num_stages - 1:
+ bbox_label = bbox_results['cls_score'].argmax(dim=1)
+ rois = bbox_head.regress_by_class(
+ rois, bbox_label, bbox_results['bbox_pred'],
+ img_meta[0])
+
+ cls_score = sum(ms_scores) / float(len(ms_scores))
+ bboxes, scores = self.bbox_head[-1].get_bboxes(
+ rois,
+ cls_score,
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ aug_bboxes.append(bboxes)
+ aug_scores.append(scores)
+
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+
+ det_bbox_results = bbox2result(det_bboxes, det_labels,
+ self.bbox_head[-1].num_classes)
+
+ if self.with_mask:
+ if det_bboxes.shape[0] == 0:
+ det_segm_results = [[]
+ for _ in range(self.mask_head.num_classes)]
+ else:
+ aug_masks = []
+ for x, img_meta, semantic_feat, glbctx_feat in zip(
+ img_feats, img_metas, semantic_feats, glbctx_feats):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip)
+ mask_rois = bbox2roi([_bboxes])
+ # get relay feature on mask_rois
+ bbox_results = self._bbox_forward(
+ -1,
+ x,
+ mask_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+ relayed_feat = bbox_results['relayed_feat']
+ relayed_feat = self.feat_relay_head(relayed_feat)
+ mask_results = self._mask_forward(
+ x,
+ mask_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat,
+ relayed_feat=relayed_feat)
+ mask_pred = mask_results['mask_pred']
+ aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+ merged_masks = merge_aug_masks(aug_masks, img_metas,
+ self.test_cfg)
+ ori_shape = img_metas[0][0]['ori_shape']
+ det_segm_results = self.mask_head.get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ rcnn_test_cfg,
+ ori_shape,
+ scale_factor=1.0,
+ rescale=False)
+ return [(det_bbox_results, det_segm_results)]
+ else:
+ return [det_bbox_results]
diff --git a/mmdet/models/roi_heads/shared_heads/__init__.py b/mmdet/models/roi_heads/shared_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d56636ab34d1dd2592828238099bcdccf179d6d3
--- /dev/null
+++ b/mmdet/models/roi_heads/shared_heads/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .res_layer import ResLayer
+
+__all__ = ['ResLayer']
diff --git a/mmdet/models/roi_heads/shared_heads/res_layer.py b/mmdet/models/roi_heads/shared_heads/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bef00a0581b225df618616e5c5b8f417337d9fe1
--- /dev/null
+++ b/mmdet/models/roi_heads/shared_heads/res_layer.py
@@ -0,0 +1,80 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+from mmcv.runner import BaseModule, auto_fp16
+
+from mmdet.models.backbones import ResNet
+from mmdet.models.builder import SHARED_HEADS
+from mmdet.models.utils import ResLayer as _ResLayer
+
+
+@SHARED_HEADS.register_module()
+class ResLayer(BaseModule):
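+    """Shared residual stage used as a shared RoI head.
+
+    A single ResNet stage (``layer{stage + 1}``, e.g. the res5 stage in a
+    C4-style model with the default ``stage=3``) is built from
+    ``ResNet.arch_settings[depth]`` and applied on top of RoI features.
+    """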
+
+ def __init__(self,
+ depth,
+ stage=3,
+ stride=2,
+ dilation=1,
+ style='pytorch',
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ with_cp=False,
+ dcn=None,
+ pretrained=None,
+ init_cfg=None):
+ super(ResLayer, self).__init__(init_cfg)
+
+ self.norm_eval = norm_eval
+ self.norm_cfg = norm_cfg
+ self.stage = stage
+ self.fp16_enabled = False
+ block, stage_blocks = ResNet.arch_settings[depth]
+ stage_block = stage_blocks[stage]
+ planes = 64 * 2**stage
+ inplanes = 64 * 2**(stage - 1) * block.expansion
+
+ res_layer = _ResLayer(
+ block,
+ inplanes,
+ planes,
+ stage_block,
+ stride=stride,
+ dilation=dilation,
+ style=style,
+ with_cp=with_cp,
+ norm_cfg=self.norm_cfg,
+ dcn=dcn)
+ self.add_module(f'layer{stage + 1}', res_layer)
+
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be specified at the same time'
+ if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is None:
+ if init_cfg is None:
+ self.init_cfg = [
+ dict(type='Kaiming', layer='Conv2d'),
+ dict(
+ type='Constant',
+ val=1,
+ layer=['_BatchNorm', 'GroupNorm'])
+ ]
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ @auto_fp16()
+ def forward(self, x):
+ res_layer = getattr(self, f'layer{self.stage + 1}')
+ out = res_layer(x)
+ return out
+
+ def train(self, mode=True):
+ super(ResLayer, self).train(mode)
+ if self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
diff --git a/mmdet/models/roi_heads/sparse_roi_head.py b/mmdet/models/roi_heads/sparse_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2613469e3a7cf397f19c04b24c43ab50b0c75551
--- /dev/null
+++ b/mmdet/models/roi_heads/sparse_roi_head.py
@@ -0,0 +1,424 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from mmdet.core import bbox2result, bbox2roi, bbox_xyxy_to_cxcywh
+from mmdet.core.bbox.samplers import PseudoSampler
+from ..builder import HEADS
+from .cascade_roi_head import CascadeRoIHead
+
+
+@HEADS.register_module()
+class SparseRoIHead(CascadeRoIHead):
+ r"""The RoIHead for `Sparse R-CNN: End-to-End Object Detection with
+ Learnable Proposals `_
+ and `Instances as Queries `_
+
+ Args:
+        num_stages (int): Number of stages in the whole iterative process.
+ Defaults to 6.
+ stage_loss_weights (Tuple[float]): The loss
+ weight of each stage. By default all stages have
+ the same weight 1.
+ bbox_roi_extractor (dict): Config of box roi extractor.
+ mask_roi_extractor (dict): Config of mask roi extractor.
+ bbox_head (dict): Config of box head.
+ mask_head (dict): Config of mask head.
+ train_cfg (dict, optional): Configuration information in train stage.
+ Defaults to None.
+ test_cfg (dict, optional): Configuration information in test stage.
+ Defaults to None.
+ pretrained (str, optional): model pretrained path. Default: None
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+
+ """
+
+ def __init__(self,
+ num_stages=6,
+ stage_loss_weights=(1, 1, 1, 1, 1, 1),
+ proposal_feature_channel=256,
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(
+ type='RoIAlign', output_size=7, sampling_ratio=2),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ mask_roi_extractor=None,
+ bbox_head=dict(
+ type='DIIHead',
+ num_classes=80,
+ num_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_reg_fcs=3,
+ feedforward_channels=2048,
+ hidden_channels=256,
+ dropout=0.0,
+ roi_feat_size=7,
+ ffn_act_cfg=dict(type='ReLU', inplace=True)),
+ mask_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ init_cfg=None):
+ assert bbox_roi_extractor is not None
+ assert bbox_head is not None
+ assert len(stage_loss_weights) == num_stages
+ self.num_stages = num_stages
+ self.stage_loss_weights = stage_loss_weights
+ self.proposal_feature_channel = proposal_feature_channel
+ super(SparseRoIHead, self).__init__(
+ num_stages,
+ stage_loss_weights,
+ bbox_roi_extractor=bbox_roi_extractor,
+ mask_roi_extractor=mask_roi_extractor,
+ bbox_head=bbox_head,
+ mask_head=mask_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained,
+ init_cfg=init_cfg)
+ # train_cfg would be None when run the test.py
+ if train_cfg is not None:
+ for stage in range(num_stages):
+ assert isinstance(self.bbox_sampler[stage], PseudoSampler), \
+ 'Sparse R-CNN and QueryInst only support `PseudoSampler`'
+
+ def _bbox_forward(self, stage, x, rois, object_feats, img_metas):
+ """Box head forward function used in both training and testing. Returns
+        all regression, classification results and an intermediate feature.
+
+ Args:
+ stage (int): The index of current stage in
+ iterative process.
+ x (List[Tensor]): List of FPN features
+ rois (Tensor): Rois in total batch. With shape (num_proposal, 5).
+ the last dimension 5 represents (img_index, x1, y1, x2, y2).
+ object_feats (Tensor): The object feature extracted from
+ the previous stage.
+            img_metas (list[dict]): Meta information of images.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of bbox head outputs,
+ Containing the following results:
+
+ - cls_score (Tensor): The score of each class, has
+ shape (batch_size, num_proposals, num_classes)
+                  when using focal loss or
+ (batch_size, num_proposals, num_classes+1)
+ otherwise.
+ - decode_bbox_pred (Tensor): The regression results
+ with shape (batch_size, num_proposal, 4).
+ The last dimension 4 represents
+ [tl_x, tl_y, br_x, br_y].
+ - object_feats (Tensor): The object feature extracted
+ from current stage
+ - detach_cls_score_list (list[Tensor]): The detached
+ classification results, length is batch_size, and
+ each tensor has shape (num_proposal, num_classes).
+ - detach_proposal_list (list[tensor]): The detached
+ regression results, length is batch_size, and each
+ tensor has shape (num_proposal, 4). The last
+ dimension 4 represents [tl_x, tl_y, br_x, br_y].
+ """
+ num_imgs = len(img_metas)
+ bbox_roi_extractor = self.bbox_roi_extractor[stage]
+ bbox_head = self.bbox_head[stage]
+ bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
+ rois)
+ cls_score, bbox_pred, object_feats, attn_feats = bbox_head(
+ bbox_feats, object_feats)
+ proposal_list = self.bbox_head[stage].refine_bboxes(
+ rois,
+ rois.new_zeros(len(rois)), # dummy arg
+ bbox_pred.view(-1, bbox_pred.size(-1)),
+ [rois.new_zeros(object_feats.size(1)) for _ in range(num_imgs)],
+ img_metas)
+ bbox_results = dict(
+ cls_score=cls_score,
+ decode_bbox_pred=torch.cat(proposal_list),
+ object_feats=object_feats,
+ attn_feats=attn_feats,
+ # detach then use it in label assign
+ detach_cls_score_list=[
+ cls_score[i].detach() for i in range(num_imgs)
+ ],
+ detach_proposal_list=[item.detach() for item in proposal_list])
+
+ return bbox_results
+
+ def _mask_forward(self, stage, x, rois, attn_feats):
+ """Mask head forward function used in both training and testing."""
+ mask_roi_extractor = self.mask_roi_extractor[stage]
+ mask_head = self.mask_head[stage]
+ mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
+ rois)
+        # the caffe_c4 model is no longer supported
+ mask_pred = mask_head(mask_feats, attn_feats)
+
+ mask_results = dict(mask_pred=mask_pred)
+ return mask_results
+
+ def _mask_forward_train(self, stage, x, attn_feats, sampling_results,
+ gt_masks, rcnn_train_cfg):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ attn_feats = torch.cat([
+ feats[res.pos_inds]
+ for (feats, res) in zip(attn_feats, sampling_results)
+ ])
+ mask_results = self._mask_forward(stage, x, pos_rois, attn_feats)
+
+ mask_targets = self.mask_head[stage].get_targets(
+ sampling_results, gt_masks, rcnn_train_cfg)
+
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+
+ loss_mask = self.mask_head[stage].loss(mask_results['mask_pred'],
+ mask_targets, pos_labels)
+ mask_results.update(loss_mask)
+ return mask_results
+
+ def forward_train(self,
+ x,
+ proposal_boxes,
+ proposal_features,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ imgs_whwh=None,
+ gt_masks=None):
+ """Forward function in training stage.
+
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+            proposal_boxes (Tensor): Decoded proposal bboxes, has shape
+ (batch_size, num_proposals, 4)
+ proposal_features (Tensor): Expanded proposal
+ features, has shape
+ (batch_size, num_proposals, proposal_feature_channel)
+ img_metas (list[dict]): list of image info dict where
+ each dict has: 'img_shape', 'scale_factor', 'flip',
+ and may also contain 'filename', 'ori_shape',
+ 'pad_shape', and 'img_norm_cfg'. For details on the
+ values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ imgs_whwh (Tensor): Tensor with shape (batch_size, 4),
+ the dimension means
+                [img_width, img_height, img_width, img_height].
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components of all stage.
+ """
+
+ num_imgs = len(img_metas)
+ num_proposals = proposal_boxes.size(1)
+ imgs_whwh = imgs_whwh.repeat(1, num_proposals, 1)
+ all_stage_bbox_results = []
+ proposal_list = [proposal_boxes[i] for i in range(len(proposal_boxes))]
+ object_feats = proposal_features
+ all_stage_loss = {}
+ for stage in range(self.num_stages):
+ rois = bbox2roi(proposal_list)
+ bbox_results = self._bbox_forward(stage, x, rois, object_feats,
+ img_metas)
+ all_stage_bbox_results.append(bbox_results)
+ if gt_bboxes_ignore is None:
+ # TODO support ignore
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+ sampling_results = []
+ cls_pred_list = bbox_results['detach_cls_score_list']
+ proposal_list = bbox_results['detach_proposal_list']
+ for i in range(num_imgs):
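+                # normalize proposals to [0, 1] and convert them to
+                # (cx, cy, w, h) format before label assignment.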
+ normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(proposal_list[i] /
+ imgs_whwh[i])
+ assign_result = self.bbox_assigner[stage].assign(
+ normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
+ gt_labels[i], img_metas[i])
+ sampling_result = self.bbox_sampler[stage].sample(
+ assign_result, proposal_list[i], gt_bboxes[i])
+ sampling_results.append(sampling_result)
+ bbox_targets = self.bbox_head[stage].get_targets(
+ sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage],
+ True)
+ cls_score = bbox_results['cls_score']
+ decode_bbox_pred = bbox_results['decode_bbox_pred']
+
+ single_stage_loss = self.bbox_head[stage].loss(
+ cls_score.view(-1, cls_score.size(-1)),
+ decode_bbox_pred.view(-1, 4),
+ *bbox_targets,
+ imgs_whwh=imgs_whwh)
+
+ if self.with_mask:
+ mask_results = self._mask_forward_train(
+ stage, x, bbox_results['attn_feats'], sampling_results,
+ gt_masks, self.train_cfg[stage])
+ single_stage_loss['loss_mask'] = mask_results['loss_mask']
+
+ for key, value in single_stage_loss.items():
+ all_stage_loss[f'stage{stage}_{key}'] = value * \
+ self.stage_loss_weights[stage]
+ object_feats = bbox_results['object_feats']
+
+ return all_stage_loss
+
+ def simple_test(self,
+ x,
+ proposal_boxes,
+ proposal_features,
+ img_metas,
+ imgs_whwh,
+ rescale=False):
+ """Test without augmentation.
+
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+ proposal_boxes (Tensor): Decoded proposal bboxes, has shape
+ (batch_size, num_proposals, 4)
+ proposal_features (Tensor): Expanded proposal
+ features, has shape
+ (batch_size, num_proposals, proposal_feature_channel)
+            img_metas (list[dict]): Meta information of images.
+ imgs_whwh (Tensor): Tensor with shape (batch_size, 4),
+ the dimension means
+                [img_width, img_height, img_width, img_height].
+ rescale (bool): If True, return boxes in original image
+ space. Defaults to False.
+
+ Returns:
+ list[list[np.ndarray]] or list[tuple]: When no mask branch,
+ it is bbox results of each image and classes with type
+ `list[list[np.ndarray]]`. The outer list
+ corresponds to each image. The inner list
+ corresponds to each class. When the model has a mask branch,
+ it is a list[tuple] that contains bbox results and mask results.
+ The outer list corresponds to each image, and first element
+ of tuple is bbox results, second element is mask results.
+ """
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ # Decode initial proposals
+ num_imgs = len(img_metas)
+ proposal_list = [proposal_boxes[i] for i in range(num_imgs)]
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ object_feats = proposal_features
+ if all([proposal.shape[0] == 0 for proposal in proposal_list]):
+ # There is no proposal in the whole batch
+ bbox_results = [[
+ np.zeros((0, 5), dtype=np.float32)
+ for i in range(self.bbox_head[-1].num_classes)
+ ]] * num_imgs
+ return bbox_results
+
+ for stage in range(self.num_stages):
+ rois = bbox2roi(proposal_list)
+ bbox_results = self._bbox_forward(stage, x, rois, object_feats,
+ img_metas)
+ object_feats = bbox_results['object_feats']
+ cls_score = bbox_results['cls_score']
+ proposal_list = bbox_results['detach_proposal_list']
+
+ if self.with_mask:
+ rois = bbox2roi(proposal_list)
+ mask_results = self._mask_forward(stage, x, rois,
+ bbox_results['attn_feats'])
+ mask_results['mask_pred'] = mask_results['mask_pred'].reshape(
+ num_imgs, -1, *mask_results['mask_pred'].size()[1:])
+
+ num_classes = self.bbox_head[-1].num_classes
+ det_bboxes = []
+ det_labels = []
+
+ if self.bbox_head[-1].loss_cls.use_sigmoid:
+ cls_score = cls_score.sigmoid()
+ else:
+ cls_score = cls_score.softmax(-1)[..., :-1]
+
+ for img_id in range(num_imgs):
+ cls_score_per_img = cls_score[img_id]
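+            # top-k over the flattened (num_proposals * num_classes) scores;
+            # a flat index decodes as proposal = idx // num_classes and
+            # class label = idx % num_classes.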
+ scores_per_img, topk_indices = cls_score_per_img.flatten(
+ 0, 1).topk(
+ self.test_cfg.max_per_img, sorted=False)
+ labels_per_img = topk_indices % num_classes
+ bbox_pred_per_img = proposal_list[img_id][topk_indices //
+ num_classes]
+ if rescale:
+ scale_factor = img_metas[img_id]['scale_factor']
+ bbox_pred_per_img /= bbox_pred_per_img.new_tensor(scale_factor)
+ det_bboxes.append(
+ torch.cat([bbox_pred_per_img, scores_per_img[:, None]], dim=1))
+ det_labels.append(labels_per_img)
+
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i], num_classes)
+ for i in range(num_imgs)
+ ]
+
+ if self.with_mask:
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i][:, :4]
+ for i in range(len(det_bboxes))
+ ]
+ segm_results = []
+ mask_pred = mask_results['mask_pred']
+ for img_id in range(num_imgs):
+ mask_pred_per_img = mask_pred[img_id].flatten(0,
+ 1)[topk_indices]
+ mask_pred_per_img = mask_pred_per_img[:, None, ...].repeat(
+ 1, num_classes, 1, 1)
+ segm_result = self.mask_head[-1].get_seg_masks(
+ mask_pred_per_img, _bboxes[img_id], det_labels[img_id],
+ self.test_cfg, ori_shapes[img_id], scale_factors[img_id],
+ rescale)
+ segm_results.append(segm_result)
+
+ if self.with_mask:
+ results = list(zip(bbox_results, segm_results))
+ else:
+ results = bbox_results
+
+ return results
+
+ def aug_test(self, features, proposal_list, img_metas, rescale=False):
+ raise NotImplementedError(
+ 'Sparse R-CNN and QueryInst does not support `aug_test`')
+
+ def forward_dummy(self, x, proposal_boxes, proposal_features, img_metas):
+        """Dummy forward function used when computing FLOPs."""
+ all_stage_bbox_results = []
+ proposal_list = [proposal_boxes[i] for i in range(len(proposal_boxes))]
+ object_feats = proposal_features
+ if self.with_bbox:
+ for stage in range(self.num_stages):
+ rois = bbox2roi(proposal_list)
+ bbox_results = self._bbox_forward(stage, x, rois, object_feats,
+ img_metas)
+
+ all_stage_bbox_results.append((bbox_results, ))
+ proposal_list = bbox_results['detach_proposal_list']
+ object_feats = bbox_results['object_feats']
+
+ if self.with_mask:
+ rois = bbox2roi(proposal_list)
+ mask_results = self._mask_forward(
+ stage, x, rois, bbox_results['attn_feats'])
+ all_stage_bbox_results[-1] += (mask_results, )
+ return all_stage_bbox_results
diff --git a/mmdet/models/roi_heads/standard_roi_head.py b/mmdet/models/roi_heads/standard_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fdd82ad1f04ba927ef35d16b140b7b23d5ff3e1
--- /dev/null
+++ b/mmdet/models/roi_heads/standard_roi_head.py
@@ -0,0 +1,397 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
+from ..builder import HEADS, build_head, build_roi_extractor
+from .base_roi_head import BaseRoIHead
+from .test_mixins import BBoxTestMixin, MaskTestMixin
+
+
+@HEADS.register_module()
+class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+ """Simplest base roi head including one bbox head and one mask head."""
+
+ def init_assigner_sampler(self):
+ """Initialize assigner and sampler."""
+ self.bbox_assigner = None
+ self.bbox_sampler = None
+ if self.train_cfg:
+ self.bbox_assigner = build_assigner(self.train_cfg.assigner)
+ self.bbox_sampler = build_sampler(
+ self.train_cfg.sampler, context=self)
+
+ def init_bbox_head(self, bbox_roi_extractor, bbox_head):
+ """Initialize ``bbox_head``"""
+ self.bbox_roi_extractor = build_roi_extractor(bbox_roi_extractor)
+ self.bbox_head = build_head(bbox_head)
+
+ def init_mask_head(self, mask_roi_extractor, mask_head):
+ """Initialize ``mask_head``"""
+ if mask_roi_extractor is not None:
+ self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
+ self.share_roi_extractor = False
+ else:
+ self.share_roi_extractor = True
+ self.mask_roi_extractor = self.bbox_roi_extractor
+ self.mask_head = build_head(mask_head)
+
+ def forward_dummy(self, x, proposals):
+ """Dummy forward function."""
+ # bbox head
+ outs = ()
+ rois = bbox2roi([proposals])
+ if self.with_bbox:
+ bbox_results = self._bbox_forward(x, rois)
+ outs = outs + (bbox_results['cls_score'],
+ bbox_results['bbox_pred'])
+ # mask head
+ if self.with_mask:
+ mask_rois = rois[:100]
+ mask_results = self._mask_forward(x, mask_rois)
+ outs = outs + (mask_results['mask_pred'], )
+ return outs
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ **kwargs):
+ """
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+            proposal_list (list[Tensor]): List of region proposals.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # assign gts and sample proposals
+ if self.with_bbox or self.with_mask:
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+ sampling_results = []
+ for i in range(num_imgs):
+ assign_result = self.bbox_assigner.assign(
+ proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
+ gt_labels[i])
+ sampling_result = self.bbox_sampler.sample(
+ assign_result,
+ proposal_list[i],
+ gt_bboxes[i],
+ gt_labels[i],
+ feats=[lvl_feat[i][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+
+ losses = dict()
+ # bbox head forward and loss
+ if self.with_bbox:
+ bbox_results = self._bbox_forward_train(x, sampling_results,
+ gt_bboxes, gt_labels,
+ img_metas)
+ losses.update(bbox_results['loss_bbox'])
+
+ # mask head forward and loss
+ if self.with_mask:
+ mask_results = self._mask_forward_train(x, sampling_results,
+ bbox_results['bbox_feats'],
+ gt_masks, img_metas)
+ losses.update(mask_results['loss_mask'])
+
+ return losses
+
+ def _bbox_forward(self, x, rois):
+ """Box head forward function used in both training and testing."""
+ # TODO: a more flexible way to decide which feature maps to use
+ bbox_feats = self.bbox_roi_extractor(
+ x[:self.bbox_roi_extractor.num_inputs], rois)
+ if self.with_shared_head:
+ bbox_feats = self.shared_head(bbox_feats)
+ cls_score, bbox_pred = self.bbox_head(bbox_feats)
+
+ bbox_results = dict(
+ cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
+ return bbox_results
+
+ def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+ img_metas):
+ """Run forward function and calculate loss for box head in training."""
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(x, rois)
+
+ bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, self.train_cfg)
+ loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(loss_bbox=loss_bbox)
+ return bbox_results
+
+ def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
+ img_metas):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ if not self.share_roi_extractor:
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ mask_results = self._mask_forward(x, pos_rois)
+ else:
+ pos_inds = []
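+            # the mask branch shares the RoI extractor with the bbox branch,
+            # so reuse ``bbox_feats``: mark each sampled RoI as positive (1)
+            # or negative (0) in sampling order and keep only the positive
+            # rows.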
+ device = bbox_feats.device
+ for res in sampling_results:
+ pos_inds.append(
+ torch.ones(
+ res.pos_bboxes.shape[0],
+ device=device,
+ dtype=torch.uint8))
+ pos_inds.append(
+ torch.zeros(
+ res.neg_bboxes.shape[0],
+ device=device,
+ dtype=torch.uint8))
+ pos_inds = torch.cat(pos_inds)
+
+ mask_results = self._mask_forward(
+ x, pos_inds=pos_inds, bbox_feats=bbox_feats)
+
+ mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
+ self.train_cfg)
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ loss_mask = self.mask_head.loss(mask_results['mask_pred'],
+ mask_targets, pos_labels)
+
+ mask_results.update(loss_mask=loss_mask, mask_targets=mask_targets)
+ return mask_results
+
+ def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
+ """Mask head forward function used in both training and testing."""
+ assert ((rois is not None) ^
+ (pos_inds is not None and bbox_feats is not None))
+ if rois is not None:
+ mask_feats = self.mask_roi_extractor(
+ x[:self.mask_roi_extractor.num_inputs], rois)
+ if self.with_shared_head:
+ mask_feats = self.shared_head(mask_feats)
+ else:
+ assert bbox_feats is not None
+ mask_feats = bbox_feats[pos_inds]
+
+ mask_pred = self.mask_head(mask_feats)
+ mask_results = dict(mask_pred=mask_pred, mask_feats=mask_feats)
+ return mask_results
+
+ async def async_simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False):
+ """Async test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+
+ det_bboxes, det_labels = await self.async_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+ bbox_results = bbox2result(det_bboxes, det_labels,
+ self.bbox_head.num_classes)
+ if not self.with_mask:
+ return bbox_results
+ else:
+ segm_results = await self.async_test_mask(
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=rescale,
+ mask_test_cfg=self.test_cfg.get('mask'))
+ return bbox_results, segm_results
+
+ def simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False):
+ """Test without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Features from upstream network. Each
+ has shape (batch_size, c, h, w).
+ proposal_list (list[Tensor]): Proposals from the RPN head.
+ Each has shape (num_proposals, 5), where the last dimension
+ 5 represents (x1, y1, x2, y2, score).
+ img_metas (list[dict]): Meta information of images.
+ rescale (bool): Whether to rescale the results to
+ the original image. Default: False.
+
+ Returns:
+ list[list[np.ndarray]] or list[tuple]: When there is no mask
+ branch, it is the bbox results of each image and each class
+ with type `list[list[np.ndarray]]`: the outer list
+ corresponds to images and the inner list to classes.
+ When the model has a mask branch, it contains the bbox and
+ mask results: the outer list corresponds to images, and each
+ tuple holds the bbox results as its first element and the
+ mask results as its second element.
+ """
+ assert self.with_bbox, 'Bbox head must be implemented.'
+
+ det_bboxes, det_labels = self.simple_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head.num_classes)
+ for i in range(len(det_bboxes))
+ ]
+
+ if not self.with_mask:
+ return bbox_results
+ else:
+ segm_results = self.simple_test_mask(
+ x, img_metas, det_bboxes, det_labels, rescale=rescale)
+ return list(zip(bbox_results, segm_results))
+
+ def aug_test(self, x, proposal_list, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ det_bboxes, det_labels = self.aug_test_bboxes(x, img_metas,
+ proposal_list,
+ self.test_cfg)
+ if rescale:
+ _det_bboxes = det_bboxes
+ else:
+ _det_bboxes = det_bboxes.clone()
+ _det_bboxes[:, :4] *= det_bboxes.new_tensor(
+ img_metas[0][0]['scale_factor'])
+ bbox_results = bbox2result(_det_bboxes, det_labels,
+ self.bbox_head.num_classes)
+
+ # det_bboxes always keep the original scale
+ if self.with_mask:
+ segm_results = self.aug_test_mask(x, img_metas, det_bboxes,
+ det_labels)
+ return [(bbox_results, segm_results)]
+ else:
+ return [bbox_results]
+
+ def onnx_export(self, x, proposals, img_metas, rescale=False):
+ """Test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ det_bboxes, det_labels = self.bbox_onnx_export(
+ x, img_metas, proposals, self.test_cfg, rescale=rescale)
+
+ if not self.with_mask:
+ return det_bboxes, det_labels
+ else:
+ segm_results = self.mask_onnx_export(
+ x, img_metas, det_bboxes, det_labels, rescale=rescale)
+ return det_bboxes, det_labels, segm_results
+
+ def mask_onnx_export(self, x, img_metas, det_bboxes, det_labels, **kwargs):
+ """Export mask branch to onnx which supports batch inference.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ img_metas (list[dict]): Image meta info.
+ det_bboxes (Tensor): Bboxes and corresponding scores,
+ has shape [N, num_bboxes, 5].
+ det_labels (Tensor): Class labels of
+ shape [N, num_bboxes].
+
+ Returns:
+ Tensor: The segmentation results of shape [N, num_bboxes,
+ image_height, image_width].
+ """
+ # image shapes of images in the batch
+
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ raise RuntimeError('[ONNX Error] Cannot record MaskHead '
+ 'as it has not been executed this time')
+ batch_size = det_bboxes.size(0)
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ det_bboxes = det_bboxes[..., :4]
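+ # RoIs are expected as (batch_index, x1, y1, x2, y2), so a per-image
+ # batch index is prepended and the result is flattened to
+ # (N * num_bboxes, 5) before the mask forward pass.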
+ batch_index = torch.arange(
+ det_bboxes.size(0), device=det_bboxes.device).float().view(
+ -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
+ mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
+ mask_rois = mask_rois.view(-1, 5)
+ mask_results = self._mask_forward(x, mask_rois)
+ mask_pred = mask_results['mask_pred']
+ max_shape = img_metas[0]['img_shape_for_onnx']
+ num_det = det_bboxes.shape[1]
+ det_bboxes = det_bboxes.reshape(-1, 4)
+ det_labels = det_labels.reshape(-1)
+ segm_results = self.mask_head.onnx_export(mask_pred, det_bboxes,
+ det_labels, self.test_cfg,
+ max_shape)
+ segm_results = segm_results.reshape(batch_size, num_det, max_shape[0],
+ max_shape[1])
+ return segm_results
+
+ def bbox_onnx_export(self, x, img_metas, proposals, rcnn_test_cfg,
+ **kwargs):
+ """Export bbox branch to onnx which supports batch inference.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ img_metas (list[dict]): Image meta info.
+ proposals (Tensor): Region proposals with
+ batch dimension, has shape [N, num_bboxes, 5].
+ rcnn_test_cfg (:obj:`ConfigDict`): `test_cfg` of R-CNN.
+
+ Returns:
+ tuple[Tensor, Tensor]: bboxes of shape [N, num_bboxes, 5]
+ and class labels of shape [N, num_bboxes].
+ """
+ # get origin input shape to support onnx dynamic input shape
+ assert len(
+ img_metas
+ ) == 1, 'Only one input image is supported when exporting to ONNX'
+ img_shapes = img_metas[0]['img_shape_for_onnx']
+
+ rois = proposals
+
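+ # Prepend a per-image batch index so each RoI becomes
+ # (batch_ind, x1, y1, x2, y2), the layout expected by the RoI
+ # extractor.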
+ batch_index = torch.arange(
+ rois.size(0), device=rois.device).float().view(-1, 1, 1).expand(
+ rois.size(0), rois.size(1), 1)
+
+ rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
+ batch_size = rois.shape[0]
+ num_proposals_per_img = rois.shape[1]
+
+ # Eliminate the batch dimension
+ rois = rois.view(-1, 5)
+ bbox_results = self._bbox_forward(x, rois)
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+
+ # Recover the batch dimension
+ rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
+ cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
+ cls_score.size(-1))
+
+ bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
+ bbox_pred.size(-1))
+ det_bboxes, det_labels = self.bbox_head.onnx_export(
+ rois, cls_score, bbox_pred, img_shapes, cfg=rcnn_test_cfg)
+
+ return det_bboxes, det_labels
diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae6e79aecf4e10a9ec25a55b480decc179ec91f6
--- /dev/null
+++ b/mmdet/models/roi_heads/test_mixins.py
@@ -0,0 +1,311 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+import warnings
+
+import numpy as np
+import torch
+
+from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
+ merge_aug_masks, multiclass_nms)
+
+if sys.version_info >= (3, 7):
+ from mmdet.utils.contextmanagers import completed
+
+
+class BBoxTestMixin:
+
+ if sys.version_info >= (3, 7):
+
+ async def async_test_bboxes(self,
+ x,
+ img_metas,
+ proposals,
+ rcnn_test_cfg,
+ rescale=False,
+ **kwargs):
+ """Asynchronous test for box head without augmentation."""
+ rois = bbox2roi(proposals)
+ roi_feats = self.bbox_roi_extractor(
+ x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
+ if self.with_shared_head:
+ roi_feats = self.shared_head(roi_feats)
+ sleep_interval = rcnn_test_cfg.get('async_sleep_interval', 0.017)
+
+ async with completed(
+ __name__, 'bbox_head_forward',
+ sleep_interval=sleep_interval):
+ cls_score, bbox_pred = self.bbox_head(roi_feats)
+
+ img_shape = img_metas[0]['img_shape']
+ scale_factor = img_metas[0]['scale_factor']
+ det_bboxes, det_labels = self.bbox_head.get_bboxes(
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ return det_bboxes, det_labels
+
+ def simple_test_bboxes(self,
+ x,
+ img_metas,
+ proposals,
+ rcnn_test_cfg,
+ rescale=False):
+ """Test only det bboxes without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ img_metas (list[dict]): Image meta info.
+ proposals (List[Tensor]): Region proposals.
+ rcnn_test_cfg (:obj:`ConfigDict`): `test_cfg` of R-CNN.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]]: The first list contains
+ the boxes of the corresponding image in the batch; each
+ tensor has shape (num_boxes, 5), where the last dimension
+ 5 represents (tl_x, tl_y, br_x, br_y, score). Each Tensor
+ in the second list is the labels with shape (num_boxes, ).
+ The length of both lists equals the batch size.
+ """
+
+ rois = bbox2roi(proposals)
+
+ if rois.shape[0] == 0:
+ batch_size = len(proposals)
+ det_bbox = rois.new_zeros(0, 5)
+ det_label = rois.new_zeros((0, ), dtype=torch.long)
+ if rcnn_test_cfg is None:
+ det_bbox = det_bbox[:, :4]
+ det_label = rois.new_zeros(
+ (0, self.bbox_head.fc_cls.out_features))
+ # There is no proposal in the whole batch
+ return [det_bbox] * batch_size, [det_label] * batch_size
+
+ bbox_results = self._bbox_forward(x, rois)
+ img_shapes = tuple(meta['img_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # split batch bbox prediction back to each image
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+ num_proposals_per_img = tuple(len(p) for p in proposals)
+ rois = rois.split(num_proposals_per_img, 0)
+ cls_score = cls_score.split(num_proposals_per_img, 0)
+
+ # some detector with_reg is False, bbox_pred will be None
+ if bbox_pred is not None:
+ # TODO move this to a sabl_roi_head
+ # the bbox prediction of some detectors like SABL is not Tensor
+ if isinstance(bbox_pred, torch.Tensor):
+ bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
+ else:
+ bbox_pred = self.bbox_head.bbox_pred_split(
+ bbox_pred, num_proposals_per_img)
+ else:
+ bbox_pred = (None, ) * len(proposals)
+
+ # apply bbox post-processing to each image individually
+ det_bboxes = []
+ det_labels = []
+ for i in range(len(proposals)):
+ if rois[i].shape[0] == 0:
+ # There is no proposal in the single image
+ det_bbox = rois[i].new_zeros(0, 5)
+ det_label = rois[i].new_zeros((0, ), dtype=torch.long)
+ if rcnn_test_cfg is None:
+ det_bbox = det_bbox[:, :4]
+ det_label = rois[i].new_zeros(
+ (0, self.bbox_head.fc_cls.out_features))
+
+ else:
+ det_bbox, det_label = self.bbox_head.get_bboxes(
+ rois[i],
+ cls_score[i],
+ bbox_pred[i],
+ img_shapes[i],
+ scale_factors[i],
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+ return det_bboxes, det_labels
+
+ def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
+ """Test det bboxes with test time augmentation."""
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta in zip(feats, img_metas):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+ # TODO more flexible
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ rois = bbox2roi([proposals])
+ bbox_results = self._bbox_forward(x, rois)
+ bboxes, scores = self.bbox_head.get_bboxes(
+ rois,
+ bbox_results['cls_score'],
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ aug_bboxes.append(bboxes)
+ aug_scores.append(scores)
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ if merged_bboxes.shape[0] == 0:
+ # There is no proposal in the single image
+ det_bboxes = merged_bboxes.new_zeros(0, 5)
+ det_labels = merged_bboxes.new_zeros((0, ), dtype=torch.long)
+ else:
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes,
+ merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+ return det_bboxes, det_labels
+
+
+class MaskTestMixin:
+
+ if sys.version_info >= (3, 7):
+
+ async def async_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False,
+ mask_test_cfg=None):
+ """Asynchronous test for mask head without augmentation."""
+ # image shape of the first image in the batch (only one)
+ ori_shape = img_metas[0]['ori_shape']
+ scale_factor = img_metas[0]['scale_factor']
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[] for _ in range(self.mask_head.num_classes)]
+ else:
+ if rescale and not isinstance(scale_factor,
+ (float, torch.Tensor)):
+ scale_factor = det_bboxes.new_tensor(scale_factor)
+ _bboxes = (
+ det_bboxes[:, :4] *
+ scale_factor if rescale else det_bboxes)
+ mask_rois = bbox2roi([_bboxes])
+ mask_feats = self.mask_roi_extractor(
+ x[:len(self.mask_roi_extractor.featmap_strides)],
+ mask_rois)
+
+ if self.with_shared_head:
+ mask_feats = self.shared_head(mask_feats)
+ if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
+ sleep_interval = mask_test_cfg['async_sleep_interval']
+ else:
+ sleep_interval = 0.035
+ async with completed(
+ __name__,
+ 'mask_head_forward',
+ sleep_interval=sleep_interval):
+ mask_pred = self.mask_head(mask_feats)
+ segm_result = self.mask_head.get_seg_masks(
+ mask_pred, _bboxes, det_labels, self.test_cfg, ori_shape,
+ scale_factor, rescale)
+ return segm_result
+
+ def simple_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False):
+ """Simple test for mask head without augmentation."""
+ # image shapes of images in the batch
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ if isinstance(scale_factors[0], float):
+ warnings.warn(
+ 'Scale factor in img_metas should be an '
+ 'ndarray with shape (4,) '
+ 'arranged as (factor_w, factor_h, factor_w, factor_h). '
+ 'A scale_factor of float type has been deprecated.')
+ scale_factors = np.array([scale_factors] * 4, dtype=np.float32)
+
+ num_imgs = len(det_bboxes)
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ segm_results = [[[] for _ in range(self.mask_head.num_classes)]
+ for _ in range(num_imgs)]
+ else:
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ if rescale:
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i][:, :4]
+ for i in range(len(det_bboxes))
+ ]
+ mask_rois = bbox2roi(_bboxes)
+ mask_results = self._mask_forward(x, mask_rois)
+ mask_pred = mask_results['mask_pred']
+ # split batch mask prediction back to each image
+ num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
+ mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ else:
+ segm_result = self.mask_head.get_seg_masks(
+ mask_preds[i], _bboxes[i], det_labels[i],
+ self.test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ segm_results.append(segm_result)
+ return segm_results
+
+ def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
+ """Test for mask head with test time augmentation."""
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[] for _ in range(self.mask_head.num_classes)]
+ else:
+ aug_masks = []
+ for x, img_meta in zip(feats, img_metas):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ mask_rois = bbox2roi([_bboxes])
+ mask_results = self._mask_forward(x, mask_rois)
+ # convert to numpy array to save memory
+ aug_masks.append(
+ mask_results['mask_pred'].sigmoid().cpu().numpy())
+ merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)
+
+ ori_shape = img_metas[0][0]['ori_shape']
+ scale_factor = det_bboxes.new_ones(4)
+ segm_result = self.mask_head.get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ self.test_cfg,
+ ori_shape,
+ scale_factor=scale_factor,
+ rescale=False)
+ return segm_result
diff --git a/mmdet/models/roi_heads/trident_roi_head.py b/mmdet/models/roi_heads/trident_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..09758792de83ad1a1c9026ad2950843a13daf1b5
--- /dev/null
+++ b/mmdet/models/roi_heads/trident_roi_head.py
@@ -0,0 +1,120 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.ops import batched_nms
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
+ multiclass_nms)
+from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
+from ..builder import HEADS
+
+
+@HEADS.register_module()
+class TridentRoIHead(StandardRoIHead):
+ """Trident RoI head.
+
+ Args:
+ num_branch (int): Number of branches in TridentNet.
+ test_branch_idx (int): In inference, all 3 branches will be used
+ if `test_branch_idx==-1`, otherwise only branch with index
+ `test_branch_idx` will be used.
+ """
+
+ def __init__(self, num_branch, test_branch_idx, **kwargs):
+ self.num_branch = num_branch
+ self.test_branch_idx = test_branch_idx
+ super(TridentRoIHead, self).__init__(**kwargs)
+
+ def merge_trident_bboxes(self, trident_det_bboxes, trident_det_labels):
+ """Merge bbox predictions of each branch."""
+ if trident_det_bboxes.numel() == 0:
+ det_bboxes = trident_det_bboxes.new_zeros((0, 5))
+ det_labels = trident_det_bboxes.new_zeros((0, ), dtype=torch.long)
+ else:
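+ # Pool the detections of all branches and run class-aware NMS:
+ # batched_nms receives the labels as its `idxs` argument, so boxes
+ # are only suppressed within the same class.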
+ nms_bboxes = trident_det_bboxes[:, :4]
+ nms_scores = trident_det_bboxes[:, 4].contiguous()
+ nms_inds = trident_det_labels
+ nms_cfg = self.test_cfg['nms']
+ det_bboxes, keep = batched_nms(nms_bboxes, nms_scores, nms_inds,
+ nms_cfg)
+ det_labels = trident_det_labels[keep]
+ if self.test_cfg['max_per_img'] > 0:
+ det_labels = det_labels[:self.test_cfg['max_per_img']]
+ det_bboxes = det_bboxes[:self.test_cfg['max_per_img']]
+
+ return det_bboxes, det_labels
+
+ def simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False):
+ """Test without augmentation as follows:
+
+ 1. Compute prediction bbox and label per branch.
+ 2. Merge predictions of all branches according to bbox scores,
+ i.e., bboxes with higher scores are kept to give the top-k
+ predictions.
+ """
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ det_bboxes_list, det_labels_list = self.simple_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+ num_branch = self.num_branch if self.test_branch_idx == -1 else 1
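+ # Detections of the `num_branch` branches that come from the same
+ # original image are concatenated and merged below, so the final
+ # result list has len(img_metas) // num_branch entries.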
+ for _ in range(len(det_bboxes_list)):
+ if det_bboxes_list[_].shape[0] == 0:
+ det_bboxes_list[_] = det_bboxes_list[_].new_empty((0, 5))
+ det_bboxes, det_labels = [], []
+ for i in range(len(img_metas) // num_branch):
+ det_result = self.merge_trident_bboxes(
+ torch.cat(det_bboxes_list[i * num_branch:(i + 1) *
+ num_branch]),
+ torch.cat(det_labels_list[i * num_branch:(i + 1) *
+ num_branch]))
+ det_bboxes.append(det_result[0])
+ det_labels.append(det_result[1])
+
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head.num_classes)
+ for i in range(len(det_bboxes))
+ ]
+ return bbox_results
+
+ def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
+ """Test det bboxes with test time augmentation."""
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta in zip(feats, img_metas):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+
+ trident_bboxes, trident_scores = [], []
+ for branch_idx in range(len(proposal_list)):
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ rois = bbox2roi([proposals])
+ bbox_results = self._bbox_forward(x, rois)
+ bboxes, scores = self.bbox_head.get_bboxes(
+ rois,
+ bbox_results['cls_score'],
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ trident_bboxes.append(bboxes)
+ trident_scores.append(scores)
+
+ aug_bboxes.append(torch.cat(trident_bboxes, 0))
+ aug_scores.append(torch.cat(trident_scores, 0))
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+ return det_bboxes, det_labels
diff --git a/mmdet/models/seg_heads/__init__.py b/mmdet/models/seg_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b489a905b1e9b6cef2e8b9575600990563128e4e
--- /dev/null
+++ b/mmdet/models/seg_heads/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .panoptic_fpn_head import PanopticFPNHead # noqa: F401,F403
+from .panoptic_fusion_heads import * # noqa: F401,F403
diff --git a/mmdet/models/seg_heads/base_semantic_head.py b/mmdet/models/seg_heads/base_semantic_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b6ca145f050fbe10f348594203b6f0aa30f5695
--- /dev/null
+++ b/mmdet/models/seg_heads/base_semantic_head.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+import torch.nn.functional as F
+from mmcv.runner import BaseModule, force_fp32
+
+from ..builder import build_loss
+from ..utils import interpolate_as
+
+
+class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
+ """Base module of Semantic Head.
+
+ Args:
+ num_classes (int): The number of classes.
+ init_cfg (dict): The initialization config.
+ loss_seg (dict): The config of the loss for the semantic head.
+ """
+
+ def __init__(self,
+ num_classes,
+ init_cfg=None,
+ loss_seg=dict(
+ type='CrossEntropyLoss',
+ ignore_index=255,
+ loss_weight=1.0)):
+ super(BaseSemanticHead, self).__init__(init_cfg)
+ self.loss_seg = build_loss(loss_seg)
+ self.num_classes = num_classes
+
+ @force_fp32(apply_to=('seg_preds', ))
+ def loss(self, seg_preds, gt_semantic_seg):
+ """Get the loss of semantic head.
+
+ Args:
+ seg_preds (Tensor): The input logits with the shape (N, C, H, W).
+ gt_semantic_seg (Tensor): The ground truth of semantic
+ segmentation with the shape (N, H, W).
+
+ Returns:
+ dict: The loss of the semantic head.
+ """
+ if seg_preds.shape[-2:] != gt_semantic_seg.shape[-2:]:
+ seg_preds = interpolate_as(seg_preds, gt_semantic_seg)
+ seg_preds = seg_preds.permute((0, 2, 3, 1))
+
+ loss_seg = self.loss_seg(
+ seg_preds.reshape(-1, self.num_classes), # => [NxHxW, C]
+ gt_semantic_seg.reshape(-1).long())
+ return dict(loss_seg=loss_seg)
+
+ @abstractmethod
+ def forward(self, x):
+ """Placeholder of forward function.
+
+ Returns:
+ dict[str, Tensor]: A dictionary, including features
+ and predicted scores. Required keys: 'seg_preds'
+ and 'feats'.
+ """
+ pass
+
+ def forward_train(self, x, gt_semantic_seg):
+ output = self.forward(x)
+ seg_preds = output['seg_preds']
+ return self.loss(seg_preds, gt_semantic_seg)
+
+ def simple_test(self, x, img_metas, rescale=False):
+ output = self.forward(x)
+ seg_preds = output['seg_preds']
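+ # Upsample the logits to the padded input size first; if `rescale`
+ # is set, crop away the padding and resize to the original image
+ # shape so the output aligns with the raw image.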
+ seg_preds = F.interpolate(
+ seg_preds,
+ size=img_metas[0]['pad_shape'][:2],
+ mode='bilinear',
+ align_corners=False)
+
+ if rescale:
+ h, w, _ = img_metas[0]['img_shape']
+ seg_preds = seg_preds[:, :, :h, :w]
+
+ h, w, _ = img_metas[0]['ori_shape']
+ seg_preds = F.interpolate(
+ seg_preds, size=(h, w), mode='bilinear', align_corners=False)
+ return seg_preds
diff --git a/mmdet/models/seg_heads/panoptic_fpn_head.py b/mmdet/models/seg_heads/panoptic_fpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1df2976121a7668ab468b8997728683360fae14
--- /dev/null
+++ b/mmdet/models/seg_heads/panoptic_fpn_head.py
@@ -0,0 +1,155 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.runner import ModuleList
+
+from ..builder import HEADS
+from ..utils import ConvUpsample
+from .base_semantic_head import BaseSemanticHead
+
+
+@HEADS.register_module()
+class PanopticFPNHead(BaseSemanticHead):
+ """PanopticFPNHead used in Panoptic FPN.
+
+ In this head, the number of output channels is ``num_stuff_classes
+ + 1``, including all stuff classes and one thing class. The stuff
+ classes will be reset from ``0`` to ``num_stuff_classes - 1``, and the
+ thing classes will be merged into the ``num_stuff_classes``-th channel.
+
+ Args:
+ num_things_classes (int): Number of thing classes. Default: 80.
+ num_stuff_classes (int): Number of stuff classes. Default: 53.
+ num_classes (int): Number of classes, including all stuff
+ classes and one thing class. This argument is deprecated,
+ please use ``num_things_classes`` and ``num_stuff_classes``.
+ The module will automatically infer ``num_classes`` as
+ ``num_stuff_classes + 1``.
+ in_channels (int): Number of channels in the input feature
+ map.
+ inner_channels (int): Number of channels in inner features.
+ start_level (int): The start level of the input features
+ used in PanopticFPN.
+ end_level (int): The end level of the used features; the
+ ``end_level``-th layer will not be used.
+ fg_range (tuple): Range of the foreground classes. It starts
+ from ``0`` to ``num_things_classes-1``. Deprecated, please use
+ ``num_things_classes`` directly.
+ bg_range (tuple): Range of the background classes. It starts
+ from ``num_things_classes`` to ``num_things_classes +
+ num_stuff_classes - 1``. Deprecated, please use
+ ``num_stuff_classes`` and ``num_things_classes`` directly.
+ conv_cfg (dict): Dictionary to construct and config
+ conv layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Use ``GN`` by default.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ loss_seg (dict): the loss of the semantic head.
+ """
+
+ def __init__(self,
+ num_things_classes=80,
+ num_stuff_classes=53,
+ num_classes=None,
+ in_channels=256,
+ inner_channels=128,
+ start_level=0,
+ end_level=4,
+ fg_range=None,
+ bg_range=None,
+ conv_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ init_cfg=None,
+ loss_seg=dict(
+ type='CrossEntropyLoss', ignore_index=-1,
+ loss_weight=1.0)):
+ if num_classes is not None:
+ warnings.warn(
+ '`num_classes` is deprecated now, please set '
+ '`num_stuff_classes` directly; `num_classes` will be '
+ 'set to `num_stuff_classes + 1`.')
+ # num_classes = num_stuff_classes + 1 for PanopticFPN.
+ assert num_classes == num_stuff_classes + 1
+ super(PanopticFPNHead, self).__init__(num_stuff_classes + 1, init_cfg,
+ loss_seg)
+ self.num_things_classes = num_things_classes
+ self.num_stuff_classes = num_stuff_classes
+ if fg_range is not None and bg_range is not None:
+ self.fg_range = fg_range
+ self.bg_range = bg_range
+ self.num_things_classes = fg_range[1] - fg_range[0] + 1
+ self.num_stuff_classes = bg_range[1] - bg_range[0] + 1
+ warnings.warn(
+ '`fg_range` and `bg_range` are deprecated now, '
+ f'please use `num_things_classes`={self.num_things_classes} '
+ f'and `num_stuff_classes`={self.num_stuff_classes} instead.')
+
+ # Used feature layers are [start_level, end_level)
+ self.start_level = start_level
+ self.end_level = end_level
+ self.num_stages = end_level - start_level
+ self.inner_channels = inner_channels
+
+ self.conv_upsample_layers = ModuleList()
+ for i in range(start_level, end_level):
+ self.conv_upsample_layers.append(
+ ConvUpsample(
+ in_channels,
+ inner_channels,
+ num_layers=i if i > 0 else 1,
+ num_upsample=i if i > 0 else 0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ ))
+ self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)
+
+ def _set_things_to_void(self, gt_semantic_seg):
+ """Merge thing classes to one class.
+
+ In PanopticFPN, the background (stuff) labels will be reset to the
+ range `0` to `self.num_stuff_classes - 1`, and the foreground (thing)
+ labels will all be merged into the `self.num_stuff_classes`-th channel.
+ """
+ gt_semantic_seg = gt_semantic_seg.int()
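+ # Example with the default 80 thing and 53 stuff classes: stuff
+ # labels 80..132 are shifted to 0..52, every thing label 0..79 is
+ # mapped to the single foreground channel 53, and any other value
+ # (e.g. the ignore label) is left untouched.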
+ fg_mask = gt_semantic_seg < self.num_things_classes
+ bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
+ gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)
+
+ new_gt_seg = torch.clone(gt_semantic_seg)
+ new_gt_seg = torch.where(bg_mask,
+ gt_semantic_seg - self.num_things_classes,
+ new_gt_seg)
+ new_gt_seg = torch.where(fg_mask,
+ fg_mask.int() * self.num_stuff_classes,
+ new_gt_seg)
+ return new_gt_seg
+
+ def loss(self, seg_preds, gt_semantic_seg):
+ """The loss of PanopticFPN head.
+
+ Things classes will be merged to one class in PanopticFPN.
+ """
+ gt_semantic_seg = self._set_things_to_void(gt_semantic_seg)
+ return super().loss(seg_preds, gt_semantic_seg)
+
+ def init_weights(self):
+ super().init_weights()
+ nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
+ self.conv_logits.bias.data.zero_()
+
+ def forward(self, x):
+ # The number of subnets must not be more than
+ # the number of feature levels.
+ assert self.num_stages <= len(x)
+
+ feats = []
+ for i, layer in enumerate(self.conv_upsample_layers):
+ f = layer(x[self.start_level + i])
+ feats.append(f)
+
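+ # Each ConvUpsample upsamples level i by 2**i, so (with the usual
+ # stride-doubling between FPN levels) all features end up at the
+ # start level's resolution and can be summed elementwise.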
+ feats = torch.sum(torch.stack(feats, dim=0), dim=0)
+ seg_preds = self.conv_logits(feats)
+ out = dict(seg_preds=seg_preds, feats=feats)
+ return out
diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py b/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..41625a61d6d1c38c633062c24b1e3455bd3ae2df
--- /dev/null
+++ b/mmdet/models/seg_heads/panoptic_fusion_heads/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_panoptic_fusion_head import \
+ BasePanopticFusionHead # noqa: F401,F403
+from .heuristic_fusion_head import HeuristicFusionHead # noqa: F401,F403
+from .maskformer_fusion_head import MaskFormerFusionHead # noqa: F401,F403
diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/base_panoptic_fusion_head.py b/mmdet/models/seg_heads/panoptic_fusion_heads/base_panoptic_fusion_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a38ac1c6cd092f0c68fa51853bcd1969de7287a7
--- /dev/null
+++ b/mmdet/models/seg_heads/panoptic_fusion_heads/base_panoptic_fusion_head.py
@@ -0,0 +1,48 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+from mmcv.runner import BaseModule
+
+from ...builder import build_loss
+
+
+class BasePanopticFusionHead(BaseModule, metaclass=ABCMeta):
+ """Base class for panoptic heads."""
+
+ def __init__(self,
+ num_things_classes=80,
+ num_stuff_classes=53,
+ test_cfg=None,
+ loss_panoptic=None,
+ init_cfg=None,
+ **kwargs):
+ super(BasePanopticFusionHead, self).__init__(init_cfg)
+ self.num_things_classes = num_things_classes
+ self.num_stuff_classes = num_stuff_classes
+ self.num_classes = num_things_classes + num_stuff_classes
+ self.test_cfg = test_cfg
+
+ if loss_panoptic:
+ self.loss_panoptic = build_loss(loss_panoptic)
+ else:
+ self.loss_panoptic = None
+
+ @property
+ def with_loss(self):
+ """bool: whether the panoptic head contains loss function."""
+ return self.loss_panoptic is not None
+
+ @abstractmethod
+ def forward_train(self, gt_masks=None, gt_semantic_seg=None, **kwargs):
+ """Forward function during training."""
+
+ @abstractmethod
+ def simple_test(self,
+ img_metas,
+ det_labels,
+ mask_preds,
+ seg_preds,
+ det_bboxes,
+ cfg=None,
+ **kwargs):
+ """Test without augmentation."""
diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/heuristic_fusion_head.py b/mmdet/models/seg_heads/panoptic_fusion_heads/heuristic_fusion_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..06c1de2b9010fef13bd2322bbd3352d82a1f3e2f
--- /dev/null
+++ b/mmdet/models/seg_heads/panoptic_fusion_heads/heuristic_fusion_head.py
@@ -0,0 +1,126 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdet.core.evaluation.panoptic_utils import INSTANCE_OFFSET
+from mmdet.models.builder import HEADS
+from .base_panoptic_fusion_head import BasePanopticFusionHead
+
+
+@HEADS.register_module()
+class HeuristicFusionHead(BasePanopticFusionHead):
+ """Fusion Head with Heuristic method."""
+
+ def __init__(self,
+ num_things_classes=80,
+ num_stuff_classes=53,
+ test_cfg=None,
+ init_cfg=None,
+ **kwargs):
+ super(HeuristicFusionHead,
+ self).__init__(num_things_classes, num_stuff_classes, test_cfg,
+ None, init_cfg, **kwargs)
+
+ def forward_train(self, gt_masks=None, gt_semantic_seg=None, **kwargs):
+ """HeuristicFusionHead has no training loss."""
+ return dict()
+
+ def _lay_masks(self, bboxes, labels, masks, overlap_thr=0.5):
+ """Lay instance masks to a result map.
+
+ Args:
+ bboxes (Tensor): The bboxes results with scores in the last
+ column, (K, 5).
+ labels (Tensor): The labels of bboxes, (K, ).
+ masks (Tensor): The instance masks, (K, H, W).
+ overlap_thr (float): Threshold to determine whether two masks
+ overlap. Default: 0.5.
+
+ Returns:
+ tuple[Tensor, Tensor]: The instance id map of shape (H, W) and
+ the labels of the laid instances.
+ """
+ num_insts = bboxes.shape[0]
+ id_map = torch.zeros(
+ masks.shape[-2:], device=bboxes.device, dtype=torch.long)
+ if num_insts == 0:
+ return id_map, labels
+
+ scores, bboxes = bboxes[:, -1], bboxes[:, :4]
+
+ # Sort by score to use heuristic fusion
+ order = torch.argsort(-scores)
+ bboxes = bboxes[order]
+ labels = labels[order]
+ segm_masks = masks[order]
+
+ instance_id = 1
+ left_labels = []
+ for idx in range(bboxes.shape[0]):
+ _cls = labels[idx]
+ _mask = segm_masks[idx]
+ instance_id_map = torch.ones_like(
+ _mask, dtype=torch.long) * instance_id
+ area = _mask.sum()
+ if area == 0:
+ continue
+
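+ # Skip an instance if more than `overlap_thr` of its area is already
+ # covered by higher-scoring instances; otherwise paste only its
+ # unoccupied pixels.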
+ pasted = id_map > 0
+ intersect = (_mask * pasted).sum()
+ if (intersect / (area + 1e-5)) > overlap_thr:
+ continue
+
+ _part = _mask * (~pasted)
+ id_map = torch.where(_part, instance_id_map, id_map)
+ left_labels.append(_cls)
+ instance_id += 1
+
+ if len(left_labels) > 0:
+ instance_labels = torch.stack(left_labels)
+ else:
+ instance_labels = bboxes.new_zeros((0, ), dtype=torch.long)
+ assert instance_id == (len(instance_labels) + 1)
+ return id_map, instance_labels
+
+ def simple_test(self, det_bboxes, det_labels, mask_preds, seg_preds,
+ **kwargs):
+ """Fuse the results of instance and semantic segmentations.
+
+ Args:
+ det_bboxes (Tensor): The bboxes results with scores in the
+ last column, (K, 5).
+ det_labels (Tensor): The labels of bboxes, (K, ).
+ mask_preds (Tensor): The masks results, (K, H, W).
+ seg_preds (Tensor): The semantic segmentation results,
+ (num_stuff_classes + 1, H, W).
+
+ Returns:
+ Tensor: The panoptic segmentation result, (H, W).
+ """
+ mask_preds = mask_preds >= self.test_cfg.mask_thr_binary
+ id_map, labels = self._lay_masks(det_bboxes, det_labels, mask_preds,
+ self.test_cfg.mask_overlap)
+
+ seg_results = seg_preds.argmax(dim=0)
+ seg_results = seg_results + self.num_things_classes
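+ # The semantic prediction indexes the stuff channels (plus the one
+ # merged thing channel), so it is shifted by num_things_classes to
+ # live in the shared label space where thing ids come first.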
+
+ pan_results = seg_results
+ instance_id = 1
+ for idx in range(det_labels.shape[0]):
+ _mask = id_map == (idx + 1)
+ if _mask.sum() == 0:
+ continue
+ _cls = labels[idx]
+ # simply trust detection
+ segment_id = _cls + instance_id * INSTANCE_OFFSET
+ pan_results[_mask] = segment_id
+ instance_id += 1
+
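+ # Stuff regions smaller than `stuff_area_limit` pixels are treated
+ # as unreliable and reset to the void label (num_classes) below.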
+ ids, counts = torch.unique(
+ pan_results % INSTANCE_OFFSET, return_counts=True)
+ stuff_ids = ids[ids >= self.num_things_classes]
+ stuff_counts = counts[ids >= self.num_things_classes]
+ ignore_stuff_ids = stuff_ids[
+ stuff_counts < self.test_cfg.stuff_area_limit]
+
+ assert pan_results.ndim == 2
+ pan_results[(pan_results.unsqueeze(2) == ignore_stuff_ids.reshape(
+ 1, 1, -1)).any(dim=2)] = self.num_classes
+
+ return pan_results
diff --git a/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py b/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b59ce4deaed11b98f5d9cf7a22f177eebfeb6b7
--- /dev/null
+++ b/mmdet/models/seg_heads/panoptic_fusion_heads/maskformer_fusion_head.py
@@ -0,0 +1,241 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+
+from mmdet.core.evaluation.panoptic_utils import INSTANCE_OFFSET
+from mmdet.core.mask import mask2bbox
+from mmdet.models.builder import HEADS
+from .base_panoptic_fusion_head import BasePanopticFusionHead
+
+
+@HEADS.register_module()
+class MaskFormerFusionHead(BasePanopticFusionHead):
+
+ def __init__(self,
+ num_things_classes=80,
+ num_stuff_classes=53,
+ test_cfg=None,
+ loss_panoptic=None,
+ init_cfg=None,
+ **kwargs):
+ super().__init__(num_things_classes, num_stuff_classes, test_cfg,
+ loss_panoptic, init_cfg, **kwargs)
+
+ def forward_train(self, **kwargs):
+ """MaskFormerFusionHead has no training loss."""
+ return dict()
+
+ def panoptic_postprocess(self, mask_cls, mask_pred):
+ """Panoptic segmentation inference.
+
+ Args:
+ mask_cls (Tensor): Classification outputs of shape
+ (num_queries, cls_out_channels) for an image.
+ Note `cls_out_channels` should include the
+ background.
+ mask_pred (Tensor): Mask outputs of shape
+ (num_queries, h, w) for an image.
+
+ Returns:
+ Tensor: Panoptic segmentation result of shape \
+ (h, w); each element encodes \
+ ``segment_id = _cls + instance_id * INSTANCE_OFFSET``.
+ """
+ object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8)
+ iou_thr = self.test_cfg.get('iou_thr', 0.8)
+ filter_low_score = self.test_cfg.get('filter_low_score', False)
+
+ scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
+ mask_pred = mask_pred.sigmoid()
+
+ keep = labels.ne(self.num_classes) & (scores > object_mask_thr)
+ cur_scores = scores[keep]
+ cur_classes = labels[keep]
+ cur_masks = mask_pred[keep]
+
+ cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks
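+ # Each pixel will be assigned to the query whose score-weighted mask
+ # probability is highest (argmax over the query dimension below).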
+
+ h, w = cur_masks.shape[-2:]
+ panoptic_seg = torch.full((h, w),
+ self.num_classes,
+ dtype=torch.int32,
+ device=cur_masks.device)
+ if cur_masks.shape[0] == 0:
+ # We didn't detect any mask :(
+ pass
+ else:
+ cur_mask_ids = cur_prob_masks.argmax(0)
+ instance_id = 1
+ for k in range(cur_classes.shape[0]):
+ pred_class = int(cur_classes[k].item())
+ isthing = pred_class < self.num_things_classes
+ mask = cur_mask_ids == k
+ mask_area = mask.sum().item()
+ original_area = (cur_masks[k] >= 0.5).sum().item()
+
+ if filter_low_score:
+ mask = mask & (cur_masks[k] >= 0.5)
+
+ if mask_area > 0 and original_area > 0:
+ if mask_area / original_area < iou_thr:
+ continue
+
+ if not isthing:
+ # different stuff regions of same class will be
+ # merged here, and stuff share the instance_id 0.
+ panoptic_seg[mask] = pred_class
+ else:
+ panoptic_seg[mask] = (
+ pred_class + instance_id * INSTANCE_OFFSET)
+ instance_id += 1
+
+ return panoptic_seg
+
+ def semantic_postprocess(self, mask_cls, mask_pred):
+ """Semantic segmentation postprocess.
+
+ Args:
+ mask_cls (Tensor): Classification outputs of shape
+ (num_queries, cls_out_channels) for an image.
+ Note `cls_out_channels` should include the
+ background.
+ mask_pred (Tensor): Mask outputs of shape
+ (num_queries, h, w) for an image.
+
+ Returns:
+ Tensor: Semantic segmentation result of shape \
+ (cls_out_channels, h, w).
+ """
+ # TODO add semantic segmentation result
+ raise NotImplementedError
+
+ def instance_postprocess(self, mask_cls, mask_pred):
+ """Instance segmentation postprocess.
+
+ Args:
+ mask_cls (Tensor): Classification outputs of shape
+ (num_queries, cls_out_channels) for an image.
+ Note `cls_out_channels` should include the
+ background.
+ mask_pred (Tensor): Mask outputs of shape
+ (num_queries, h, w) for an image.
+
+ Returns:
+ tuple[Tensor]: Instance segmentation results.
+
+ - labels_per_image (Tensor): Predicted labels,\
+ shape (n, ).
+ - bboxes (Tensor): Bboxes and scores with shape (n, 5) of \
+ positive region in binary mask, the last column is scores.
+ - mask_pred_binary (Tensor): Instance masks of \
+ shape (n, h, w).
+ """
+ max_per_image = self.test_cfg.get('max_per_image', 100)
+ num_queries = mask_cls.shape[0]
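+ # Treat every (query, class) pair as a candidate detection: the
+ # (num_queries, num_classes) scores are flattened, the global
+ # top-`max_per_image` pairs are kept, and the owning query is
+ # recovered via `top_indices // num_classes`.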
+ # shape (num_queries, num_class)
+ scores = F.softmax(mask_cls, dim=-1)[:, :-1]
+ # shape (num_queries * num_class, )
+ labels = torch.arange(self.num_classes, device=mask_cls.device).\
+ unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
+ scores_per_image, top_indices = scores.flatten(0, 1).topk(
+ max_per_image, sorted=False)
+ labels_per_image = labels[top_indices]
+
+ query_indices = top_indices // self.num_classes
+ mask_pred = mask_pred[query_indices]
+
+ # extract things
+ is_thing = labels_per_image < self.num_things_classes
+ scores_per_image = scores_per_image[is_thing]
+ labels_per_image = labels_per_image[is_thing]
+ mask_pred = mask_pred[is_thing]
+
+ mask_pred_binary = (mask_pred > 0).float()
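+ # The detection score is the classification score multiplied by the
+ # average mask probability inside the predicted binary mask (a small
+ # epsilon guards against empty masks).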
+ mask_scores_per_image = (mask_pred.sigmoid() *
+ mask_pred_binary).flatten(1).sum(1) / (
+ mask_pred_binary.flatten(1).sum(1) + 1e-6)
+ det_scores = scores_per_image * mask_scores_per_image
+ mask_pred_binary = mask_pred_binary.bool()
+ bboxes = mask2bbox(mask_pred_binary)
+ bboxes = torch.cat([bboxes, det_scores[:, None]], dim=-1)
+
+ return labels_per_image, bboxes, mask_pred_binary
+
+ def simple_test(self,
+ mask_cls_results,
+ mask_pred_results,
+ img_metas,
+ rescale=False,
+ **kwargs):
+ """Test segmentation without test-time augmentation.
+
+ Only the output of the last decoder layer is used.
+
+ Args:
+ mask_cls_results (Tensor): Mask classification logits,
+ shape (batch_size, num_queries, cls_out_channels).
+ Note `cls_out_channels` should include the background.
+ mask_pred_results (Tensor): Mask logits, shape
+ (batch_size, num_queries, h, w).
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): If True, return boxes in
+ original image space. Default False.
+
+ Returns:
+ list[dict[str, Tensor | tuple[Tensor]]]: Panoptic, instance \
+ and (not yet supported) semantic segmentation results \
+ for each image.
+
+ .. code-block:: none
+
+ [
+ {
+ 'pan_results': Tensor, # shape = [h, w]
+ 'ins_results': tuple[Tensor],
+ # semantic segmentation results are not supported yet
+ 'sem_results': Tensor
+ },
+ ...
+ ]
+ """
+ panoptic_on = self.test_cfg.get('panoptic_on', True)
+ semantic_on = self.test_cfg.get('semantic_on', False)
+ instance_on = self.test_cfg.get('instance_on', False)
+ assert not semantic_on, 'semantic segmentation '\
+ 'results are not supported yet.'
+
+ results = []
+ for mask_cls_result, mask_pred_result, meta in zip(
+ mask_cls_results, mask_pred_results, img_metas):
+ # remove padding
+ img_height, img_width = meta['img_shape'][:2]
+ mask_pred_result = mask_pred_result[:, :img_height, :img_width]
+
+ if rescale:
+ # return result in original resolution
+ ori_height, ori_width = meta['ori_shape'][:2]
+ mask_pred_result = F.interpolate(
+ mask_pred_result[:, None],
+ size=(ori_height, ori_width),
+ mode='bilinear',
+ align_corners=False)[:, 0]
+
+ result = dict()
+ if panoptic_on:
+ pan_results = self.panoptic_postprocess(
+ mask_cls_result, mask_pred_result)
+ result['pan_results'] = pan_results
+
+ if instance_on:
+ ins_results = self.instance_postprocess(
+ mask_cls_result, mask_pred_result)
+ result['ins_results'] = ins_results
+
+ if semantic_on:
+ sem_results = self.semantic_postprocess(
+ mask_cls_result, mask_pred_result)
+ result['sem_results'] = sem_results
+
+ results.append(result)
+
+ return results
diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e74ba89e8c2101360d921a5f8437da48d0250e9a
--- /dev/null
+++ b/mmdet/models/utils/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .brick_wrappers import AdaptiveAvgPool2d, adaptive_avg_pool2d
+from .builder import build_linear_layer, build_transformer
+from .ckpt_convert import pvt_convert
+from .conv_upsample import ConvUpsample
+from .csp_layer import CSPLayer
+from .gaussian_target import gaussian_radius, gen_gaussian_target
+from .inverted_residual import InvertedResidual
+from .make_divisible import make_divisible
+from .misc import interpolate_as, sigmoid_geometric_mean
+from .normed_predictor import NormedConv2d, NormedLinear
+from .panoptic_gt_processing import preprocess_panoptic_gt
+from .point_sample import (get_uncertain_point_coords_with_randomness,
+ get_uncertainty)
+from .positional_encoding import (LearnedPositionalEncoding,
+ SinePositionalEncoding)
+from .res_layer import ResLayer, SimplifiedBasicBlock
+from .se_layer import DyReLU, SELayer
+from .transformer import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
+ DynamicConv, PatchEmbed, Transformer, nchw_to_nlc,
+ nlc_to_nchw)
+
+__all__ = [
+ 'ResLayer', 'gaussian_radius', 'gen_gaussian_target',
+ 'DetrTransformerDecoderLayer', 'DetrTransformerDecoder', 'Transformer',
+ 'build_transformer', 'build_linear_layer', 'SinePositionalEncoding',
+ 'LearnedPositionalEncoding', 'DynamicConv', 'SimplifiedBasicBlock',
+ 'NormedLinear', 'NormedConv2d', 'make_divisible', 'InvertedResidual',
+ 'SELayer', 'interpolate_as', 'ConvUpsample', 'CSPLayer',
+ 'adaptive_avg_pool2d', 'AdaptiveAvgPool2d', 'PatchEmbed', 'nchw_to_nlc',
+ 'nlc_to_nchw', 'pvt_convert', 'sigmoid_geometric_mean',
+ 'preprocess_panoptic_gt', 'DyReLU',
+ 'get_uncertain_point_coords_with_randomness', 'get_uncertainty'
+]
diff --git a/mmdet/models/utils/brick_wrappers.py b/mmdet/models/utils/brick_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa0279ab60d0943bf68ea2616df9dad87e220db4
--- /dev/null
+++ b/mmdet/models/utils/brick_wrappers.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version
+
+if torch.__version__ == 'parrots':
+ TORCH_VERSION = torch.__version__
+else:
+ # torch.__version__ could be 1.3.1+cu92, we only need the first two
+ # for comparison
+ TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+
+
+def adaptive_avg_pool2d(input, output_size):
+ """Handle an empty batch dimension for adaptive_avg_pool2d.
+
+ Args:
+ input (Tensor): 4D tensor.
+ output_size (int | tuple[int, int]): The target output size.
+ """
+ if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ if isinstance(output_size, int):
+ output_size = [output_size, output_size]
+ output_size = [*input.shape[:2], *output_size]
+ empty = NewEmptyTensorOp.apply(input, output_size)
+ return empty
+ else:
+ return F.adaptive_avg_pool2d(input, output_size)
+
+
+class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
+ """Handle an empty batch dimension for AdaptiveAvgPool2d."""
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ output_size = self.output_size
+ if isinstance(output_size, int):
+ output_size = [output_size, output_size]
+ else:
+ output_size = [
+ v if v is not None else d
+ for v, d in zip(output_size,
+ x.size()[-2:])
+ ]
+ output_size = [*x.shape[:2], *output_size]
+ empty = NewEmptyTensorOp.apply(x, output_size)
+ return empty
+
+ return super().forward(x)
diff --git a/mmdet/models/utils/builder.py b/mmdet/models/utils/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..20fe7a6dcfcf242728dcd7b7639032006cc6c4e2
--- /dev/null
+++ b/mmdet/models/utils/builder.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.utils import Registry, build_from_cfg
+
+TRANSFORMER = Registry('Transformer')
+LINEAR_LAYERS = Registry('linear layers')
+
+
+def build_transformer(cfg, default_args=None):
+ """Builder for Transformer."""
+ return build_from_cfg(cfg, TRANSFORMER, default_args)
+
+
+LINEAR_LAYERS.register_module('Linear', module=nn.Linear)
+
+
+def build_linear_layer(cfg, *args, **kwargs):
+ """Build a linear layer.
+
+ Args:
+ cfg (None | dict): The linear layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a linear layer.
+ args (argument list): Arguments passed to the `__init__`
+ method of the corresponding linear layer.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of the corresponding linear layer.
+
+ Returns:
+ nn.Module: Created linear layer.
+ """
+ if cfg is None:
+ cfg_ = dict(type='Linear')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in LINEAR_LAYERS:
+ raise KeyError(f'Unrecognized linear type {layer_type}')
+ else:
+ linear_layer = LINEAR_LAYERS.get(layer_type)
+
+ layer = linear_layer(*args, **kwargs, **cfg_)
+
+ return layer
diff --git a/mmdet/models/utils/ckpt_convert.py b/mmdet/models/utils/ckpt_convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d660c4e4ddbc289f6882333e5eec4360a17aaf2
--- /dev/null
+++ b/mmdet/models/utils/ckpt_convert.py
@@ -0,0 +1,137 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# This script consists of several convert functions which
+# can modify the weights of model in original repo to be
+# pre-trained weights.
+
+from collections import OrderedDict
+
+import torch
+
+
+def pvt_convert(ckpt):
+ new_ckpt = OrderedDict()
+ # Process the concat between q linear weights and kv linear weights
+ use_abs_pos_embed = False
+ use_conv_ffn = False
+ for k in ckpt.keys():
+ if k.startswith('pos_embed'):
+ use_abs_pos_embed = True
+ if k.find('dwconv') >= 0:
+ use_conv_ffn = True
+ for k, v in ckpt.items():
+ if k.startswith('head'):
+ continue
+ if k.startswith('norm.'):
+ continue
+ if k.startswith('cls_token'):
+ continue
+ if k.startswith('pos_embed'):
+ stage_i = int(k.replace('pos_embed', ''))
+ new_k = k.replace(f'pos_embed{stage_i}',
+ f'layers.{stage_i - 1}.1.0.pos_embed')
+ if stage_i == 4 and v.size(1) == 50: # 1 (cls token) + 7 * 7
+ new_v = v[:, 1:, :] # remove cls token
+ else:
+ new_v = v
+ elif k.startswith('patch_embed'):
+ stage_i = int(k.split('.')[0].replace('patch_embed', ''))
+ new_k = k.replace(f'patch_embed{stage_i}',
+ f'layers.{stage_i - 1}.0')
+ new_v = v
+ if 'proj.' in new_k:
+ new_k = new_k.replace('proj.', 'projection.')
+ elif k.startswith('block'):
+ stage_i = int(k.split('.')[0].replace('block', ''))
+ layer_i = int(k.split('.')[1])
+ new_layer_i = layer_i + use_abs_pos_embed
+ new_k = k.replace(f'block{stage_i}.{layer_i}',
+ f'layers.{stage_i - 1}.1.{new_layer_i}')
+ new_v = v
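+ # The original repo keeps separate q and kv projections; concatenate
+ # them along dim 0 to form the fused in_proj weights expected by
+ # MultiheadAttention-style layers (the kv entry itself is skipped).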
+ if 'attn.q.' in new_k:
+ sub_item_k = k.replace('q.', 'kv.')
+ new_k = new_k.replace('q.', 'attn.in_proj_')
+ new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
+ elif 'attn.kv.' in new_k:
+ continue
+ elif 'attn.proj.' in new_k:
+ new_k = new_k.replace('proj.', 'attn.out_proj.')
+ elif 'attn.sr.' in new_k:
+ new_k = new_k.replace('sr.', 'sr.')
+ elif 'mlp.' in new_k:
+ new_k = new_k.replace('mlp.', 'ffn.layers.')
+ if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
+ new_v = v.reshape((*v.shape, 1, 1))
+ new_k = new_k.replace('fc1.', '0.')
+ new_k = new_k.replace('dwconv.dwconv.', '1.')
+ if use_conv_ffn:
+ new_k = new_k.replace('fc2.', '4.')
+ else:
+ new_k = new_k.replace('fc2.', '3.')
+ elif k.startswith('norm'):
+ stage_i = int(k[4])
+ new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i - 1}.2')
+ new_v = v
+ else:
+ new_k = k
+ new_v = v
+ new_ckpt[new_k] = new_v
+
+ return new_ckpt
+
+
+def swin_converter(ckpt):
+
+ new_ckpt = OrderedDict()
+
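+ # The 2x2 patch-merging neighborhood is unfolded in a different order
+ # in the official Swin repo than in this codebase (an assumption based
+ # on the permutation below), so the four channel groups of the
+ # reduction and norm weights are reordered as [0, 2, 1, 3].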
+ def correct_unfold_reduction_order(x):
+ out_channel, in_channel = x.shape
+ x = x.reshape(out_channel, 4, in_channel // 4)
+ x = x[:, [0, 2, 1, 3], :].transpose(1,
+ 2).reshape(out_channel, in_channel)
+ return x
+
+ def correct_unfold_norm_order(x):
+ in_channel = x.shape[0]
+ x = x.reshape(4, in_channel // 4)
+ x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
+ return x
+
+ for k, v in ckpt.items():
+ if k.startswith('head'):
+ continue
+ elif k.startswith('layers'):
+ new_v = v
+ if 'attn.' in k:
+ new_k = k.replace('attn.', 'attn.w_msa.')
+ elif 'mlp.' in k:
+ if 'mlp.fc1.' in k:
+ new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
+ elif 'mlp.fc2.' in k:
+ new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
+ else:
+ new_k = k.replace('mlp.', 'ffn.')
+ elif 'downsample' in k:
+ new_k = k
+ if 'reduction.' in k:
+ new_v = correct_unfold_reduction_order(v)
+ elif 'norm.' in k:
+ new_v = correct_unfold_norm_order(v)
+ else:
+ new_k = k
+ new_k = new_k.replace('layers', 'stages', 1)
+ elif k.startswith('patch_embed'):
+ new_v = v
+ if 'proj' in k:
+ new_k = k.replace('proj', 'projection')
+ else:
+ new_k = k
+ else:
+ new_v = v
+ new_k = k
+
+ new_ckpt['backbone.' + new_k] = new_v
+
+ return new_ckpt
diff --git a/mmdet/models/utils/conv_upsample.py b/mmdet/models/utils/conv_upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb5ba7670a996af7debf5a33da955faa9fb1827a
--- /dev/null
+++ b/mmdet/models/utils/conv_upsample.py
@@ -0,0 +1,67 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule, ModuleList
+
+
+class ConvUpsample(BaseModule):
+ """ConvUpsample performs 2x upsampling after Conv.
+
+ There are several `ConvModule` layers. In the first few layers, upsampling
+ will be applied after each layer of convolution. The number of upsampling
+ must be no more than the number of ConvModule layers.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ inner_channels (int): Number of channels produced by the convolution.
+ num_layers (int): Number of convolution layers.
+ num_upsample (int, optional): Number of upsampling layers. Must be no
+ more than num_layers. Upsampling will be applied after the first
+ ``num_upsample`` layers of convolution. Default: ``num_layers``.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ init_cfg (dict): Config dict for initialization. Default: None.
+ kwargs (keyword arguments): Other arguments used in ConvModule.
+ """
+
+ def __init__(self,
+ in_channels,
+ inner_channels,
+ num_layers=1,
+ num_upsample=None,
+ conv_cfg=None,
+ norm_cfg=None,
+ init_cfg=None,
+ **kwargs):
+ super(ConvUpsample, self).__init__(init_cfg)
+ if num_upsample is None:
+ num_upsample = num_layers
+ assert num_upsample <= num_layers, \
+ f'num_upsample({num_upsample}) must be no more than ' \
+ f'num_layers({num_layers})'
+ self.num_layers = num_layers
+ self.num_upsample = num_upsample
+ self.conv = ModuleList()
+ for i in range(num_layers):
+ self.conv.append(
+ ConvModule(
+ in_channels,
+ inner_channels,
+ 3,
+ padding=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ in_channels = inner_channels
+
+ def forward(self, x):
+ num_upsample = self.num_upsample
+ for i in range(self.num_layers):
+ x = self.conv[i](x)
+ if num_upsample > 0:
+ num_upsample -= 1
+ x = F.interpolate(
+ x, scale_factor=2, mode='bilinear', align_corners=False)
+ return x
diff --git a/mmdet/models/utils/csp_layer.py b/mmdet/models/utils/csp_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5760b014f25219a4f1d547edc9dcebe618ada2c5
--- /dev/null
+++ b/mmdet/models/utils/csp_layer.py
@@ -0,0 +1,150 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+from mmcv.runner import BaseModule
+
+
+class DarknetBottleneck(BaseModule):
+ """The basic bottleneck block used in Darknet.
+
+ Each ResBlock consists of two ConvModules and the input is added to the
+ final output. Each ConvModule is composed of Conv, BN and an activation
+ layer (Swish by default). The first conv layer has a filter size of 1x1
+ and the second one has a filter size of 3x3.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+        expansion (float): The ratio of ``hidden_channels`` to
+            ``out_channels``. Default: 0.5
+ add_identity (bool): Whether to add identity to the out.
+ Default: True
+ use_depthwise (bool): Whether to use depthwise separable convolution.
+ Default: False
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='Swish').
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expansion=0.5,
+ add_identity=True,
+ use_depthwise=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish'),
+ init_cfg=None):
+ super().__init__(init_cfg)
+ hidden_channels = int(out_channels * expansion)
+ conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
+ self.conv1 = ConvModule(
+ in_channels,
+ hidden_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.conv2 = conv(
+ hidden_channels,
+ out_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.add_identity = \
+ add_identity and in_channels == out_channels
+
+ def forward(self, x):
+ identity = x
+ out = self.conv1(x)
+ out = self.conv2(out)
+
+ if self.add_identity:
+ return out + identity
+ else:
+ return out
+
+
+class CSPLayer(BaseModule):
+ """Cross Stage Partial Layer.
+
+ Args:
+ in_channels (int): The input channels of the CSP layer.
+ out_channels (int): The output channels of the CSP layer.
+ expand_ratio (float): Ratio to adjust the number of channels of the
+ hidden layer. Default: 0.5
+ num_blocks (int): Number of blocks. Default: 1
+ add_identity (bool): Whether to add identity in blocks.
+ Default: True
+        use_depthwise (bool): Whether to use depthwise separable convolution
+            in blocks. Default: False
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN')
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='Swish')
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ expand_ratio=0.5,
+ num_blocks=1,
+ add_identity=True,
+ use_depthwise=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish'),
+ init_cfg=None):
+ super().__init__(init_cfg)
+ mid_channels = int(out_channels * expand_ratio)
+ self.main_conv = ConvModule(
+ in_channels,
+ mid_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.short_conv = ConvModule(
+ in_channels,
+ mid_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.final_conv = ConvModule(
+ 2 * mid_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ self.blocks = nn.Sequential(*[
+ DarknetBottleneck(
+ mid_channels,
+ mid_channels,
+ 1.0,
+ add_identity,
+ use_depthwise,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg) for _ in range(num_blocks)
+ ])
+
+ def forward(self, x):
+ x_short = self.short_conv(x)
+
+ x_main = self.main_conv(x)
+ x_main = self.blocks(x_main)
+
+ x_final = torch.cat((x_main, x_short), dim=1)
+ return self.final_conv(x_final)
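+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): a CSPLayer with
+    # two Darknet bottlenecks keeps the spatial size and maps 64 -> 128
+    # channels; the short_conv branch bypasses the bottlenecks and is
+    # concatenated back before final_conv.
+    csp = CSPLayer(64, 128, expand_ratio=0.5, num_blocks=2)
+    x = torch.rand(2, 64, 32, 32)
+    assert csp(x).shape == (2, 128, 32, 32)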
diff --git a/mmdet/models/utils/gaussian_target.py b/mmdet/models/utils/gaussian_target.py
new file mode 100644
index 0000000000000000000000000000000000000000..9997d3b13a90eca2b302b170b09a445776eda1ee
--- /dev/null
+++ b/mmdet/models/utils/gaussian_target.py
@@ -0,0 +1,268 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from math import sqrt
+
+import torch
+import torch.nn.functional as F
+
+
+def gaussian2D(radius, sigma=1, dtype=torch.float32, device='cpu'):
+ """Generate 2D gaussian kernel.
+
+ Args:
+ radius (int): Radius of gaussian kernel.
+ sigma (int): Sigma of gaussian function. Default: 1.
+ dtype (torch.dtype): Dtype of gaussian tensor. Default: torch.float32.
+ device (str): Device of gaussian tensor. Default: 'cpu'.
+
+ Returns:
+ h (Tensor): Gaussian kernel with a
+ ``(2 * radius + 1) * (2 * radius + 1)`` shape.
+ """
+ x = torch.arange(
+ -radius, radius + 1, dtype=dtype, device=device).view(1, -1)
+ y = torch.arange(
+ -radius, radius + 1, dtype=dtype, device=device).view(-1, 1)
+
+ h = (-(x * x + y * y) / (2 * sigma * sigma)).exp()
+
+ h[h < torch.finfo(h.dtype).eps * h.max()] = 0
+ return h
+
+
+def gen_gaussian_target(heatmap, center, radius, k=1):
+ """Generate 2D gaussian heatmap.
+
+ Args:
+        heatmap (Tensor): Input heatmap; the gaussian kernel will be overlaid
+            on it, keeping the elementwise maximum value.
+ center (list[int]): Coord of gaussian kernel's center.
+ radius (int): Radius of gaussian kernel.
+ k (int): Coefficient of gaussian kernel. Default: 1.
+
+ Returns:
+ out_heatmap (Tensor): Updated heatmap covered by gaussian kernel.
+ """
+ diameter = 2 * radius + 1
+ gaussian_kernel = gaussian2D(
+ radius, sigma=diameter / 6, dtype=heatmap.dtype, device=heatmap.device)
+
+ x, y = center
+
+ height, width = heatmap.shape[:2]
+
+ left, right = min(x, radius), min(width - x, radius + 1)
+ top, bottom = min(y, radius), min(height - y, radius + 1)
+
+ masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+ masked_gaussian = gaussian_kernel[radius - top:radius + bottom,
+ radius - left:radius + right]
+ out_heatmap = heatmap
+ torch.max(
+ masked_heatmap,
+ masked_gaussian * k,
+ out=out_heatmap[y - top:y + bottom, x - left:x + right])
+
+ return out_heatmap
+
+
+def gaussian_radius(det_size, min_overlap):
+ r"""Generate 2D gaussian radius.
+
+ This function is modified from the `official github repo
+ `_.
+
+    Given ``min_overlap``, the radius can be computed by solving a quadratic
+    equation according to Vieta's formulas.
+
+    There are 3 cases for computing the gaussian radius, detailed as follows:
+
+    - Explanation of figure: ``lt`` and ``br`` indicate the left-top and
+      bottom-right corners of the ground truth box. ``x`` indicates the
+ generated corner at the limited position when ``radius=r``.
+
+ - Case1: one corner is inside the gt box and the other is outside.
+
+ .. code:: text
+
+ |< width >|
+
+ lt-+----------+ -
+ | | | ^
+ +--x----------+--+
+ | | | |
+ | | | | height
+ | | overlap | |
+ | | | |
+ | | | | v
+ +--+---------br--+ -
+ | | |
+ +----------+--x
+
+ To ensure IoU of generated box and gt box is larger than ``min_overlap``:
+
+ .. math::
+ \cfrac{(w-r)*(h-r)}{w*h+(w+h)r-r^2} \ge {iou} \quad\Rightarrow\quad
+ {r^2-(w+h)r+\cfrac{1-iou}{1+iou}*w*h} \ge 0 \\
+ {a} = 1,\quad{b} = {-(w+h)},\quad{c} = {\cfrac{1-iou}{1+iou}*w*h} \\
+ {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a}
+
+    - Case2: both corners are inside the gt box.
+
+ .. code:: text
+
+ |< width >|
+
+ lt-+----------+ -
+ | | | ^
+ +--x-------+ |
+ | | | |
+ | |overlap| | height
+ | | | |
+ | +-------x--+
+ | | | v
+ +----------+-br -
+
+ To ensure IoU of generated box and gt box is larger than ``min_overlap``:
+
+ .. math::
+ \cfrac{(w-2*r)*(h-2*r)}{w*h} \ge {iou} \quad\Rightarrow\quad
+ {4r^2-2(w+h)r+(1-iou)*w*h} \ge 0 \\
+ {a} = 4,\quad {b} = {-2(w+h)},\quad {c} = {(1-iou)*w*h} \\
+ {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a}
+
+    - Case3: both corners are outside the gt box.
+
+ .. code:: text
+
+ |< width >|
+
+ x--+----------------+
+ | | |
+ +-lt-------------+ | -
+ | | | | ^
+ | | | |
+ | | overlap | | height
+ | | | |
+ | | | | v
+ | +------------br--+ -
+ | | |
+ +----------------+--x
+
+ To ensure IoU of generated box and gt box is larger than ``min_overlap``:
+
+ .. math::
+ \cfrac{w*h}{(w+2*r)*(h+2*r)} \ge {iou} \quad\Rightarrow\quad
+ {4*iou*r^2+2*iou*(w+h)r+(iou-1)*w*h} \le 0 \\
+ {a} = {4*iou},\quad {b} = {2*iou*(w+h)},\quad {c} = {(iou-1)*w*h} \\
+ {r} \le \cfrac{-b+\sqrt{b^2-4*a*c}}{2*a}
+
+ Args:
+ det_size (list[int]): Shape of object.
+ min_overlap (float): Min IoU with ground truth for boxes generated by
+ keypoints inside the gaussian kernel.
+
+ Returns:
+ radius (int): Radius of gaussian kernel.
+ """
+ height, width = det_size
+
+ a1 = 1
+ b1 = (height + width)
+ c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+ sq1 = sqrt(b1**2 - 4 * a1 * c1)
+ r1 = (b1 - sq1) / (2 * a1)
+
+ a2 = 4
+ b2 = 2 * (height + width)
+ c2 = (1 - min_overlap) * width * height
+ sq2 = sqrt(b2**2 - 4 * a2 * c2)
+ r2 = (b2 - sq2) / (2 * a2)
+
+ a3 = 4 * min_overlap
+ b3 = -2 * min_overlap * (height + width)
+ c3 = (min_overlap - 1) * width * height
+ sq3 = sqrt(b3**2 - 4 * a3 * c3)
+ r3 = (b3 + sq3) / (2 * a3)
+ return min(r1, r2, r3)
+
+
+def get_local_maximum(heat, kernel=3):
+ """Extract local maximum pixel with given kernel.
+
+ Args:
+ heat (Tensor): Target heatmap.
+ kernel (int): Kernel size of max pooling. Default: 3.
+
+ Returns:
+        heat (Tensor): A heatmap where local maximum pixels keep their own
+            values and all other positions are set to 0.
+ """
+ pad = (kernel - 1) // 2
+ hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
+ keep = (hmax == heat).float()
+ return heat * keep
+
+
+def get_topk_from_heatmap(scores, k=20):
+ """Get top k positions from heatmap.
+
+ Args:
+ scores (Tensor): Target heatmap with shape
+ [batch, num_classes, height, width].
+ k (int): Target number. Default: 20.
+
+ Returns:
+ tuple[torch.Tensor]: Scores, indexes, categories and coords of
+ topk keypoint. Containing following Tensors:
+
+ - topk_scores (Tensor): Max scores of each topk keypoint.
+ - topk_inds (Tensor): Indexes of each topk keypoint.
+ - topk_clses (Tensor): Categories of each topk keypoint.
+ - topk_ys (Tensor): Y-coord of each topk keypoint.
+ - topk_xs (Tensor): X-coord of each topk keypoint.
+ """
+ batch, _, height, width = scores.size()
+ topk_scores, topk_inds = torch.topk(scores.view(batch, -1), k)
+ topk_clses = topk_inds // (height * width)
+ topk_inds = topk_inds % (height * width)
+ topk_ys = topk_inds // width
+ topk_xs = (topk_inds % width).int().float()
+ return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs
+
+
+def gather_feat(feat, ind, mask=None):
+ """Gather feature according to index.
+
+ Args:
+ feat (Tensor): Target feature map.
+ ind (Tensor): Target coord index.
+ mask (Tensor | None): Mask of feature map. Default: None.
+
+ Returns:
+ feat (Tensor): Gathered feature.
+ """
+ dim = feat.size(2)
+ ind = ind.unsqueeze(2).repeat(1, 1, dim)
+ feat = feat.gather(1, ind)
+ if mask is not None:
+ mask = mask.unsqueeze(2).expand_as(feat)
+ feat = feat[mask]
+ feat = feat.view(-1, dim)
+ return feat
+
+
+def transpose_and_gather_feat(feat, ind):
+ """Transpose and gather feature according to index.
+
+ Args:
+ feat (Tensor): Target feature map.
+ ind (Tensor): Target coord index.
+
+ Returns:
+ feat (Tensor): Transposed and gathered feature.
+ """
+ feat = feat.permute(0, 2, 3, 1).contiguous()
+ feat = feat.view(feat.size(0), -1, feat.size(3))
+ feat = gather_feat(feat, ind)
+ return feat
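+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): splat a
+    # gaussian onto an empty heatmap, roughly as done when building heatmap
+    # training targets; the box size and min_overlap here are arbitrary.
+    heatmap = torch.zeros(16, 16)
+    radius = max(1, int(gaussian_radius((8, 8), min_overlap=0.3)))
+    heatmap = gen_gaussian_target(heatmap, center=[8, 8], radius=radius)
+    assert heatmap[8, 8] == 1.0  # the kernel peak sits at the centre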
diff --git a/mmdet/models/utils/inverted_residual.py b/mmdet/models/utils/inverted_residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f241ae3e433c4aba1496cf2038ae88e9ef395ef
--- /dev/null
+++ b/mmdet/models/utils/inverted_residual.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule
+from mmcv.cnn.bricks import DropPath
+from mmcv.runner import BaseModule
+
+from .se_layer import SELayer
+
+
+class InvertedResidual(BaseModule):
+ """Inverted Residual Block.
+
+ Args:
+ in_channels (int): The input channels of this Module.
+ out_channels (int): The output channels of this Module.
+ mid_channels (int): The input channels of the depthwise convolution.
+ kernel_size (int): The kernel size of the depthwise convolution.
+ Default: 3.
+ stride (int): The stride of the depthwise convolution. Default: 1.
+ se_cfg (dict): Config dict for se layer. Default: None, which means no
+ se layer.
+ with_expand_conv (bool): Use expand conv or not. If set False,
+ mid_channels must be the same with in_channels.
+ Default: True.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+
+ Returns:
+ Tensor: The output tensor.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ se_cfg=None,
+ with_expand_conv=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ drop_path_rate=0.,
+ with_cp=False,
+ init_cfg=None):
+ super(InvertedResidual, self).__init__(init_cfg)
+ self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
+        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
+ f'But received {stride}.'
+ self.with_cp = with_cp
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+ self.with_se = se_cfg is not None
+ self.with_expand_conv = with_expand_conv
+
+ if self.with_se:
+ assert isinstance(se_cfg, dict)
+ if not self.with_expand_conv:
+ assert mid_channels == in_channels
+
+ if self.with_expand_conv:
+ self.expand_conv = ConvModule(
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.depthwise_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=kernel_size // 2,
+ groups=mid_channels,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ if self.with_se:
+ self.se = SELayer(**se_cfg)
+
+ self.linear_conv = ConvModule(
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ out = x
+
+ if self.with_expand_conv:
+ out = self.expand_conv(out)
+
+ out = self.depthwise_conv(out)
+
+ if self.with_se:
+ out = self.se(out)
+
+ out = self.linear_conv(out)
+
+ if self.with_res_shortcut:
+ return x + self.drop_path(out)
+ else:
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ return out
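+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): a block that
+    # expands 32 -> 96 channels, applies a depthwise conv and an SE module,
+    # then projects back to 32, so the residual shortcut is taken.
+    import torch
+    block = InvertedResidual(
+        in_channels=32,
+        out_channels=32,
+        mid_channels=96,
+        kernel_size=3,
+        stride=1,
+        se_cfg=dict(channels=96, ratio=4))
+    x = torch.rand(2, 32, 16, 16)
+    assert block(x).shape == (2, 32, 16, 16)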
diff --git a/mmdet/models/utils/make_divisible.py b/mmdet/models/utils/make_divisible.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed42c2eeea2a6aed03a0be5516b8d1ef1139e486
--- /dev/null
+++ b/mmdet/models/utils/make_divisible.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+ """Make divisible function.
+
+ This function rounds the channel number to the nearest value that can be
+ divisible by the divisor. It is taken from the original tf repo. It ensures
+ that all layers have a channel number that is divisible by divisor. It can
+ be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py # noqa
+
+ Args:
+ value (int): The original channel number.
+ divisor (int): The divisor to fully divide the channel number.
+ min_value (int): The minimum value of the output channel.
+            Default: None, which means the minimum value equals the divisor.
+ min_ratio (float): The minimum ratio of the rounded channel number to
+ the original channel number. Default: 0.9.
+
+ Returns:
+ int: The modified output channel number.
+ """
+
+ if min_value is None:
+ min_value = divisor
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than (1-min_ratio).
+ if new_value < min_ratio * value:
+ new_value += divisor
+ return new_value
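+
+
+if __name__ == '__main__':
+    # Editor's worked examples (not part of the original module): 30 rounds
+    # to the nearest multiple of 8, i.e. 32; 10 would round down to 8, which
+    # is below 0.9 * 10, so one extra divisor is added back, giving 16.
+    assert make_divisible(30, 8) == 32
+    assert make_divisible(10, 8) == 16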
diff --git a/mmdet/models/utils/misc.py b/mmdet/models/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f9be9abb75f99a3db9b8f6e30dcdc09748c3952
--- /dev/null
+++ b/mmdet/models/utils/misc.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch.autograd import Function
+from torch.nn import functional as F
+
+
+class SigmoidGeometricMean(Function):
+ """Forward and backward function of geometric mean of two sigmoid
+ functions.
+
+    This implementation, with an analytical gradient function, substitutes
+    the autograd function of (x.sigmoid() * y.sigmoid()).sqrt(). The
+    original autograd implementation can produce unstable (NaN) gradients
+    during backpropagation if both x and y are very small values.
+ """
+
+ @staticmethod
+ def forward(ctx, x, y):
+ x_sigmoid = x.sigmoid()
+ y_sigmoid = y.sigmoid()
+ z = (x_sigmoid * y_sigmoid).sqrt()
+ ctx.save_for_backward(x_sigmoid, y_sigmoid, z)
+ return z
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x_sigmoid, y_sigmoid, z = ctx.saved_tensors
+ grad_x = grad_output * z * (1 - x_sigmoid) / 2
+ grad_y = grad_output * z * (1 - y_sigmoid) / 2
+ return grad_x, grad_y
+
+
+sigmoid_geometric_mean = SigmoidGeometricMean.apply
+
+
+def interpolate_as(source, target, mode='bilinear', align_corners=False):
+ """Interpolate the `source` to the shape of the `target`.
+
+ The `source` must be a Tensor, but the `target` can be a Tensor or a
+ np.ndarray with the shape (..., target_h, target_w).
+
+ Args:
+ source (Tensor): A 3D/4D Tensor with the shape (N, H, W) or
+ (N, C, H, W).
+ target (Tensor | np.ndarray): The interpolation target with the shape
+ (..., target_h, target_w).
+ mode (str): Algorithm used for interpolation. The options are the
+ same as those in F.interpolate(). Default: ``'bilinear'``.
+ align_corners (bool): The same as the argument in F.interpolate().
+
+ Returns:
+ Tensor: The interpolated source Tensor.
+ """
+ assert len(target.shape) >= 2
+
+ def _interpolate_as(source, target, mode='bilinear', align_corners=False):
+ """Interpolate the `source` (4D) to the shape of the `target`."""
+ target_h, target_w = target.shape[-2:]
+ source_h, source_w = source.shape[-2:]
+ if target_h != source_h or target_w != source_w:
+ source = F.interpolate(
+ source,
+ size=(target_h, target_w),
+ mode=mode,
+ align_corners=align_corners)
+ return source
+
+ if len(source.shape) == 3:
+ source = source[:, None, :, :]
+ source = _interpolate_as(source, target, mode, align_corners)
+ return source[:, 0, :, :]
+ else:
+ return _interpolate_as(source, target, mode, align_corners)
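+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): the custom
+    # autograd op returns finite (here zero) gradients even when the product
+    # of the two sigmoids underflows, which is the failure case the docstring
+    # warns about.
+    import torch
+    x = torch.full((3,), -60., requires_grad=True)
+    y = torch.full((3,), -60., requires_grad=True)
+    sigmoid_geometric_mean(x, y).sum().backward()
+    assert torch.isfinite(x.grad).all()
+
+    # interpolate_as resizes a (N, C, H, W) tensor to the target's H x W.
+    src = torch.rand(1, 3, 8, 8)
+    tgt = torch.rand(1, 3, 32, 32)
+    assert interpolate_as(src, tgt).shape == (1, 3, 32, 32)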
diff --git a/mmdet/models/utils/normed_predictor.py b/mmdet/models/utils/normed_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0eeef7db0ca8af73c87a14f925bfa52edda0232
--- /dev/null
+++ b/mmdet/models/utils/normed_predictor.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import CONV_LAYERS
+from mmcv.utils import digit_version
+
+from .builder import LINEAR_LAYERS
+
+
+@LINEAR_LAYERS.register_module(name='NormedLinear')
+class NormedLinear(nn.Linear):
+ """Normalized Linear Layer.
+
+ Args:
+        tempearture (float, optional): Temperature term. Defaults to 20.
+        power (float, optional): Power term. Defaults to 1.0.
+ eps (float, optional): The minimal value of divisor to
+ keep numerical stability. Default to 1e-6.
+ """
+
+ def __init__(self, *args, tempearture=20, power=1.0, eps=1e-6, **kwargs):
+ super(NormedLinear, self).__init__(*args, **kwargs)
+ self.tempearture = tempearture
+ self.power = power
+ self.eps = eps
+ self.init_weights()
+
+ def init_weights(self):
+ nn.init.normal_(self.weight, mean=0, std=0.01)
+ if self.bias is not None:
+ nn.init.constant_(self.bias, 0)
+
+ def forward(self, x):
+ weight_ = self.weight / (
+ self.weight.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
+ x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
+ x_ = x_ * self.tempearture
+
+ return F.linear(x_, weight_, self.bias)
+
+
+@CONV_LAYERS.register_module(name='NormedConv2d')
+class NormedConv2d(nn.Conv2d):
+ """Normalized Conv2d Layer.
+
+ Args:
+        tempearture (float, optional): Temperature term. Defaults to 20.
+        power (float, optional): Power term. Defaults to 1.0.
+ eps (float, optional): The minimal value of divisor to
+ keep numerical stability. Default to 1e-6.
+ norm_over_kernel (bool, optional): Normalize over kernel.
+ Default to False.
+ """
+
+ def __init__(self,
+ *args,
+ tempearture=20,
+ power=1.0,
+ eps=1e-6,
+ norm_over_kernel=False,
+ **kwargs):
+ super(NormedConv2d, self).__init__(*args, **kwargs)
+ self.tempearture = tempearture
+ self.power = power
+ self.norm_over_kernel = norm_over_kernel
+ self.eps = eps
+
+ def forward(self, x):
+ if not self.norm_over_kernel:
+ weight_ = self.weight / (
+ self.weight.norm(dim=1, keepdim=True).pow(self.power) +
+ self.eps)
+ else:
+ weight_ = self.weight / (
+ self.weight.view(self.weight.size(0), -1).norm(
+ dim=1, keepdim=True).pow(self.power)[..., None, None] +
+ self.eps)
+ x_ = x / (x.norm(dim=1, keepdim=True).pow(self.power) + self.eps)
+ x_ = x_ * self.tempearture
+
+ if hasattr(self, 'conv2d_forward'):
+ x_ = self.conv2d_forward(x_, weight_)
+ else:
+            if digit_version(torch.__version__) >= digit_version('1.8'):
+ x_ = self._conv_forward(x_, weight_, self.bias)
+ else:
+ x_ = self._conv_forward(x_, weight_)
+ return x_
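+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): both predictors
+    # L2-normalize the inputs and weights before the linear / conv operation,
+    # and keep the same input/output shapes as nn.Linear / nn.Conv2d. The
+    # ``tempearture`` spelling follows the argument name used above.
+    fc = NormedLinear(256, 80, tempearture=20)
+    assert fc(torch.rand(4, 256)).shape == (4, 80)
+
+    conv = NormedConv2d(256, 80, kernel_size=3, padding=1)
+    assert conv(torch.rand(2, 256, 7, 7)).shape == (2, 80, 7, 7)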
diff --git a/mmdet/models/utils/panoptic_gt_processing.py b/mmdet/models/utils/panoptic_gt_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..7685ac96fb9750e5c3dd11aa13aa22d9fc7eeb2f
--- /dev/null
+++ b/mmdet/models/utils/panoptic_gt_processing.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def preprocess_panoptic_gt(gt_labels, gt_masks, gt_semantic_seg, num_things,
+ num_stuff, img_metas):
+    """Preprocess the ground truth for an image.
+
+ Args:
+ gt_labels (Tensor): Ground truth labels of each bbox,
+ with shape (num_gts, ).
+        gt_masks (BitmapMasks): Ground truth masks of each instance
+            of an image, shape (num_gts, h, w).
+ gt_semantic_seg (Tensor | None): Ground truth of semantic
+ segmentation with the shape (1, h, w).
+ [0, num_thing_class - 1] means things,
+ [num_thing_class, num_class-1] means stuff,
+ 255 means VOID. It's None when training instance segmentation.
+        num_things (int): Number of thing classes.
+        num_stuff (int): Number of stuff classes.
+        img_metas (dict): Meta information of the image.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+
+        - labels (Tensor): Ground truth class indices for an
+            image, with shape (n, ), where n is the sum of the number
+            of stuff types and the number of instances in the image.
+        - masks (Tensor): Ground truth masks for an image, with
+ shape (n, h, w). Contains stuff and things when training
+ panoptic segmentation, and things only when training
+ instance segmentation.
+ """
+ num_classes = num_things + num_stuff
+
+ things_masks = gt_masks.pad(img_metas['pad_shape'][:2], pad_val=0)\
+ .to_tensor(dtype=torch.bool, device=gt_labels.device)
+
+ if gt_semantic_seg is None:
+ masks = things_masks.long()
+ return gt_labels, masks
+
+ things_labels = gt_labels
+ gt_semantic_seg = gt_semantic_seg.squeeze(0)
+
+ semantic_labels = torch.unique(
+ gt_semantic_seg,
+ sorted=False,
+ return_inverse=False,
+ return_counts=False)
+ stuff_masks_list = []
+ stuff_labels_list = []
+ for label in semantic_labels:
+ if label < num_things or label >= num_classes:
+ continue
+ stuff_mask = gt_semantic_seg == label
+ stuff_masks_list.append(stuff_mask)
+ stuff_labels_list.append(label)
+
+ if len(stuff_masks_list) > 0:
+ stuff_masks = torch.stack(stuff_masks_list, dim=0)
+ stuff_labels = torch.stack(stuff_labels_list, dim=0)
+ labels = torch.cat([things_labels, stuff_labels], dim=0)
+ masks = torch.cat([things_masks, stuff_masks], dim=0)
+ else:
+ labels = things_labels
+ masks = things_masks
+
+ masks = masks.long()
+ return labels, masks
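+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module), assuming
+    # ``BitmapMasks`` from ``mmdet.core.mask``: two "thing" instances plus a
+    # semantic map containing one stuff class (label 2, with num_things=2 and
+    # num_stuff=2) yield three targets in total.
+    import numpy as np
+    from mmdet.core.mask import BitmapMasks
+
+    gt_labels = torch.tensor([0, 1])
+    gt_masks = BitmapMasks(np.random.rand(2, 16, 16) > 0.5, 16, 16)
+    gt_semantic_seg = torch.full((1, 16, 16), 255, dtype=torch.long)
+    gt_semantic_seg[0, :8] = 2  # one stuff region
+    img_metas = dict(pad_shape=(16, 16, 3))
+
+    labels, masks = preprocess_panoptic_gt(
+        gt_labels, gt_masks, gt_semantic_seg, 2, 2, img_metas)
+    assert labels.tolist() == [0, 1, 2] and masks.shape == (3, 16, 16)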
diff --git a/mmdet/models/utils/point_sample.py b/mmdet/models/utils/point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2c3cf91cc934987f57cf528d4a1763c0873e4b2
--- /dev/null
+++ b/mmdet/models/utils/point_sample.py
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.ops import point_sample
+
+
+def get_uncertainty(mask_pred, labels):
+ """Estimate uncertainty based on pred logits.
+
+    We estimate the uncertainty as the L1 distance between 0.0 and the logit
+    prediction in 'mask_pred' for the foreground class given by `labels`.
+
+ Args:
+        mask_pred (Tensor): mask prediction logits, shape (num_rois,
+ num_classes, mask_height, mask_width).
+
+ labels (list[Tensor]): Either predicted or ground truth label for
+ each predicted mask, of length num_rois.
+
+ Returns:
+ scores (Tensor): Uncertainty scores with the most uncertain
+ locations having the highest uncertainty score,
+ shape (num_rois, 1, mask_height, mask_width)
+ """
+ if mask_pred.shape[1] == 1:
+ gt_class_logits = mask_pred.clone()
+ else:
+ inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
+ gt_class_logits = mask_pred[inds, labels].unsqueeze(1)
+ return -torch.abs(gt_class_logits)
+
+
+def get_uncertain_point_coords_with_randomness(mask_pred, labels, num_points,
+ oversample_ratio,
+ importance_sample_ratio):
+ """Get ``num_points`` most uncertain points with random points during
+ train.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+ uncertainty. The uncertainties are calculated for each point using
+ 'get_uncertainty()' function that takes point's logit prediction as
+ input.
+
+ Args:
+ mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
+ mask_height, mask_width) for class-specific or class-agnostic
+ prediction.
+ labels (list): The ground truth class for each instance.
+ num_points (int): The number of points to sample.
+ oversample_ratio (int): Oversampling parameter.
+ importance_sample_ratio (float): Ratio of points that are sampled
+            via importance sampling.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+            that contains the coordinates of the sampled points.
+ """
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = mask_pred.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(
+ batch_size, num_sampled, 2, device=mask_pred.device)
+ point_logits = point_sample(mask_pred, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+ # and sampled point will get -1 uncertainty.
+ point_uncertainties = get_uncertainty(point_logits, labels)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(
+ point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(
+ batch_size, dtype=torch.long, device=mask_pred.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+ batch_size, num_uncertain_points, 2)
+ if num_random_points > 0:
+ rand_roi_coords = torch.rand(
+ batch_size, num_random_points, 2, device=mask_pred.device)
+ point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
+ return point_coords
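+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): sample 16 point
+    # coordinates per RoI in the unit square, 75% of them at the most
+    # uncertain (near-zero) logits and the remaining 25% uniformly at random.
+    mask_pred = torch.randn(4, 1, 28, 28)  # class-agnostic mask logits
+    labels = torch.zeros(4, dtype=torch.long)
+    coords = get_uncertain_point_coords_with_randomness(
+        mask_pred,
+        labels,
+        num_points=16,
+        oversample_ratio=3,
+        importance_sample_ratio=0.75)
+    assert coords.shape == (4, 16, 2)
+    assert coords.min() >= 0 and coords.max() <= 1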
diff --git a/mmdet/models/utils/positional_encoding.py b/mmdet/models/utils/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd29cd65606e9af1b91d422fb199d71532deeffe
--- /dev/null
+++ b/mmdet/models/utils/positional_encoding.py
@@ -0,0 +1,163 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING
+from mmcv.runner import BaseModule
+
+
+@POSITIONAL_ENCODING.register_module()
+class SinePositionalEncoding(BaseModule):
+ """Position encoding with sine and cosine functions.
+
+ See `End-to-End Object Detection with Transformers
+ `_ for details.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. Note the final returned dimension
+ for each position is 2 times of this value.
+ temperature (int, optional): The temperature used for scaling
+ the position embedding. Defaults to 10000.
+ normalize (bool, optional): Whether to normalize the position
+ embedding. Defaults to False.
+ scale (float, optional): A scale factor that scales the position
+ embedding. The scale will be used only when `normalize` is True.
+ Defaults to 2*pi.
+ eps (float, optional): A value added to the denominator for
+ numerical stability. Defaults to 1e-6.
+        offset (float): An offset added to the embedding when performing
+            normalization. Defaults to 0.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ num_feats,
+ temperature=10000,
+ normalize=False,
+ scale=2 * math.pi,
+ eps=1e-6,
+ offset=0.,
+ init_cfg=None):
+ super(SinePositionalEncoding, self).__init__(init_cfg)
+ if normalize:
+ assert isinstance(scale, (float, int)), 'when normalize is set,' \
+ 'scale should be provided and in float or int type, ' \
+ f'found {type(scale)}'
+ self.num_feats = num_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = scale
+ self.eps = eps
+ self.offset = offset
+
+ def forward(self, mask):
+ """Forward function for `SinePositionalEncoding`.
+
+ Args:
+            mask (Tensor): ByteTensor mask. Non-zero values represent
+                ignored positions, while zero values mean valid positions
+                for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ # For convenience of exporting to ONNX, it's required to convert
+ # `masks` from bool to int.
+ mask = mask.to(torch.int)
+ not_mask = 1 - mask # logical_not
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ y_embed = (y_embed + self.offset) / \
+ (y_embed[:, -1:, :] + self.eps) * self.scale
+ x_embed = (x_embed + self.offset) / \
+ (x_embed[:, :, -1:] + self.eps) * self.scale
+ dim_t = torch.arange(
+ self.num_feats, dtype=torch.float32, device=mask.device)
+ dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ B, H, W = mask.size()
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+ dim=4).view(B, H, W, -1)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+ dim=4).view(B, H, W, -1)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_feats={self.num_feats}, '
+ repr_str += f'temperature={self.temperature}, '
+ repr_str += f'normalize={self.normalize}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'eps={self.eps})'
+ return repr_str
+
+
+@POSITIONAL_ENCODING.register_module()
+class LearnedPositionalEncoding(BaseModule):
+ """Position embedding with learnable embedding weights.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. The final returned dimension for
+ each position is 2 times of this value.
+ row_num_embed (int, optional): The dictionary size of row embeddings.
+ Default 50.
+ col_num_embed (int, optional): The dictionary size of col embeddings.
+ Default 50.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """
+
+ def __init__(self,
+ num_feats,
+ row_num_embed=50,
+ col_num_embed=50,
+ init_cfg=dict(type='Uniform', layer='Embedding')):
+ super(LearnedPositionalEncoding, self).__init__(init_cfg)
+ self.row_embed = nn.Embedding(row_num_embed, num_feats)
+ self.col_embed = nn.Embedding(col_num_embed, num_feats)
+ self.num_feats = num_feats
+ self.row_num_embed = row_num_embed
+ self.col_num_embed = col_num_embed
+
+ def forward(self, mask):
+ """Forward function for `LearnedPositionalEncoding`.
+
+ Args:
+            mask (Tensor): ByteTensor mask. Non-zero values represent
+                ignored positions, while zero values mean valid positions
+                for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ h, w = mask.shape[-2:]
+ x = torch.arange(w, device=mask.device)
+ y = torch.arange(h, device=mask.device)
+ x_embed = self.col_embed(x)
+ y_embed = self.row_embed(y)
+ pos = torch.cat(
+ (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(
+ 1, w, 1)),
+ dim=-1).permute(2, 0,
+ 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_feats={self.num_feats}, '
+ repr_str += f'row_num_embed={self.row_num_embed}, '
+ repr_str += f'col_num_embed={self.col_num_embed})'
+ return repr_str
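+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): for a padding
+    # mask of shape (bs, h, w), both encodings return a (bs, 2 * num_feats,
+    # h, w) embedding, matching the channel layout expected by DETR-style
+    # heads.
+    mask = torch.zeros(2, 16, 20, dtype=torch.uint8)  # no padded pixels
+    sine = SinePositionalEncoding(num_feats=128, normalize=True)
+    learned = LearnedPositionalEncoding(num_feats=128)
+    assert sine(mask).shape == (2, 256, 16, 20)
+    assert learned(mask).shape == (2, 256, 16, 20)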
diff --git a/mmdet/models/utils/res_layer.py b/mmdet/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c3e89fb035d197cb82173e90659dac89ff07fab
--- /dev/null
+++ b/mmdet/models/utils/res_layer.py
@@ -0,0 +1,190 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import BaseModule, Sequential
+from torch import nn as nn
+
+
+class ResLayer(Sequential):
+ """ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ """
+
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ downsample_first=True,
+ **kwargs):
+ self.block = block
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ if avg_down:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if downsample_first:
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ inplanes = planes * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+
+ else: # downsample_first=False is for HourglassModule
+ for _ in range(num_blocks - 1):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=inplanes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ super(ResLayer, self).__init__(*layers)
+
+
+class SimplifiedBasicBlock(BaseModule):
+ """Simplified version of original basic residual block. This is used in
+ `SCNet `_.
+
+ - Norm layer is now optional
+ - Last ReLU in forward function is removed
+ """
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None,
+                 init_cfg=None):
+        super(SimplifiedBasicBlock, self).__init__(init_cfg)
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert not with_cp, 'Not implemented yet.'
+ self.with_norm = norm_cfg is not None
+        with_bias = norm_cfg is None
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=with_bias)
+ if self.with_norm:
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, planes, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg, planes, planes, 3, padding=1, bias=with_bias)
+ if self.with_norm:
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, planes, postfix=2)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name) if self.with_norm else None
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name) if self.with_norm else None
+
+ def forward(self, x):
+ """Forward function."""
+
+ identity = x
+
+ out = self.conv1(x)
+ if self.with_norm:
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ if self.with_norm:
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
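+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): stack three
+    # SimplifiedBasicBlocks into a ResNet-style stage; the first block
+    # downsamples with stride 2, so a 64x32x32 input becomes 128x16x16.
+    import torch
+    layer = ResLayer(
+        SimplifiedBasicBlock, inplanes=64, planes=128, num_blocks=3, stride=2)
+    assert layer(torch.rand(2, 64, 32, 32)).shape == (2, 128, 16, 16)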
diff --git a/mmdet/models/utils/se_layer.py b/mmdet/models/utils/se_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2492103b1559df3b6d3a06811ba829621ad0cae
--- /dev/null
+++ b/mmdet/models/utils/se_layer.py
@@ -0,0 +1,127 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
+
+
+class SELayer(BaseModule):
+ """Squeeze-and-Excitation Module.
+
+ Args:
+ channels (int): The input (and output) channels of the SE layer.
+ ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
+ ``int(channels/ratio)``. Default: 16.
+ conv_cfg (None or dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+            If act_cfg is a dict, two activation layers will be configured
+            by this dict. If act_cfg is a sequence of dicts, the first
+            activation layer will be configured by the first dict and the
+            second activation layer will be configured by the second dict.
+ Default: (dict(type='ReLU'), dict(type='Sigmoid'))
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ channels,
+ ratio=16,
+ conv_cfg=None,
+ act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
+ init_cfg=None):
+ super(SELayer, self).__init__(init_cfg)
+ if isinstance(act_cfg, dict):
+ act_cfg = (act_cfg, act_cfg)
+ assert len(act_cfg) == 2
+ assert mmcv.is_tuple_of(act_cfg, dict)
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = ConvModule(
+ in_channels=channels,
+ out_channels=int(channels / ratio),
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[0])
+ self.conv2 = ConvModule(
+ in_channels=int(channels / ratio),
+ out_channels=channels,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[1])
+
+ def forward(self, x):
+ out = self.global_avgpool(x)
+ out = self.conv1(out)
+ out = self.conv2(out)
+ return x * out
+
+
+class DyReLU(BaseModule):
+ """Dynamic ReLU (DyReLU) module.
+
+ See `Dynamic ReLU `_ for details.
+ Current implementation is specialized for task-aware attention in DyHead.
+ HSigmoid arguments in default act_cfg follow DyHead official code.
+ https://github.com/microsoft/DynamicHead/blob/master/dyhead/dyrelu.py
+
+ Args:
+ channels (int): The input (and output) channels of DyReLU module.
+ ratio (int): Squeeze ratio in Squeeze-and-Excitation-like module,
+ the intermediate channel will be ``int(channels/ratio)``.
+ Default: 4.
+ conv_cfg (None or dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+            If act_cfg is a dict, two activation layers will be configured
+            by this dict. If act_cfg is a sequence of dicts, the first
+            activation layer will be configured by the first dict and the
+            second activation layer will be configured by the second dict.
+ Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
+ divisor=6.0))
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None
+ """
+
+ def __init__(self,
+ channels,
+ ratio=4,
+ conv_cfg=None,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=3.0, divisor=6.0)),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ if isinstance(act_cfg, dict):
+ act_cfg = (act_cfg, act_cfg)
+ assert len(act_cfg) == 2
+ assert mmcv.is_tuple_of(act_cfg, dict)
+ self.channels = channels
+ self.expansion = 4 # for a1, b1, a2, b2
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+ self.conv1 = ConvModule(
+ in_channels=channels,
+ out_channels=int(channels / ratio),
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[0])
+ self.conv2 = ConvModule(
+ in_channels=int(channels / ratio),
+ out_channels=channels * self.expansion,
+ kernel_size=1,
+ stride=1,
+ conv_cfg=conv_cfg,
+ act_cfg=act_cfg[1])
+
+ def forward(self, x):
+ """Forward function."""
+ coeffs = self.global_avgpool(x)
+ coeffs = self.conv1(coeffs)
+ coeffs = self.conv2(coeffs) - 0.5 # value range: [-0.5, 0.5]
+ a1, b1, a2, b2 = torch.split(coeffs, self.channels, dim=1)
+ a1 = a1 * 2.0 + 1.0 # [-1.0, 1.0] + 1.0
+ a2 = a2 * 2.0 # [-1.0, 1.0]
+ out = torch.max(x * a1 + b1, x * a2 + b2)
+ return out
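+
+
+if __name__ == '__main__':
+    # Editor's usage sketch (not part of the original module): both modules
+    # are shape-preserving. SELayer rescales channels with a learned gate,
+    # while DyReLU additionally predicts piecewise-linear activation
+    # coefficients (a1, b1, a2, b2) per channel.
+    se = SELayer(channels=64, ratio=16)
+    dyrelu = DyReLU(channels=64, ratio=4)
+    x = torch.rand(2, 64, 14, 14)
+    assert se(x).shape == x.shape
+    assert dyrelu(x).shape == x.shape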
diff --git a/mmdet/models/utils/transformer.py b/mmdet/models/utils/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c390c83a1aaba0d293a4e8f927e6fceead10965
--- /dev/null
+++ b/mmdet/models/utils/transformer.py
@@ -0,0 +1,1167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (build_activation_layer, build_conv_layer,
+ build_norm_layer, xavier_init)
+from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER,
+ TRANSFORMER_LAYER_SEQUENCE)
+from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
+ TransformerLayerSequence,
+ build_transformer_layer_sequence)
+from mmcv.runner.base_module import BaseModule
+from mmcv.utils import to_2tuple
+from torch.nn.init import normal_
+
+from mmdet.models.utils.builder import TRANSFORMER
+
+try:
+ from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
+
+except ImportError:
+ warnings.warn(
+ '`MultiScaleDeformableAttention` in MMCV has been moved to '
+ '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV')
+ from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
+
+
+def nlc_to_nchw(x, hw_shape):
+ """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
+
+ Args:
+ x (Tensor): The input tensor of shape [N, L, C] before conversion.
+ hw_shape (Sequence[int]): The height and width of output feature map.
+
+ Returns:
+ Tensor: The output tensor of shape [N, C, H, W] after conversion.
+ """
+ H, W = hw_shape
+ assert len(x.shape) == 3
+ B, L, C = x.shape
+ assert L == H * W, 'The seq_len does not match H, W'
+ return x.transpose(1, 2).reshape(B, C, H, W).contiguous()
+
+
+def nchw_to_nlc(x):
+ """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
+
+ Args:
+ x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
+
+ Returns:
+ Tensor: The output tensor of shape [N, L, C] after conversion.
+ """
+ assert len(x.shape) == 4
+ return x.flatten(2).transpose(1, 2).contiguous()
+
+
+class AdaptivePadding(nn.Module):
+    """Applies padding to the input (if needed) so that the input can be fully
+    covered by the filter you specified. It supports two modes, "same" and
+    "corner". The "same" mode is the same as the "SAME" padding mode in
+    TensorFlow, padding zeros around the input. The "corner" mode pads zeros
+    to the bottom right of the input.
+
+ Args:
+        kernel_size (int | tuple): Size of the kernel.
+        stride (int | tuple): Stride of the filter. Default: 1.
+ dilation (int | tuple): Spacing between kernel elements.
+ Default: 1
+ padding (str): Support "same" and "corner", "corner" mode
+ would pad zero to bottom right, and "same" mode would
+ pad zero around input. Default: "corner".
+ Example:
+ >>> kernel_size = 16
+ >>> stride = 16
+ >>> dilation = 1
+ >>> input = torch.rand(1, 1, 15, 17)
+ >>> adap_pad = AdaptivePadding(
+ >>> kernel_size=kernel_size,
+ >>> stride=stride,
+ >>> dilation=dilation,
+ >>> padding="corner")
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ >>> input = torch.rand(1, 1, 16, 17)
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ """
+
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
+
+ super(AdaptivePadding, self).__init__()
+
+ assert padding in ('same', 'corner')
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ padding = to_2tuple(padding)
+ dilation = to_2tuple(dilation)
+
+ self.padding = padding
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+
+ def get_pad_shape(self, input_shape):
+ input_h, input_w = input_shape
+ kernel_h, kernel_w = self.kernel_size
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(input_h / stride_h)
+ output_w = math.ceil(input_w / stride_w)
+ pad_h = max((output_h - 1) * stride_h +
+ (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+ pad_w = max((output_w - 1) * stride_w +
+ (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+ return pad_h, pad_w
+
+ def forward(self, x):
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+ if pad_h > 0 or pad_w > 0:
+ if self.padding == 'corner':
+ x = F.pad(x, [0, pad_w, 0, pad_h])
+ elif self.padding == 'same':
+ x = F.pad(x, [
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+ pad_h - pad_h // 2
+ ])
+ return x
+
+
+class PatchEmbed(BaseModule):
+ """Image to Patch Embedding.
+
+ We use a conv layer to implement PatchEmbed.
+
+ Args:
+ in_channels (int): The num of input channels. Default: 3
+ embed_dims (int): The dimensions of embedding. Default: 768
+        conv_type (str): The type selection of the embedding
+            conv layer. Default: "Conv2d".
+        kernel_size (int): The kernel_size of embedding conv. Default: 16.
+        stride (int): The slide stride of embedding conv.
+            Default: 16. If set to None, it will be set to `kernel_size`.
+        padding (int | tuple | str): The padding length of the
+            embedding conv. When it is a string, it means the mode
+            of adaptive padding; "same" and "corner" are supported now.
+            Default: "corner".
+ dilation (int): The dilation rate of embedding conv. Default: 1.
+ bias (bool): Bias of embed conv. Default: True.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ input_size (int | tuple | None): The size of input, which will be
+            used to calculate the out size. Only works when `dynamic_size`
+ is False. Default: None.
+ init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(
+ self,
+ in_channels=3,
+ embed_dims=768,
+ conv_type='Conv2d',
+ kernel_size=16,
+ stride=16,
+ padding='corner',
+ dilation=1,
+ bias=True,
+ norm_cfg=None,
+ input_size=None,
+ init_cfg=None,
+ ):
+ super(PatchEmbed, self).__init__(init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ if stride is None:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adap_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of conv
+ padding = 0
+ else:
+ self.adap_padding = None
+ padding = to_2tuple(padding)
+
+ self.projection = build_conv_layer(
+ dict(type=conv_type),
+ in_channels=in_channels,
+ out_channels=embed_dims,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=bias)
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+ else:
+ self.norm = None
+
+ if input_size:
+ input_size = to_2tuple(input_size)
+ # `init_out_size` would be used outside to
+ # calculate the num_patches
+ # when `use_abs_pos_embed` outside
+ self.init_input_size = input_size
+ if self.adap_padding:
+ pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
+ input_h, input_w = input_size
+ input_h = input_h + pad_h
+ input_w = input_w + pad_w
+ input_size = (input_h, input_w)
+
+ # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+ h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
+ (kernel_size[0] - 1) - 1) // stride[0] + 1
+ w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
+ (kernel_size[1] - 1) - 1) // stride[1] + 1
+ self.init_out_size = (h_out, w_out)
+ else:
+ self.init_input_size = None
+ self.init_out_size = None
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (out_h, out_w).
+ """
+
+ if self.adap_padding:
+ x = self.adap_padding(x)
+
+ x = self.projection(x)
+ out_size = (x.shape[2], x.shape[3])
+ x = x.flatten(2).transpose(1, 2)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x, out_size
+
+
+class PatchMerging(BaseModule):
+ """Merge patch feature map.
+
+    This layer groups the feature map by kernel_size, and applies norm and
+    linear layers to the grouped feature map. Our implementation uses
+    `nn.Unfold` to merge patches, which is about 25% faster than the original
+    implementation. However, we need to modify pretrained models for
+    compatibility.
+
+ Args:
+        in_channels (int): The num of input channels.
+ out_channels (int): The num of output channels.
+ kernel_size (int | tuple, optional): the kernel size in the unfold
+ layer. Defaults to 2.
+ stride (int | tuple, optional): the stride of the sliding blocks in the
+ unfold layer. Default: None. (Would be set as `kernel_size`)
+ padding (int | tuple | string ): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding, support "same" and "corner" now.
+ Default: "corner".
+ dilation (int | tuple, optional): dilation parameter in the unfold
+ layer. Default: 1.
+ bias (bool, optional): Whether to add bias in linear layer or not.
+ Defaults: False.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=2,
+ stride=None,
+ padding='corner',
+ dilation=1,
+ bias=False,
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ if stride:
+ stride = stride
+ else:
+ stride = kernel_size
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+
+ if isinstance(padding, str):
+ self.adap_padding = AdaptivePadding(
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding)
+ # disable the padding of unfold
+ padding = 0
+ else:
+ self.adap_padding = None
+
+ padding = to_2tuple(padding)
+ self.sampler = nn.Unfold(
+ kernel_size=kernel_size,
+ dilation=dilation,
+ padding=padding,
+ stride=stride)
+
+ sample_dim = kernel_size[0] * kernel_size[1] * in_channels
+
+ if norm_cfg is not None:
+ self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
+ else:
+ self.norm = None
+
+ self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
+
+ def forward(self, x, input_size):
+ """
+ Args:
+ x (Tensor): Has shape (B, H*W, C_in).
+ input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
+ Default: None.
+
+ Returns:
+ tuple: Contains merged results and its spatial shape.
+
+ - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
+ - out_size (tuple[int]): Spatial shape of x, arrange as
+ (Merged_H, Merged_W).
+ """
+ B, L, C = x.shape
+ assert isinstance(input_size, Sequence), f'Expect ' \
+ f'input_size is ' \
+ f'`Sequence` ' \
+ f'but get {input_size}'
+
+ H, W = input_size
+ assert L == H * W, 'input feature has wrong size'
+
+ x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
+ # Use nn.Unfold to merge patch. About 25% faster than original method,
+ # but need to modify pretrained model for compatibility
+
+ if self.adap_padding:
+ x = self.adap_padding(x)
+ H, W = x.shape[-2:]
+
+ x = self.sampler(x)
+        # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)
+
+ out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
+ (self.sampler.kernel_size[0] - 1) -
+ 1) // self.sampler.stride[0] + 1
+ out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
+ (self.sampler.kernel_size[1] - 1) -
+ 1) // self.sampler.stride[1] + 1
+
+ output_size = (out_h, out_w)
+ x = x.transpose(1, 2) # B, H/2*W/2, 4*C
+ x = self.norm(x) if self.norm else x
+ x = self.reduction(x)
+ return x, output_size
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ """Inverse function of sigmoid.
+
+ Args:
+        x (Tensor): The tensor to apply the
+            inverse sigmoid to.
+        eps (float): A small value to avoid numerical
+            overflow. Defaults to 1e-5.
+    Returns:
+        Tensor: The result of applying the inverse
+            sigmoid to ``x``, with the same
+            shape as the input.
+ """
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
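+
+
+# Editor's note (illustrative, not part of the original module): away from the
+# ``eps`` clamp, ``inverse_sigmoid`` is simply the logit function, so it
+# inverts ``torch.sigmoid``:
+#
+#     >>> t = torch.tensor([0.1, 0.5, 0.9])
+#     >>> torch.allclose(inverse_sigmoid(t).sigmoid(), t, atol=1e-6)
+#     True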
+
+
+@TRANSFORMER_LAYER.register_module()
+class DetrTransformerDecoderLayer(BaseTransformerLayer):
+ """Implements decoder layer in DETR transformer.
+
+ Args:
+        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
+            Configs for self_attention or cross_attention; the order
+            should be consistent with that in `operation_order`. If it is
+            a dict, it will be expanded to the number of attention modules
+            in `operation_order`.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ ffn_dropout (float): Probability of an element to be zeroed
+ in ffn. Default 0.0.
+        operation_order (tuple[str]): The execution order of operations
+            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
+            Default: None.
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='ReLU', inplace=True).
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: `LN`.
+ ffn_num_fcs (int): The number of fully-connected layers in FFNs.
+ Default:2.
+ """
+
+ def __init__(self,
+ attn_cfgs,
+ feedforward_channels,
+ ffn_dropout=0.0,
+ operation_order=None,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ ffn_num_fcs=2,
+ **kwargs):
+ super(DetrTransformerDecoderLayer, self).__init__(
+ attn_cfgs=attn_cfgs,
+ feedforward_channels=feedforward_channels,
+ ffn_dropout=ffn_dropout,
+ operation_order=operation_order,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ ffn_num_fcs=ffn_num_fcs,
+ **kwargs)
+ assert len(operation_order) == 6
+ assert set(operation_order) == set(
+ ['self_attn', 'norm', 'cross_attn', 'ffn'])
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DetrTransformerEncoder(TransformerLayerSequence):
+ """TransformerEncoder of DETR.
+
+ Args:
+ post_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`. Only used when `self.pre_norm` is `True`
+ """
+
+ def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):
+ super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
+ if post_norm_cfg is not None:
+ self.post_norm = build_norm_layer(
+ post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
+ else:
+ assert not self.pre_norm, \
+ f'Use prenorm in {self.__class__.__name__}, ' \
+ f'please specify post_norm_cfg'
+ self.post_norm = None
+
+ def forward(self, *args, **kwargs):
+ """Forward function for `TransformerCoder`.
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
+ if self.post_norm is not None:
+ x = self.post_norm(x)
+ return x
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DetrTransformerDecoder(TransformerLayerSequence):
+ """Implements the decoder in DETR transformer.
+
+ Args:
+ return_intermediate (bool): Whether to return intermediate outputs.
+ post_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`.
+ """
+
+ def __init__(self,
+ *args,
+ post_norm_cfg=dict(type='LN'),
+ return_intermediate=False,
+ **kwargs):
+
+ super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
+ self.return_intermediate = return_intermediate
+ if post_norm_cfg is not None:
+ self.post_norm = build_norm_layer(post_norm_cfg,
+ self.embed_dims)[1]
+ else:
+ self.post_norm = None
+
+ def forward(self, query, *args, **kwargs):
+ """Forward function for `TransformerDecoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_query, bs, embed_dims)`.
+
+ Returns:
+ Tensor: Results with shape [1, num_query, bs, embed_dims] when
+ return_intermediate is `False`, otherwise it has shape
+ [num_layers, num_query, bs, embed_dims].
+ """
+ if not self.return_intermediate:
+ x = super().forward(query, *args, **kwargs)
+ if self.post_norm:
+ x = self.post_norm(x)[None]
+ return x
+
+ intermediate = []
+ for layer in self.layers:
+ query = layer(query, *args, **kwargs)
+ if self.return_intermediate:
+ if self.post_norm is not None:
+ intermediate.append(self.post_norm(query))
+ else:
+ intermediate.append(query)
+ return torch.stack(intermediate)
+
+
+@TRANSFORMER.register_module()
+class Transformer(BaseModule):
+ """Implements the DETR transformer.
+
+ Following the official DETR implementation, this module is copy-pasted
+ from torch.nn.Transformer with modifications:
+
+ * positional encodings are passed in MultiheadAttention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+
+ See `paper: End-to-End Object Detection with Transformers
+ `_ for details.
+
+ Args:
+ encoder (`mmcv.ConfigDict` | Dict): Config of
+ TransformerEncoder. Defaults to None.
+ decoder (`mmcv.ConfigDict` | Dict): Config of
+ TransformerDecoder. Defaults to None.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self, encoder=None, decoder=None, init_cfg=None):
+ super(Transformer, self).__init__(init_cfg=init_cfg)
+ self.encoder = build_transformer_layer_sequence(encoder)
+ self.decoder = build_transformer_layer_sequence(decoder)
+ self.embed_dims = self.encoder.embed_dims
+
+ def init_weights(self):
+ # follow the official DETR to init parameters
+ for m in self.modules():
+ if hasattr(m, 'weight') and m.weight.dim() > 1:
+ xavier_init(m, distribution='uniform')
+ self._is_init = True
+
+ def forward(self, x, mask, query_embed, pos_embed):
+ """Forward function for `Transformer`.
+
+ Args:
+ x (Tensor): Input query with shape [bs, c, h, w] where
+ c = embed_dims.
+ mask (Tensor): The key_padding_mask used for encoder and decoder,
+ with shape [bs, h, w].
+ query_embed (Tensor): The query embedding for decoder, with shape
+ [num_query, c].
+ pos_embed (Tensor): The positional encoding for encoder and
+ decoder, with the same shape as `x`.
+
+ Returns:
+ tuple[Tensor]: results of decoder containing the following tensor.
+
+ - out_dec: Output from decoder. If return_intermediate_dec \
+ is True output has shape [num_dec_layers, bs,
+ num_query, embed_dims], else has shape [1, bs, \
+ num_query, embed_dims].
+ - memory: Output results from encoder, with shape \
+ [bs, embed_dims, h, w].
+ """
+ bs, c, h, w = x.shape
+ # use `view` instead of `flatten` for dynamically exporting to ONNX
+ x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c]
+ pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(
+ 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim]
+ mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w]
+ memory = self.encoder(
+ query=x,
+ key=None,
+ value=None,
+ query_pos=pos_embed,
+ query_key_padding_mask=mask)
+ target = torch.zeros_like(query_embed)
+ # out_dec: [num_layers, num_query, bs, dim]
+ out_dec = self.decoder(
+ query=target,
+ key=memory,
+ value=memory,
+ key_pos=pos_embed,
+ query_pos=query_embed,
+ key_padding_mask=mask)
+ out_dec = out_dec.transpose(1, 2)
+ memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
+ return out_dec, memory
+
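+# A minimal shape walkthrough for `Transformer.forward` (the sizes below are
+# assumed, not taken from any particular config):
+# x: (bs=2, c=256, h=25, w=34), mask: (2, 25, 34),
+# query_embed: (num_query=100, 256), pos_embed: same shape as x.
+# out_dec, memory = transformer(x, mask, query_embed, pos_embed)
+# gives out_dec of shape (num_dec_layers, 2, 100, 256) when the decoder
+# returns intermediate outputs (otherwise (1, 2, 100, 256)) and memory of
+# shape (2, 256, 25, 34).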
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class DeformableDetrTransformerDecoder(TransformerLayerSequence):
+ """Implements the decoder in DETR transformer.
+
+ Args:
+ return_intermediate (bool): Whether to return intermediate outputs.
+ coder_norm_cfg (dict): Config of last normalization layer. Default:
+ `LN`.
+ """
+
+ def __init__(self, *args, return_intermediate=False, **kwargs):
+
+ super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs)
+ self.return_intermediate = return_intermediate
+
+ def forward(self,
+ query,
+ *args,
+ reference_points=None,
+ valid_ratios=None,
+ reg_branches=None,
+ **kwargs):
+ """Forward function for `TransformerDecoder`.
+
+ Args:
+ query (Tensor): Input query with shape
+ `(num_query, bs, embed_dims)`.
+ reference_points (Tensor): The reference
+ points of offsets, has shape
+ (bs, num_query, 4) when as_two_stage,
+ otherwise has shape (bs, num_query, 2).
+ valid_ratios (Tensor): The ratios of valid
+ points on the feature map, has shape
+ (bs, num_levels, 2).
+ reg_branches (obj:`nn.ModuleList`): Used for
+ refining the regression results. Only
+ passed when with_box_refine is True,
+ otherwise `None` is passed.
+
+ Returns:
+ tuple[Tensor]: The output of the last decoder layer, with shape
+ [num_query, bs, embed_dims], and its reference points when
+ return_intermediate is `False`; otherwise the stacked
+ intermediate outputs with shape
+ [num_layers, num_query, bs, embed_dims] and the stacked
+ intermediate reference points.
+ """
+ output = query
+ intermediate = []
+ intermediate_reference_points = []
+ for lid, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = reference_points[:, :, None] * \
+ torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :, None] * \
+ valid_ratios[:, None]
+ output = layer(
+ output,
+ *args,
+ reference_points=reference_points_input,
+ **kwargs)
+ output = output.permute(1, 0, 2)
+
+ if reg_branches is not None:
+ tmp = reg_branches[lid](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp + inverse_sigmoid(
+ reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp
+ new_reference_points[..., :2] = tmp[
+ ..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+
+ output = output.permute(1, 0, 2)
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(
+ intermediate_reference_points)
+
+ return output, reference_points
+
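+# Note on the refinement step above: each decoder layer predicts an offset in
+# logit space and the reference points are updated as
+# new_ref = sigmoid(reg_branches[lid](output) + inverse_sigmoid(ref))
+# (only the first two coordinates when the reference points are 2-d), and
+# `.detach()` stops gradients from flowing through the reference points
+# across layers.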
+
+@TRANSFORMER.register_module()
+class DeformableDetrTransformer(Transformer):
+ """Implements the DeformableDETR transformer.
+
+ Args:
+ as_two_stage (bool): Generate query from encoder features.
+ Default: False.
+ num_feature_levels (int): Number of feature maps from FPN.
+ Default: 4.
+ two_stage_num_proposals (int): Number of proposals when set
+ `as_two_stage` as True. Default: 300.
+ """
+
+ def __init__(self,
+ as_two_stage=False,
+ num_feature_levels=4,
+ two_stage_num_proposals=300,
+ **kwargs):
+ super(DeformableDetrTransformer, self).__init__(**kwargs)
+ self.as_two_stage = as_two_stage
+ self.num_feature_levels = num_feature_levels
+ self.two_stage_num_proposals = two_stage_num_proposals
+ self.embed_dims = self.encoder.embed_dims
+ self.init_layers()
+
+ def init_layers(self):
+ """Initialize layers of the DeformableDetrTransformer."""
+ self.level_embeds = nn.Parameter(
+ torch.Tensor(self.num_feature_levels, self.embed_dims))
+
+ if self.as_two_stage:
+ self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
+ self.enc_output_norm = nn.LayerNorm(self.embed_dims)
+ self.pos_trans = nn.Linear(self.embed_dims * 2,
+ self.embed_dims * 2)
+ self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
+ else:
+ self.reference_points = nn.Linear(self.embed_dims, 2)
+
+ def init_weights(self):
+ """Initialize the transformer weights."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MultiScaleDeformableAttention):
+ m.init_weights()
+ if not self.as_two_stage:
+ xavier_init(self.reference_points, distribution='uniform', bias=0.)
+ normal_(self.level_embeds)
+
+ def gen_encoder_output_proposals(self, memory, memory_padding_mask,
+ spatial_shapes):
+ """Generate proposals from encoded memory.
+
+ Args:
+ memory (Tensor): The output of encoder,
+ has shape (bs, num_key, embed_dim). num_key is
+ equal to the number of points on feature maps from
+ all levels.
+ memory_padding_mask (Tensor): Padding mask for memory,
+ has shape (bs, num_key).
+ spatial_shapes (Tensor): The shape of all feature maps,
+ has shape (num_level, 2).
+
+ Returns:
+ tuple: A tuple of feature map and bbox prediction.
+
+ - output_memory (Tensor): The input of decoder, \
+ has shape (bs, num_key, embed_dim). num_key is \
+ equal to the number of points on feature maps from \
+ all levels.
+ - output_proposals (Tensor): The normalized proposal \
+ after an inverse sigmoid, has shape \
+ (bs, num_keys, 4).
+ """
+
+ N, S, C = memory.shape
+ proposals = []
+ _cur = 0
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view(
+ N, H, W, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(
+ 0, H - 1, H, dtype=torch.float32, device=memory.device),
+ torch.linspace(
+ 0, W - 1, W, dtype=torch.float32, device=memory.device))
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W.unsqueeze(-1),
+ valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+ proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
+ proposals.append(proposal)
+ _cur += (H * W)
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) &
+ (output_proposals < 0.99)).all(
+ -1, keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(
+ memory_padding_mask.unsqueeze(-1), float('inf'))
+ output_proposals = output_proposals.masked_fill(
+ ~output_proposals_valid, float('inf'))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(
+ memory_padding_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid,
+ float(0))
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+ return output_memory, output_proposals
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ """Get the reference points used in decoder.
+
+ Args:
+ spatial_shapes (Tensor): The shape of all
+ feature maps, has shape (num_level, 2).
+ valid_ratios (Tensor): The ratios of valid
+ points on the feature map, has shape
+ (bs, num_levels, 2).
+ device (obj:`device`): The device where
+ reference_points should be.
+
+ Returns:
+ Tensor: reference points used in decoder, has \
+ shape (bs, num_keys, num_levels, 2).
+ """
+ reference_points_list = []
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ # TODO check this 0.5
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(
+ 0.5, H - 0.5, H, dtype=torch.float32, device=device),
+ torch.linspace(
+ 0.5, W - 0.5, W, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / (
+ valid_ratios[:, None, lvl, 1] * H)
+ ref_x = ref_x.reshape(-1)[None] / (
+ valid_ratios[:, None, lvl, 0] * W)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def get_valid_ratio(self, mask):
+ """Get the valid radios of feature maps of all level."""
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def get_proposal_pos_embed(self,
+ proposals,
+ num_pos_feats=128,
+ temperature=10000):
+ """Get the position embedding of proposal."""
+ scale = 2 * math.pi
+ dim_t = torch.arange(
+ num_pos_feats, dtype=torch.float32, device=proposals.device)
+ dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
+ # N, L, 4
+ proposals = proposals.sigmoid() * scale
+ # N, L, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # N, L, 4, 64, 2
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()),
+ dim=4).flatten(2)
+ return pos
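+
+ # Shape sketch for `get_proposal_pos_embed` (illustrative): proposals of
+ # shape (N, L, 4) are scaled by 2*pi after sigmoid, divided by the
+ # sinusoid frequencies `dim_t` (num_pos_feats entries), and the
+ # interleaved sin/cos pairs are flattened to an embedding of shape
+ # (N, L, 4 * num_pos_feats), i.e. (N, L, 512) for the default
+ # num_pos_feats=128, which matches embed_dims * 2 when embed_dims is 256.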
+
+ def forward(self,
+ mlvl_feats,
+ mlvl_masks,
+ query_embed,
+ mlvl_pos_embeds,
+ reg_branches=None,
+ cls_branches=None,
+ **kwargs):
+ """Forward function for `Transformer`.
+
+ Args:
+ mlvl_feats (list(Tensor)): Input queries from
+ different levels. Each element has shape
+ [bs, embed_dims, h, w].
+ mlvl_masks (list(Tensor)): The key_padding_mask from
+ different levels used for encoder and decoder,
+ each element has shape [bs, h, w].
+ query_embed (Tensor): The query embedding for decoder,
+ with shape [num_query, c].
+ mlvl_pos_embeds (list(Tensor)): The positional encoding
+ of feats from different levels, has the shape
+ [bs, embed_dims, h, w].
+ reg_branches (obj:`nn.ModuleList`): Regression heads for
+ feature maps from each decoder layer. Only would
+ be passed when
+ `with_box_refine` is True. Default to None.
+ cls_branches (obj:`nn.ModuleList`): Classification heads
+ for feature maps from each decoder layer. Only would
+ be passed when `as_two_stage`
+ is True. Default to None.
+
+
+ Returns:
+ tuple[Tensor]: results of decoder containing the following tensor.
+
+ - inter_states: Outputs from decoder. If
+ return_intermediate_dec is True output has shape \
+ (num_dec_layers, bs, num_query, embed_dims), else has \
+ shape (1, bs, num_query, embed_dims).
+ - init_reference_out: The initial value of reference \
+ points, has shape (bs, num_queries, 4).
+ - inter_references_out: The intermediate reference \
+ points in decoder, has shape \
+ (num_dec_layers, bs, num_query, 4) when as_two_stage \
+ is True, otherwise (num_dec_layers, bs, num_query, 2).
+ - enc_outputs_class: The classification score of \
+ proposals generated from \
+ encoder's feature maps, has shape \
+ (batch, h*w, num_classes). \
+ Only would be returned when `as_two_stage` is True, \
+ otherwise None.
+ - enc_outputs_coord_unact: The regression results \
+ generated from encoder's feature maps, has shape \
+ (batch, h*w, 4). Only would \
+ be returned when `as_two_stage` is True, \
+ otherwise None.
+ """
+ assert self.as_two_stage or query_embed is not None
+
+ feat_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (feat, mask, pos_embed) in enumerate(
+ zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
+ bs, c, h, w = feat.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ feat = feat.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ feat_flatten.append(feat)
+ mask_flatten.append(mask)
+ feat_flatten = torch.cat(feat_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(
+ spatial_shapes, dtype=torch.long, device=feat_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros(
+ (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack(
+ [self.get_valid_ratio(m) for m in mlvl_masks], 1)
+
+ reference_points = \
+ self.get_reference_points(spatial_shapes,
+ valid_ratios,
+ device=feat.device)
+
+ feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims)
+ lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(
+ 1, 0, 2) # (H*W, bs, embed_dims)
+ memory = self.encoder(
+ query=feat_flatten,
+ key=None,
+ value=None,
+ query_pos=lvl_pos_embed_flatten,
+ query_key_padding_mask=mask_flatten,
+ spatial_shapes=spatial_shapes,
+ reference_points=reference_points,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ **kwargs)
+
+ memory = memory.permute(1, 0, 2)
+ bs, _, c = memory.shape
+ if self.as_two_stage:
+ output_memory, output_proposals = \
+ self.gen_encoder_output_proposals(
+ memory, mask_flatten, spatial_shapes)
+ enc_outputs_class = cls_branches[self.decoder.num_layers](
+ output_memory)
+ enc_outputs_coord_unact = \
+ reg_branches[
+ self.decoder.num_layers](output_memory) + output_proposals
+
+ topk = self.two_stage_num_proposals
+ # We only use the first channel in enc_outputs_class as foreground,
+ # the other (num_classes - 1) channels are actually not used.
+ # Its targets are set to be 0s, which indicates the first
+ # class (foreground) because we use [0, num_classes - 1] to
+ # indicate class labels, background class is indicated by
+ # num_classes (similar convention in RPN).
+ # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa
+ # This follows the official implementation of Deformable DETR.
+ topk_proposals = torch.topk(
+ enc_outputs_class[..., 0], topk, dim=1)[1]
+ topk_coords_unact = torch.gather(
+ enc_outputs_coord_unact, 1,
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
+ topk_coords_unact = topk_coords_unact.detach()
+ reference_points = topk_coords_unact.sigmoid()
+ init_reference_out = reference_points
+ pos_trans_out = self.pos_trans_norm(
+ self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
+ query_pos, query = torch.split(pos_trans_out, c, dim=2)
+ else:
+ query_pos, query = torch.split(query_embed, c, dim=1)
+ query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
+ query = query.unsqueeze(0).expand(bs, -1, -1)
+ reference_points = self.reference_points(query_pos).sigmoid()
+ init_reference_out = reference_points
+
+ # decoder
+ query = query.permute(1, 0, 2)
+ memory = memory.permute(1, 0, 2)
+ query_pos = query_pos.permute(1, 0, 2)
+ inter_states, inter_references = self.decoder(
+ query=query,
+ key=None,
+ value=memory,
+ query_pos=query_pos,
+ key_padding_mask=mask_flatten,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reg_branches=reg_branches,
+ **kwargs)
+
+ inter_references_out = inter_references
+ if self.as_two_stage:
+ return inter_states, init_reference_out,\
+ inter_references_out, enc_outputs_class,\
+ enc_outputs_coord_unact
+ return inter_states, init_reference_out, \
+ inter_references_out, None, None
+
+
+@TRANSFORMER.register_module()
+class DynamicConv(BaseModule):
+ """Implements Dynamic Convolution.
+
+ This module generates parameters for each sample and
+ uses bmm to implement 1x1 convolution. Code is modified
+ from the `official github repo `_ .
+
+ Args:
+ in_channels (int): The input feature channel.
+ Defaults to 256.
+ feat_channels (int): The inner feature channel.
+ Defaults to 64.
+ out_channels (int, optional): The output feature channel.
+ When not specified, it will be set to `in_channels`
+ by default.
+ input_feat_shape (int): The shape of input feature.
+ Defaults to 7.
+ with_proj (bool): Project two-dimensional feature to
+ one-dimensional feature. Defaults to True.
+ act_cfg (dict): The activation config for DynamicConv.
+ norm_cfg (dict): Config dict for normalization layer. Default
+ layer normalization.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels=256,
+ feat_channels=64,
+ out_channels=None,
+ input_feat_shape=7,
+ with_proj=True,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ init_cfg=None):
+ super(DynamicConv, self).__init__(init_cfg)
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.out_channels_raw = out_channels
+ self.input_feat_shape = input_feat_shape
+ self.with_proj = with_proj
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.out_channels = out_channels if out_channels else in_channels
+
+ self.num_params_in = self.in_channels * self.feat_channels
+ self.num_params_out = self.out_channels * self.feat_channels
+ self.dynamic_layer = nn.Linear(
+ self.in_channels, self.num_params_in + self.num_params_out)
+
+ self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ self.activation = build_activation_layer(act_cfg)
+
+ num_output = self.out_channels * input_feat_shape**2
+ if self.with_proj:
+ self.fc_layer = nn.Linear(num_output, self.out_channels)
+ self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ def forward(self, param_feature, input_feature):
+ """Forward function for `DynamicConv`.
+
+ Args:
+ param_feature (Tensor): The feature can be used
+ to generate the parameter, has shape
+ (num_all_proposals, in_channels).
+ input_feature (Tensor): Feature that
+ interact with parameters, has shape
+ (num_all_proposals, in_channels, H, W).
+
+ Returns:
+ Tensor: The output feature has shape
+ (num_all_proposals, out_channels).
+ """
+ input_feature = input_feature.flatten(2).permute(2, 0, 1)
+
+ input_feature = input_feature.permute(1, 0, 2)
+ parameters = self.dynamic_layer(param_feature)
+
+ param_in = parameters[:, :self.num_params_in].view(
+ -1, self.in_channels, self.feat_channels)
+ param_out = parameters[:, -self.num_params_out:].view(
+ -1, self.feat_channels, self.out_channels)
+
+ # input_feature has shape (num_all_proposals, H*W, in_channels)
+ # param_in has shape (num_all_proposals, in_channels, feat_channels)
+ # feature has shape (num_all_proposals, H*W, feat_channels)
+ features = torch.bmm(input_feature, param_in)
+ features = self.norm_in(features)
+ features = self.activation(features)
+
+ # param_out has shape (num_all_proposals, feat_channels, out_channels)
+ features = torch.bmm(features, param_out)
+ features = self.norm_out(features)
+ features = self.activation(features)
+
+ if self.with_proj:
+ features = features.flatten(1)
+ features = self.fc_layer(features)
+ features = self.fc_norm(features)
+ features = self.activation(features)
+
+ return features
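+
+
+# A minimal usage sketch for `DynamicConv` (shapes are illustrative):
+# >>> import torch
+# >>> dynamic_conv = DynamicConv(in_channels=256, feat_channels=64,
+# ... input_feat_shape=7)
+# >>> param_feature = torch.rand(100, 256) # (num_proposals, in_channels)
+# >>> input_feature = torch.rand(100, 256, 7, 7) # (num_proposals, C, H, W)
+# >>> dynamic_conv(param_feature, input_feature).shape
+# torch.Size([100, 256])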
diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5a2b6b3a0e6031954b434686dd49d87de331baa
--- /dev/null
+++ b/mmdet/utils/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .ascend_util import (batch_images_to_levels,
+ get_max_num_gt_division_factor, masked_fill)
+from .collect_env import collect_env
+from .compat_config import compat_cfg
+from .logger import get_caller_name, get_root_logger, log_img_scale
+from .memory import AvoidCUDAOOM, AvoidOOM
+from .misc import find_latest_checkpoint, update_data_root
+from .replace_cfg_vals import replace_cfg_vals
+from .rfnext import rfnext_init_model
+from .setup_env import setup_multi_processes
+from .split_batch import split_batch
+from .util_distribution import build_ddp, build_dp, get_device
+
+__all__ = [
+ 'get_root_logger', 'collect_env', 'find_latest_checkpoint',
+ 'update_data_root', 'setup_multi_processes', 'get_caller_name',
+ 'log_img_scale', 'compat_cfg', 'split_batch', 'build_ddp', 'build_dp',
+ 'get_device', 'replace_cfg_vals', 'AvoidOOM', 'AvoidCUDAOOM',
+ 'get_max_num_gt_division_factor', 'masked_fill', 'batch_images_to_levels',
+ 'rfnext_init_model'
+]
diff --git a/mmdet/utils/ascend_util.py b/mmdet/utils/ascend_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..df90dec820567e8c129baf44de788e6735ef4b94
--- /dev/null
+++ b/mmdet/utils/ascend_util.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def masked_fill(ori_tensor, mask, new_value, neg=False):
+ """The Value of ori_tensor is new_value, depending on mask.
+
+ Args:
+ ori_tensor (Tensor): Input tensor.
+ mask (Tensor): If select new_value.
+ new_value(Tensor | scalar): Value selected for ori_tensor.
+ neg (bool): If True, select ori_tensor. If False, select new_value.
+ Returns:
+ ori_tensor: (Tensor): The Value of ori_tensor is new_value,
+ depending on mask.
+ """
+ if mask is None:
+ return ori_tensor
+ else:
+ if neg:
+ return ori_tensor * mask + new_value * (1 - mask)
+ else:
+ return ori_tensor * (1 - mask) + new_value * mask
+
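+# A small illustrative example for `masked_fill` (values are assumed):
+# >>> import torch
+# >>> ori = torch.tensor([1., 2., 3.])
+# >>> mask = torch.tensor([0., 1., 0.])
+# >>> masked_fill(ori, mask, 9.) # fill new_value where mask == 1
+# tensor([1., 9., 3.])
+# >>> masked_fill(ori, mask, 9., neg=True) # keep ori where mask == 1
+# tensor([9., 2., 9.])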
+
+def batch_images_to_levels(target, num_levels):
+ """Convert targets by image to targets by feature level.
+
+ [target_img0, target_img1] -> [target_level0, target_level1, ...] or
+ target_imgs -> [target_level0, target_level1, ...]
+ Args:
+ target (Tensor | List[Tensor]): Per-image targets to be split.
+ num_levels (List[int]): Number of anchors on each feature level.
+ Returns:
+ list[Tensor]: Targets split by feature level.
+ """
+ if not isinstance(target, torch.Tensor):
+ target = torch.stack(target, 0)
+ level_targets = []
+ start = 0
+ for n in num_levels:
+ end = start + n
+ # level_targets.append(target[:, start:end].squeeze(0))
+ level_targets.append(target[:, start:end])
+ start = end
+ return level_targets
+
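+# An illustrative call of `batch_images_to_levels` (sizes are assumed):
+# >>> import torch
+# >>> target = torch.zeros(2, 5) # 2 images, 5 anchors each
+# >>> [t.shape for t in batch_images_to_levels(target, [3, 2])]
+# [torch.Size([2, 3]), torch.Size([2, 2])]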
+
+def get_max_num_gt_division_factor(gt_nums,
+ min_num_gt=32,
+ max_num_gt=1024,
+ division_factor=2):
+ """Count max num of gt.
+
+ Args:
+ gt_nums (List[int]): Ground truth bboxes num of images.
+ min_num_gt (int): Min num of ground truth bboxes.
+ max_num_gt (int): Max num of ground truth bboxes.
+ division_factor (int): Division factor of result.
+ Returns:
+ max_gt_nums_align: (int): max num of ground truth bboxes.
+ """
+ max_gt_nums = max(gt_nums)
+ max_gt_nums_align = min_num_gt
+ while max_gt_nums_align < max_gt_nums:
+ max_gt_nums_align *= division_factor
+ if max_gt_nums_align > max_num_gt:
+ raise RuntimeError
+ return max_gt_nums_align
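+
+
+# An illustrative call of `get_max_num_gt_division_factor` (values assumed):
+# >>> get_max_num_gt_division_factor([10, 40]) # 32 -> 64, and 64 >= 40
+# 64
+# Values larger than `max_num_gt` raise a RuntimeError instead of aligning.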
diff --git a/mmdet/utils/collect_env.py b/mmdet/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..97e25c0e95394dcced4b9ddd25df7a16758886d5
--- /dev/null
+++ b/mmdet/utils/collect_env.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.utils import collect_env as collect_base_env
+from mmcv.utils import get_git_hash
+
+import mmdet
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ env_info = collect_base_env()
+ env_info['MMDetection'] = mmdet.__version__ + '+' + get_git_hash()[:7]
+ return env_info
+
+
+if __name__ == '__main__':
+ for name, val in collect_env().items():
+ print(f'{name}: {val}')
diff --git a/mmdet/utils/compat_config.py b/mmdet/utils/compat_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..05aa37dcd6f74dd1884069e90edf39684c897798
--- /dev/null
+++ b/mmdet/utils/compat_config.py
@@ -0,0 +1,139 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+from mmcv import ConfigDict
+
+
+def compat_cfg(cfg):
+ """This function would modify some filed to keep the compatibility of
+ config.
+
+ For example, it will move some args which will be deprecated to the correct
+ fields.
+ """
+ cfg = copy.deepcopy(cfg)
+ cfg = compat_imgs_per_gpu(cfg)
+ cfg = compat_loader_args(cfg)
+ cfg = compat_runner_args(cfg)
+ return cfg
+
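+# A sketch of what `compat_cfg` normalizes (the keys shown are the deprecated
+# ones handled below; the concrete values are illustrative):
+# >>> from mmcv import Config
+# >>> cfg = Config(dict(total_epochs=12,
+# ... data=dict(imgs_per_gpu=2, workers_per_gpu=2,
+# ... val=dict(), test=dict())))
+# >>> cfg = compat_cfg(cfg)
+# >>> cfg.runner.max_epochs, cfg.data.train_dataloader.samples_per_gpu
+# (12, 2)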
+
+def compat_runner_args(cfg):
+ if 'runner' not in cfg:
+ cfg.runner = ConfigDict({
+ 'type': 'EpochBasedRunner',
+ 'max_epochs': cfg.total_epochs
+ })
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+ else:
+ if 'total_epochs' in cfg:
+ assert cfg.total_epochs == cfg.runner.max_epochs
+ return cfg
+
+
+def compat_imgs_per_gpu(cfg):
+ cfg = copy.deepcopy(cfg)
+ if 'imgs_per_gpu' in cfg.data:
+ warnings.warn('"imgs_per_gpu" is deprecated in MMDet V2.0. '
+ 'Please use "samples_per_gpu" instead')
+ if 'samples_per_gpu' in cfg.data:
+ warnings.warn(
+ f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+ f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+ f'={cfg.data.imgs_per_gpu} is used in this experiment')
+ else:
+ warnings.warn('Automatically set "samples_per_gpu"="imgs_per_gpu"='
+ f'{cfg.data.imgs_per_gpu} in this experiment')
+ cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+ return cfg
+
+
+def compat_loader_args(cfg):
+ """Deprecated sample_per_gpu in cfg.data."""
+
+ cfg = copy.deepcopy(cfg)
+ if 'train_dataloader' not in cfg.data:
+ cfg.data['train_dataloader'] = ConfigDict()
+ if 'val_dataloader' not in cfg.data:
+ cfg.data['val_dataloader'] = ConfigDict()
+ if 'test_dataloader' not in cfg.data:
+ cfg.data['test_dataloader'] = ConfigDict()
+
+ # special process for train_dataloader
+ if 'samples_per_gpu' in cfg.data:
+
+ samples_per_gpu = cfg.data.pop('samples_per_gpu')
+ assert 'samples_per_gpu' not in \
+ cfg.data.train_dataloader, ('`samples_per_gpu` is set '
+ 'in the `data` field and in '
+ '`data.train_dataloader` '
+ 'at the same time. '
+ 'Please only set it in '
+ '`data.train_dataloader`. ')
+ cfg.data.train_dataloader['samples_per_gpu'] = samples_per_gpu
+
+ if 'persistent_workers' in cfg.data:
+
+ persistent_workers = cfg.data.pop('persistent_workers')
+ assert 'persistent_workers' not in \
+ cfg.data.train_dataloader, ('`persistent_workers` is set '
+ 'in the `data` field and in '
+ '`data.train_dataloader` '
+ 'at the same time. '
+ 'Please only set it in '
+ '`data.train_dataloader`. ')
+ cfg.data.train_dataloader['persistent_workers'] = persistent_workers
+
+ if 'workers_per_gpu' in cfg.data:
+
+ workers_per_gpu = cfg.data.pop('workers_per_gpu')
+ cfg.data.train_dataloader['workers_per_gpu'] = workers_per_gpu
+ cfg.data.val_dataloader['workers_per_gpu'] = workers_per_gpu
+ cfg.data.test_dataloader['workers_per_gpu'] = workers_per_gpu
+
+ # special process for val_dataloader
+ if 'samples_per_gpu' in cfg.data.val:
+ # keep the default value of `samples_per_gpu` as 1
+ assert 'samples_per_gpu' not in \
+ cfg.data.val_dataloader, ('`samples_per_gpu` is set '
+ 'in the `data.val` field and in '
+ '`data.val_dataloader` at '
+ 'the same time. '
+ 'Please only set it in '
+ '`data.val_dataloader`. ')
+ cfg.data.val_dataloader['samples_per_gpu'] = \
+ cfg.data.val.pop('samples_per_gpu')
+ # special process for test_dataloader
+
+ # in case the test dataset is concatenated
+ if isinstance(cfg.data.test, dict):
+ if 'samples_per_gpu' in cfg.data.test:
+ assert 'samples_per_gpu' not in \
+ cfg.data.test_dataloader, ('`samples_per_gpu` is set '
+ 'in the `data.test` field and in '
+ '`data.test_dataloader` '
+ 'at the same time. '
+ 'Please only set it in '
+ '`data.test_dataloader`. ')
+
+ cfg.data.test_dataloader['samples_per_gpu'] = \
+ cfg.data.test.pop('samples_per_gpu')
+
+ elif isinstance(cfg.data.test, list):
+ for ds_cfg in cfg.data.test:
+ if 'samples_per_gpu' in ds_cfg:
+ assert 'samples_per_gpu' not in \
+ cfg.data.test_dataloader, ('`samples_per_gpu` is set '
+ 'in the `data.test` field and in '
+ '`data.test_dataloader` at'
+ ' the same time. '
+ 'Please only set it in '
+ '`data.test_dataloader`. ')
+ samples_per_gpu = max(
+ [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
+ cfg.data.test_dataloader['samples_per_gpu'] = samples_per_gpu
+
+ return cfg
diff --git a/mmdet/utils/contextmanagers.py b/mmdet/utils/contextmanagers.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa12bfcaff1e781b0a8cc7d7c8b839c2f2955a05
--- /dev/null
+++ b/mmdet/utils/contextmanagers.py
@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
+import contextlib
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
+
+
+@contextlib.asynccontextmanager
+async def completed(trace_name='',
+ name='',
+ sleep_interval=0.05,
+ streams: List[torch.cuda.Stream] = None):
+ """Async context manager that waits for work to complete on given CUDA
+ streams."""
+ if not torch.cuda.is_available():
+ yield
+ return
+
+ stream_before_context_switch = torch.cuda.current_stream()
+ if not streams:
+ streams = [stream_before_context_switch]
+ else:
+ streams = [s if s else stream_before_context_switch for s in streams]
+
+ end_events = [
+ torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
+ ]
+
+ if DEBUG_COMPLETED_TIME:
+ start = torch.cuda.Event(enable_timing=True)
+ stream_before_context_switch.record_event(start)
+
+ cpu_start = time.monotonic()
+ logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
+ grad_enabled_before = torch.is_grad_enabled()
+ try:
+ yield
+ finally:
+ current_stream = torch.cuda.current_stream()
+ assert current_stream == stream_before_context_switch
+
+ if DEBUG_COMPLETED_TIME:
+ cpu_end = time.monotonic()
+ for i, stream in enumerate(streams):
+ event = end_events[i]
+ stream.record_event(event)
+
+ grad_enabled_after = torch.is_grad_enabled()
+
+ # observed change of torch.is_grad_enabled() during concurrent run of
+ # async_test_bboxes code
+ assert (grad_enabled_before == grad_enabled_after
+ ), 'Unexpected is_grad_enabled() value change'
+
+ are_done = [e.query() for e in end_events]
+ logger.debug('%s %s completed: %s streams: %s', trace_name, name,
+ are_done, streams)
+ with torch.cuda.stream(stream_before_context_switch):
+ while not all(are_done):
+ await asyncio.sleep(sleep_interval)
+ are_done = [e.query() for e in end_events]
+ logger.debug(
+ '%s %s completed: %s streams: %s',
+ trace_name,
+ name,
+ are_done,
+ streams,
+ )
+
+ current_stream = torch.cuda.current_stream()
+ assert current_stream == stream_before_context_switch
+
+ if DEBUG_COMPLETED_TIME:
+ cpu_time = (cpu_end - cpu_start) * 1000
+ stream_times_ms = ''
+ for i, stream in enumerate(streams):
+ elapsed_time = start.elapsed_time(end_events[i])
+ stream_times_ms += f' {stream} {elapsed_time:.2f} ms'
+ logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
+ stream_times_ms)
+
+
+@contextlib.asynccontextmanager
+async def concurrent(streamqueue: asyncio.Queue,
+ trace_name='concurrent',
+ name='stream'):
+ """Run code concurrently in different streams.
+
+ :param streamqueue: asyncio.Queue instance.
+
+ Queue tasks define the pool of streams used for concurrent execution.
+ """
+ if not torch.cuda.is_available():
+ yield
+ return
+
+ initial_stream = torch.cuda.current_stream()
+
+ with torch.cuda.stream(initial_stream):
+ stream = await streamqueue.get()
+ assert isinstance(stream, torch.cuda.Stream)
+
+ try:
+ with torch.cuda.stream(stream):
+ logger.debug('%s %s is starting, stream: %s', trace_name, name,
+ stream)
+ yield
+ current = torch.cuda.current_stream()
+ assert current == stream
+ logger.debug('%s %s has finished, stream: %s', trace_name,
+ name, stream)
+ finally:
+ streamqueue.task_done()
+ streamqueue.put_nowait(stream)
diff --git a/mmdet/utils/logger.py b/mmdet/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..485f641b709d88f21789c7c6048ff058bcb2bf29
--- /dev/null
+++ b/mmdet/utils/logger.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import logging
+
+from mmcv.utils import get_logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+ """Get root logger.
+
+ Args:
+ log_file (str, optional): File path of log. Defaults to None.
+ log_level (int, optional): The level of logger.
+ Defaults to logging.INFO.
+
+ Returns:
+ :obj:`logging.Logger`: The obtained logger
+ """
+ logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)
+
+ return logger
+
+
+def get_caller_name():
+ """Get name of caller method."""
+ # this_func_frame = inspect.stack()[0][0] # i.e., get_caller_name
+ # callee_frame = inspect.stack()[1][0] # e.g., log_img_scale
+ caller_frame = inspect.stack()[2][0] # e.g., caller of log_img_scale
+ caller_method = caller_frame.f_code.co_name
+ try:
+ caller_class = caller_frame.f_locals['self'].__class__.__name__
+ return f'{caller_class}.{caller_method}'
+ except KeyError: # caller is a function
+ return caller_method
+
+
+def log_img_scale(img_scale, shape_order='hw', skip_square=False):
+ """Log image size.
+
+ Args:
+ img_scale (tuple): Image size to be logged.
+ shape_order (str, optional): The order of image shape.
+ 'hw' for (height, width) and 'wh' for (width, height).
+ Defaults to 'hw'.
+ skip_square (bool, optional): Whether to skip logging for square
+ img_scale. Defaults to False.
+
+ Returns:
+ bool: Whether to have done logging.
+ """
+ if shape_order == 'hw':
+ height, width = img_scale
+ elif shape_order == 'wh':
+ width, height = img_scale
+ else:
+ raise ValueError(f'Invalid shape_order {shape_order}.')
+
+ if skip_square and (height == width):
+ return False
+
+ logger = get_root_logger()
+ caller = get_caller_name()
+ logger.info(f'image shape: height={height}, width={width} in {caller}')
+
+ return True
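+
+
+# An illustrative call (the logged caller name depends on where it is used):
+# >>> log_img_scale((640, 480), shape_order='hw')
+# # logs: "image shape: height=640, width=480 in <caller>"
+# True
+# >>> log_img_scale((512, 512), skip_square=True)
+# False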
diff --git a/mmdet/utils/memory.py b/mmdet/utils/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb212bcaed139e5c9db595186ee8e16677921512
--- /dev/null
+++ b/mmdet/utils/memory.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from collections import abc
+from contextlib import contextmanager
+from functools import wraps
+
+import torch
+
+from mmdet.utils import get_root_logger
+
+
+def cast_tensor_type(inputs, src_type=None, dst_type=None):
+ """Recursively convert Tensor in inputs from ``src_type`` to ``dst_type``.
+
+ Args:
+ inputs: Inputs to be cast.
+ src_type (torch.dtype | torch.device): Source type.
+ dst_type (torch.dtype | torch.device): Destination type.
+
+ Returns:
+ The same type as inputs, but all contained Tensors have been cast.
+ """
+ assert dst_type is not None
+ if isinstance(inputs, torch.Tensor):
+ if isinstance(dst_type, torch.device):
+ # convert Tensor to dst_device
+ if hasattr(inputs, 'to') and \
+ hasattr(inputs, 'device') and \
+ (inputs.device == src_type or src_type is None):
+ return inputs.to(dst_type)
+ else:
+ return inputs
+ else:
+ # convert Tensor to dst_dtype
+ if hasattr(inputs, 'to') and \
+ hasattr(inputs, 'dtype') and \
+ (inputs.dtype == src_type or src_type is None):
+ return inputs.to(dst_type)
+ else:
+ return inputs
+ # we need to ensure that the type of inputs to be cast is the same
+ # as the argument `src_type`.
+ elif isinstance(inputs, abc.Mapping):
+ return type(inputs)({
+ k: cast_tensor_type(v, src_type=src_type, dst_type=dst_type)
+ for k, v in inputs.items()
+ })
+ elif isinstance(inputs, abc.Iterable):
+ return type(inputs)(
+ cast_tensor_type(item, src_type=src_type, dst_type=dst_type)
+ for item in inputs)
+ # TODO: Currently not supported
+ # elif isinstance(inputs, InstanceData):
+ # for key, value in inputs.items():
+ # inputs[key] = cast_tensor_type(
+ # value, src_type=src_type, dst_type=dst_type)
+ # return inputs
+ else:
+ return inputs
+
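+# A small illustrative use of `cast_tensor_type` (inputs are assumed):
+# >>> import torch
+# >>> data = {'feat': torch.ones(2, 3)}
+# >>> cast_tensor_type(data, dst_type=torch.half)['feat'].dtype
+# torch.float16
+# Passing a torch.device as `dst_type` moves the tensors instead of
+# changing their dtype.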
+
+@contextmanager
+def _ignore_torch_cuda_oom():
+ """A context which ignores CUDA OOM exception from pytorch.
+
+ Code is modified from
+ # noqa: E501
+ """
+ try:
+ yield
+ except RuntimeError as e:
+ # NOTE: the string may change?
+ if 'CUDA out of memory. ' in str(e):
+ pass
+ else:
+ raise
+
+
+class AvoidOOM:
+ """Try to convert inputs to FP16 and CPU if got a PyTorch's CUDA Out of
+ Memory error. It will do the following steps:
+
+ 1. First retry after calling `torch.cuda.empty_cache()`.
+ 2. If that still fails, it will then retry by converting inputs
+ to FP16.
+ 3. If that still fails trying to convert inputs to CPUs.
+ In this case, it expects the function to dispatch to
+ CPU implementation.
+
+ Args:
+ to_cpu (bool): Whether to convert outputs to CPU if get an OOM
+ error. This will slow down the code significantly.
+ Defaults to True.
+ test (bool): Skip `_ignore_torch_cuda_oom` so that lightweight
+ data can be used in unit tests. Only used for testing.
+ Defaults to False.
+
+ Examples:
+ >>> from mmdet.utils.memory import AvoidOOM
+ >>> AvoidCUDAOOM = AvoidOOM()
+ >>> output = AvoidOOM.retry_if_cuda_oom(
+ >>> some_torch_function)(input1, input2)
+ >>> # To use as a decorator
+ >>> # from mmdet.utils import AvoidCUDAOOM
+ >>> @AvoidCUDAOOM.retry_if_cuda_oom
+ >>> def function(*args, **kwargs):
+ >>> return None
+
+ Note:
+ 1. The output may be on CPU even if inputs are on GPU. Processing
+ on CPU will slow down the code significantly.
+ 2. When converting inputs to CPU, it will only look at each argument
+ and check if it has `.device` and `.to` for conversion. Nested
+ structures of tensors are not supported.
+ 3. Since the function might be called more than once, it has to be
+ stateless.
+ """
+
+ def __init__(self, to_cpu=True, test=False):
+ self.to_cpu = to_cpu
+ self.test = test
+
+ def retry_if_cuda_oom(self, func):
+ """Makes a function retry itself after encountering pytorch's CUDA OOM
+ error.
+
+ The implementation logic follows
+ https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py
+
+ Args:
+ func: a stateless callable that takes tensor-like objects
+ as arguments.
+ Returns:
+ func: a callable which retries `func` if OOM is encountered.
+ """ # noqa: W605
+
+ @wraps(func)
+ def wrapped(*args, **kwargs):
+
+ # raw function
+ if not self.test:
+ with _ignore_torch_cuda_oom():
+ return func(*args, **kwargs)
+
+ # Clear cache and retry
+ torch.cuda.empty_cache()
+ with _ignore_torch_cuda_oom():
+ return func(*args, **kwargs)
+
+ # get the type and device of first tensor
+ dtype, device = None, None
+ values = args + tuple(kwargs.values())
+ for value in values:
+ if isinstance(value, torch.Tensor):
+ dtype = value.dtype
+ device = value.device
+ break
+ if dtype is None or device is None:
+ raise ValueError('There is no tensor in the inputs, '
+ 'cannot get dtype and device.')
+
+ # Convert to FP16
+ fp16_args = cast_tensor_type(args, dst_type=torch.half)
+ fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
+ logger = get_root_logger()
+ logger.warning(f'Attempting to copy inputs of {str(func)} '
+ 'to FP16 due to CUDA OOM')
+
+ # get the input tensor type; the output type will be the same as
+ # the first parameter type.
+ with _ignore_torch_cuda_oom():
+ output = func(*fp16_args, **fp16_kwargs)
+ output = cast_tensor_type(
+ output, src_type=torch.half, dst_type=dtype)
+ if not self.test:
+ return output
+ logger.warning('Using FP16 still meets CUDA OOM')
+
+ # Try on CPU. This will slow down the code significantly,
+ # therefore print a notice.
+ if self.to_cpu:
+ logger.warning(f'Attempting to copy inputs of {str(func)} '
+ 'to CPU due to CUDA OOM')
+ cpu_device = torch.empty(0).device
+ cpu_args = cast_tensor_type(args, dst_type=cpu_device)
+ cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device)
+
+ # convert outputs to GPU
+ with _ignore_torch_cuda_oom():
+ logger.warning(f'Convert outputs to GPU (device={device})')
+ output = func(*cpu_args, **cpu_kwargs)
+ output = cast_tensor_type(
+ output, src_type=cpu_device, dst_type=device)
+ return output
+
+ warnings.warn('Cannot convert output to GPU due to CUDA OOM, '
+ 'the output is now on CPU, which might cause '
+ 'errors if the output needs to interact with GPU '
+ 'data in subsequent operations')
+ logger.warning('Cannot convert output to GPU due to '
+ 'CUDA OOM, the output is on CPU now.')
+
+ return func(*cpu_args, **cpu_kwargs)
+ else:
+ # may still get CUDA OOM error
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
+# To use AvoidOOM as a decorator
+AvoidCUDAOOM = AvoidOOM()
diff --git a/mmdet/utils/misc.py b/mmdet/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2017cbb94660c919a99e522393e83b42b27e46fe
--- /dev/null
+++ b/mmdet/utils/misc.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os
+import os.path as osp
+import warnings
+
+import mmcv
+import torch
+from mmcv.utils import TORCH_VERSION, digit_version, print_log
+
+
+def find_latest_checkpoint(path, suffix='pth'):
+ """Find the latest checkpoint from the working directory.
+
+ Args:
+ path(str): The path to find checkpoints.
+ suffix(str): File extension.
+ Defaults to pth.
+
+ Returns:
+ latest_path(str | None): File path of the latest checkpoint.
+ References:
+ .. [1] https://github.com/microsoft/SoftTeacher
+ /blob/main/ssod/utils/patch.py
+ """
+ if not osp.exists(path):
+ warnings.warn('The path of checkpoints does not exist.')
+ return None
+ if osp.exists(osp.join(path, f'latest.{suffix}')):
+ return osp.join(path, f'latest.{suffix}')
+
+ checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
+ if len(checkpoints) == 0:
+ warnings.warn('There are no checkpoints in the path.')
+ return None
+ latest = -1
+ latest_path = None
+ for checkpoint in checkpoints:
+ count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
+ if count > latest:
+ latest = count
+ latest_path = checkpoint
+ return latest_path
+
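+# An illustrative layout (paths are hypothetical): for a work_dir containing
+# epoch_1.pth and epoch_12.pth but no latest.pth, the call below returns the
+# checkpoint with the largest numeric suffix.
+# >>> find_latest_checkpoint('work_dirs/exp') # -> 'work_dirs/exp/epoch_12.pth'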
+
+def update_data_root(cfg, logger=None):
+ """Update data root according to env MMDET_DATASETS.
+
+ If the env MMDET_DATASETS is set, update cfg.data_root according to
+ MMDET_DATASETS. Otherwise, cfg.data_root is used as the default.
+
+ Args:
+ cfg (mmcv.Config): The model config to be modified.
+ logger (logging.Logger | str | None): The way to print the message.
+ """
+ assert isinstance(cfg, mmcv.Config), \
+ f'cfg got wrong type: {type(cfg)}, expected mmcv.Config'
+
+ if 'MMDET_DATASETS' in os.environ:
+ dst_root = os.environ['MMDET_DATASETS']
+ print_log(f'MMDET_DATASETS has been set to be {dst_root}. '
+ f'Using {dst_root} as data root.')
+ else:
+ return
+
+ def update(cfg, src_str, dst_str):
+ for k, v in cfg.items():
+ if isinstance(v, mmcv.ConfigDict):
+ update(cfg[k], src_str, dst_str)
+ if isinstance(v, str) and src_str in v:
+ cfg[k] = v.replace(src_str, dst_str)
+
+ update(cfg.data, cfg.data_root, dst_root)
+ cfg.data_root = dst_root
+
+
+_torch_version_div_indexing = (
+ 'parrots' not in TORCH_VERSION
+ and digit_version(TORCH_VERSION) >= digit_version('1.8'))
+
+
+def floordiv(dividend, divisor, rounding_mode='trunc'):
+ if _torch_version_div_indexing:
+ return torch.div(dividend, divisor, rounding_mode=rounding_mode)
+ else:
+ return dividend // divisor
diff --git a/mmdet/utils/profiling.py b/mmdet/utils/profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f53f456c72db57bfa69a8d022c92d153580209e
--- /dev/null
+++ b/mmdet/utils/profiling.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import contextlib
+import sys
+import time
+
+import torch
+
+if sys.version_info >= (3, 7):
+
+ @contextlib.contextmanager
+ def profile_time(trace_name,
+ name,
+ enabled=True,
+ stream=None,
+ end_stream=None):
+ """Print time spent by CPU and GPU.
+
+ Useful as a temporary context manager to find sweet spots of code
+ suitable for async implementation.
+ """
+ if (not enabled) or not torch.cuda.is_available():
+ yield
+ return
+ stream = stream if stream else torch.cuda.current_stream()
+ end_stream = end_stream if end_stream else stream
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ stream.record_event(start)
+ try:
+ cpu_start = time.monotonic()
+ yield
+ finally:
+ cpu_end = time.monotonic()
+ end_stream.record_event(end)
+ end.synchronize()
+ cpu_time = (cpu_end - cpu_start) * 1000
+ gpu_time = start.elapsed_time(end)
+ msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
+ msg += f'gpu_time {gpu_time:.2f} ms stream {stream}'
+ print(msg, end_stream)
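+
+
+# A minimal usage sketch (only meaningful on a CUDA-enabled machine):
+# >>> import torch
+# >>> with profile_time('forward', 'backbone'):
+# ... _ = torch.ones(1, device='cuda') * 2
+# prints something like:
+# "forward backbone cpu_time 0.12 ms gpu_time 0.08 ms stream <Stream ...>"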
diff --git a/mmdet/utils/replace_cfg_vals.py b/mmdet/utils/replace_cfg_vals.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ca301dc937bb9c3fe376d7a047b8c0430e8ec73
--- /dev/null
+++ b/mmdet/utils/replace_cfg_vals.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import re
+
+from mmcv.utils import Config
+
+
+def replace_cfg_vals(ori_cfg):
+ """Replace the string "${key}" with the corresponding value.
+
+ Replace the "${key}" with the value of ori_cfg.key in the config. And
+ support replacing the chained ${key}. Such as, replace "${key0.key1}"
+ with the value of cfg.key0.key1. Code is modified from `vars.py
+ < https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/vars.py>`_ # noqa: E501
+
+ Args:
+ ori_cfg (mmcv.utils.config.Config):
+ The origin config with "${key}" generated from a file.
+
+ Returns:
+ updated_cfg (mmcv.utils.config.Config):
+ The config with "${key}" replaced by the corresponding value.
+ """
+
+ def get_value(cfg, key):
+ for k in key.split('.'):
+ cfg = cfg[k]
+ return cfg
+
+ def replace_value(cfg):
+ if isinstance(cfg, dict):
+ return {key: replace_value(value) for key, value in cfg.items()}
+ elif isinstance(cfg, list):
+ return [replace_value(item) for item in cfg]
+ elif isinstance(cfg, tuple):
+ return tuple([replace_value(item) for item in cfg])
+ elif isinstance(cfg, str):
+ # the format of string cfg may be:
+ # 1) "${key}", which will be replaced with cfg.key directly
+ # 2) "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx",
+ # which will be replaced with the string of the cfg.key
+ keys = pattern_key.findall(cfg)
+ values = [get_value(ori_cfg, key[2:-1]) for key in keys]
+ if len(keys) == 1 and keys[0] == cfg:
+ # the format of string cfg is "${key}"
+ cfg = values[0]
+ else:
+ for key, value in zip(keys, values):
+ # the format of string cfg is
+ # "xxx${key}xxx" or "xxx${key1}xxx${key2}xxx"
+ assert not isinstance(value, (dict, list, tuple)), \
+ f'for the format of string cfg is ' \
+ f"'xxxxx${key}xxxxx' or 'xxx${key}xxx${key}xxx', " \
+ f"the type of the value of '${key}' " \
+ f'can not be dict, list, or tuple' \
+ f'but you input {type(value)} in {cfg}'
+ cfg = cfg.replace(key, str(value))
+ return cfg
+ else:
+ return cfg
+
+ # the pattern of string "${key}"
+ pattern_key = re.compile(r'\$\{[a-zA-Z\d_.]*\}')
+ # the type of ori_cfg._cfg_dict is mmcv.utils.config.ConfigDict
+ updated_cfg = Config(
+ replace_value(ori_cfg._cfg_dict), filename=ori_cfg.filename)
+ # replace the model with model_wrapper
+ if updated_cfg.get('model_wrapper', None) is not None:
+ updated_cfg.model = updated_cfg.model_wrapper
+ updated_cfg.pop('model_wrapper')
+ return updated_cfg
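+
+
+# An illustrative expansion (the config keys are hypothetical):
+# >>> from mmcv.utils import Config
+# >>> cfg = Config(dict(img_size=640, pipeline=dict(scale='${img_size}')))
+# >>> replace_cfg_vals(cfg).pipeline.scale
+# 640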
diff --git a/mmdet/utils/rfnext.py b/mmdet/utils/rfnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..568f3d3829dc1f0db3d610390b9f0817d738e4a7
--- /dev/null
+++ b/mmdet/utils/rfnext.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+try:
+ from mmcv.cnn import RFSearchHook
+except ImportError:
+ RFSearchHook = None
+
+
+def rfnext_init_model(detector, cfg):
+ """Rcecptive field search via dilation rates.
+
+ Please refer to `RF-Next: Efficient Receptive Field
+ Search for Convolutional Neural Networks
+ `_ for more details.
+
+ Args:
+ detector (nn.Module): The detector before initializing RF-Next.
+ cfg (mmcv.Config): The config for RF-Next.
+ If the RFSearchHook is defined in the cfg.custom_hooks,
+ the detector will be initialized for RF-Next.
+ """
+
+ if cfg.get('custom_hooks', None) is None:
+ return
+ custom_hook_types = [hook['type'] for hook in cfg.custom_hooks]
+ if 'RFSearchHook' not in custom_hook_types:
+ return
+
+ index = custom_hook_types.index('RFSearchHook')
+ rfsearch_cfg = cfg.custom_hooks[index]
+ assert rfsearch_cfg['type'] == 'RFSearchHook'
+
+ assert RFSearchHook is not None, 'Please install mmcv > 1.7.0'
+
+ # initialize an RFSearchHook
+ rfsearch_warp = RFSearchHook(
+ mode=rfsearch_cfg.get('mode', 'search'),
+ config=rfsearch_cfg.get('config', None),
+ rfstructure_file=rfsearch_cfg.get('rfstructure_file', None),
+ by_epoch=rfsearch_cfg.get('by_epoch', True),
+ verbose=rfsearch_cfg.get('verbose', True),
+ )
+ rfsearch_warp.init_model(detector)
+ rfsearch_cfg['rfstructure_file'] = None
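+
+
+# A sketch of the config entry this function looks for (field values are
+# illustrative; see mmcv's RFSearchHook for the full set of options):
+# custom_hooks = [
+# dict(type='RFSearchHook',
+# mode='search',
+# rfstructure_file=None,
+# config=dict(...))
+# ]
+# With such an entry, `rfnext_init_model(detector, cfg)` initializes the
+# detector for receptive-field search; without it the function is a no-op.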
diff --git a/mmdet/utils/setup_env.py b/mmdet/utils/setup_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..6637cf878f8205f1a3fc3938472e07f272bc19b8
--- /dev/null
+++ b/mmdet/utils/setup_env.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import platform
+import warnings
+
+import cv2
+import torch.multiprocessing as mp
+
+
+def setup_multi_processes(cfg):
+ """Setup multi-processing environment variables."""
+ # set multi-process start method as `fork` to speed up the training
+ if platform.system() != 'Windows':
+ mp_start_method = cfg.get('mp_start_method', 'fork')
+ current_method = mp.get_start_method(allow_none=True)
+ if current_method is not None and current_method != mp_start_method:
+ warnings.warn(
+ f'Multi-processing start method `{mp_start_method}` is '
+ f'different from the previous setting `{current_method}`.'
+ f'It will be force set to `{mp_start_method}`. You can change '
+ f'this behavior by changing `mp_start_method` in your config.')
+ mp.set_start_method(mp_start_method, force=True)
+
+ # disable opencv multithreading to avoid system being overloaded
+ opencv_num_threads = cfg.get('opencv_num_threads', 0)
+ cv2.setNumThreads(opencv_num_threads)
+
+ # setup OMP threads
+ # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
+ workers_per_gpu = cfg.data.get('workers_per_gpu', 1)
+ if 'train_dataloader' in cfg.data:
+ workers_per_gpu = \
+ max(cfg.data.train_dataloader.get('workers_per_gpu', 1),
+ workers_per_gpu)
+
+ if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
+ omp_num_threads = 1
+ warnings.warn(
+ f'Setting OMP_NUM_THREADS environment variable for each process '
+ f'to be {omp_num_threads} by default, to avoid your system being '
+ f'overloaded, please further tune the variable for optimal '
+ f'performance in your application as needed.')
+ os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
+
+ # setup MKL threads
+ if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
+ mkl_num_threads = 1
+ warnings.warn(
+ f'Setting MKL_NUM_THREADS environment variable for each process '
+ f'to be {mkl_num_threads} by default, to avoid your system being '
+ f'overloaded, please further tune the variable for optimal '
+ f'performance in your application as needed.')
+ os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
diff --git a/mmdet/utils/split_batch.py b/mmdet/utils/split_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..0276fb331f23c1a7f7451faf2a8f768e616d45fd
--- /dev/null
+++ b/mmdet/utils/split_batch.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def split_batch(img, img_metas, kwargs):
+ """Split data_batch by tags.
+
+ Code is modified from
+ # noqa: E501
+
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys, see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ kwargs (dict): Specific to concrete implementation.
+
+ Returns:
+ data_groups (dict): a dict of data_batch split by tags,
+ such as 'sup', 'unsup_teacher', and 'unsup_student'.
+ """
+
+ # only stack img in the batch
+ def fuse_list(obj_list, obj):
+ return torch.stack(obj_list) if isinstance(obj,
+ torch.Tensor) else obj_list
+
+ # select data with tag from data_batch
+ def select_group(data_batch, current_tag):
+ group_flag = [tag == current_tag for tag in data_batch['tag']]
+ return {
+ k: fuse_list([vv for vv, gf in zip(v, group_flag) if gf], v)
+ for k, v in data_batch.items()
+ }
+
+ kwargs.update({'img': img, 'img_metas': img_metas})
+ kwargs.update({'tag': [meta['tag'] for meta in img_metas]})
+ tags = list(set(kwargs['tag']))
+ data_groups = {tag: select_group(kwargs, tag) for tag in tags}
+ for tag, group in data_groups.items():
+ group.pop('tag')
+ return data_groups
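A toy sketch of the grouping behaviour, with `split_batch` as defined above in scope (tag names follow the docstring; shapes are illustrative):

    import torch

    imgs = torch.zeros(4, 3, 32, 32)
    img_metas = [{'tag': t} for t in
                 ('sup', 'sup', 'unsup_teacher', 'unsup_student')]
    groups = split_batch(imgs, img_metas, kwargs={})
    # Tensors are re-stacked per group, while lists are simply filtered.
    assert groups['sup']['img'].shape == (2, 3, 32, 32)
    assert len(groups['unsup_teacher']['img_metas']) == 1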
diff --git a/mmdet/utils/util_distribution.py b/mmdet/utils/util_distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba32cc9fbd8905c8d1adcae5f17c3e9216760c2e
--- /dev/null
+++ b/mmdet/utils/util_distribution.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+
+dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}
+
+ddp_factory = {'cuda': MMDistributedDataParallel}
+
+
+def build_dp(model, device='cuda', dim=0, *args, **kwargs):
+ """build DataParallel module by device type.
+
+ if device is cuda, return a MMDataParallel model; if device is mlu,
+ return a MLUDataParallel model.
+
+ Args:
+ model (:class:`nn.Module`): model to be parallelized.
+ device (str): device type, cuda, cpu or mlu. Defaults to cuda.
+ dim (int): Dimension used to scatter the data. Defaults to 0.
+
+ Returns:
+ nn.Module: the model to be parallelized.
+ """
+ if device == 'npu':
+ from mmcv.device.npu import NPUDataParallel
+ dp_factory['npu'] = NPUDataParallel
+ torch.npu.set_device(kwargs['device_ids'][0])
+ torch.npu.set_compile_mode(jit_compile=False)
+ model = model.npu()
+ elif device == 'cuda':
+ model = model.cuda(kwargs['device_ids'][0])
+ elif device == 'mlu':
+ from mmcv.device.mlu import MLUDataParallel
+ dp_factory['mlu'] = MLUDataParallel
+ model = model.mlu()
+
+ return dp_factory[device](model, dim=dim, *args, **kwargs)
+
+
+def build_ddp(model, device='cuda', *args, **kwargs):
+ """Build DistributedDataParallel module by device type.
+
+ If the device is cuda, return an MMDistributedDataParallel model;
+ if it is mlu or npu, return the corresponding device-specific
+ DistributedDataParallel model.
+
+ Args:
+ model (:class:`nn.Module`): module to be parallelized.
+ device (str): device type, cuda, mlu or npu.
+
+ Returns:
+ :class:`nn.Module`: the module to be parallelized
+
+ References:
+ .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
+ DistributedDataParallel.html
+ """
+ assert device in ['cuda', 'mlu',
+ 'npu'], 'Only available for cuda, mlu or npu devices.'
+ if device == 'npu':
+ from mmcv.device.npu import NPUDistributedDataParallel
+ torch.npu.set_compile_mode(jit_compile=False)
+ ddp_factory['npu'] = NPUDistributedDataParallel
+ model = model.npu()
+ elif device == 'cuda':
+ model = model.cuda()
+ elif device == 'mlu':
+ from mmcv.device.mlu import MLUDistributedDataParallel
+ ddp_factory['mlu'] = MLUDistributedDataParallel
+ model = model.mlu()
+
+ return ddp_factory[device](model, *args, **kwargs)
+
+
+def is_npu_available():
+ """Returns a bool indicating if NPU is currently available."""
+ return hasattr(torch, 'npu') and torch.npu.is_available()
+
+
+def is_mlu_available():
+ """Returns a bool indicating if MLU is currently available."""
+ return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()
+
+
+def get_device():
+ """Returns an available device, cpu, cuda or mlu."""
+ is_device_available = {
+ 'npu': is_npu_available(),
+ 'cuda': torch.cuda.is_available(),
+ 'mlu': is_mlu_available()
+ }
+ device_list = [k for k, v in is_device_available.items() if v]
+ return device_list[0] if len(device_list) >= 1 else 'cpu'
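A hedged sketch of how these helpers are typically combined, assuming they are re-exported from `mmdet.utils` (note that `build_dp` on cuda requires `device_ids`):

    import torch.nn as nn
    from mmdet.utils import build_dp, get_device  # assumed re-exports

    device = get_device()  # priority: npu > cuda > mlu, else 'cpu'
    model = nn.Conv2d(3, 8, 3)
    if device == 'cuda':
        # Moves the model to cuda:0 and wraps it in MMDataParallel.
        model = build_dp(model, device=device, device_ids=[0])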
diff --git a/mmdet/utils/util_mixins.py b/mmdet/utils/util_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..b83b6617f5e4a202067e1659bf448962a2a2bc72
--- /dev/null
+++ b/mmdet/utils/util_mixins.py
@@ -0,0 +1,105 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This module defines the :class:`NiceRepr` mixin class, which defines a
+``__repr__`` and ``__str__`` method that only depend on a custom ``__nice__``
+method, which you must define. This means you only have to overload one
+function instead of two. Furthermore, if the object defines a ``__len__``
+method, then the ``__nice__`` method defaults to something sensible, otherwise
+it is treated as abstract and raises ``NotImplementedError``.
+
+To use simply have your object inherit from :class:`NiceRepr`
+(multi-inheritance should be ok).
+
+This code was copied from the ubelt library: https://github.com/Erotemic/ubelt
+
+Example:
+ >>> # Objects that define __nice__ have a default __str__ and __repr__
+ >>> class Student(NiceRepr):
+ ... def __init__(self, name):
+ ... self.name = name
+ ... def __nice__(self):
+ ... return self.name
+ >>> s1 = Student('Alice')
+ >>> s2 = Student('Bob')
+ >>> print(f's1 = {s1}')
+ >>> print(f's2 = {s2}')
+ s1 = <Student(Alice)>
+ s2 = <Student(Bob)>
+
+Example:
+ >>> # Objects that define __len__ have a default __nice__
+ >>> class Group(NiceRepr):
+ ... def __init__(self, data):
+ ... self.data = data
+ ... def __len__(self):
+ ... return len(self.data)
+ >>> g = Group([1, 2, 3])
+ >>> print(f'g = {g}')
+ g = <Group(3)>
+"""
+import warnings
+
+
+class NiceRepr:
+ """Inherit from this class and define ``__nice__`` to "nicely" print your
+ objects.
+
+ Defines ``__str__`` and ``__repr__`` in terms of the ``__nice__`` function.
+ Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
+ If the inheriting class has a ``__len__`` method, then the default
+ ``__nice__`` method will return its length.
+
+ Example:
+ >>> class Foo(NiceRepr):
+ ... def __nice__(self):
+ ... return 'info'
+ >>> foo = Foo()
+ >>> assert str(foo) == '<Foo(info)>'
+ >>> assert repr(foo).startswith('<Foo(info) at ')
+
+ Example:
+ >>> class Bar(NiceRepr):
+ ... pass
+ >>> bar = Bar()
+ >>> import pytest
+ >>> with pytest.warns(None) as record:
+ >>> assert 'object at' in str(bar)
+ >>> assert 'object at' in repr(bar)
+
+ Example:
+ >>> class Baz(NiceRepr):
+ ... def __len__(self):
+ ... return 5
+ >>> baz = Baz()
+ >>> assert str(baz) == '<Baz(5)>'
+ """
+
+ def __nice__(self):
+ """str: a "nice" summary string describing this module"""
+ if hasattr(self, '__len__'):
+ # It is a common pattern for objects to use __len__ in __nice__
+ # As a convenience we define a default __nice__ for these objects
+ return str(len(self))
+ else:
+ # In all other cases force the subclass to overload __nice__
+ raise NotImplementedError(
+ f'Define the __nice__ method for {self.__class__!r}')
+
+ def __repr__(self):
+ """str: the string of the module"""
+ try:
+ nice = self.__nice__()
+ classname = self.__class__.__name__
+ return f'<{classname}({nice}) at {hex(id(self))}>'
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
+
+ def __str__(self):
+ """str: the string of the module"""
+ try:
+ classname = self.__class__.__name__
+ nice = self.__nice__()
+ return f'<{classname}({nice})>'
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
diff --git a/mmdet/utils/util_random.py b/mmdet/utils/util_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc1ecb6c03b026156c9947cb6d356a822448be0f
--- /dev/null
+++ b/mmdet/utils/util_random.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Helpers for random number generators."""
+import numpy as np
+
+
+def ensure_rng(rng=None):
+ """Coerces input into a random number generator.
+
+ If the input is None, then a global random state is returned.
+
+ If the input is a numeric value, then that is used as a seed to construct a
+ random state. Otherwise the input is returned as-is.
+
+ Adapted from [1]_.
+
+ Args:
+ rng (int | numpy.random.RandomState | None):
+ if None, then defaults to the global rng. Otherwise this can be an
+ integer or a RandomState class
+ Returns:
+ numpy.random.RandomState: a numpy random number generator
+
+ References:
+ .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
+ """
+
+ if rng is None:
+ rng = np.random.mtrand._rand
+ elif isinstance(rng, int):
+ rng = np.random.RandomState(rng)
+ else:
+ rng = rng
+ return rng
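A quick sketch of the three accepted inputs, with `ensure_rng` as defined above in scope:

    import numpy as np

    assert ensure_rng(None) is np.random.mtrand._rand        # global state
    assert isinstance(ensure_rng(0), np.random.RandomState)  # seeded from an int
    rng = np.random.RandomState(42)
    assert ensure_rng(rng) is rng                             # passed through unchanged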
diff --git a/mmdet/version.py b/mmdet/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..fecd645024d90770d008d94fe62c532189a5f6b2
--- /dev/null
+++ b/mmdet/version.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+__version__ = '2.28.2'
+short_version = __version__
+
+
+def parse_version_info(version_str):
+ version_info = []
+ for x in version_str.split('.'):
+ if x.isdigit():
+ version_info.append(int(x))
+ elif x.find('rc') != -1:
+ patch_version = x.split('rc')
+ version_info.append(int(patch_version[0]))
+ version_info.append(f'rc{patch_version[1]}')
+ return tuple(version_info)
+
+
+version_info = parse_version_info(__version__)
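For reference, the parser above yields, for example:

    assert parse_version_info('2.28.2') == (2, 28, 2)
    assert parse_version_info('2.25.0rc1') == (2, 25, 0, 'rc1')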
diff --git a/projects/configs/_base_/datasets/coco_detection.py b/projects/configs/_base_/datasets/coco_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..149f590bb45fa65c29fd4c005e4a237d7dd2e117
--- /dev/null
+++ b/projects/configs/_base_/datasets/coco_detection.py
@@ -0,0 +1,49 @@
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_train2017.json',
+ img_prefix=data_root + 'train2017/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
+evaluation = dict(interval=1, metric='bbox')
diff --git a/projects/configs/_base_/datasets/coco_instance.py b/projects/configs/_base_/datasets/coco_instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..9901a858414465d19d8ec6ced316b460166176b4
--- /dev/null
+++ b/projects/configs/_base_/datasets/coco_instance.py
@@ -0,0 +1,49 @@
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_train2017.json',
+ img_prefix=data_root + 'train2017/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
+evaluation = dict(metric=['bbox', 'segm'])
diff --git a/projects/configs/_base_/datasets/coco_panoptic.py b/projects/configs/_base_/datasets/coco_panoptic.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbade7c0ac20141806b93f0ea7b5ca26d748246e
--- /dev/null
+++ b/projects/configs/_base_/datasets/coco_panoptic.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'CocoPanopticDataset'
+data_root = 'data/coco/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='LoadPanopticAnnotations',
+ with_bbox=True,
+ with_mask=True,
+ with_seg=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='SegRescale', scale_factor=1 / 4),
+ dict(type='DefaultFormatBundle'),
+ dict(
+ type='Collect',
+ keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/panoptic_train2017.json',
+ img_prefix=data_root + 'train2017/',
+ seg_prefix=data_root + 'annotations/panoptic_train2017/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/panoptic_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ seg_prefix=data_root + 'annotations/panoptic_val2017/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/panoptic_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ seg_prefix=data_root + 'annotations/panoptic_val2017/',
+ pipeline=test_pipeline))
+evaluation = dict(interval=1, metric=['PQ'])
diff --git a/projects/configs/_base_/default_runtime.py b/projects/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6a9563b85145661c1c72f22e0fd1c01e8ed75e1
--- /dev/null
+++ b/projects/configs/_base_/default_runtime.py
@@ -0,0 +1,30 @@
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+custom_hooks = [dict(type='NumClassCheckHook')]
+
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+
+# disable opencv multithreading to avoid system being overloaded
+opencv_num_threads = 0
+# set multi-process start method as `fork` to speed up the training
+mp_start_method = 'fork'
+
+# Default setting for scaling LR automatically
+# - `enable`: whether to enable scaling LR automatically
+# (disabled by default).
+# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
+auto_scale_lr = dict(enable=False, base_batch_size=16)
+
+# placeholder
+total_epochs = 1
diff --git a/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py b/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e6dc0aadb858b5063dd11d73d115e70ca3664c0
--- /dev/null
+++ b/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py
@@ -0,0 +1,130 @@
+_base_ = [
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
+]
+
+plugin = True
+plugin_dir = 'projects/instance_segment_anything/'
+
+model = dict(
+ type='DetWrapperInstanceSAM',
+ det_wrapper_type='focalnet_dino',
+ det_wrapper_cfg=dict(num_classes=91,
+ param_dict_type='default',
+ ddetr_lr_param=False,
+ onecyclelr=False,
+ modelname='dino',
+ frozen_weights=None,
+ backbone='focalnet_L_384_22k_fl4',
+ focal_levels=4,
+ focal_windows=3,
+ use_checkpoint=False,
+ dilation=False,
+ position_embedding='sine',
+ pe_temperatureH=20,
+ pe_temperatureW=20,
+ return_interm_indices=[0, 1, 2, 3],
+ backbone_freeze_keywords=None,
+ enc_layers=6,
+ dec_layers=6,
+ unic_layers=0,
+ pre_norm=False,
+ dim_feedforward=2048,
+ hidden_dim=256,
+ dropout=0.0,
+ nheads=8,
+ num_queries=900,
+ query_dim=4,
+ num_patterns=0,
+ pdetr3_bbox_embed_diff_each_layer=False,
+ pdetr3_refHW=-1,
+ random_refpoints_xy=False,
+ fix_refpoints_hw=-1,
+ dabdetr_yolo_like_anchor_update=False,
+ dabdetr_deformable_encoder=False,
+ dabdetr_deformable_decoder=False,
+ use_deformable_box_attn=False,
+ box_attn_type='roi_align',
+ dec_layer_number=None,
+ num_feature_levels=5,
+ enc_n_points=4,
+ dec_n_points=4,
+ decoder_layer_noise=False,
+ dln_xy_noise=0.2,
+ dln_hw_noise=0.2,
+ add_channel_attention=False,
+ add_pos_value=False,
+ two_stage_type='standard',
+ two_stage_pat_embed=0,
+ two_stage_add_query_num=0,
+ two_stage_bbox_embed_share=False,
+ two_stage_class_embed_share=False,
+ two_stage_learn_wh=False,
+ two_stage_default_hw=0.05,
+ two_stage_keep_all_tokens=False,
+ num_select=300,
+ transformer_activation='relu',
+ batch_norm_type='FrozenBatchNorm2d',
+ masks=False,
+ aux_loss=True,
+ set_cost_class=2.0,
+ set_cost_bbox=5.0,
+ set_cost_giou=2.0,
+ no_interm_box_loss=False,
+ focal_alpha=0.25,
+ decoder_sa_type='sa', # ['sa', 'ca_label', 'ca_content']
+ matcher_type='HungarianMatcher', # or SimpleMinsumMatcher
+ decoder_module_seq=['sa', 'ca', 'ffn'],
+ nms_iou_threshold=-1,
+ dec_pred_bbox_embed_share=True,
+ dec_pred_class_embed_share=True,
+ use_dn=False,
+ dn_number=100,
+ dn_box_noise_scale=0.4,
+ dn_label_noise_ratio=0.5,
+ embed_init_tgt=True,
+ dn_labelbook_size=91,
+ match_unstable_error=True,
+ # for ema
+ use_ema=False,
+ ema_decay=0.9997,
+ ema_epoch=0,
+ use_detached_boxes_dec_out=False),
+ det_model_ckpt='ckpt/focalnet_l_dino.pth',
+ num_classes=80,
+ model_type='vit_b',
+ sam_checkpoint='ckpt/sam_vit_b_01ec64.pth',
+ use_sam_iou=True,
+)
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# test_pipeline. NOTE: the Pad's size_divisor differs from the default
+# setting (size_divisor=32). There is little effect on performance whether
+# we use the default setting or size_divisor=1.
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=1),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
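A small sketch of inspecting the merged config with mmcv, run from the repository root (the path comes from the diff header above):

    from mmcv import Config

    cfg = Config.fromfile(
        'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py')
    # The _base_ panoptic test set is overridden below: evaluation runs on
    # CocoDataset instances_val2017 with samples_per_gpu=1 and Pad(size_divisor=1).
    assert cfg.model.type == 'DetWrapperInstanceSAM'
    assert cfg.data.test.type == 'CocoDataset'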
diff --git a/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py b/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca22867656cb4e6d3a5d588d839bce0becb492e5
--- /dev/null
+++ b/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py
@@ -0,0 +1,130 @@
+_base_ = [
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
+]
+
+plugin = True
+plugin_dir = 'projects/instance_segment_anything/'
+
+model = dict(
+ type='DetWrapperInstanceSAM',
+ det_wrapper_type='focalnet_dino',
+ det_wrapper_cfg=dict(num_classes=91,
+ param_dict_type='default',
+ ddetr_lr_param=False,
+ onecyclelr=False,
+ modelname='dino',
+ frozen_weights=None,
+ backbone='focalnet_L_384_22k_fl4',
+ focal_levels=4,
+ focal_windows=3,
+ use_checkpoint=False,
+ dilation=False,
+ position_embedding='sine',
+ pe_temperatureH=20,
+ pe_temperatureW=20,
+ return_interm_indices=[0, 1, 2, 3],
+ backbone_freeze_keywords=None,
+ enc_layers=6,
+ dec_layers=6,
+ unic_layers=0,
+ pre_norm=False,
+ dim_feedforward=2048,
+ hidden_dim=256,
+ dropout=0.0,
+ nheads=8,
+ num_queries=900,
+ query_dim=4,
+ num_patterns=0,
+ pdetr3_bbox_embed_diff_each_layer=False,
+ pdetr3_refHW=-1,
+ random_refpoints_xy=False,
+ fix_refpoints_hw=-1,
+ dabdetr_yolo_like_anchor_update=False,
+ dabdetr_deformable_encoder=False,
+ dabdetr_deformable_decoder=False,
+ use_deformable_box_attn=False,
+ box_attn_type='roi_align',
+ dec_layer_number=None,
+ num_feature_levels=5,
+ enc_n_points=4,
+ dec_n_points=4,
+ decoder_layer_noise=False,
+ dln_xy_noise=0.2,
+ dln_hw_noise=0.2,
+ add_channel_attention=False,
+ add_pos_value=False,
+ two_stage_type='standard',
+ two_stage_pat_embed=0,
+ two_stage_add_query_num=0,
+ two_stage_bbox_embed_share=False,
+ two_stage_class_embed_share=False,
+ two_stage_learn_wh=False,
+ two_stage_default_hw=0.05,
+ two_stage_keep_all_tokens=False,
+ num_select=300,
+ transformer_activation='relu',
+ batch_norm_type='FrozenBatchNorm2d',
+ masks=False,
+ aux_loss=True,
+ set_cost_class=2.0,
+ set_cost_bbox=5.0,
+ set_cost_giou=2.0,
+ no_interm_box_loss=False,
+ focal_alpha=0.25,
+ decoder_sa_type='sa', # ['sa', 'ca_label', 'ca_content']
+ matcher_type='HungarianMatcher', # or SimpleMinsumMatcher
+ decoder_module_seq=['sa', 'ca', 'ffn'],
+ nms_iou_threshold=-1,
+ dec_pred_bbox_embed_share=True,
+ dec_pred_class_embed_share=True,
+ use_dn=False,
+ dn_number=100,
+ dn_box_noise_scale=0.4,
+ dn_label_noise_ratio=0.5,
+ embed_init_tgt=True,
+ dn_labelbook_size=91,
+ match_unstable_error=True,
+ # for ema
+ use_ema=False,
+ ema_decay=0.9997,
+ ema_epoch=0,
+ use_detached_boxes_dec_out=False),
+ det_model_ckpt='ckpt/focalnet_l_dino.pth',
+ num_classes=80,
+ model_type='vit_h',
+ sam_checkpoint='ckpt/sam_vit_h_4b8939.pth',
+ use_sam_iou=True,
+)
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# test_pipeline. NOTE: the Pad's size_divisor differs from the default
+# setting (size_divisor=32). There is little effect on performance whether
+# we use the default setting or size_divisor=1.
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=1),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
diff --git a/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py b/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py
new file mode 100644
index 0000000000000000000000000000000000000000..b51f0253fa420bde78a395c61cabd368df9a42be
--- /dev/null
+++ b/projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py
@@ -0,0 +1,130 @@
+_base_ = [
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
+]
+
+plugin = True
+plugin_dir = 'projects/instance_segment_anything/'
+
+model = dict(
+ type='DetWrapperInstanceSAM',
+ det_wrapper_type='focalnet_dino',
+ det_wrapper_cfg=dict(num_classes=91,
+ param_dict_type='default',
+ ddetr_lr_param=False,
+ onecyclelr=False,
+ modelname='dino',
+ frozen_weights=None,
+ backbone='focalnet_L_384_22k_fl4',
+ focal_levels=4,
+ focal_windows=3,
+ use_checkpoint=False,
+ dilation=False,
+ position_embedding='sine',
+ pe_temperatureH=20,
+ pe_temperatureW=20,
+ return_interm_indices=[0, 1, 2, 3],
+ backbone_freeze_keywords=None,
+ enc_layers=6,
+ dec_layers=6,
+ unic_layers=0,
+ pre_norm=False,
+ dim_feedforward=2048,
+ hidden_dim=256,
+ dropout=0.0,
+ nheads=8,
+ num_queries=900,
+ query_dim=4,
+ num_patterns=0,
+ pdetr3_bbox_embed_diff_each_layer=False,
+ pdetr3_refHW=-1,
+ random_refpoints_xy=False,
+ fix_refpoints_hw=-1,
+ dabdetr_yolo_like_anchor_update=False,
+ dabdetr_deformable_encoder=False,
+ dabdetr_deformable_decoder=False,
+ use_deformable_box_attn=False,
+ box_attn_type='roi_align',
+ dec_layer_number=None,
+ num_feature_levels=5,
+ enc_n_points=4,
+ dec_n_points=4,
+ decoder_layer_noise=False,
+ dln_xy_noise=0.2,
+ dln_hw_noise=0.2,
+ add_channel_attention=False,
+ add_pos_value=False,
+ two_stage_type='standard',
+ two_stage_pat_embed=0,
+ two_stage_add_query_num=0,
+ two_stage_bbox_embed_share=False,
+ two_stage_class_embed_share=False,
+ two_stage_learn_wh=False,
+ two_stage_default_hw=0.05,
+ two_stage_keep_all_tokens=False,
+ num_select=300,
+ transformer_activation='relu',
+ batch_norm_type='FrozenBatchNorm2d',
+ masks=False,
+ aux_loss=True,
+ set_cost_class=2.0,
+ set_cost_bbox=5.0,
+ set_cost_giou=2.0,
+ no_interm_box_loss=False,
+ focal_alpha=0.25,
+ decoder_sa_type='sa', # ['sa', 'ca_label', 'ca_content']
+ matcher_type='HungarianMatcher', # or SimpleMinsumMatcher
+ decoder_module_seq=['sa', 'ca', 'ffn'],
+ nms_iou_threshold=-1,
+ dec_pred_bbox_embed_share=True,
+ dec_pred_class_embed_share=True,
+ use_dn=False,
+ dn_number=100,
+ dn_box_noise_scale=0.4,
+ dn_label_noise_ratio=0.5,
+ embed_init_tgt=True,
+ dn_labelbook_size=91,
+ match_unstable_error=True,
+ # for ema
+ use_ema=False,
+ ema_decay=0.9997,
+ ema_epoch=0,
+ use_detached_boxes_dec_out=False),
+ det_model_ckpt='ckpt/focalnet_l_dino.pth',
+ num_classes=80,
+ model_type='vit_l',
+ sam_checkpoint='ckpt/sam_vit_l_0b3195.pth',
+ use_sam_iou=True,
+)
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# test_pipeline. NOTE: the Pad's size_divisor differs from the default
+# setting (size_divisor=32). There is little effect on performance whether
+# we use the default setting or size_divisor=1.
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=1),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
diff --git a/projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py b/projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py
new file mode 100644
index 0000000000000000000000000000000000000000..d315fc4a84cce44182342578ace77e6a80adf31a
--- /dev/null
+++ b/projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py
@@ -0,0 +1,82 @@
+_base_ = [
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
+]
+
+plugin = True
+plugin_dir = 'projects/instance_segment_anything/'
+
+model = dict(
+ type='DetWrapperInstanceSAM',
+ det_wrapper_type='hdetr',
+ det_wrapper_cfg=dict(aux_loss=False,
+ backbone='swin_large',
+ num_classes=91,
+ cache_mode=False,
+ dec_layers=6,
+ dec_n_points=4,
+ dilation=False,
+ dim_feedforward=2048,
+ drop_path_rate=0.5,
+ dropout=0.0,
+ enc_layers=6,
+ enc_n_points=4,
+ focal_alpha=0.25,
+ frozen_weights=None,
+ hidden_dim=256,
+ k_one2many=6,
+ lambda_one2many=1.0,
+ look_forward_twice=True,
+ masks=False,
+ mixed_selection=True,
+ nheads=8,
+ num_feature_levels=4,
+ num_queries_one2many=1500,
+ num_queries_one2one=900,
+ position_embedding='sine',
+ position_embedding_scale=6.283185307179586,
+ remove_difficult=False,
+ topk=300,
+ two_stage=True,
+ use_checkpoint=False,
+ use_fp16=False,
+ use_wandb=False,
+ with_box_refine=True),
+ det_model_ckpt='ckpt/swin_l_hdetr.pth',
+ num_classes=80,
+ model_type='vit_b',
+ sam_checkpoint='ckpt/sam_vit_b_01ec64.pth',
+ use_sam_iou=True,
+)
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# test_pipeline. NOTE: the Pad's size_divisor differs from the default
+# setting (size_divisor=32). There is little effect on performance whether
+# we use the default setting or size_divisor=1.
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=1),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
diff --git a/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py b/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py
new file mode 100644
index 0000000000000000000000000000000000000000..56c13cf2207d8b1671d72e5c71f50a861e1d5d53
--- /dev/null
+++ b/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py
@@ -0,0 +1,82 @@
+_base_ = [
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
+]
+
+plugin = True
+plugin_dir = 'projects/instance_segment_anything/'
+
+model = dict(
+ type='DetWrapperInstanceSAM',
+ det_wrapper_type='hdetr',
+ det_wrapper_cfg=dict(aux_loss=False,
+ backbone='swin_large',
+ num_classes=91,
+ cache_mode=False,
+ dec_layers=6,
+ dec_n_points=4,
+ dilation=False,
+ dim_feedforward=2048,
+ drop_path_rate=0.5,
+ dropout=0.0,
+ enc_layers=6,
+ enc_n_points=4,
+ focal_alpha=0.25,
+ frozen_weights=None,
+ hidden_dim=256,
+ k_one2many=6,
+ lambda_one2many=1.0,
+ look_forward_twice=True,
+ masks=False,
+ mixed_selection=True,
+ nheads=8,
+ num_feature_levels=4,
+ num_queries_one2many=1500,
+ num_queries_one2one=900,
+ position_embedding='sine',
+ position_embedding_scale=6.283185307179586,
+ remove_difficult=False,
+ topk=300,
+ two_stage=True,
+ use_checkpoint=False,
+ use_fp16=False,
+ use_wandb=False,
+ with_box_refine=True),
+ det_model_ckpt='ckpt/swin_l_hdetr.pth',
+ num_classes=80,
+ model_type='vit_h',
+ sam_checkpoint='ckpt/sam_vit_h_4b8939.pth',
+ use_sam_iou=True,
+)
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# test_pipeline. NOTE: the Pad's size_divisor differs from the default
+# setting (size_divisor=32). There is little effect on performance whether
+# we use the default setting or size_divisor=1.
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=1),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
diff --git a/projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py b/projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py
new file mode 100644
index 0000000000000000000000000000000000000000..de52f962e4a9a0bf2ed615e5c3f0c2bb69366150
--- /dev/null
+++ b/projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py
@@ -0,0 +1,82 @@
+_base_ = [
+ '../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
+]
+
+plugin = True
+plugin_dir = 'projects/instance_segment_anything/'
+
+model = dict(
+ type='DetWrapperInstanceSAM',
+ det_wrapper_type='hdetr',
+ det_wrapper_cfg=dict(aux_loss=False,
+ backbone='swin_large',
+ num_classes=91,
+ cache_mode=False,
+ dec_layers=6,
+ dec_n_points=4,
+ dilation=False,
+ dim_feedforward=2048,
+ drop_path_rate=0.5,
+ dropout=0.0,
+ enc_layers=6,
+ enc_n_points=4,
+ focal_alpha=0.25,
+ frozen_weights=None,
+ hidden_dim=256,
+ k_one2many=6,
+ lambda_one2many=1.0,
+ look_forward_twice=True,
+ masks=False,
+ mixed_selection=True,
+ nheads=8,
+ num_feature_levels=4,
+ num_queries_one2many=1500,
+ num_queries_one2one=900,
+ position_embedding='sine',
+ position_embedding_scale=6.283185307179586,
+ remove_difficult=False,
+ topk=300,
+ two_stage=True,
+ use_checkpoint=False,
+ use_fp16=False,
+ use_wandb=False,
+ with_box_refine=True),
+ det_model_ckpt='ckpt/swin_l_hdetr.pth',
+ num_classes=80,
+ model_type='vit_l',
+ sam_checkpoint='ckpt/sam_vit_l_0b3195.pth',
+ use_sam_iou=True,
+)
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# test_pipeline. NOTE: the Pad's size_divisor differs from the default
+# setting (size_divisor=32). There is little effect on performance whether
+# we use the default setting or size_divisor=1.
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=1),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+
+dataset_type = 'CocoDataset'
+data_root = 'data/coco/'
+
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'annotations/instances_val2017.json',
+ img_prefix=data_root + 'val2017/',
+ pipeline=test_pipeline))
diff --git a/projects/instance_segment_anything/__init__.py b/projects/instance_segment_anything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9052293afe309cf6bdf32b8e02059acdada04765
--- /dev/null
+++ b/projects/instance_segment_anything/__init__.py
@@ -0,0 +1 @@
+from .models.det_wrapper_instance_sam import DetWrapperInstanceSAM
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/det_wrapper_instance_sam.py b/projects/instance_segment_anything/models/det_wrapper_instance_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..34651ea8948038fbe06de3c9982d7de1b441dad3
--- /dev/null
+++ b/projects/instance_segment_anything/models/det_wrapper_instance_sam.py
@@ -0,0 +1,129 @@
+import cv2
+import torch
+import torch.nn as nn
+from mmcv import Config
+from mmcv.runner import load_checkpoint
+
+from mmdet.core import bbox2result
+from mmdet.models import DETECTORS, BaseDetector
+from projects.instance_segment_anything.models.segment_anything import sam_model_registry, SamPredictor
+from .focalnet_dino.focalnet_dino_wrapper import FocalNetDINOWrapper
+from .hdetr.hdetr_wrapper import HDetrWrapper
+
+
+@DETECTORS.register_module()
+class DetWrapperInstanceSAM(BaseDetector):
+ wrapper_dict = {'hdetr': HDetrWrapper,
+ 'focalnet_dino': FocalNetDINOWrapper}
+
+ def __init__(self,
+ det_wrapper_type='hdetr',
+ det_wrapper_cfg=None,
+ det_model_ckpt=None,
+ num_classes=80,
+
+ model_type='vit_b',
+ sam_checkpoint=None,
+ use_sam_iou=True,
+
+ init_cfg=None,
+ train_cfg=None,
+ test_cfg=None):
+ super(DetWrapperInstanceSAM, self).__init__(init_cfg)
+ self.learnable_placeholder = nn.Embedding(1, 1)
+ det_wrapper_cfg = Config(det_wrapper_cfg)
+ assert det_wrapper_type in self.wrapper_dict.keys()
+ self.det_model = self.wrapper_dict[det_wrapper_type](args=det_wrapper_cfg)
+ if det_model_ckpt is not None:
+ load_checkpoint(self.det_model.model,
+ filename=det_model_ckpt,
+ map_location='cpu')
+
+ self.num_classes = num_classes
+
+ # Segment Anything
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+ _ = sam.to(device=self.learnable_placeholder.weight.device)
+ self.predictor = SamPredictor(sam)
+ self.use_sam_iou = use_sam_iou
+
+ def init_weights(self):
+ pass
+
+ def simple_test(self, img, img_metas, ori_img, rescale=True):
+ """Test without augmentation.
+ Args:
+ imgs (Tensor): A batch of images.
+ img_metas (list[dict]): List of image information.
+ """
+ assert rescale
+ assert len(img_metas) == 1
+ # results: List[dict(scores, labels, boxes)]
+ results = self.det_model.simple_test(img,
+ img_metas,
+ rescale)
+
+ # Tensor(n,4), xyxy, ori image scale
+ output_boxes = results[0]['boxes']
+
+ self.predictor.set_image(ori_img)
+
+ transformed_boxes = self.predictor.transform.apply_boxes_torch(output_boxes, ori_img.shape[:2])
+
+ # mask_pred: n,1,h,w
+ # sam_score: n, 1
+ mask_pred, sam_score, _ = self.predictor.predict_torch(
+ point_coords=None,
+ point_labels=None,
+ boxes=transformed_boxes,
+ multimask_output=False,
+ return_logits=True,
+ )
+ # Tensor(n,h,w), raw mask pred
+ mask_pred = mask_pred.squeeze(1)
+ sam_score = sam_score.squeeze(-1)
+
+ # Tensor(n,)
+ label_pred = results[0]['labels']
+
+ score_pred = results[0]['scores']
+
+ # mask_pred: Tensor(n,h,w)
+ # label_pred: Tensor(n,)
+ # score_pred: Tensor(n,)
+ # sam_score: Tensor(n,)
+ mask_pred_binary = (mask_pred > self.predictor.model.mask_threshold).float()
+ if self.use_sam_iou:
+ det_scores = score_pred * sam_score
+ else:
+ # n
+ mask_scores_per_image = (mask_pred * mask_pred_binary).flatten(1).sum(1) / (
+ mask_pred_binary.flatten(1).sum(1) + 1e-6)
+ det_scores = score_pred * mask_scores_per_image
+ # det_scores = score_pred
+ mask_pred_binary = mask_pred_binary.bool()
+ bboxes = torch.cat([output_boxes, det_scores[:, None]], dim=-1)
+ bbox_results = bbox2result(bboxes, label_pred, self.num_classes)
+ mask_results = [[] for _ in range(self.num_classes)]
+ for j, label in enumerate(label_pred):
+ mask = mask_pred_binary[j].detach().cpu().numpy()
+ mask_results[label].append(mask)
+ output_results = [(bbox_results, mask_results)]
+
+ return output_results
+
+ # not implemented:
+ def aug_test(self, imgs, img_metas, **kwargs):
+ raise NotImplementedError
+
+ def onnx_export(self, img, img_metas):
+ raise NotImplementedError
+
+ async def async_simple_test(self, img, img_metas, **kwargs):
+ raise NotImplementedError
+
+ def forward_train(self, imgs, img_metas, **kwargs):
+ raise NotImplementedError
+
+ def extract_feat(self, imgs):
+ raise NotImplementedError
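For clarity on the two scoring modes above: with `use_sam_iou=True` the detector score is multiplied by SAM's predicted IoU, otherwise by the mean raw mask score inside the binarized mask. A standalone sketch of that fallback term (the function name is illustrative; `threshold` stands in for `predictor.model.mask_threshold`):

    import torch

    def mask_confidence(mask_logits: torch.Tensor,
                        threshold: float = 0.0) -> torch.Tensor:
        """Mean raw mask score inside the thresholded mask, per instance.

        mask_logits: (n, h, w) raw SAM outputs (return_logits=True above).
        """
        binary = (mask_logits > threshold).float()
        return (mask_logits * binary).flatten(1).sum(1) / (
            binary.flatten(1).sum(1) + 1e-6)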
diff --git a/projects/instance_segment_anything/models/focalnet_dino/focalnet_dino_wrapper.py b/projects/instance_segment_anything/models/focalnet_dino/focalnet_dino_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4e50100c205a8e09c65a605301100e86d477714
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/focalnet_dino_wrapper.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn.functional as F
+from mmcv.runner import BaseModule
+
+from .models import build_model
+from .models.dino.util.misc import NestedTensor, inverse_sigmoid
+
+
+class FocalNetDINOWrapper(BaseModule):
+ def __init__(self,
+ args=None,
+ init_cfg=None):
+ super(FocalNetDINOWrapper, self).__init__(init_cfg)
+ model, _, box_postprocessor = build_model(args)
+ self.model = model
+ self.box_postprocessor = box_postprocessor
+
+ self.cls_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28,
+ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+ 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+ 82, 84, 85, 86, 87, 88, 89, 90]
+
+ def forward(self,
+ img,
+ img_metas):
+ """Forward function for training mode.
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ """
+ input_img_h, input_img_w = img_metas[0]["batch_input_shape"]
+ batch_size = img.size(0)
+ img_masks = img.new_ones((batch_size, input_img_h, input_img_w),
+ dtype=torch.bool)
+ for img_id in range(batch_size):
+ img_h, img_w, _ = img_metas[img_id]["img_shape"]
+ img_masks[img_id, :img_h, :img_w] = False
+ samples = NestedTensor(tensors=img, mask=img_masks)
+ features, poss = self.model.backbone(samples)
+
+ srcs = []
+ masks = []
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ srcs.append(self.model.input_proj[l](src))
+ masks.append(mask)
+ assert mask is not None
+ if self.model.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.model.num_feature_levels):
+ if l == _len_srcs:
+ src = self.model.input_proj[l](features[-1].tensors)
+ else:
+ src = self.model.input_proj[l](srcs[-1])
+ m = samples.mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
+ pos_l = self.model.backbone[1](NestedTensor(src, mask)).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ poss.append(pos_l)
+
+ input_query_bbox = input_query_label = attn_mask = dn_meta = None
+
+ hs, reference, hs_enc, ref_enc, init_box_proposal = self.model.transformer(srcs, masks,
+ input_query_bbox, poss,
+ input_query_label,
+ attn_mask)
+ # In case the number of objects is 0.
+ hs[0] += self.model.label_enc.weight[0, 0] * 0.0
+
+ # deformable-detr-like anchor update
+ # reference_before_sigmoid = inverse_sigmoid(reference[:-1]) # n_dec, bs, nq, 4
+ outputs_coord_list = []
+ for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1],
+ self.model.bbox_embed,
+ hs)):
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
+ outputs_coord_list.append(layer_outputs_unsig)
+ outputs_coord_list = torch.stack(outputs_coord_list)
+
+ outputs_class = torch.stack([layer_cls_embed(layer_hs) for
+ layer_cls_embed, layer_hs in zip(self.model.class_embed,
+ hs)])
+ sampled_logits = outputs_class[-1][:, :, self.cls_index]
+ out = {'pred_logits': sampled_logits, 'pred_boxes': outputs_coord_list[-1]}
+
+ return out
+
+ def simple_test(self, img, img_metas, rescale=False):
+ # out: dict
+ out = self(img, img_metas)
+ if rescale:
+ ori_target_sizes = [meta_info['ori_shape'][:2] for meta_info in img_metas]
+ else:
+ ori_target_sizes = [meta_info['img_shape'][:2] for meta_info in img_metas]
+ ori_target_sizes = (out['pred_logits']).new_tensor(ori_target_sizes, dtype=torch.int64)
+ # results: List[dict(scores, labels, boxes)]
+ results = self.box_postprocessor(out, ori_target_sizes)
+
+ return results
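For context, the hard-coded `cls_index` above keeps the 80 COCO category ids that actually carry annotations out of the model's 91-way classification head (index 0 and the ten unused ids are dropped). An equivalent construction:

    # Category ids absent from the COCO 2017 annotations.
    missing = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
    cls_index = [i for i in range(1, 91) if i not in missing]
    assert len(cls_index) == 80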
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/__init__.py b/projects/instance_segment_anything/models/focalnet_dino/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f61d118927a7e97509b1b228046f7bc6712be8f
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/__init__.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from .dino import build_dino
+
+def build_model(args):
+ return build_dino(args)
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/__init__.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d61ba5d68aa2e0bc27a06a482f59f4fcc78cb0c2
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/__init__.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+from .dino import build_dino
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/attention.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..de4bbbd88262c7aeadae5f4fce7594b7115fcd61
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/attention.py
@@ -0,0 +1,393 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from codes in torch.nn
+# ------------------------------------------------------------------------
+
+"""
+MultiheadAttention that supports query, key, and value with different dimensions.
+The query, key, and value projections are removed.
+
+Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
+and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
+"""
+
+import copy
+from typing import Optional, List
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+import warnings
+from typing import Tuple, Optional
+
+import torch
+from torch import Tensor
+from torch.nn.modules.linear import Linear
+from torch.nn.init import xavier_uniform_
+from torch.nn.init import constant_
+from torch.nn.init import xavier_normal_
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch.nn import functional as F
+
+import warnings
+import math
+
+from torch._C import _infer_size, _add_docstr
+from torch.nn import _reduction as _Reduction
+from torch.nn.modules import utils
+from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default
+from torch.nn import grad
+from torch import _VF
+from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple
+try:
+ from torch.overrides import has_torch_function, handle_torch_function
+except ImportError:
+ from torch._overrides import has_torch_function, handle_torch_function
+Tensor = torch.Tensor
+
+from torch.nn.functional import linear, pad, softmax, dropout
+
+class MultiheadAttention(Module):
+ r"""Allows the model to jointly attend to information
+ from different representation subspaces.
+ See reference: Attention Is All You Need
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+ bias: add bias as module parameter. Default: True.
+ add_bias_kv: add bias to the key and value sequences at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ kdim: total number of features in key. Default: None.
+ vdim: total number of features in value. Default: None.
+ Note: if kdim and vdim are None, they will be set to embed_dim such that
+ query, key, and value have the same number of features.
+ Examples::
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+ """
+ bias_k: Optional[torch.Tensor]
+ bias_v: Optional[torch.Tensor]
+
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+ vdim = vdim if vdim is not None else embed_dim
+ self.out_proj = Linear(vdim, vdim)
+
+ self.in_proj_bias = None
+ self.in_proj_weight = None
+ self.bias_k = self.bias_v = None
+ self.q_proj_weight = None
+ self.k_proj_weight = None
+ self.v_proj_weight = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.out_proj.bias, 0.)
+
+ def __setstate__(self, state):
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
+ if '_qkv_same_embed_dim' not in state:
+ state['_qkv_same_embed_dim'] = True
+
+ super(MultiheadAttention, self).__setstate__(state)
+
+ def forward(self, query, key, value, key_padding_mask=None,
+ need_weights=True, attn_mask=None):
+ # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. When given a binary mask and a value is True,
+ the corresponding value on the attention layer will be ignored. When given
+ a byte mask and a value is non-zero, the corresponding value on the attention
+ layer will be ignored
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+ Shape:
+ - Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length,
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+ - Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ if not self._qkv_same_embed_dim:
+ return multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight, out_dim=self.vdim)
+ else:
+ return multi_head_attention_forward(
+ query, key, value, self.embed_dim, self.num_heads,
+ self.in_proj_weight, self.in_proj_bias,
+ self.bias_k, self.bias_v, self.add_zero_attn,
+ self.dropout, self.out_proj.weight, self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
+ attn_mask=attn_mask, out_dim=self.vdim)
+
+
+def multi_head_attention_forward(query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ embed_dim_to_check: int,
+ num_heads: int,
+ in_proj_weight: Tensor,
+ in_proj_bias: Tensor,
+ bias_k: Optional[Tensor],
+ bias_v: Optional[Tensor],
+ add_zero_attn: bool,
+ dropout_p: float,
+ out_proj_weight: Tensor,
+ out_proj_bias: Tensor,
+ training: bool = True,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ use_separate_proj_weight: bool = False,
+ q_proj_weight: Optional[Tensor] = None,
+ k_proj_weight: Optional[Tensor] = None,
+ v_proj_weight: Optional[Tensor] = None,
+ static_k: Optional[Tensor] = None,
+ static_v: Optional[Tensor] = None,
+ out_dim: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query, key, value: map a query and a set of key-value pairs to an output.
+ See "Attention Is All You Need" for more details.
+ embed_dim_to_check: total dimension of the model.
+ num_heads: parallel attention heads.
+ in_proj_weight, in_proj_bias: input projection weight and bias.
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
+ add_zero_attn: add a new batch of zeros to the key and
+ value sequences at dim=1.
+ dropout_p: probability of an element to be zeroed.
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
+ training: apply dropout if is ``True``.
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. This is an binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ need_weights: output attn_output_weights.
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
+ and value in different forms. If false, in_proj_weight will be used, which is
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
+ static_k, static_v: static key and value used for attention operators.
+ Shape:
+ Inputs:
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
+ the embedding dimension.
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
+ will be unchanged. If a BoolTensor is provided, the positions with the
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
+ is provided, it will be added to the attention weight.
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
+ Outputs:
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
+ E is the embedding dimension.
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
+ L is the target sequence length, S is the source sequence length.
+ """
+ if not torch.jit.is_scripting():
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
+ out_proj_weight, out_proj_bias)
+ if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+ return handle_torch_function(
+ multi_head_attention_forward, tens_ops, query, key, value,
+ embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
+ bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
+ out_proj_bias, training=training, key_padding_mask=key_padding_mask,
+ need_weights=need_weights, attn_mask=attn_mask,
+ use_separate_proj_weight=use_separate_proj_weight,
+ q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
+ v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == embed_dim_to_check
+ # allow MHA to have different sizes for the feature dimension
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
+
+    head_dim = embed_dim // num_heads
+    # fall back to the embedding dimension when no separate value/output dim is given
+    if out_dim is None:
+        out_dim = embed_dim
+    v_head_dim = out_dim // num_heads
+    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+ scaling = float(head_dim) ** -0.5
+
+ q = query * scaling
+ k = key
+ v = value
+
+ if attn_mask is not None:
+ assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
+ attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
+ 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
+ if attn_mask.dtype == torch.uint8:
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+ attn_mask = attn_mask.to(torch.bool)
+
+ if attn_mask.dim() == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
+ raise RuntimeError('The size of the 2D attn_mask is not correct.')
+ elif attn_mask.dim() == 3:
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
+ raise RuntimeError('The size of the 3D attn_mask is not correct.')
+ else:
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
+ # attn_mask's dim is 3 now.
+
+ # convert ByteTensor key_padding_mask to bool
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
+ key_padding_mask = key_padding_mask.to(torch.bool)
+
+ if bias_k is not None and bias_v is not None:
+ if static_k is None and static_v is None:
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+ else:
+ assert static_k is None, "bias cannot be added to static key."
+ assert static_v is None, "bias cannot be added to static value."
+ else:
+ assert bias_k is None
+ assert bias_v is None
+
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)
+
+ if static_k is not None:
+ assert static_k.size(0) == bsz * num_heads
+ assert static_k.size(2) == head_dim
+ k = static_k
+
+ if static_v is not None:
+ assert static_v.size(0) == bsz * num_heads
+ assert static_v.size(2) == v_head_dim
+ v = static_v
+
+ src_len = k.size(1)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
+ if attn_mask is not None:
+ attn_mask = pad(attn_mask, (0, 1))
+ if key_padding_mask is not None:
+ key_padding_mask = pad(key_padding_mask, (0, 1))
+
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_output_weights.masked_fill_(attn_mask, float('-inf'))
+ else:
+ attn_output_weights += attn_mask
+
+
+ if key_padding_mask is not None:
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float('-inf'),
+ )
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+
+ # attn_output_weights = softmax(
+ # attn_output_weights, dim=-1)
+ attn_output_weights = softmax(
+ attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0], dim=-1)
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
+
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim)
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+ if need_weights:
+ # average attention weights over heads
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
+ else:
+ return attn_output, None
+
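+
+# ---------------------------------------------------------------------------
+# Illustrative usage sketch (not part of the upstream DAB-DETR/DINO code): how
+# the custom ``multi_head_attention_forward`` above can be called directly when
+# the value/output dimension (``out_dim``) differs from ``embed_dim``.  This
+# variant applies no input projection itself (q, k, v are used as given), so
+# ``in_proj_weight``/``in_proj_bias`` below are random placeholders kept only
+# for API compatibility in this demo.
+# ---------------------------------------------------------------------------
+def _demo_multi_head_attention_forward():
+    L, S, N, E, heads, vdim = 4, 6, 2, 8, 2, 16
+    query = torch.randn(L, N, E)
+    key = torch.randn(S, N, E)
+    value = torch.randn(S, N, vdim)
+    attn_output, attn_weights = multi_head_attention_forward(
+        query, key, value, E, heads,
+        in_proj_weight=torch.randn(3 * E, E),
+        in_proj_bias=torch.zeros(3 * E),
+        bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0.0,
+        out_proj_weight=torch.randn(vdim, vdim),
+        out_proj_bias=torch.zeros(vdim),
+        training=False, need_weights=True, out_dim=vdim)
+    # attn_output: (L, N, vdim); attn_weights: (N, L, S)
+    return attn_output, attn_weights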
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/backbone.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce856f9d9a0e3616b45a3404cc39689f424273b4
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/backbone.py
@@ -0,0 +1,263 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+Backbone modules.
+"""
+from collections import OrderedDict
+import os
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
+
+
+from .util.misc import NestedTensor, clean_state_dict, is_main_process
+
+from .position_encoding import build_position_encoding
+from .convnext import build_convnext
+from .swin_transformer import build_swin_transformer
+from .focal import build_focalnet
+
+class FrozenBatchNorm2d(torch.nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
+    without which models other than torchvision.models.resnet[18,34,50,101]
+    produce nans.
+ """
+
+ def __init__(self, n):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it fuser-friendly
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ eps = 1e-5
+ scale = w * (rv + eps).rsqrt()
+ bias = b - rm * scale
+ return x * scale + bias
+
+
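+# Minimal sanity-check sketch (added for illustration, not part of the original
+# DETR/DINO code): FrozenBatchNorm2d registers buffers rather than parameters,
+# so it acts as a fixed per-channel affine transform and exposes nothing to the
+# optimizer.
+def _demo_frozen_batchnorm():
+    bn = FrozenBatchNorm2d(8)
+    x = torch.randn(2, 8, 16, 16)
+    y = bn(x)
+    assert y.shape == x.shape
+    assert len(list(bn.parameters())) == 0  # only buffers, no trainable params
+    return y
+
+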
+class BackboneBase(nn.Module):
+
+ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_indices: list):
+ super().__init__()
+ for name, parameter in backbone.named_parameters():
+            if not train_backbone or ('layer2' not in name and 'layer3' not in name and 'layer4' not in name):
+ parameter.requires_grad_(False)
+
+ return_layers = {}
+ for idx, layer_index in enumerate(return_interm_indices):
+ return_layers.update({"layer{}".format(5 - len(return_interm_indices) + idx): "{}".format(layer_index)})
+
+ # if len:
+ # if use_stage1_feature:
+ # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+ # else:
+ # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
+ # else:
+ # return_layers = {'layer4': "0"}
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ self.num_channels = num_channels
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self.body(tensor_list.tensors)
+ out: Dict[str, NestedTensor] = {}
+ for name, x in xs.items():
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ # import ipdb; ipdb.set_trace()
+ return out
+
+
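+# Illustrative helper (not from the upstream repo): the ``return_layers`` dict
+# built in ``BackboneBase.__init__`` maps the last len(return_interm_indices)
+# ResNet stages onto the requested output indices.
+def _demo_return_layers(return_interm_indices=(1, 2, 3)):
+    return_layers = {}
+    for idx, layer_index in enumerate(return_interm_indices):
+        return_layers["layer{}".format(5 - len(return_interm_indices) + idx)] = str(layer_index)
+    # e.g. (1, 2, 3) -> {'layer2': '1', 'layer3': '2', 'layer4': '3'}
+    return return_layers
+
+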
+class Backbone(BackboneBase):
+ """ResNet backbone with frozen BatchNorm."""
+ def __init__(self, name: str,
+ train_backbone: bool,
+ dilation: bool,
+ return_interm_indices:list,
+ batch_norm=FrozenBatchNorm2d,
+ ):
+ if name in ['resnet18', 'resnet34', 'resnet50', 'resnet101']:
+ backbone = getattr(torchvision.models, name)(
+ replace_stride_with_dilation=[False, False, dilation],
+ pretrained=is_main_process(), norm_layer=batch_norm)
+ else:
+            raise NotImplementedError("Unknown backbone name {}".format(name))
+ # num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
+ assert name not in ('resnet18', 'resnet34'), "Only resnet50 and resnet101 are available."
+ assert return_interm_indices in [[0,1,2,3], [1,2,3], [3]]
+ num_channels_all = [256, 512, 1024, 2048]
+ num_channels = num_channels_all[4-len(return_interm_indices):]
+ super().__init__(backbone, train_backbone, num_channels, return_interm_indices)
+
+
+class Joiner(nn.Sequential):
+ def __init__(self, backbone, position_embedding):
+ super().__init__(backbone, position_embedding)
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self[0](tensor_list)
+ out: List[NestedTensor] = []
+ pos = []
+ for name, x in xs.items():
+ out.append(x)
+ # position encoding
+ pos.append(self[1](x).to(x.tensors.dtype))
+
+ return out, pos
+
+
+def build_backbone(args):
+ """
+ Useful args:
+ - backbone: backbone name
+ - lr_backbone:
+ - dilation
+ - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
+ - backbone_freeze_keywords:
+ - use_checkpoint: for swin only for now
+
+ """
+ position_embedding = build_position_encoding(args)
+ train_backbone = False
+ # if not train_backbone:
+ # raise ValueError("Please set lr_backbone > 0")
+ return_interm_indices = args.return_interm_indices
+ assert return_interm_indices in [[0,1,2,3], [1,2,3], [3]]
+ backbone_freeze_keywords = args.backbone_freeze_keywords
+ use_checkpoint = getattr(args, 'use_checkpoint', False)
+
+ if args.backbone in ['resnet50', 'resnet101']:
+ backbone = Backbone(args.backbone, train_backbone, args.dilation,
+ return_interm_indices,
+ batch_norm=FrozenBatchNorm2d)
+ bb_num_channels = backbone.num_channels
+ elif args.backbone in ['swin_T_224_1k', 'swin_B_224_22k', 'swin_B_384_22k', 'swin_L_224_22k', 'swin_L_384_22k']:
+ pretrain_img_size = int(args.backbone.split('_')[-2])
+ backbone = build_swin_transformer(args.backbone, \
+ pretrain_img_size=pretrain_img_size, \
+ out_indices=tuple(return_interm_indices), \
+ dilation=args.dilation, use_checkpoint=use_checkpoint)
+
+ # freeze some layers
+ if backbone_freeze_keywords is not None:
+ for name, parameter in backbone.named_parameters():
+ for keyword in backbone_freeze_keywords:
+ if keyword in name:
+ parameter.requires_grad_(False)
+ break
+
+ pretrained_dir = args.backbone_dir
+ PTDICT = {
+ 'swin_T_224_1k': 'swin_tiny_patch4_window7_224.pth',
+ 'swin_B_384_22k': 'swin_base_patch4_window12_384.pth',
+ 'swin_L_384_22k': 'swin_large_patch4_window12_384_22k.pth',
+ }
+ # pretrainedpath = os.path.join(pretrained_dir, PTDICT[args.backbone])
+ # checkpoint = torch.load(pretrainedpath, map_location='cpu')['model']
+ from collections import OrderedDict
+ def key_select_function(keyname):
+ if 'head' in keyname:
+ return False
+ if args.dilation and 'layers.3' in keyname:
+ return False
+ return True
+        # The checkpoint loading above is commented out in this project, so the
+        # state-dict filtering below is disabled as well (mirroring the FocalNet
+        # branch); otherwise `checkpoint` would be an undefined name.
+        # _tmp_st = OrderedDict({k: v for k, v in clean_state_dict(checkpoint).items() if key_select_function(k)})
+        # _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
+        # print(str(_tmp_st_output))
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices):]
+ elif args.backbone in [
+ 'focalnet_L_384_22k',
+ 'focalnet_L_384_22k_fl4',
+ 'focalnet_XL_384_22k',
+ 'focalnet_XL_384_22k_fl4',
+ 'focalnet_H_224_22k',
+ 'focalnet_H_224_22k_fl4',
+ ]:
+ # added by Jianwei
+ backbone = build_focalnet(args.backbone, \
+ focal_levels=args.focal_levels, \
+ focal_windows=args.focal_windows, \
+ out_indices=tuple(return_interm_indices), \
+ use_checkpoint=use_checkpoint)
+
+ # freeze some layers
+ if backbone_freeze_keywords is not None:
+ for name, parameter in backbone.named_parameters():
+ for keyword in backbone_freeze_keywords:
+ if keyword in name:
+ parameter.requires_grad_(False)
+ break
+
+ pretrained_dir = '/'
+ PTDICT = {
+ 'focalnet_L_384_22k': 'focalnet_large_lrf_384.pth',
+ 'focalnet_L_384_22k_fl4': 'focalnet_large_lrf_384_fl4.pth',
+ 'focalnet_XL_384_22k': 'focalnet_xlarge_lrf_384.pth',
+ 'focalnet_XL_384_22k_fl4': 'focalnet_xlarge_lrf_384_fl4.pth',
+ 'focalnet_H_224_22k': 'focalnet_huge_lrf_224.pth',
+ 'focalnet_H_224_22k_fl4': 'focalnet_huge_lrf_224_fl4.pth',
+ }
+ # pretrainedpath = os.path.join(pretrained_dir, PTDICT[args.backbone])
+ # checkpoint = torch.load(pretrainedpath, map_location='cpu')['model']
+ from collections import OrderedDict
+ def key_select_function(keyname):
+ if 'head' in keyname:
+ return False
+ if args.dilation and 'layers.3' in keyname:
+ return False
+ return True
+ # _tmp_st = OrderedDict({k:v for k, v in clean_state_dict(checkpoint).items() if key_select_function(k)})
+ # _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
+ # print(str(_tmp_st_output))
+ bb_num_channels = backbone.num_features[4 - len(return_interm_indices):]
+ elif args.backbone in ['convnext_xlarge_22k']:
+ backbone = build_convnext(modelname=args.backbone, pretrained=True, out_indices=tuple(return_interm_indices),backbone_dir=args.backbone_dir)
+ bb_num_channels = backbone.dims[4 - len(return_interm_indices):]
+ else:
+ raise NotImplementedError("Unknown backbone {}".format(args.backbone))
+
+
+ assert len(bb_num_channels) == len(return_interm_indices), f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"
+
+
+ model = Joiner(backbone, position_embedding)
+ model.num_channels = bb_num_channels
+    assert isinstance(bb_num_channels, List), "bb_num_channels is expected to be a List but got {}".format(type(bb_num_channels))
+ return model
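+
+
+# Sketch of the argument namespace read by ``build_backbone`` above (added for
+# illustration only; attribute values are examples).  Extra attributes may be
+# required by ``build_position_encoding`` and by the swin/focalnet branches
+# (e.g. ``focal_levels`` / ``focal_windows``), which are not listed here.
+def _demo_backbone_args():
+    from types import SimpleNamespace
+    return SimpleNamespace(
+        backbone='resnet50',              # or a swin_* / focalnet_* / convnext_* name
+        dilation=False,
+        return_interm_indices=[1, 2, 3],  # must be [0,1,2,3], [1,2,3] or [3]
+        backbone_freeze_keywords=None,
+        backbone_dir=None,
+        use_checkpoint=False,
+    )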
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/convnext.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..76eeeb2a15d9379968db53fc59fbf0f9a996f0bb
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/convnext.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_, DropPath
+
+from .util.misc import NestedTensor
+# from timm.models.registry import register_model
+
+class Block(nn.Module):
+ r""" ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
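+# Minimal shape-check sketch (added for illustration, not part of the original
+# ConvNeXt code): a Block preserves both the channel count and the spatial
+# resolution, so blocks can be stacked freely inside a stage.
+def _demo_convnext_block():
+    block = Block(dim=96, drop_path=0.0)
+    x = torch.randn(1, 96, 32, 32)
+    y = block(x)
+    assert y.shape == x.shape
+    return y
+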
+class ConvNeXt(nn.Module):
+ r""" ConvNeXt
+ A PyTorch impl of : `A ConvNet for the 2020s` -
+ https://arxiv.org/pdf/2201.03545.pdf
+
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+ """
+ def __init__(self, in_chans=3, num_classes=1000,
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
+ layer_scale_init_value=1e-6, head_init_scale=1.,
+ out_indices=[0, 1, 2, 3]
+ ):
+ super().__init__()
+ self.dims = dims
+
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+ )
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ cur = 0
+ for i in range(4):
+ stage = nn.Sequential(
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.out_indices = out_indices
+
+ norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
+ for i_layer in range(4):
+ layer = norm_layer(dims[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ # self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
+ # self.head = nn.Linear(dims[-1], num_classes)
+
+ # self.apply(self._init_weights)
+ # self.head.weight.data.mul_(head_init_scale)
+ # self.head.bias.data.mul_(head_init_scale)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward_features(self, x):
+ outs = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x)
+ outs.append(x_out)
+ # return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
+ return tuple(outs)
+
+ # def forward(self, x):
+ # x = self.forward_features(x)
+ # return x
+
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ outs = self.forward_features(x)
+
+ # collect for nesttensors
+ outs_dict = {}
+ for idx, out_i in enumerate(outs):
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
+ outs_dict[idx] = NestedTensor(out_i, mask)
+
+ return outs_dict
+
+class LayerNorm(nn.Module):
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
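+# Sketch (added for illustration, not from the upstream file): the two data
+# formats normalise over the channel axis wherever it lives, so their outputs
+# agree up to a permutation of the tensor dimensions.
+def _demo_layernorm_formats():
+    ln_first = LayerNorm(8, data_format="channels_first")
+    ln_last = LayerNorm(8, data_format="channels_last")
+    x = torch.randn(2, 8, 4, 4)                                  # (N, C, H, W)
+    y_first = ln_first(x)
+    y_last = ln_last(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # back to (N, C, H, W)
+    assert torch.allclose(y_first, y_last, atol=1e-5)
+    return y_first
+
+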
+model_urls = {
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+}
+
+# @register_model
+# def convnext_tiny(pretrained=False, **kwargs):
+# model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+# if pretrained:
+# url = model_urls['convnext_tiny_1k']
+# checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+# model.load_state_dict(checkpoint["model"])
+# return model
+
+# @register_model
+# def convnext_small(pretrained=False, **kwargs):
+# model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+# if pretrained:
+# url = model_urls['convnext_small_1k']
+# checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+# model.load_state_dict(checkpoint["model"])
+# return model
+
+# @register_model
+# def convnext_base(pretrained=False, in_22k=False, **kwargs):
+# model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
+# if pretrained:
+# url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
+# checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+# model.load_state_dict(checkpoint["model"])
+# return model
+
+# @register_model
+# def convnext_large(pretrained=False, in_22k=False, **kwargs):
+# model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
+# if pretrained:
+# url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
+# checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+# model.load_state_dict(checkpoint["model"])
+# return model
+
+# @register_model
+# def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
+# model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
+# if pretrained:
+# url = model_urls['convnext_xlarge_22k'] if in_22k else model_urls['convnext_xlarge_1k']
+# checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+# model.load_state_dict(checkpoint["model"])
+# return model
+
+def build_convnext(modelname, pretrained, backbone_dir=None, **kw):
+ assert modelname in ['convnext_xlarge_22k']
+
+ model_para_dict = {
+ 'convnext_xlarge_22k': dict(
+ depths=[3, 3, 27, 3],
+ dims=[256, 512, 1024, 2048],
+ ),
+ }
+    kw_cfg = model_para_dict[modelname]
+    kw_cfg.update(kw)
+    model = ConvNeXt(**kw_cfg)
+ if pretrained:
+ url = model_urls[modelname]
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, model_dir=backbone_dir, map_location="cpu", check_hash=True)
+ _tmp_st_output = model.load_state_dict(checkpoint["model"], strict=False)
+ print(str(_tmp_st_output))
+
+ return model
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/deformable_transformer.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/deformable_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a197eea1bc4215fad266c835a5b4fdaf509ecd05
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/deformable_transformer.py
@@ -0,0 +1,1104 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR Transformer class.
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+import math, random
+import copy
+import warnings
+from typing import Optional
+
+from .util.misc import inverse_sigmoid
+
+import torch
+from torch import nn, Tensor
+
+from .utils import gen_encoder_output_proposals, MLP,_get_activation_fn, gen_sineembed_for_position
+from projects.instance_segment_anything.ops.modules import MSDeformAttn
+
+class DeformableTransformer(nn.Module):
+
+ def __init__(self, d_model=256, nhead=8,
+ num_queries=300,
+ num_encoder_layers=6,
+ num_unicoder_layers=0,
+ num_decoder_layers=6,
+ dim_feedforward=2048, dropout=0.0,
+ activation="relu", normalize_before=False,
+ return_intermediate_dec=False, query_dim=4,
+ num_patterns=0,
+ modulate_hw_attn=False,
+ # for deformable encoder
+ deformable_encoder=False,
+ deformable_decoder=False,
+ num_feature_levels=1,
+ enc_n_points=4,
+ dec_n_points=4,
+ use_deformable_box_attn=False,
+ box_attn_type='roi_align',
+ # init query
+ learnable_tgt_init=False,
+ decoder_query_perturber=None,
+ add_channel_attention=False,
+ add_pos_value=False,
+ random_refpoints_xy=False,
+ # two stage
+ two_stage_type='no', # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
+ two_stage_pat_embed=0,
+ two_stage_add_query_num=0,
+ two_stage_learn_wh=False,
+ two_stage_keep_all_tokens=False,
+ # evo of #anchors
+ dec_layer_number=None,
+ rm_enc_query_scale=True,
+ rm_dec_query_scale=True,
+ rm_self_attn_layers=None,
+ key_aware_type=None,
+ # layer share
+ layer_share_type=None,
+ # for detach
+ rm_detach=None,
+ decoder_sa_type='ca',
+ module_seq=['sa', 'ca', 'ffn'],
+ # for dn
+ embed_init_tgt=False,
+
+ use_detached_boxes_dec_out=False,
+ ):
+ super().__init__()
+ self.num_feature_levels = num_feature_levels
+ self.num_encoder_layers = num_encoder_layers
+ self.num_unicoder_layers = num_unicoder_layers
+ self.num_decoder_layers = num_decoder_layers
+ self.deformable_encoder = deformable_encoder
+ self.deformable_decoder = deformable_decoder
+ self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
+ self.num_queries = num_queries
+ self.random_refpoints_xy = random_refpoints_xy
+ self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
+ assert query_dim == 4
+
+ if num_feature_levels > 1:
+ assert deformable_encoder, "only support deformable_encoder for num_feature_levels > 1"
+ if use_deformable_box_attn:
+            assert deformable_encoder or deformable_decoder
+
+ assert layer_share_type in [None, 'encoder', 'decoder', 'both']
+ if layer_share_type in ['encoder', 'both']:
+ enc_layer_share = True
+ else:
+ enc_layer_share = False
+ if layer_share_type in ['decoder', 'both']:
+ dec_layer_share = True
+ else:
+ dec_layer_share = False
+ assert layer_share_type is None
+
+ self.decoder_sa_type = decoder_sa_type
+ assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
+
+ # choose encoder layer type
+ if deformable_encoder:
+ encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, enc_n_points, add_channel_attention=add_channel_attention, use_deformable_box_attn=use_deformable_box_attn, box_attn_type=box_attn_type)
+ else:
+ raise NotImplementedError
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder = TransformerEncoder(
+ encoder_layer, num_encoder_layers,
+ encoder_norm, d_model=d_model,
+ num_queries=num_queries,
+ deformable_encoder=deformable_encoder,
+ enc_layer_share=enc_layer_share,
+ two_stage_type=two_stage_type
+ )
+
+ # choose decoder layer type
+ if deformable_decoder:
+ decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, dec_n_points, use_deformable_box_attn=use_deformable_box_attn, box_attn_type=box_attn_type,
+ key_aware_type=key_aware_type,
+ decoder_sa_type=decoder_sa_type,
+ module_seq=module_seq)
+
+ else:
+ raise NotImplementedError
+
+ decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
+ return_intermediate=return_intermediate_dec,
+ d_model=d_model, query_dim=query_dim,
+ modulate_hw_attn=modulate_hw_attn,
+ num_feature_levels=num_feature_levels,
+ deformable_decoder=deformable_decoder,
+ decoder_query_perturber=decoder_query_perturber,
+ dec_layer_number=dec_layer_number, rm_dec_query_scale=rm_dec_query_scale,
+ dec_layer_share=dec_layer_share,
+ use_detached_boxes_dec_out=use_detached_boxes_dec_out
+ )
+
+ self.d_model = d_model
+ self.nhead = nhead
+ self.dec_layers = num_decoder_layers
+ self.num_queries = num_queries # useful for single stage model only
+ self.num_patterns = num_patterns
+        if not isinstance(num_patterns, int):
+            warnings.warn("num_patterns should be int but got {}".format(type(num_patterns)))
+            self.num_patterns = 0
+
+ if num_feature_levels > 1:
+ if self.num_encoder_layers > 0:
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+ else:
+ self.level_embed = None
+
+ self.learnable_tgt_init = learnable_tgt_init
+        assert learnable_tgt_init, "only learnable_tgt_init=True is supported"
+ self.embed_init_tgt = embed_init_tgt
+ if (two_stage_type != 'no' and embed_init_tgt) or (two_stage_type == 'no'):
+ self.tgt_embed = nn.Embedding(self.num_queries, d_model)
+ nn.init.normal_(self.tgt_embed.weight.data)
+ else:
+ self.tgt_embed = None
+
+ # for two stage
+ self.two_stage_type = two_stage_type
+ self.two_stage_pat_embed = two_stage_pat_embed
+ self.two_stage_add_query_num = two_stage_add_query_num
+ self.two_stage_learn_wh = two_stage_learn_wh
+        assert two_stage_type in ['no', 'standard'], "unknown two_stage_type {}".format(two_stage_type)
+ if two_stage_type =='standard':
+ # anchor selection at the output of encoder
+ self.enc_output = nn.Linear(d_model, d_model)
+ self.enc_output_norm = nn.LayerNorm(d_model)
+
+ if two_stage_pat_embed > 0:
+ self.pat_embed_for_2stage = nn.Parameter(torch.Tensor(two_stage_pat_embed, d_model))
+ nn.init.normal_(self.pat_embed_for_2stage)
+
+ if two_stage_add_query_num > 0:
+ self.tgt_embed = nn.Embedding(self.two_stage_add_query_num, d_model)
+
+ if two_stage_learn_wh:
+ # import ipdb; ipdb.set_trace()
+ self.two_stage_wh_embedding = nn.Embedding(1, 2)
+ else:
+ self.two_stage_wh_embedding = None
+
+ if two_stage_type == 'no':
+ self.init_ref_points(num_queries) # init self.refpoint_embed
+
+
+ self.enc_out_class_embed = None
+ self.enc_out_bbox_embed = None
+
+ # evolution of anchors
+ self.dec_layer_number = dec_layer_number
+ if dec_layer_number is not None:
+ if self.two_stage_type != 'no' or num_patterns == 0:
+ assert dec_layer_number[0] == num_queries, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries})"
+ else:
+ assert dec_layer_number[0] == num_queries * num_patterns, f"dec_layer_number[0]({dec_layer_number[0]}) != num_queries({num_queries}) * num_patterns({num_patterns})"
+
+ self._reset_parameters()
+
+ self.rm_self_attn_layers = rm_self_attn_layers
+ if rm_self_attn_layers is not None:
+ # assert len(rm_self_attn_layers) == num_decoder_layers
+ print("Removing the self-attn in {} decoder layers".format(rm_self_attn_layers))
+ for lid, dec_layer in enumerate(self.decoder.layers):
+ if lid in rm_self_attn_layers:
+ dec_layer.rm_self_attn_modules()
+
+ self.rm_detach = rm_detach
+ if self.rm_detach:
+ assert isinstance(rm_detach, list)
+ assert any([i in ['enc_ref', 'enc_tgt', 'dec'] for i in rm_detach])
+ self.decoder.rm_detach = rm_detach
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ if self.num_feature_levels > 1 and self.level_embed is not None:
+ nn.init.normal_(self.level_embed)
+
+ if self.two_stage_learn_wh:
+ nn.init.constant_(self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05)))
+
+
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def init_ref_points(self, use_num_queries):
+ self.refpoint_embed = nn.Embedding(use_num_queries, 4)
+
+ if self.random_refpoints_xy:
+ # import ipdb; ipdb.set_trace()
+ self.refpoint_embed.weight.data[:, :2].uniform_(0,1)
+ self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
+ self.refpoint_embed.weight.data[:, :2].requires_grad = False
+
+
+
+ def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None):
+ """
+ Input:
+ - srcs: List of multi features [bs, ci, hi, wi]
+ - masks: List of multi masks [bs, hi, wi]
+            - refpoint_embed: [bs, num_dn, 4]. None at inference time
+            - pos_embeds: List of multi pos embeds [bs, ci, hi, wi]
+            - tgt: [bs, num_dn, d_model]. None at inference time
+
+ """
+ # if self.two_stage_type != 'no' and self.two_stage_add_query_num == 0:
+ # assert refpoint_embed is None
+
+ # prepare input for encoder
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
+ mask = mask.flatten(1) # bs, hw
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
+ if self.num_feature_levels > 1 and self.level_embed is not None:
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ else:
+ lvl_pos_embed = pos_embed
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
+ mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ # two stage
+ enc_topk_proposals = enc_refpoint_embed = None
+
+ #########################################################
+ # Begin Encoder
+ #########################################################
+ memory, enc_intermediate_output, enc_intermediate_refpoints = self.encoder(
+ src_flatten,
+ pos=lvl_pos_embed_flatten,
+ level_start_index=level_start_index,
+ spatial_shapes=spatial_shapes,
+ valid_ratios=valid_ratios,
+ key_padding_mask=mask_flatten,
+ ref_token_index=enc_topk_proposals, # bs, nq
+ ref_token_coord=enc_refpoint_embed, # bs, nq, 4
+ )
+ #########################################################
+ # End Encoder
+ # - memory: bs, \sum{hw}, c
+ # - mask_flatten: bs, \sum{hw}
+ # - lvl_pos_embed_flatten: bs, \sum{hw}, c
+ # - enc_intermediate_output: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
+ # - enc_intermediate_refpoints: None or (nenc+1, bs, nq, c) or (nenc, bs, nq, c)
+ #########################################################
+
+
+ if self.two_stage_type =='standard':
+ if self.two_stage_learn_wh:
+ input_hw = self.two_stage_wh_embedding.weight[0]
+ else:
+ input_hw = None
+ output_memory, output_proposals = gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes, input_hw)
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+
+ if self.two_stage_pat_embed > 0:
+ bs, nhw, _ = output_memory.shape
+ # output_memory: bs, n, 256; self.pat_embed_for_2stage: k, 256
+ output_memory = output_memory.repeat(1, self.two_stage_pat_embed, 1)
+ _pats = self.pat_embed_for_2stage.repeat_interleave(nhw, 0)
+ output_memory = output_memory + _pats
+ output_proposals = output_proposals.repeat(1, self.two_stage_pat_embed, 1)
+
+ if self.two_stage_add_query_num > 0:
+ assert refpoint_embed is not None
+ output_memory = torch.cat((output_memory, tgt), dim=1)
+ output_proposals = torch.cat((output_proposals, refpoint_embed), dim=1)
+
+ enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
+ enc_outputs_coord_unselected = self.enc_out_bbox_embed(output_memory) + output_proposals # (bs, \sum{hw}, 4) unsigmoid
+ topk = self.num_queries
+ topk_proposals = torch.topk(enc_outputs_class_unselected.max(-1)[0], topk, dim=1)[1] # bs, nq
+
+
+ # gather boxes
+ refpoint_embed_undetach = torch.gather(enc_outputs_coord_unselected, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid
+ refpoint_embed_ = refpoint_embed_undetach.detach()
+ init_box_proposal = torch.gather(output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)).sigmoid() # sigmoid
+
+ # gather tgt
+ tgt_undetach = torch.gather(output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model))
+ if self.embed_init_tgt:
+ tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
+ else:
+ tgt_ = tgt_undetach.detach()
+
+ if refpoint_embed is not None:
+ refpoint_embed=torch.cat([refpoint_embed,refpoint_embed_],dim=1)
+ tgt=torch.cat([tgt,tgt_],dim=1)
+ else:
+ refpoint_embed,tgt=refpoint_embed_,tgt_
+
+ elif self.two_stage_type == 'no':
+ tgt_ = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model
+ refpoint_embed_ = self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, 4
+
+ if refpoint_embed is not None:
+ refpoint_embed=torch.cat([refpoint_embed,refpoint_embed_],dim=1)
+ tgt=torch.cat([tgt,tgt_],dim=1)
+ else:
+ refpoint_embed,tgt=refpoint_embed_,tgt_
+
+ if self.num_patterns > 0:
+ tgt_embed = tgt.repeat(1, self.num_patterns, 1)
+ refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1)
+ tgt_pat = self.patterns.weight[None, :, :].repeat_interleave(self.num_queries, 1) # 1, n_q*n_pat, d_model
+ tgt = tgt_embed + tgt_pat
+
+ init_box_proposal = refpoint_embed_.sigmoid()
+
+ else:
+ raise NotImplementedError("unknown two_stage_type {}".format(self.two_stage_type))
+ #########################################################
+ # End preparing tgt
+ # - tgt: bs, NQ, d_model
+ # - refpoint_embed(unsigmoid): bs, NQ, d_model
+ #########################################################
+
+
+ #########################################################
+ # Begin Decoder
+ #########################################################
+ hs, references = self.decoder(
+ tgt=tgt.transpose(0, 1),
+ memory=memory.transpose(0, 1),
+ memory_key_padding_mask=mask_flatten,
+ pos=lvl_pos_embed_flatten.transpose(0, 1),
+ refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
+ level_start_index=level_start_index,
+ spatial_shapes=spatial_shapes,
+ valid_ratios=valid_ratios,tgt_mask=attn_mask)
+ #########################################################
+ # End Decoder
+ # hs: n_dec, bs, nq, d_model
+ # references: n_dec+1, bs, nq, query_dim
+ #########################################################
+
+
+ #########################################################
+ # Begin postprocess
+ #########################################################
+ if self.two_stage_type == 'standard':
+ if self.two_stage_keep_all_tokens:
+ hs_enc = output_memory.unsqueeze(0)
+ ref_enc = enc_outputs_coord_unselected.unsqueeze(0)
+ init_box_proposal = output_proposals
+ # import ipdb; ipdb.set_trace()
+ else:
+ hs_enc = tgt_undetach.unsqueeze(0)
+ ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0)
+ else:
+ hs_enc = ref_enc = None
+ #########################################################
+ # End postprocess
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or (n_enc, bs, nq, d_model) or None
+ # ref_enc: (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or (n_enc, bs, nq, d_model) or None
+ #########################################################
+
+ return hs, references, hs_enc, ref_enc, init_box_proposal
+ # hs: (n_dec, bs, nq, d_model)
+ # references: sigmoid coordinates. (n_dec+1, bs, bq, 4)
+ # hs_enc: (n_enc+1, bs, nq, d_model) or (1, bs, nq, d_model) or None
+ # ref_enc: sigmoid coordinates. \
+ # (n_enc+1, bs, nq, query_dim) or (1, bs, nq, query_dim) or None
+
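+
+# Illustrative sketch (not part of the upstream DINO code): ``get_valid_ratio``
+# measures, per image, which fraction of a padded feature map contains real
+# content.  ``self`` is unused by that method, so it can be exercised directly.
+def _demo_valid_ratio():
+    mask = torch.zeros(1, 8, 8, dtype=torch.bool)
+    mask[:, :, 4:] = True                # the right half of the width is padding
+    ratio = DeformableTransformer.get_valid_ratio(None, mask)
+    # ratio ~= tensor([[0.5, 1.0]])  ->  (valid_w / W, valid_h / H)
+    return ratio
+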
+class TransformerEncoder(nn.Module):
+
+ def __init__(self,
+ encoder_layer, num_layers, norm=None, d_model=256,
+ num_queries=300,
+ deformable_encoder=False,
+ enc_layer_share=False, enc_layer_dropout_prob=None,
+ two_stage_type='no', # ['no', 'standard', 'early', 'combine', 'enceachlayer', 'enclayer1']
+ ):
+ super().__init__()
+ # prepare layers
+ if num_layers > 0:
+ self.layers = _get_clones(encoder_layer, num_layers, layer_share=enc_layer_share)
+ else:
+ self.layers = []
+ del encoder_layer
+
+ self.query_scale = None
+ self.num_queries = num_queries
+ self.deformable_encoder = deformable_encoder
+ self.num_layers = num_layers
+ self.norm = norm
+ self.d_model = d_model
+
+ self.enc_layer_dropout_prob = enc_layer_dropout_prob
+ if enc_layer_dropout_prob is not None:
+ assert isinstance(enc_layer_dropout_prob, list)
+ assert len(enc_layer_dropout_prob) == num_layers
+ for i in enc_layer_dropout_prob:
+ assert 0.0 <= i <= 1.0
+
+ self.two_stage_type = two_stage_type
+ if two_stage_type in ['enceachlayer', 'enclayer1']:
+ _proj_layer = nn.Linear(d_model, d_model)
+ _norm_layer = nn.LayerNorm(d_model)
+ if two_stage_type == 'enclayer1':
+ self.enc_norm = nn.ModuleList([_norm_layer])
+ self.enc_proj = nn.ModuleList([_proj_layer])
+ else:
+ self.enc_norm = nn.ModuleList([copy.deepcopy(_norm_layer) for i in range(num_layers - 1) ])
+ self.enc_proj = nn.ModuleList([copy.deepcopy(_proj_layer) for i in range(num_layers - 1) ])
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(self,
+ src: Tensor,
+ pos: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ valid_ratios: Tensor,
+ key_padding_mask: Tensor,
+ ref_token_index: Optional[Tensor]=None,
+ ref_token_coord: Optional[Tensor]=None
+ ):
+ """
+ Input:
+ - src: [bs, sum(hi*wi), 256]
+ - pos: pos embed for src. [bs, sum(hi*wi), 256]
+ - spatial_shapes: h,w of each level [num_level, 2]
+ - level_start_index: [num_level] start point of level in sum(hi*wi).
+ - valid_ratios: [bs, num_level, 2]
+ - key_padding_mask: [bs, sum(hi*wi)]
+
+ - ref_token_index: bs, nq
+ - ref_token_coord: bs, nq, 4
+        Intermediate:
+            - reference_points: [bs, sum(hi*wi), num_level, 2]
+        Outputs:
+            - output: [bs, sum(hi*wi), 256]
+ """
+ if self.two_stage_type in ['no', 'standard', 'enceachlayer', 'enclayer1']:
+ assert ref_token_index is None
+
+ output = src
+
+ # preparation and reshape
+ if self.num_layers > 0:
+ if self.deformable_encoder:
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
+ # import ipdb; ipdb.set_trace()
+
+ intermediate_output = []
+ intermediate_ref = []
+ if ref_token_index is not None:
+ out_i = torch.gather(output, 1, ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
+ intermediate_output.append(out_i)
+ intermediate_ref.append(ref_token_coord)
+
+
+ # intermediate_coord = []
+ # main process
+ for layer_id, layer in enumerate(self.layers):
+ # main process
+ dropflag = False
+ if self.enc_layer_dropout_prob is not None:
+ prob = random.random()
+ if prob < self.enc_layer_dropout_prob[layer_id]:
+ dropflag = True
+
+ if not dropflag:
+ if self.deformable_encoder:
+ output = layer(src=output, pos=pos, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, key_padding_mask=key_padding_mask)
+ else:
+ output = layer(src=output.transpose(0, 1), pos=pos.transpose(0, 1), key_padding_mask=key_padding_mask).transpose(0, 1)
+
+ if ((layer_id == 0 and self.two_stage_type in ['enceachlayer', 'enclayer1']) \
+ or (self.two_stage_type == 'enceachlayer')) \
+ and (layer_id != self.num_layers - 1):
+ output_memory, output_proposals = gen_encoder_output_proposals(output, key_padding_mask, spatial_shapes)
+ output_memory = self.enc_norm[layer_id](self.enc_proj[layer_id](output_memory))
+
+ # gather boxes
+ topk = self.num_queries
+ enc_outputs_class = self.class_embed[layer_id](output_memory)
+ ref_token_index = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1] # bs, nq
+ ref_token_coord = torch.gather(output_proposals, 1, ref_token_index.unsqueeze(-1).repeat(1, 1, 4))
+
+ output = output_memory
+
+ # aux loss
+ if (layer_id != self.num_layers - 1) and ref_token_index is not None:
+ out_i = torch.gather(output, 1, ref_token_index.unsqueeze(-1).repeat(1, 1, self.d_model))
+ intermediate_output.append(out_i)
+ intermediate_ref.append(ref_token_coord)
+
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ if ref_token_index is not None:
+ intermediate_output = torch.stack(intermediate_output) # n_enc/n_enc-1, bs, \sum{hw}, d_model
+ intermediate_ref = torch.stack(intermediate_ref)
+ else:
+ intermediate_output = intermediate_ref = None
+
+ return output, intermediate_output, intermediate_ref
+
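+
+# Illustrative sketch (not part of the upstream DINO code): the encoder builds
+# one normalised (x, y) reference point per feature-map cell and per level,
+# scaled by the per-image valid ratios.
+def _demo_reference_points():
+    spatial_shapes = torch.as_tensor([[4, 4], [2, 2]], dtype=torch.long)
+    valid_ratios = torch.ones(1, 2, 2)   # bs=1, two levels, no padding
+    ref = TransformerEncoder.get_reference_points(
+        spatial_shapes, valid_ratios, device=spatial_shapes.device)
+    assert ref.shape == (1, 4 * 4 + 2 * 2, 2, 2)  # [bs, sum(hi*wi), n_levels, 2]
+    return ref
+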
+class TransformerDecoder(nn.Module):
+
+ def __init__(self, decoder_layer, num_layers, norm=None,
+ return_intermediate=False,
+ d_model=256, query_dim=4,
+ modulate_hw_attn=False,
+ num_feature_levels=1,
+ deformable_decoder=False,
+ decoder_query_perturber=None,
+ dec_layer_number=None, # number of queries each layer in decoder
+ rm_dec_query_scale=False,
+ dec_layer_share=False,
+ dec_layer_dropout_prob=None,
+ use_detached_boxes_dec_out=False
+ ):
+ super().__init__()
+ if num_layers > 0:
+ self.layers = _get_clones(decoder_layer, num_layers, layer_share=dec_layer_share)
+ else:
+ self.layers = []
+ self.num_layers = num_layers
+ self.norm = norm
+ self.return_intermediate = return_intermediate
+ assert return_intermediate, "support return_intermediate only"
+ self.query_dim = query_dim
+ assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim)
+ self.num_feature_levels = num_feature_levels
+ self.use_detached_boxes_dec_out = use_detached_boxes_dec_out
+
+
+ self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
+ if not deformable_decoder:
+ self.query_pos_sine_scale = MLP(d_model, d_model, d_model, 2)
+ else:
+ self.query_pos_sine_scale = None
+
+ if rm_dec_query_scale:
+ self.query_scale = None
+ else:
+ raise NotImplementedError
+ self.query_scale = MLP(d_model, d_model, d_model, 2)
+ self.bbox_embed = None
+ self.class_embed = None
+
+ self.d_model = d_model
+ self.modulate_hw_attn = modulate_hw_attn
+ self.deformable_decoder = deformable_decoder
+
+ if not deformable_decoder and modulate_hw_attn:
+ self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
+ else:
+ self.ref_anchor_head = None
+
+ self.decoder_query_perturber = decoder_query_perturber
+ self.box_pred_damping = None
+
+ self.dec_layer_number = dec_layer_number
+ if dec_layer_number is not None:
+ assert isinstance(dec_layer_number, list)
+ assert len(dec_layer_number) == num_layers
+ # assert dec_layer_number[0] ==
+
+ self.dec_layer_dropout_prob = dec_layer_dropout_prob
+ if dec_layer_dropout_prob is not None:
+ assert isinstance(dec_layer_dropout_prob, list)
+ assert len(dec_layer_dropout_prob) == num_layers
+ for i in dec_layer_dropout_prob:
+ assert 0.0 <= i <= 1.0
+
+ self.rm_detach = None
+
+ def forward(self, tgt, memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2
+ # for memory
+ level_start_index: Optional[Tensor] = None, # num_levels
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ valid_ratios: Optional[Tensor] = None,
+
+ ):
+ """
+ Input:
+ - tgt: nq, bs, d_model
+ - memory: hw, bs, d_model
+ - pos: hw, bs, d_model
+ - refpoints_unsigmoid: nq, bs, 2/4
+ - valid_ratios/spatial_shapes: bs, nlevel, 2
+ """
+ output = tgt
+
+ intermediate = []
+ reference_points = refpoints_unsigmoid.sigmoid()
+ ref_points = [reference_points]
+
+ for layer_id, layer in enumerate(self.layers):
+ # preprocess ref points
+ if self.training and self.decoder_query_perturber is not None and layer_id != 0:
+ reference_points = self.decoder_query_perturber(reference_points)
+
+
+
+ if self.deformable_decoder:
+ if reference_points.shape[-1] == 4:
+ reference_points_input = reference_points[:, :, None] \
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :] # nq, bs, nlevel, 4
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :, None] * valid_ratios[None, :]
+ query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # nq, bs, 256*2
+ else:
+ query_sine_embed = gen_sineembed_for_position(reference_points) # nq, bs, 256*2
+ reference_points_input = None
+
+ # conditional query
+ # import ipdb; ipdb.set_trace()
+ raw_query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
+ pos_scale = self.query_scale(output) if self.query_scale is not None else 1
+ query_pos = pos_scale * raw_query_pos
+ if not self.deformable_decoder:
+ query_sine_embed = query_sine_embed[..., :self.d_model] * self.query_pos_sine_scale(output)
+
+ # modulated HW attentions
+ if not self.deformable_decoder and self.modulate_hw_attn:
+ refHW_cond = self.ref_anchor_head(output).sigmoid() # nq, bs, 2
+ query_sine_embed[..., self.d_model // 2:] *= (refHW_cond[..., 0] / reference_points[..., 2]).unsqueeze(-1)
+ query_sine_embed[..., :self.d_model // 2] *= (refHW_cond[..., 1] / reference_points[..., 3]).unsqueeze(-1)
+
+ # main process
+ # import ipdb; ipdb.set_trace()
+ dropflag = False
+ if self.dec_layer_dropout_prob is not None:
+ prob = random.random()
+ if prob < self.dec_layer_dropout_prob[layer_id]:
+ dropflag = True
+ if not dropflag:
+ output = layer(
+ tgt = output,
+ tgt_query_pos = query_pos,
+ tgt_query_sine_embed = query_sine_embed,
+ tgt_key_padding_mask = tgt_key_padding_mask,
+ tgt_reference_points = reference_points_input,
+
+ memory = memory,
+ memory_key_padding_mask = memory_key_padding_mask,
+ memory_level_start_index = level_start_index,
+ memory_spatial_shapes = spatial_shapes,
+ memory_pos = pos,
+
+ self_attn_mask = tgt_mask,
+ cross_attn_mask = memory_mask
+ )
+
+ # iter update
+ if self.bbox_embed is not None:
+ # box_holder = self.bbox_embed(output)
+ # box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
+ # new_reference_points = box_holder[..., :self.query_dim].sigmoid()
+
+ reference_before_sigmoid = inverse_sigmoid(reference_points)
+ delta_unsig = self.bbox_embed[layer_id](output)
+ outputs_unsig = delta_unsig + reference_before_sigmoid
+ new_reference_points = outputs_unsig.sigmoid()
+
+ # select # ref points
+ if self.dec_layer_number is not None and layer_id != self.num_layers - 1:
+ # import ipdb; ipdb.set_trace()
+ nq_now = new_reference_points.shape[0]
+ select_number = self.dec_layer_number[layer_id + 1]
+ if nq_now != select_number:
+ class_unselected = self.class_embed[layer_id](output) # nq, bs, 91
+ topk_proposals = torch.topk(class_unselected.max(-1)[0], select_number, dim=0)[1] # new_nq, bs
+ new_reference_points = torch.gather(new_reference_points, 0, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) # unsigmoid
+
+ if self.rm_detach and 'dec' in self.rm_detach:
+ reference_points = new_reference_points
+ else:
+ reference_points = new_reference_points.detach()
+ if self.use_detached_boxes_dec_out:
+ ref_points.append(reference_points)
+ else:
+ ref_points.append(new_reference_points)
+
+
+ intermediate.append(self.norm(output))
+ if self.dec_layer_number is not None and layer_id != self.num_layers - 1:
+ if nq_now != select_number:
+ output = torch.gather(output, 0, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)) # unsigmoid
+
+
+ return [
+ [itm_out.transpose(0, 1) for itm_out in intermediate],
+ [itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points]
+ ]
+
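+
+# Illustrative sketch (not from the original repo): each decoder layer refines
+# the reference boxes in logit space, new_ref = sigmoid(inverse_sigmoid(ref) + delta),
+# which keeps every coordinate inside (0, 1).
+def _demo_box_refinement():
+    ref = torch.tensor([[0.5, 0.5, 0.2, 0.2]])     # (cx, cy, w, h) in [0, 1]
+    delta = torch.tensor([[0.1, -0.1, 0.0, 0.0]])  # predicted offset in logit space
+    new_ref = (inverse_sigmoid(ref) + delta).sigmoid()
+    assert ((new_ref > 0) & (new_ref < 1)).all()
+    return new_ref
+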
+class DeformableTransformerEncoderLayer(nn.Module):
+ def __init__(self,
+ d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4,
+ add_channel_attention=False,
+ use_deformable_box_attn=False,
+ box_attn_type='roi_align',
+ ):
+ super().__init__()
+
+ # self attention
+ if use_deformable_box_attn:
+ self.self_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type)
+ else:
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation, d_model=d_ffn)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # channel attention
+ self.add_channel_attention = add_channel_attention
+ if add_channel_attention:
+ self.activ_channel = _get_activation_fn('dyrelu', d_model=d_model)
+ self.norm_channel = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, src):
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+ src = src + self.dropout3(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None):
+ # self attention
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, key_padding_mask)
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ # ffn
+ src = self.forward_ffn(src)
+
+ # channel attn
+ if self.add_channel_attention:
+ src = self.norm_channel(src + self.activ_channel(src))
+
+ return src
+
+class DeformableTransformerDecoderLayer(nn.Module):
+ def __init__(self, d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4,
+ use_deformable_box_attn=False,
+ box_attn_type='roi_align',
+ key_aware_type=None,
+ decoder_sa_type='ca',
+ module_seq=['sa', 'ca', 'ffn'],
+ ):
+ super().__init__()
+ self.module_seq = module_seq
+ assert sorted(module_seq) == ['ca', 'ffn', 'sa']
+
+ # cross attention
+ # self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ if use_deformable_box_attn:
+ self.cross_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type)
+ else:
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # self attention
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
+ self.dropout3 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ self.key_aware_type = key_aware_type
+ self.key_aware_proj = None
+ self.decoder_sa_type = decoder_sa_type
+ assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
+
+ if decoder_sa_type == 'ca_content':
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+
+ def rm_self_attn_modules(self):
+ self.self_attn = None
+ self.dropout2 = None
+ self.norm2 = None
+
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_sa(self,
+ # for tgt
+ tgt: Optional[Tensor], # nq, bs, d_model
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
+
+ # for memory
+ memory: Optional[Tensor] = None, # hw, bs, d_model
+ memory_key_padding_mask: Optional[Tensor] = None,
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ memory_pos: Optional[Tensor] = None, # pos for memory
+
+ # sa
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
+ ):
+ # self attention
+ if self.self_attn is not None:
+ if self.decoder_sa_type == 'sa':
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ elif self.decoder_sa_type == 'ca_label':
+ bs = tgt.shape[1]
+ k = v = self.label_embedding.weight[:, None, :].repeat(1, bs, 1)
+ tgt2 = self.self_attn(tgt, k, v, attn_mask=self_attn_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ elif self.decoder_sa_type == 'ca_content':
+ tgt2 = self.self_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
+ tgt_reference_points.transpose(0, 1).contiguous(),
+ memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ else:
+ raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type))
+
+ return tgt
+
+ def forward_ca(self,
+ # for tgt
+ tgt: Optional[Tensor], # nq, bs, d_model
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
+
+ # for memory
+ memory: Optional[Tensor] = None, # hw, bs, d_model
+ memory_key_padding_mask: Optional[Tensor] = None,
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ memory_pos: Optional[Tensor] = None, # pos for memory
+
+ # sa
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
+ ):
+ # cross attention
+ if self.key_aware_type is not None:
+
+ if self.key_aware_type == 'mean':
+ tgt = tgt + memory.mean(0, keepdim=True)
+ elif self.key_aware_type == 'proj_mean':
+ tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
+ else:
+ raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type))
+ tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
+ tgt_reference_points.transpose(0, 1).contiguous(),
+ memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ return tgt
+
+ def forward(self,
+ # for tgt
+ tgt: Optional[Tensor], # nq, bs, d_model
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
+
+ # for memory
+ memory: Optional[Tensor] = None, # hw, bs, d_model
+ memory_key_padding_mask: Optional[Tensor] = None,
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ memory_pos: Optional[Tensor] = None, # pos for memory
+
+ # sa
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
+ ):
+
+ for funcname in self.module_seq:
+ if funcname == 'ffn':
+ tgt = self.forward_ffn(tgt)
+ elif funcname == 'ca':
+ tgt = self.forward_ca(tgt, tgt_query_pos, tgt_query_sine_embed, \
+ tgt_key_padding_mask, tgt_reference_points, \
+ memory, memory_key_padding_mask, memory_level_start_index, \
+ memory_spatial_shapes, memory_pos, self_attn_mask, cross_attn_mask)
+ elif funcname == 'sa':
+ tgt = self.forward_sa(tgt, tgt_query_pos, tgt_query_sine_embed, \
+ tgt_key_padding_mask, tgt_reference_points, \
+ memory, memory_key_padding_mask, memory_level_start_index, \
+ memory_spatial_shapes, memory_pos, self_attn_mask, cross_attn_mask)
+ else:
+ raise ValueError('unknown funcname {}'.format(funcname))
+
+ return tgt
+
+
+def _get_clones(module, N, layer_share=False):
+    if layer_share:
+        return nn.ModuleList([module for _ in range(N)])
+    else:
+        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
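+# Illustrative usage sketch: with layer_share=True every entry is the *same*
+# module object, so parameters are shared across layers, e.g.
+#   layer = nn.Linear(4, 4)
+#   shared = _get_clones(layer, 3, layer_share=True)   # shared[0] is shared[2]
+#   cloned = _get_clones(layer, 3)                      # deep copies with separate weights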
+
+
+def build_deformable_transformer(args):
+ decoder_query_perturber = None
+ if args.decoder_layer_noise:
+ from .utils import RandomBoxPerturber
+ decoder_query_perturber=RandomBoxPerturber(
+ x_noise_scale=args.dln_xy_noise, y_noise_scale=args.dln_xy_noise,
+ w_noise_scale=args.dln_hw_noise, h_noise_scale=args.dln_hw_noise)
+
+    use_detached_boxes_dec_out = getattr(args, 'use_detached_boxes_dec_out', False)
+
+ return DeformableTransformer(
+ d_model=args.hidden_dim,
+ dropout=args.dropout,
+ nhead=args.nheads,
+ num_queries=args.num_queries,
+ dim_feedforward=args.dim_feedforward,
+ num_encoder_layers=args.enc_layers,
+ num_unicoder_layers=args.unic_layers,
+ num_decoder_layers=args.dec_layers,
+ normalize_before=args.pre_norm,
+ return_intermediate_dec=True,
+ query_dim=args.query_dim,
+ activation=args.transformer_activation,
+ num_patterns=args.num_patterns,
+ modulate_hw_attn=True,
+
+ deformable_encoder=True,
+ deformable_decoder=True,
+ num_feature_levels=args.num_feature_levels,
+ enc_n_points=args.enc_n_points,
+ dec_n_points=args.dec_n_points,
+ use_deformable_box_attn=args.use_deformable_box_attn,
+ box_attn_type=args.box_attn_type,
+
+ learnable_tgt_init=True,
+ decoder_query_perturber=decoder_query_perturber,
+
+ add_channel_attention=args.add_channel_attention,
+ add_pos_value=args.add_pos_value,
+ random_refpoints_xy=args.random_refpoints_xy,
+
+ # two stage
+ two_stage_type=args.two_stage_type, # ['no', 'standard', 'early']
+ two_stage_pat_embed=args.two_stage_pat_embed,
+ two_stage_add_query_num=args.two_stage_add_query_num,
+ two_stage_learn_wh=args.two_stage_learn_wh,
+ two_stage_keep_all_tokens=args.two_stage_keep_all_tokens,
+ dec_layer_number=args.dec_layer_number,
+ rm_self_attn_layers=None,
+ key_aware_type=None,
+ layer_share_type=None,
+
+ rm_detach=None,
+ decoder_sa_type=args.decoder_sa_type,
+ module_seq=args.decoder_module_seq,
+
+ embed_init_tgt=args.embed_init_tgt,
+ use_detached_boxes_dec_out=use_detached_boxes_dec_out
+ )
+
+
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/dino.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/dino.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eae52cf0af83ffe94fc22fbab0a5698812aae27
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/dino.py
@@ -0,0 +1,778 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR model and criterion classes.
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# ------------------------------------------------------------------------
+import copy
+import math
+from typing import List
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torchvision.ops.boxes import nms
+
+from .util import box_ops
+from .util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+ accuracy, get_world_size, interpolate,
+ is_dist_avail_and_initialized, inverse_sigmoid)
+
+from .backbone import build_backbone
+from .matcher import build_matcher
+from .segmentation import (dice_loss)
+from .deformable_transformer import build_deformable_transformer
+from .utils import sigmoid_focal_loss, MLP
+
+from .dn_components import prepare_for_cdn, dn_post_process
+
+
+class DINO(nn.Module):
+ """ This is the Cross-Attention Detector module that performs object detection """
+
+ def __init__(self, backbone, transformer, num_classes, num_queries,
+ aux_loss=False, iter_update=False,
+ query_dim=2,
+ random_refpoints_xy=False,
+ fix_refpoints_hw=-1,
+ num_feature_levels=1,
+ nheads=8,
+ # two stage
+ two_stage_type='no', # ['no', 'standard']
+ two_stage_add_query_num=0,
+ dec_pred_class_embed_share=True,
+ dec_pred_bbox_embed_share=True,
+ two_stage_class_embed_share=True,
+ two_stage_bbox_embed_share=True,
+ decoder_sa_type='sa',
+ num_patterns=0,
+ dn_number=100,
+ dn_box_noise_scale=0.4,
+ dn_label_noise_ratio=0.5,
+ dn_labelbook_size=100,
+ ):
+ """ Initializes the model.
+ Parameters:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_classes: number of object classes
+            num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
+ Conditional DETR can detect in a single image. For COCO, we recommend 100 queries.
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+
+            fix_refpoints_hw: -1 (default): learn w and h for each box separately
+                >0: use the given fixed value for w and h
+                -2: learn a shared w and h
+ """
+ super().__init__()
+ self.num_queries = num_queries
+ self.transformer = transformer
+ self.num_classes = num_classes
+ self.hidden_dim = hidden_dim = transformer.d_model
+ self.num_feature_levels = num_feature_levels
+ self.nheads = nheads
+ self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim)
+
+ # setting query dim
+ self.query_dim = query_dim
+ assert query_dim == 4
+ self.random_refpoints_xy = random_refpoints_xy
+ self.fix_refpoints_hw = fix_refpoints_hw
+
+ # for dn training
+ self.num_patterns = num_patterns
+ self.dn_number = dn_number
+ self.dn_box_noise_scale = dn_box_noise_scale
+ self.dn_label_noise_ratio = dn_label_noise_ratio
+ self.dn_labelbook_size = dn_labelbook_size
+
+ # prepare input projection layers
+ if num_feature_levels > 1:
+ num_backbone_outs = len(backbone.num_channels)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone.num_channels[_]
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ for _ in range(num_feature_levels - num_backbone_outs):
+ input_proj_list.append(nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
+ nn.GroupNorm(32, hidden_dim),
+ ))
+ in_channels = hidden_dim
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+            assert two_stage_type == 'no', "two_stage_type should be 'no' when num_feature_levels=1"
+ self.input_proj = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )])
+
+ self.backbone = backbone
+ self.aux_loss = aux_loss
+ self.box_pred_damping = box_pred_damping = None
+
+ self.iter_update = iter_update
+        assert iter_update, "iter_update must be True for this model"
+
+ # prepare pred layers
+ self.dec_pred_class_embed_share = dec_pred_class_embed_share
+ self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
+ # prepare class & box embed
+ _class_embed = nn.Linear(hidden_dim, num_classes)
+ _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+ # init the two embed layers
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ _class_embed.bias.data = torch.ones(self.num_classes) * bias_value
+ nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
+ nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)
+
+ if dec_pred_bbox_embed_share:
+ box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
+ else:
+ box_embed_layerlist = [copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)]
+ if dec_pred_class_embed_share:
+ class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
+ else:
+ class_embed_layerlist = [copy.deepcopy(_class_embed) for i in range(transformer.num_decoder_layers)]
+ self.bbox_embed = nn.ModuleList(box_embed_layerlist)
+ self.class_embed = nn.ModuleList(class_embed_layerlist)
+ self.transformer.decoder.bbox_embed = self.bbox_embed
+ self.transformer.decoder.class_embed = self.class_embed
+
+ # two stage
+ self.two_stage_type = two_stage_type
+ self.two_stage_add_query_num = two_stage_add_query_num
+        assert two_stage_type in ['no', 'standard'], "unknown two_stage_type: {}".format(two_stage_type)
+ if two_stage_type != 'no':
+ if two_stage_bbox_embed_share:
+ assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
+ self.transformer.enc_out_bbox_embed = _bbox_embed
+ else:
+ self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
+
+ if two_stage_class_embed_share:
+ assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
+ self.transformer.enc_out_class_embed = _class_embed
+ else:
+ self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
+
+ self.refpoint_embed = None
+ if self.two_stage_add_query_num > 0:
+ self.init_ref_points(two_stage_add_query_num)
+
+ self.decoder_sa_type = decoder_sa_type
+ assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
+ # self.replace_sa_with_double_ca = replace_sa_with_double_ca
+ if decoder_sa_type == 'ca_label':
+ self.label_embedding = nn.Embedding(num_classes, hidden_dim)
+ for layer in self.transformer.decoder.layers:
+ layer.label_embedding = self.label_embedding
+ else:
+ for layer in self.transformer.decoder.layers:
+ layer.label_embedding = None
+ self.label_embedding = None
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ # init input_proj
+ for proj in self.input_proj:
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
+ nn.init.constant_(proj[0].bias, 0)
+
+ def init_ref_points(self, use_num_queries):
+ self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)
+
+ if self.random_refpoints_xy:
+ self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
+ self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
+ self.refpoint_embed.weight.data[:, :2].requires_grad = False
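+            # Note (illustrative): the xy reference points are stored in inverse-sigmoid
+            # space, e.g. a coordinate drawn at 0.25 is stored as log(0.25 / 0.75) ≈ -1.099,
+            # so that a later .sigmoid() recovers 0.25.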
+
+ if self.fix_refpoints_hw > 0:
+ print("fix_refpoints_hw: {}".format(self.fix_refpoints_hw))
+ assert self.random_refpoints_xy
+ self.refpoint_embed.weight.data[:, 2:] = self.fix_refpoints_hw
+ self.refpoint_embed.weight.data[:, 2:] = inverse_sigmoid(self.refpoint_embed.weight.data[:, 2:])
+ self.refpoint_embed.weight.data[:, 2:].requires_grad = False
+ elif int(self.fix_refpoints_hw) == -1:
+ pass
+ elif int(self.fix_refpoints_hw) == -2:
+ print('learn a shared h and w')
+ assert self.random_refpoints_xy
+ self.refpoint_embed = nn.Embedding(use_num_queries, 2)
+ self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
+ self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
+ self.refpoint_embed.weight.data[:, :2].requires_grad = False
+ self.hw_embed = nn.Embedding(1, 1)
+ else:
+ raise NotImplementedError('Unknown fix_refpoints_hw {}'.format(self.fix_refpoints_hw))
+
+ def forward(self, samples: NestedTensor, targets: List = None):
+ """ The forward expects a NestedTensor, which consists of:
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+
+ It returns a dict with the following elements:
+ - "pred_logits": the classification logits (including no-object) for all queries.
+ Shape= [batch_size x num_queries x num_classes]
+ - "pred_boxes": The normalized boxes coordinates for all queries, represented as
+ (center_x, center_y, width, height). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
+               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
+                dictionaries containing the two above keys for each decoder layer.
+ """
+ if isinstance(samples, (list, torch.Tensor)):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, poss = self.backbone(samples)
+
+ srcs = []
+ masks = []
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ srcs.append(self.input_proj[l](src))
+ masks.append(mask)
+ assert mask is not None
+ if self.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.num_feature_levels):
+ if l == _len_srcs:
+ src = self.input_proj[l](features[-1].tensors)
+ else:
+ src = self.input_proj[l](srcs[-1])
+ m = samples.mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ poss.append(pos_l)
+
+ if self.dn_number > 0 or targets is not None:
+ input_query_label, input_query_bbox, attn_mask, dn_meta = \
+ prepare_for_cdn(dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_box_noise_scale),
+ training=self.training, num_queries=self.num_queries, num_classes=self.num_classes,
+ hidden_dim=self.hidden_dim, label_enc=self.label_enc)
+ else:
+ assert targets is None
+ input_query_bbox = input_query_label = attn_mask = dn_meta = None
+
+ hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(srcs, masks, input_query_bbox, poss,
+ input_query_label, attn_mask)
+        # In case the number of objects is 0
+ hs[0] += self.label_enc.weight[0, 0] * 0.0
+
+ # deformable-detr-like anchor update
+ # reference_before_sigmoid = inverse_sigmoid(reference[:-1]) # n_dec, bs, nq, 4
+ outputs_coord_list = []
+ for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
+ outputs_coord_list.append(layer_outputs_unsig)
+ outputs_coord_list = torch.stack(outputs_coord_list)
+
+ # outputs_class = self.class_embed(hs)
+ outputs_class = torch.stack([layer_cls_embed(layer_hs) for
+ layer_cls_embed, layer_hs in zip(self.class_embed, hs)])
+ if self.dn_number > 0 and dn_meta is not None:
+ outputs_class, outputs_coord_list = \
+ dn_post_process(outputs_class, outputs_coord_list,
+ dn_meta, self.aux_loss, self._set_aux_loss)
+ out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord_list[-1]}
+ if self.aux_loss:
+ out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord_list)
+
+ # for encoder output
+ if hs_enc is not None:
+ # prepare intermediate outputs
+ interm_coord = ref_enc[-1]
+ interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
+ out['interm_outputs'] = {'pred_logits': interm_class, 'pred_boxes': interm_coord}
+ out['interm_outputs_for_matching_pre'] = {'pred_logits': interm_class, 'pred_boxes': init_box_proposal}
+
+ # prepare enc outputs
+ if hs_enc.shape[0] > 1:
+ enc_outputs_coord = []
+ enc_outputs_class = []
+ for layer_id, (layer_box_embed, layer_class_embed, layer_hs_enc, layer_ref_enc) in enumerate(
+ zip(self.enc_bbox_embed, self.enc_class_embed, hs_enc[:-1], ref_enc[:-1])):
+ layer_enc_delta_unsig = layer_box_embed(layer_hs_enc)
+ layer_enc_outputs_coord_unsig = layer_enc_delta_unsig + inverse_sigmoid(layer_ref_enc)
+ layer_enc_outputs_coord = layer_enc_outputs_coord_unsig.sigmoid()
+
+ layer_enc_outputs_class = layer_class_embed(layer_hs_enc)
+ enc_outputs_coord.append(layer_enc_outputs_coord)
+ enc_outputs_class.append(layer_enc_outputs_class)
+
+ # enc_delta_unsig = self.enc_bbox_embed(hs_enc[:-1])
+ # enc_outputs_unsig = enc_delta_unsig + ref_enc[:-1]
+ # enc_outputs_coord = enc_outputs_unsig.sigmoid()
+ # enc_outputs_class = self.enc_class_embed(hs_enc[:-1])
+ out['enc_outputs'] = [
+ {'pred_logits': a, 'pred_boxes': b} for a, b in zip(enc_outputs_class, enc_outputs_coord)
+ ]
+
+ out['dn_meta'] = dn_meta
+
+ return out
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ return [{'pred_logits': a, 'pred_boxes': b}
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for Conditional DETR.
+ The process happens in two steps:
+        1) we compute a Hungarian assignment between the ground-truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses):
+ """ Create the criterion.
+ Parameters:
+ num_classes: number of object categories, omitting the special no-object category
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ focal_alpha: alpha in Focal Loss
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.focal_alpha = focal_alpha
+
+ def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
+ """Classification loss (Binary focal loss)
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ assert 'pred_logits' in outputs
+ src_logits = outputs['pred_logits']
+
+ idx = self._get_src_permutation_idx(indices)
+ target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+ target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+ dtype=torch.int64, device=src_logits.device)
+ target_classes[idx] = target_classes_o
+
+ target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+ dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+
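+        # The extra last channel absorbs the "no object" entries (label == self.num_classes)
+        # during scatter_; dropping it below leaves those positions as all-zero targets for
+        # the sigmoid focal loss.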
+ target_classes_onehot = target_classes_onehot[:, :, :-1]
+ loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * \
+ src_logits.shape[1]
+ losses = {'loss_ce': loss_ce}
+
+ if log:
+ # TODO this should probably be a separate loss, not hacked in this one here
+ losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+ return losses
+
+ @torch.no_grad()
+ def loss_cardinality(self, outputs, targets, indices, num_boxes):
+        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
+ """
+ pred_logits = outputs['pred_logits']
+ device = pred_logits.device
+ tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+ losses = {'cardinality_error': card_err}
+ return losses
+
+ def loss_boxes(self, outputs, targets, indices, num_boxes):
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+ The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+ """
+ assert 'pred_boxes' in outputs
+ idx = self._get_src_permutation_idx(indices)
+ src_boxes = outputs['pred_boxes'][idx]
+ target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+ losses = {}
+ losses['loss_bbox'] = loss_bbox.sum() / num_boxes
+
+ loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
+ box_ops.box_cxcywh_to_xyxy(src_boxes),
+ box_ops.box_cxcywh_to_xyxy(target_boxes)))
+ losses['loss_giou'] = loss_giou.sum() / num_boxes
+
+ # calculate the x,y and h,w loss
+ with torch.no_grad():
+ losses['loss_xy'] = loss_bbox[..., :2].sum() / num_boxes
+ losses['loss_hw'] = loss_bbox[..., 2:].sum() / num_boxes
+
+ return losses
+
+ def loss_masks(self, outputs, targets, indices, num_boxes):
+ """Compute the losses related to the masks: the focal loss and the dice loss.
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+ """
+ assert "pred_masks" in outputs
+
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+ src_masks = outputs["pred_masks"]
+ src_masks = src_masks[src_idx]
+ masks = [t["masks"] for t in targets]
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+ target_masks = target_masks.to(src_masks)
+ target_masks = target_masks[tgt_idx]
+
+ # upsample predictions to the target size
+ src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
+ mode="bilinear", align_corners=False)
+ src_masks = src_masks[:, 0].flatten(1)
+
+ target_masks = target_masks.flatten(1)
+ target_masks = target_masks.view(src_masks.shape)
+ losses = {
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+ "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+ }
+ return losses
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+ loss_map = {
+ 'labels': self.loss_labels,
+ 'cardinality': self.loss_cardinality,
+ 'boxes': self.loss_boxes,
+ 'masks': self.loss_masks,
+ # 'dn_labels': self.loss_dn_labels,
+ # 'dn_boxes': self.loss_dn_boxes
+ }
+        assert loss in loss_map, f'unsupported loss: {loss}'
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+ def forward(self, outputs, targets, return_indices=False):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depend on the losses applied, see each loss' doc
+
+             return_indices: used for visualization. If True, the layer 0-5 matching indices will be returned as well.
+
+ """
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
+ device = next(iter(outputs.values())).device
+ indices = self.matcher(outputs_without_aux, targets)
+
+ if return_indices:
+ indices0_copy = indices
+ indices_list = []
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+ num_boxes = sum(len(t["labels"]) for t in targets)
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device)
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_boxes)
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+
+ # prepare for dn loss
+ dn_meta = outputs['dn_meta']
+
+ if self.training and dn_meta and 'output_known_lbs_bboxes' in dn_meta:
+ output_known_lbs_bboxes, single_pad, scalar = self.prep_for_dn(dn_meta)
+
+ dn_pos_idx = []
+ dn_neg_idx = []
+ for i in range(len(targets)):
+ if len(targets[i]['labels']) > 0:
+                    t = torch.arange(0, len(targets[i]['labels'])).long().cuda()
+ t = t.unsqueeze(0).repeat(scalar, 1)
+ tgt_idx = t.flatten()
+ output_idx = (torch.tensor(range(scalar)) * single_pad).long().cuda().unsqueeze(1) + t
+ output_idx = output_idx.flatten()
+ else:
+ output_idx = tgt_idx = torch.tensor([]).long().cuda()
+
+ dn_pos_idx.append((output_idx, tgt_idx))
+ dn_neg_idx.append((output_idx + single_pad // 2, tgt_idx))
+
+ output_known_lbs_bboxes = dn_meta['output_known_lbs_bboxes']
+ l_dict = {}
+ for loss in self.losses:
+ kwargs = {}
+ if 'labels' in loss:
+ kwargs = {'log': False}
+ l_dict.update(
+ self.get_loss(loss, output_known_lbs_bboxes, targets, dn_pos_idx, num_boxes * scalar, **kwargs))
+
+ l_dict = {k + f'_dn': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+ else:
+ l_dict = dict()
+ l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['loss_xy_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['loss_hw_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['cardinality_error_dn'] = torch.as_tensor(0.).to('cuda')
+ losses.update(l_dict)
+
+ for loss in self.losses:
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if 'aux_outputs' in outputs:
+ for idx, aux_outputs in enumerate(outputs['aux_outputs']):
+ indices = self.matcher(aux_outputs, targets)
+ if return_indices:
+ indices_list.append(indices)
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs = {'log': False}
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_{idx}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if self.training and dn_meta and 'output_known_lbs_bboxes' in dn_meta:
+ aux_outputs_known = output_known_lbs_bboxes['aux_outputs'][idx]
+ l_dict = {}
+ for loss in self.losses:
+ kwargs = {}
+ if 'labels' in loss:
+ kwargs = {'log': False}
+
+ l_dict.update(self.get_loss(loss, aux_outputs_known, targets, dn_pos_idx, num_boxes * scalar,
+ **kwargs))
+
+ l_dict = {k + f'_dn_{idx}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+ else:
+ l_dict = dict()
+ l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['loss_xy_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['loss_hw_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict['cardinality_error_dn'] = torch.as_tensor(0.).to('cuda')
+ l_dict = {k + f'_{idx}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ # interm_outputs loss
+ if 'interm_outputs' in outputs:
+ interm_outputs = outputs['interm_outputs']
+ indices = self.matcher(interm_outputs, targets)
+ if return_indices:
+ indices_list.append(indices)
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs = {'log': False}
+ l_dict = self.get_loss(loss, interm_outputs, targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_interm': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ # enc output loss
+ if 'enc_outputs' in outputs:
+ for i, enc_outputs in enumerate(outputs['enc_outputs']):
+ indices = self.matcher(enc_outputs, targets)
+ if return_indices:
+ indices_list.append(indices)
+ for loss in self.losses:
+ if loss == 'masks':
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs = {'log': False}
+ l_dict = self.get_loss(loss, enc_outputs, targets, indices, num_boxes, **kwargs)
+ l_dict = {k + f'_enc_{i}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if return_indices:
+ indices_list.append(indices0_copy)
+ return losses, indices_list
+
+ return losses
+
+ def prep_for_dn(self, dn_meta):
+ output_known_lbs_bboxes = dn_meta['output_known_lbs_bboxes']
+ num_dn_groups, pad_size = dn_meta['num_dn_group'], dn_meta['pad_size']
+ assert pad_size % num_dn_groups == 0
+ single_pad = pad_size // num_dn_groups
+
+ return output_known_lbs_bboxes, single_pad, num_dn_groups
+
+
+class PostProcess(nn.Module):
+    """ This module converts the model's output into the format expected by the COCO API"""
+
+ def __init__(self, num_select=100, nms_iou_threshold=-1) -> None:
+ super().__init__()
+ self.num_select = num_select
+ self.nms_iou_threshold = nms_iou_threshold
+
+ @torch.no_grad()
+ def forward(self, outputs, target_sizes, not_to_xyxy=False, test=False):
+ """ Perform the computation
+ Parameters:
+ outputs: raw outputs of the model
+            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image in the batch
+                          For evaluation, this must be the original image size (before any data augmentation)
+                          For visualization, this should be the image size after data augmentation, but before padding
+ """
+ num_select = self.num_select
+ out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
+
+ assert len(out_logits) == len(target_sizes)
+ assert target_sizes.shape[1] == 2
+
+ prob = out_logits.sigmoid()
+ topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
+ scores = topk_values
+ topk_boxes = topk_indexes // out_logits.shape[2]
+ labels = topk_indexes % out_logits.shape[2]
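+        # Worked example (illustrative, assuming 91 classes): flat index 275 in the
+        # (num_queries * 91)-long score vector maps to query 275 // 91 = 3 and class 275 % 91 = 2.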
+ if not_to_xyxy:
+ boxes = out_bbox
+ else:
+ boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
+
+ if test:
+ assert not not_to_xyxy
+ boxes[:, :, 2:] = boxes[:, :, 2:] - boxes[:, :, :2]
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ if self.nms_iou_threshold > 0:
+ item_indices = [nms(b, s, iou_threshold=self.nms_iou_threshold) for b, s in zip(boxes, scores)]
+ results = [{'scores': s[i], 'labels': l[i], 'boxes': b[i]} for s, l, b, i in
+ zip(scores, labels, boxes, item_indices)]
+ else:
+ results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
+
+ return results
+
+
+def build_dino(args):
+ # the `num_classes` naming here is somewhat misleading.
+ # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
+ # is the maximum id for a class in your dataset. For example,
+ # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
+ # As another example, for a dataset that has a single class with id 1,
+ # you should pass `num_classes` to be 2 (max_obj_id + 1).
+ # For more details on this, check the following discussion
+ # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
+ # num_classes = 20 if args.dataset_file != 'coco' else 91
+ # if args.dataset_file == "coco_panoptic":
+ # # for panoptic, we just add a num_classes that is large enough to hold
+ # # max_obj_id + 1, but the exact value doesn't really matter
+ # num_classes = 250
+ # if args.dataset_file == 'o365':
+ # num_classes = 366
+ # if args.dataset_file == 'vanke':
+ # num_classes = 51
+ num_classes = args.num_classes
+
+ backbone = build_backbone(args)
+
+ transformer = build_deformable_transformer(args)
+
+    match_unstable_error = getattr(args, 'match_unstable_error', True)
+    dn_labelbook_size = getattr(args, 'dn_labelbook_size', num_classes)
+    dec_pred_class_embed_share = getattr(args, 'dec_pred_class_embed_share', True)
+    dec_pred_bbox_embed_share = getattr(args, 'dec_pred_bbox_embed_share', True)
+
+ model = DINO(
+ backbone,
+ transformer,
+ num_classes=num_classes,
+ num_queries=args.num_queries,
+ aux_loss=True,
+ iter_update=True,
+ query_dim=4,
+ random_refpoints_xy=args.random_refpoints_xy,
+ fix_refpoints_hw=args.fix_refpoints_hw,
+ num_feature_levels=args.num_feature_levels,
+ nheads=args.nheads,
+ dec_pred_class_embed_share=dec_pred_class_embed_share,
+ dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
+ # two stage
+ two_stage_type=args.two_stage_type,
+ # box_share
+ two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
+ two_stage_class_embed_share=args.two_stage_class_embed_share,
+ decoder_sa_type=args.decoder_sa_type,
+ num_patterns=args.num_patterns,
+ dn_number=args.dn_number if args.use_dn else 0,
+ dn_box_noise_scale=args.dn_box_noise_scale,
+ dn_label_noise_ratio=args.dn_label_noise_ratio,
+ dn_labelbook_size=dn_labelbook_size,
+ )
+ matcher = build_matcher(args)
+
+ # prepare weight dict
+ box_postprocessor = PostProcess(num_select=args.num_select, nms_iou_threshold=args.nms_iou_threshold)
+
+ return model, matcher, box_postprocessor
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/dn_components.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/dn_components.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c57d11fae00f4a2c33d41e64f4b181e93d3c729
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/dn_components.py
@@ -0,0 +1,154 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# DN-DETR
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+
+
+import torch
+from .util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+ accuracy, get_world_size, interpolate,
+ is_dist_avail_and_initialized, inverse_sigmoid)
+# from .DABDETR import sigmoid_focal_loss
+from .util import box_ops
+import torch.nn.functional as F
+
+
+def prepare_for_cdn(dn_args, training, num_queries, num_classes, hidden_dim, label_enc):
+ """
+        A major difference between DINO and DN-DETR is that DINO processes the pattern embedding in its detector
+        forward function and uses a learnable tgt embedding, so this function is changed slightly.
+ :param dn_args: targets, dn_number, label_noise_ratio, box_noise_scale
+ :param training: if it is training or inference
+        :param num_queries: number of queries
+ :param num_classes: number of classes
+ :param hidden_dim: transformer hidden dim
+ :param label_enc: encode labels in dn
+ :return:
+ """
+ if training:
+ targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
+ # positive and negative dn queries
+ dn_number = dn_number * 2
+ known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
+ batch_size = len(known)
+ known_num = [sum(k) for k in known]
+ if int(max(known_num)) == 0:
+ dn_number = 1
+ else:
+ if dn_number >= 100:
+ dn_number = dn_number // (int(max(known_num) * 2))
+ elif dn_number < 1:
+ dn_number = 1
+ if dn_number == 0:
+ dn_number = 1
+ unmask_bbox = unmask_label = torch.cat(known)
+ labels = torch.cat([t['labels'] for t in targets])
+ boxes = torch.cat([t['boxes'] for t in targets])
+ batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
+
+ known_indice = torch.nonzero(unmask_label + unmask_bbox)
+ known_indice = known_indice.view(-1)
+
+ known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
+ known_labels = labels.repeat(2 * dn_number, 1).view(-1)
+ known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
+ known_bboxs = boxes.repeat(2 * dn_number, 1)
+ known_labels_expaned = known_labels.clone()
+ known_bbox_expand = known_bboxs.clone()
+
+ if label_noise_ratio > 0:
+ p = torch.rand_like(known_labels_expaned.float())
+ chosen_indice = torch.nonzero(p < (label_noise_ratio * 0.5)).view(-1) # half of bbox prob
+ new_label = torch.randint_like(chosen_indice, 0, num_classes) # randomly put a new one here
+ known_labels_expaned.scatter_(0, chosen_indice, new_label)
+ single_pad = int(max(known_num))
+
+ pad_size = int(single_pad * 2 * dn_number)
+ positive_idx = torch.tensor(range(len(boxes))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
+ positive_idx += (torch.tensor(range(dn_number)) * len(boxes) * 2).long().cuda().unsqueeze(1)
+ positive_idx = positive_idx.flatten()
+ negative_idx = positive_idx + len(boxes)
+ if box_noise_scale > 0:
+ known_bbox_ = torch.zeros_like(known_bboxs)
+ known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
+ known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2
+
+ diff = torch.zeros_like(known_bboxs)
+ diff[:, :2] = known_bboxs[:, 2:] / 2
+ diff[:, 2:] = known_bboxs[:, 2:] / 2
+
+ rand_sign = torch.randint_like(known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
+ rand_part = torch.rand_like(known_bboxs)
+ rand_part[negative_idx] += 1.0
+ rand_part *= rand_sign
+ known_bbox_ = known_bbox_ + torch.mul(rand_part,
+ diff).cuda() * box_noise_scale
+ known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
+ known_bbox_expand[:, :2] = (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2
+ known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]
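+            # Noise magnitude (illustrative): each xyxy corner of a positive dn query is
+            # shifted by at most half the box extent times box_noise_scale, while negative
+            # queries (rand_part in [1, 2)) are shifted by between one half and one full
+            # box extent times box_noise_scale, pushing them clearly off the ground truth.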
+
+ m = known_labels_expaned.long().to('cuda')
+ input_label_embed = label_enc(m)
+ input_bbox_embed = inverse_sigmoid(known_bbox_expand)
+
+ padding_label = torch.zeros(pad_size, hidden_dim).cuda()
+ padding_bbox = torch.zeros(pad_size, 4).cuda()
+
+ input_query_label = padding_label.repeat(batch_size, 1, 1)
+ input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)
+
+ map_known_indice = torch.tensor([]).to('cuda')
+ if len(known_num):
+ map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num]) # [1,2, 1,2,3]
+ map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(2 * dn_number)]).long()
+ if len(known_bid):
+ input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed
+ input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed
+
+ tgt_size = pad_size + num_queries
+ attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
+        # matching queries cannot see the reconstruction (dn) part
+ attn_mask[pad_size:, :pad_size] = True
+        # dn queries in different groups cannot see each other
+ for i in range(dn_number):
+ if i == 0:
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
+ if i == dn_number - 1:
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * i * 2] = True
+ else:
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), single_pad * 2 * (i + 1):pad_size] = True
+ attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1), :single_pad * 2 * i] = True
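+        # Resulting mask layout (illustrative, True = attention blocked), e.g. with two dn
+        # groups of 2 * single_pad queries each, followed by the matching queries:
+        #
+        #                     dn group 0   dn group 1   matching
+        #   dn group 0 rows       F            T           F
+        #   dn group 1 rows       T            F           F
+        #   matching rows         T            T           F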
+
+ dn_meta = {
+ 'pad_size': pad_size,
+ 'num_dn_group': dn_number,
+ }
+ else:
+
+ input_query_label = None
+ input_query_bbox = None
+ attn_mask = None
+ dn_meta = None
+
+ return input_query_label, input_query_bbox, attn_mask, dn_meta
+
+
+def dn_post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
+ """
+ post process of dn after output from the transformer
+ put the dn part in the dn_meta
+ """
+ if dn_meta and dn_meta['pad_size'] > 0:
+ output_known_class = outputs_class[:, :, :dn_meta['pad_size'], :]
+ output_known_coord = outputs_coord[:, :, :dn_meta['pad_size'], :]
+ outputs_class = outputs_class[:, :, dn_meta['pad_size']:, :]
+ outputs_coord = outputs_coord[:, :, dn_meta['pad_size']:, :]
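+        # The first pad_size entries along the query axis are the denoising (dn) predictions;
+        # the remaining entries are the ordinary matching queries.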
+ out = {'pred_logits': output_known_class[-1], 'pred_boxes': output_known_coord[-1]}
+ if aux_loss:
+ out['aux_outputs'] = _set_aux_loss(output_known_class, output_known_coord)
+ dn_meta['output_known_lbs_bboxes'] = out
+ return outputs_class, outputs_coord
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/focal.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/focal.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b190a5271f81a3732a068e1ac95610e0349f1ab
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/focal.py
@@ -0,0 +1,603 @@
+# --------------------------------------------------------
+# FocalNet for Semantic Segmentation
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Jianwei Yang
+# --------------------------------------------------------
+import math
+import time
+import numpy as np
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from .util.misc import NestedTensor
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class FocalModulation(nn.Module):
+ """ Focal Modulation
+
+ Args:
+ dim (int): Number of input channels.
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ focal_level (int): Number of focal levels
+ focal_window (int): Focal window size at focal level 1
+ focal_factor (int, default=2): Step to increase the focal window
+ use_postln (bool, default=False): Whether use post-modulation layernorm
+        use_postln (bool, default=False): Whether to use post-modulation layernorm
+
+ def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False,
+ use_postln_in_modulation=False, normalize_modulator=False):
+
+ super().__init__()
+ self.dim = dim
+
+ # specific args for focalv3
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+ self.focal_factor = focal_factor
+ self.use_postln_in_modulation = use_postln_in_modulation
+ self.normalize_modulator = normalize_modulator
+
+ self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True)
+ self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
+
+ self.act = nn.GELU()
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.focal_layers = nn.ModuleList()
+
+ if self.use_postln_in_modulation:
+ self.ln = nn.LayerNorm(dim)
+
+ for k in range(self.focal_level):
+ kernel_size = self.focal_factor*k + self.focal_window
+ self.focal_layers.append(
+ nn.Sequential(
+ nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim,
+ padding=kernel_size//2, bias=False),
+ nn.GELU(),
+ )
+ )
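+        # With the defaults (focal_level=2, focal_window=7, focal_factor=2) this builds
+        # depthwise convolutions with kernel sizes 7 and 9, i.e. progressively larger
+        # context windows per focal level.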
+
+ def forward(self, x):
+ """ Forward function.
+
+ Args:
+ x: input features with shape of (B, H, W, C)
+ """
+ B, nH, nW, C = x.shape
+ x = self.f(x)
+ x = x.permute(0, 3, 1, 2).contiguous()
+ q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1)
+
+ ctx_all = 0
+ for l in range(self.focal_level):
+ ctx = self.focal_layers[l](ctx)
+ ctx_all = ctx_all + ctx*gates[:, l:l+1]
+ ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
+ ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:]
+ if self.normalize_modulator:
+ ctx_all = ctx_all / (self.focal_level+1)
+
+ x_out = q * self.h(ctx_all)
+ x_out = x_out.permute(0, 2, 3, 1).contiguous()
+ if self.use_postln_in_modulation:
+ x_out = self.ln(x_out)
+ x_out = self.proj(x_out)
+ x_out = self.proj_drop(x_out)
+ return x_out
+
+class FocalModulationBlock(nn.Module):
+ """ Focal Modulation Block.
+
+ Args:
+ dim (int): Number of input channels.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ focal_level (int): number of focal levels
+ focal_window (int): focal kernel size at level 1
+ """
+
+ def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ focal_level=2, focal_window=9,
+ use_postln=False, use_postln_in_modulation=False,
+ normalize_modulator=False,
+ use_layerscale=False,
+ layerscale_value=1e-4):
+ super().__init__()
+ self.dim = dim
+ self.mlp_ratio = mlp_ratio
+ self.focal_window = focal_window
+ self.focal_level = focal_level
+ self.use_postln = use_postln
+ self.use_layerscale = use_layerscale
+
+ self.norm1 = norm_layer(dim)
+ self.modulation = FocalModulation(
+ dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop,
+ use_postln_in_modulation=use_postln_in_modulation,
+ normalize_modulator=normalize_modulator,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.H = None
+ self.W = None
+
+ self.gamma_1 = 1.0
+ self.gamma_2 = 1.0
+ if self.use_layerscale:
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ if not self.use_postln:
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # FM
+ x = self.modulation(x).view(B, H * W, C)
+ if self.use_postln:
+ x = self.norm1(x)
+
+ # FFN
+ x = shortcut + self.drop_path(self.gamma_1 * x)
+
+ if self.use_postln:
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
+ else:
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+
+ return x
+
+class BasicLayer(nn.Module):
+ """ A basic focal modulation layer for one stage.
+
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ drop (float, optional): Dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ focal_level (int): Number of focal levels
+ focal_window (int): Focal window size at focal level 1
+        use_conv_embed (bool): Whether to use overlapped convolution for patch embedding or not. Default: False
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ mlp_ratio=4.,
+ drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ focal_window=9,
+ focal_level=2,
+ use_conv_embed=False,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ normalize_modulator=False,
+ use_layerscale=False,
+ use_checkpoint=False
+ ):
+ super().__init__()
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ FocalModulationBlock(
+ dim=dim,
+ mlp_ratio=mlp_ratio,
+ drop=drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ focal_window=focal_window,
+ focal_level=focal_level,
+ use_postln=use_postln,
+ use_postln_in_modulation=use_postln_in_modulation,
+ normalize_modulator=normalize_modulator,
+ use_layerscale=use_layerscale,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ patch_size=2,
+ in_chans=dim, embed_dim=2*dim,
+ use_conv_embed=use_conv_embed,
+ norm_layer=norm_layer,
+ is_stem=False
+ )
+
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ if self.downsample is not None:
+ x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W)
+ x_down = self.downsample(x_reshaped)
+ x_down = x_down.flatten(2).transpose(1, 2)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
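+# Note on the return value above (illustrative): the six-tuple exposes both the
+# pre-downsample features (used for the per-stage outputs) and the features that feed
+# the next stage:
+#
+#   x_out, H, W, x_next, Wh, Ww = layer(x, H, W)
+#   # x_out:  (B, H*W, C)     this stage's features at the input resolution
+#   # x_next: (B, Wh*Ww, 2C)  input to the next stage (x_out itself if downsample is None)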
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ use_conv_embed (bool): Whether to use overlapped convolution for patch embedding. Default: False
+ is_stem (bool): Whether this is the stem patch embedding (a larger kernel/stride is used when use_conv_embed is True).
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ if use_conv_embed:
+ # if we choose to use conv embedding, then we treat the stem and non-stem differently
+ if is_stem:
+ kernel_size = 7; padding = 2; stride = 4
+ else:
+ kernel_size = 3; padding = 1; stride = 2
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+ else:
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
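+# Shape sketch (illustrative): PatchEmbed pads H and W up to multiples of the patch size
+# before the strided projection, so arbitrary input sizes are accepted. With the default
+# non-overlapping embedding (patch_size=4):
+#
+#   embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
+#   img = torch.randn(1, 3, 1022, 1022)        # not divisible by 4
+#   feat = embed(img)                          # padded to 1024x1024 -> (1, 96, 256, 256)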
+
+class FocalNet(nn.Module):
+ """ FocalNet backbone.
+
+ Args:
+ pretrain_img_size (int): Input image size used when training the pretrained model. Default: 1600.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each FocalNet stage.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ drop_rate (float): Dropout rate.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ focal_levels (Sequence[int]): Number of focal levels at four stages
+ focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages
+ use_conv_embed (bool): Whether to use overlapped convolution for patch embedding
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ pretrain_img_size=1600,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ mlp_ratio=4.,
+ drop_rate=0.,
+ drop_path_rate=0.3, # 0.3 or 0.4 works better for large+ models
+ norm_layer=nn.LayerNorm,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ focal_levels=[3,3,3,3],
+ focal_windows=[3,3,3,3],
+ use_conv_embed=False,
+ use_postln=False,
+ use_postln_in_modulation=False,
+ use_layerscale=False,
+ normalize_modulator=False,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None,
+ use_conv_embed=use_conv_embed, is_stem=True)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ mlp_ratio=mlp_ratio,
+ drop=drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
+ focal_window=focal_windows[i_layer],
+ focal_level=focal_levels[i_layer],
+ use_conv_embed=use_conv_embed,
+ use_postln=use_postln,
+ use_postln_in_modulation=use_postln_in_modulation,
+ normalize_modulator=normalize_modulator,
+ use_layerscale=use_layerscale,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ if isinstance(pretrained, str):
+ self.apply(_init_weights)
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ self.apply(_init_weights)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, tensor_list: NestedTensor):
+ """Forward function."""
+ x = tensor_list.tensors
+ tic = time.time()
+ x = self.patch_embed(x)
+ Wh, Ww = x.size(2), x.size(3)
+
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+
+ toc = time.time()
+
+ # collect for nesttensors
+ outs_dict = {}
+ for idx, out_i in enumerate(outs):
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
+ outs_dict[idx] = NestedTensor(out_i, mask)
+
+ return outs_dict
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(FocalNet, self).train(mode)
+ self._freeze_stages()
+
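+# Usage sketch (illustrative, not part of the upstream code): the backbone consumes a
+# NestedTensor (padded image batch plus padding mask) and returns a dict of NestedTensors,
+# one per stage in out_indices, with the mask resized to each feature resolution.
+# For the named model variants, see build_focalnet below.
+#
+#   backbone = FocalNet(embed_dim=96, depths=[2, 2, 6, 2], out_indices=(0, 1, 2, 3))
+#   images = torch.randn(2, 3, 800, 800)
+#   masks = torch.zeros(2, 800, 800, dtype=torch.bool)      # False = valid pixel
+#   feats = backbone(NestedTensor(images, masks))
+#   # feats[0].tensors: (2, 96, 200, 200), ..., feats[3].tensors: (2, 768, 25, 25)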
+
+
+def build_focalnet(modelname, **kw):
+ assert modelname in [
+ 'focalnet_L_384_22k',
+ 'focalnet_L_384_22k_fl4',
+ 'focalnet_XL_384_22k',
+ 'focalnet_XL_384_22k_fl4',
+ 'focalnet_H_224_22k',
+ 'focalnet_H_224_22k_fl4',
+ ]
+
+ if 'focal_levels' in kw:
+ kw['focal_levels'] = [kw['focal_levels']] * 4
+
+ if 'focal_windows' in kw:
+ kw['focal_windows'] = [kw['focal_windows']] * 4
+
+ model_para_dict = {
+ 'focalnet_L_384_22k': dict(
+ embed_dim=192,
+ depths=[ 2, 2, 18, 2 ],
+ focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
+ focal_windows=kw.get('focal_windows', [5, 5, 5, 5]),
+ use_conv_embed=True,
+ use_postln=True,
+ use_postln_in_modulation=False,
+ use_layerscale=True,
+ normalize_modulator=False,
+ ),
+ 'focalnet_L_384_22k_fl4': dict(
+ embed_dim=192,
+ depths=[ 2, 2, 18, 2 ],
+ focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
+ focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
+ use_conv_embed=True,
+ use_postln=True,
+ use_postln_in_modulation=False,
+ use_layerscale=True,
+ normalize_modulator=True,
+ ),
+ 'focalnet_XL_384_22k': dict(
+ embed_dim=256,
+ depths=[ 2, 2, 18, 2 ],
+ focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
+ focal_windows=kw.get('focal_windows', [5, 5, 5, 5]),
+ use_conv_embed=True,
+ use_postln=True,
+ use_postln_in_modulation=False,
+ use_layerscale=True,
+ normalize_modulator=False,
+ ),
+ 'focalnet_XL_384_22k_fl4': dict(
+ embed_dim=256,
+ depths=[ 2, 2, 18, 2 ],
+ focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
+ focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
+ use_conv_embed=True,
+ use_postln=True,
+ use_postln_in_modulation=False,
+ use_layerscale=True,
+ normalize_modulator=True,
+ ),
+ 'focalnet_H_224_22k': dict(
+ embed_dim=352,
+ depths=[ 2, 2, 18, 2 ],
+ focal_levels=kw.get('focal_levels', [3, 3, 3, 3]),
+ focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
+ use_conv_embed=True,
+ use_postln=True,
+ use_layerscale=True,
+ use_postln_in_modulation=True,
+ normalize_modulator=False,
+ ),
+ 'focalnet_H_224_22k_fl4': dict(
+ embed_dim=352,
+ depths=[ 2, 2, 18, 2 ],
+ focal_levels=kw.get('focal_levels', [4, 4, 4, 4]),
+ focal_windows=kw.get('focal_windows', [3, 3, 3, 3]),
+ use_conv_embed=True,
+ use_postln=True,
+ use_postln_in_modulation=True,
+ use_layerscale=True,
+ normalize_modulator=False,
+ ),
+ }
+
+ kw_cfg = model_para_dict[modelname]
+ kw_cfg.update(kw)
+ model = FocalNet(**kw_cfg)
+ return model
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/matcher.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..983c743bdce99c3471f112581ad7f4fac5ee0af6
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/matcher.py
@@ -0,0 +1,191 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modules to compute the matching cost and solve the corresponding LSAP.
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+
+import torch, os
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+
+from .util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+class HungarianMatcher(nn.Module):
+ """This class computes an assignment between the targets and the predictions of the network
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+ while the others are un-matched (and thus treated as non-objects).
+ """
+
+ def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, focal_alpha = 0.25):
+ """Creates the matcher
+ Params:
+ cost_class: This is the relative weight of the classification error in the matching cost
+ cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+ cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+ """
+ super().__init__()
+ self.cost_class = cost_class
+ self.cost_bbox = cost_bbox
+ self.cost_giou = cost_giou
+ assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
+
+ self.focal_alpha = focal_alpha
+
+ @torch.no_grad()
+ def forward(self, outputs, targets):
+ """ Performs the matching
+ Params:
+ outputs: This is a dict that contains at least these entries:
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+ "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+ objects in the target) containing the class labels
+ "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ """
+
+ bs, num_queries = outputs["pred_logits"].shape[:2]
+
+ # We flatten to compute the cost matrices in a batch
+ out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
+ out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
+
+ # Also concat the target labels and boxes
+ tgt_ids = torch.cat([v["labels"] for v in targets])
+ tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+ # Compute the classification cost.
+ alpha = self.focal_alpha
+ gamma = 2.0
+ neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
+ pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+ cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+
+ # Compute the L1 cost between boxes
+ cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+ # Compute the GIoU cost between boxes
+ cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+
+ # Final cost matrix
+ C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+ C = C.view(bs, num_queries, -1).cpu()
+
+ sizes = [len(v["boxes"]) for v in targets]
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
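+# Example (illustrative; shapes and values invented): matching a batch of two images,
+# one with two ground-truth boxes and one with a single box. Boxes are normalized cxcywh.
+#
+#   matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
+#   outputs = {"pred_logits": torch.randn(2, 100, 91),
+#              "pred_boxes": torch.rand(2, 100, 4)}
+#   targets = [{"labels": torch.tensor([3, 17]), "boxes": torch.rand(2, 4)},
+#              {"labels": torch.tensor([1]), "boxes": torch.rand(1, 4)}]
+#   indices = matcher(outputs, targets)
+#   # indices[0]: (pred_idx, tgt_idx) each of length 2; indices[1]: each of length 1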
+
+class SimpleMinsumMatcher(nn.Module):
+ """This class computes an assignment between the targets and the predictions of the network
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+ while the others are un-matched (and thus treated as non-objects).
+ """
+
+ def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, focal_alpha = 0.25):
+ """Creates the matcher
+ Params:
+ cost_class: This is the relative weight of the classification error in the matching cost
+ cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+ cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+ """
+ super().__init__()
+ self.cost_class = cost_class
+ self.cost_bbox = cost_bbox
+ self.cost_giou = cost_giou
+ assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
+
+ self.focal_alpha = focal_alpha
+
+ @torch.no_grad()
+ def forward(self, outputs, targets):
+ """ Performs the matching
+ Params:
+ outputs: This is a dict that contains at least these entries:
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+ "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+ objects in the target) containing the class labels
+ "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ """
+
+ bs, num_queries = outputs["pred_logits"].shape[:2]
+
+ # We flatten to compute the cost matrices in a batch
+ out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
+ out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
+
+ # Also concat the target labels and boxes
+ tgt_ids = torch.cat([v["labels"] for v in targets])
+ tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+ # Compute the classification cost.
+ alpha = self.focal_alpha
+ gamma = 2.0
+ neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
+ pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+ cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+
+ # Compute the L1 cost between boxes
+ cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+ # Compute the GIoU cost between boxes
+ cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+
+ # Final cost matrix
+ C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+ C = C.view(bs, num_queries, -1)
+
+ sizes = [len(v["boxes"]) for v in targets]
+ indices = []
+ device = C.device
+ for i, (c, _size) in enumerate(zip(C.split(sizes, -1), sizes)):
+ weight_mat = c[i]
+ idx_i = weight_mat.min(0)[1]
+ idx_j = torch.arange(_size).to(device)
+ indices.append((idx_i, idx_j))
+
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+def build_matcher(args):
+ assert args.matcher_type in ['HungarianMatcher', 'SimpleMinsumMatcher'], "Unknown args.matcher_type: {}".format(args.matcher_type)
+ if args.matcher_type == 'HungarianMatcher':
+ return HungarianMatcher(
+ cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou,
+ focal_alpha=args.focal_alpha
+ )
+ elif args.matcher_type == 'SimpleMinsumMatcher':
+ return SimpleMinsumMatcher(
+ cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou,
+ focal_alpha=args.focal_alpha
+ )
+ else:
+ raise NotImplementedError("Unknown args.matcher_type: {}".format(args.matcher_type))
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/position_encoding.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbc589769266853e31a2fed87a56903635c3fefd
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/position_encoding.py
@@ -0,0 +1,153 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+from .util.misc import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
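+# Note on the encoding above (illustrative): frequencies follow the "Attention Is All You
+# Need" scheme, dim_t[i] = temperature ** (2 * (i // 2) / num_pos_feats), with sin on even
+# and cos on odd channels of the (optionally normalized) cumulative position. The output
+# concatenates the y- and x-embeddings, giving 2 * num_pos_feats channels:
+#
+#   pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
+#   nt = NestedTensor(torch.randn(1, 3, 64, 64), torch.zeros(1, 64, 64, dtype=torch.bool))
+#   pos = pe(nt)                               # (1, 256, 64, 64)
+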
+class PositionEmbeddingSineHW(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+ def __init__(self, num_pos_feats=64, temperatureH=10000, temperatureW=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperatureH = temperatureH
+ self.temperatureW = temperatureW
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
+ pos_x = x_embed[:, :, :, None] / dim_tx
+
+ dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
+ pos_y = y_embed[:, :, :, None] / dim_ty
+
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+
+ return pos
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = torch.cat([
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+ return pos
+
+
+def build_position_encoding(args):
+ N_steps = args.hidden_dim // 2
+ if args.position_embedding in ('v2', 'sine'):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSineHW(
+ N_steps,
+ temperatureH=args.pe_temperatureH,
+ temperatureW=args.pe_temperatureW,
+ normalize=True
+ )
+ elif args.position_embedding in ('v3', 'learned'):
+ position_embedding = PositionEmbeddingLearned(N_steps)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ return position_embedding
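+
+# Usage sketch (illustrative; the temperature values are arbitrary, the attribute names
+# are the ones read above):
+#
+#   import argparse
+#   args = argparse.Namespace(hidden_dim=256, position_embedding='sine',
+#                             pe_temperatureH=20, pe_temperatureW=20)
+#   pos_embed = build_position_encoding(args)  # PositionEmbeddingSineHW, 128 feats per axis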
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/segmentation.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bab139dd7937f08bd06036b46f1b912dbf03a13
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/segmentation.py
@@ -0,0 +1,375 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Conditional DETR
+# Copyright (c) 2021 Microsoft. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copied from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# ------------------------------------------------------------------------
+
+"""
+This file provides the definition of the convolutional heads used to predict masks, as well as the losses
+"""
+import io
+from collections import defaultdict
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from PIL import Image
+
+from .util import box_ops
+from .util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
+
+try:
+ from panopticapi.utils import id2rgb, rgb2id
+except ImportError:
+ pass
+
+
+class DETRsegm(nn.Module):
+ def __init__(self, detr, freeze_detr=False):
+ super().__init__()
+ self.detr = detr
+
+ if freeze_detr:
+ for p in self.parameters():
+ p.requires_grad_(False)
+
+ hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
+ self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
+ self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)
+
+ def forward(self, samples: NestedTensor):
+ if isinstance(samples, (list, torch.Tensor)):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, pos = self.detr.backbone(samples)
+
+ bs = features[-1].tensors.shape[0]
+
+ src, mask = features[-1].decompose()
+ assert mask is not None
+ src_proj = self.detr.input_proj(src)
+ hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])
+
+ outputs_class = self.detr.class_embed(hs)
+ outputs_coord = self.detr.bbox_embed(hs).sigmoid()
+ out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
+ if self.detr.aux_loss:
+ out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord)
+
+ # FIXME h_boxes takes the last one computed, keep this in mind
+ bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
+
+ seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
+ outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])
+
+ out["pred_masks"] = outputs_seg_masks
+ return out
+
+
+def _expand(tensor, length: int):
+ return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
+
+
+class MaskHeadSmallConv(nn.Module):
+ """
+ Simple convolutional head, using group norm.
+ Upsampling is done using a FPN approach
+ """
+
+ def __init__(self, dim, fpn_dims, context_dim):
+ super().__init__()
+
+ inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
+ self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
+ self.gn1 = torch.nn.GroupNorm(8, dim)
+ self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
+ self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
+ self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
+ self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
+ self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
+ self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
+ self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
+ self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
+ self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)
+
+ self.dim = dim
+
+ self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
+ self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
+ self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_uniform_(m.weight, a=1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
+ x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
+
+ x = self.lay1(x)
+ x = self.gn1(x)
+ x = F.relu(x)
+ x = self.lay2(x)
+ x = self.gn2(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter1(fpns[0])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay3(x)
+ x = self.gn3(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter2(fpns[1])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay4(x)
+ x = self.gn4(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter3(fpns[2])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay5(x)
+ x = self.gn5(x)
+ x = F.relu(x)
+
+ x = self.out_lay(x)
+ return x
+
+
+class MHAttentionMap(nn.Module):
+ """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
+
+ def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
+ super().__init__()
+ self.num_heads = num_heads
+ self.hidden_dim = hidden_dim
+ self.dropout = nn.Dropout(dropout)
+
+ self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
+ self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
+
+ nn.init.zeros_(self.k_linear.bias)
+ nn.init.zeros_(self.q_linear.bias)
+ nn.init.xavier_uniform_(self.k_linear.weight)
+ nn.init.xavier_uniform_(self.q_linear.weight)
+ self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
+
+ def forward(self, q, k, mask: Optional[Tensor] = None):
+ q = self.q_linear(q)
+ k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
+ qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
+ kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
+ weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
+
+ if mask is not None:
+ weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
+ weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
+ weights = self.dropout(weights)
+ return weights
+
+
+def dice_loss(inputs, targets, num_boxes):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ """
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ numerator = 2 * (inputs * targets).sum(1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss.sum() / num_boxes
+
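+# Worked form of the loss above: with predicted probabilities p and binary targets t,
+# flattened per box,
+#
+#   dice_loss = 1 - (2 * sum(p * t) + 1) / (sum(p) + sum(t) + 1)
+#
+# summed over boxes and divided by num_boxes; the +1 terms keep the ratio defined when a
+# box has an empty mask.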
+
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha: (optional) Weighting factor in range (0,1) to balance
+ positive vs negative examples. Default: 0.25; a negative value disables the weighting.
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
+ Returns:
+ Loss tensor
+ """
+ prob = inputs.sigmoid()
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ return loss.mean(1).sum() / num_boxes
+
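+# In closed form this is the focal loss FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t),
+# where p_t is the predicted probability of the true label per element; it is computed
+# from the per-element BCE above, then reduced with loss.mean(1).sum() / num_boxes.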
+
+class PostProcessSegm(nn.Module):
+ def __init__(self, threshold=0.5):
+ super().__init__()
+ self.threshold = threshold
+
+ @torch.no_grad()
+ def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
+ assert len(orig_target_sizes) == len(max_target_sizes)
+ max_h, max_w = max_target_sizes.max(0)[0].tolist()
+ outputs_masks = outputs["pred_masks"].squeeze(2)
+ outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
+ outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
+
+ for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
+ img_h, img_w = t[0], t[1]
+ results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
+ results[i]["masks"] = F.interpolate(
+ results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
+ ).byte()
+
+ return results
+
+
+class PostProcessPanoptic(nn.Module):
+ """This class converts the output of the model to the final panoptic result, in the format expected by the
+ coco panoptic API """
+
+ def __init__(self, is_thing_map, threshold=0.85):
+ """
+ Parameters:
+ is_thing_map: This is a dict whose keys are the class ids, and the values a boolean indicating whether
+ the class is a thing (True) or a stuff (False) class
+ threshold: confidence threshold: segments with confidence lower than this will be deleted
+ """
+ super().__init__()
+ self.threshold = threshold
+ self.is_thing_map = is_thing_map
+
+ def forward(self, outputs, processed_sizes, target_sizes=None):
+ """ This function computes the panoptic prediction from the model's predictions.
+ Parameters:
+ outputs: This is a dict coming directly from the model. See the model doc for the content.
+ processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
+ model, i.e. the size after data augmentation but before batching.
+ target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
+ of each prediction. If left to None, it will default to the processed_sizes
+ """
+ if target_sizes is None:
+ target_sizes = processed_sizes
+ assert len(processed_sizes) == len(target_sizes)
+ out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
+ assert len(out_logits) == len(raw_masks) == len(target_sizes)
+ preds = []
+
+ def to_tuple(tup):
+ if isinstance(tup, tuple):
+ return tup
+ return tuple(tup.cpu().tolist())
+
+ for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
+ out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
+ ):
+ # we filter empty queries and detection below threshold
+ scores, labels = cur_logits.softmax(-1).max(-1)
+ keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
+ cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+ cur_scores = cur_scores[keep]
+ cur_classes = cur_classes[keep]
+ cur_masks = cur_masks[keep]
+ cur_masks = interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
+ cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])
+
+ h, w = cur_masks.shape[-2:]
+ assert len(cur_boxes) == len(cur_classes)
+
+ # It may be that we have several predicted masks for the same stuff class.
+ # In the following, we track the list of mask ids for each stuff class (they are merged later on)
+ cur_masks = cur_masks.flatten(1)
+ stuff_equiv_classes = defaultdict(lambda: [])
+ for k, label in enumerate(cur_classes):
+ if not self.is_thing_map[label.item()]:
+ stuff_equiv_classes[label.item()].append(k)
+
+ def get_ids_area(masks, scores, dedup=False):
+ # This helper function creates the final panoptic segmentation image
+ # It also returns the area of the masks that appear on the image
+
+ m_id = masks.transpose(0, 1).softmax(-1)
+
+ if m_id.shape[-1] == 0:
+ # We didn't detect any mask :(
+ m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
+ else:
+ m_id = m_id.argmax(-1).view(h, w)
+
+ if dedup:
+ # Merge the masks corresponding to the same stuff class
+ for equiv in stuff_equiv_classes.values():
+ if len(equiv) > 1:
+ for eq_id in equiv:
+ m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
+
+ final_h, final_w = to_tuple(target_size)
+
+ seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
+ seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)
+
+ np_seg_img = (
+ torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
+ )
+ m_id = torch.from_numpy(rgb2id(np_seg_img))
+
+ area = []
+ for i in range(len(scores)):
+ area.append(m_id.eq(i).sum().item())
+ return area, seg_img
+
+ area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
+ if cur_classes.numel() > 0:
+ # We now filter empty masks as long as we find some
+ while True:
+ filtered_small = torch.as_tensor(
+ [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
+ )
+ if filtered_small.any().item():
+ cur_scores = cur_scores[~filtered_small]
+ cur_classes = cur_classes[~filtered_small]
+ cur_masks = cur_masks[~filtered_small]
+ area, seg_img = get_ids_area(cur_masks, cur_scores)
+ else:
+ break
+
+ else:
+ cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
+
+ segments_info = []
+ for i, a in enumerate(area):
+ cat = cur_classes[i].item()
+ segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
+ del cur_classes
+
+ with io.BytesIO() as out:
+ seg_img.save(out, format="PNG")
+ predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
+ preds.append(predictions)
+ return preds
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/swin_transformer.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9fb7c6d4a2e2c2a4ea6035dbffa90796d582610
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/swin_transformer.py
@@ -0,0 +1,729 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# --------------------------------------------------------
+# modified from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from .util.misc import NestedTensor
+
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
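+# Round-trip sketch (illustrative): partitioning into non-overlapping windows and
+# reversing recovers the original tensor, assuming H and W are multiples of window_size.
+#
+#   feat = torch.randn(1, 14, 14, 96)
+#   wins = window_partition(feat, 7)           # (4, 7, 7, 96)
+#   back = window_reverse(wins, 7, 14, 14)     # (1, 14, 14, 96)
+#   assert torch.equal(feat, back)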
+
+class WindowAttention(nn.Module):
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """ Forward function.
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
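+# Note (illustrative): the relative position bias table stores one bias per head for each
+# of the (2*Wh-1) * (2*Ww-1) possible (dy, dx) offsets between two tokens in a window;
+# relative_position_index gathers it into a (Wh*Ww, Wh*Ww, num_heads) bias added to the
+# attention logits. Shape sketch:
+#
+#   attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
+#   x = torch.randn(8, 49, 96)                 # 8 windows of 7*7 tokens
+#   out = attn(x)                              # (8, 49, 96)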
+
+class SwinTransformerBlock(nn.Module):
+ """ Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ """ Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ mask_matrix: Attention mask for cyclic shift.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """ Patch Merging Layer
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """ Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
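+# Shape sketch (illustrative): PatchMerging concatenates each 2x2 spatial neighborhood
+# (4*C channels) and projects it to 2*C, halving the resolution:
+#
+#   merge = PatchMerging(dim=96)
+#   x = torch.randn(1, 56 * 56, 96)
+#   y = merge(x, 56, 56)                       # (1, 28 * 28, 192)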
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ num_heads (int): Number of attention head.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """ Forward function.
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
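+# Note (illustrative): the attention mask built in forward() marks pairs of positions that
+# originate from different image regions after the cyclic shift. Same-region pairs get 0,
+# cross-region pairs get -100, which suppresses their attention weights in SW-MSA. The
+# mask has shape (num_windows, window_size**2, window_size**2) and is recomputed from
+# (H, W) on every call.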
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class SwinTransformer(nn.Module):
+ """ Swin Transformer backbone.
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute position embedding. Default: 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ dilation (bool): If True, the final stage output is 16x downsampled instead of 32x. Default: False.
+ """
+
+ def __init__(self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ dilation=False,
+ use_checkpoint=False):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.dilation = dilation
+
+ if use_checkpoint:
+ print("use_checkpoint!!!!!!!!!!!!!!!!!!!!!!!!")
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ # prepare downsample list
+ downsamplelist = [PatchMerging for i in range(self.num_layers)]
+ downsamplelist[-1] = None
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ if self.dilation:
+ downsamplelist[-2] = None
+ num_features[-1] = int(embed_dim * 2 ** (self.num_layers - 1)) // 2
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ # dim=int(embed_dim * 2 ** i_layer),
+ dim=num_features[i_layer],
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ downsample=downsamplelist[i_layer],
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ # num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ # def init_weights(self, pretrained=None):
+ # """Initialize the weights in backbone.
+ # Args:
+ # pretrained (str, optional): Path to pre-trained weights.
+ # Defaults to None.
+ # """
+
+ # def _init_weights(m):
+ # if isinstance(m, nn.Linear):
+ # trunc_normal_(m.weight, std=.02)
+ # if isinstance(m, nn.Linear) and m.bias is not None:
+ # nn.init.constant_(m.bias, 0)
+ # elif isinstance(m, nn.LayerNorm):
+ # nn.init.constant_(m.bias, 0)
+ # nn.init.constant_(m.weight, 1.0)
+
+ # if isinstance(pretrained, str):
+ # self.apply(_init_weights)
+ # logger = get_root_logger()
+ # load_checkpoint(self, pretrained, strict=False, logger=logger)
+ # elif pretrained is None:
+ # self.apply(_init_weights)
+ # else:
+ # raise TypeError('pretrained must be a str or None')
+
+
+ def forward_raw(self, x):
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+ # import ipdb; ipdb.set_trace()
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+ # in:
+ # torch.Size([2, 3, 1024, 1024])
+ # outs:
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
+ return tuple(outs)
+
+
+ def forward(self, tensor_list: NestedTensor):
+ """Forward function."""
+ x = tensor_list.tensors
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+ # in:
+ # torch.Size([2, 3, 1024, 1024])
+ # out:
+ # [torch.Size([2, 192, 256, 256]), torch.Size([2, 384, 128, 128]), \
+ # torch.Size([2, 768, 64, 64]), torch.Size([2, 1536, 32, 32])]
+
+ # collect outputs into NestedTensors
+ outs_dict = {}
+ for idx, out_i in enumerate(outs):
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=out_i.shape[-2:]).to(torch.bool)[0]
+ outs_dict[idx] = NestedTensor(out_i, mask)
+
+ return outs_dict
+
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
+
+
+
+def build_swin_transformer(modelname, pretrain_img_size, **kw):
+ assert modelname in ['swin_T_224_1k', 'swin_B_224_22k', 'swin_B_384_22k', 'swin_L_224_22k', 'swin_L_384_22k']
+
+ model_para_dict = {
+ 'swin_T_224_1k': dict(
+ embed_dim=96,
+ depths=[ 2, 2, 6, 2 ],
+ num_heads=[ 3, 6, 12, 24],
+ window_size=7
+ ),
+ 'swin_B_224_22k': dict(
+ embed_dim=128,
+ depths=[ 2, 2, 18, 2 ],
+ num_heads=[ 4, 8, 16, 32 ],
+ window_size=7
+ ),
+ 'swin_B_384_22k': dict(
+ embed_dim=128,
+ depths=[ 2, 2, 18, 2 ],
+ num_heads=[ 4, 8, 16, 32 ],
+ window_size=12
+ ),
+ 'swin_L_224_22k': dict(
+ embed_dim=192,
+ depths=[ 2, 2, 18, 2 ],
+ num_heads=[ 6, 12, 24, 48 ],
+ window_size=7
+ ),
+ 'swin_L_384_22k': dict(
+ embed_dim=192,
+ depths=[ 2, 2, 18, 2 ],
+ num_heads=[ 6, 12, 24, 48 ],
+ window_size=12
+ ),
+ }
+ kw_cfg = model_para_dict[modelname]
+ kw_cfg.update(kw)
+ model = SwinTransformer(pretrain_img_size=pretrain_img_size, **kw_cfg)
+ return model
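+
+
+# Hedged sketch (not part of the original file): builds the smallest variant and checks
+# the expected per-stage channels and 4x/8x/16x/32x strides of forward_raw. Defined for
+# illustration only; nothing in this repo calls it.
+def _demo_build_swin_transformer():
+    model = build_swin_transformer('swin_T_224_1k', 224)
+    feats = model.forward_raw(torch.rand(1, 3, 224, 224))
+    assert [f.shape[1] for f in feats] == [96, 192, 384, 768]
+    assert [f.shape[-1] for f in feats] == [56, 28, 14, 7]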
+
+if __name__ == "__main__":
+ model = build_swin_transformer('swin_L_384_22k', 384, dilation=True)
+ x = torch.rand(2, 3, 1024, 1024)
+ y = model.forward_raw(x)
+ import ipdb; ipdb.set_trace()
+ x = torch.rand(2, 3, 384, 384)
+ y = model.forward_raw(x)
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/transformer_deformable.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/transformer_deformable.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab3535447670ef33f4dbe0c127c9b1aaf098f1eb
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/transformer_deformable.py
@@ -0,0 +1,670 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+import copy
+import os
+from typing import Optional, List
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+
+from .util.misc import inverse_sigmoid
+from projects.instance_segment_anything.ops.modules import MSDeformAttn
+
+from .utils import sigmoid_focal_loss, MLP, _get_activation_fn, gen_sineembed_for_position
+
+class DeformableTransformer(nn.Module):
+ def __init__(self, d_model=256, nhead=8,
+ num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
+ activation="relu", return_intermediate_dec=False,
+ num_feature_levels=4, dec_n_points=4, enc_n_points=4,
+ two_stage=False, two_stage_num_proposals=300,
+ use_dab=False, high_dim_query_update=False, no_sine_embed=False):
+ super().__init__()
+
+ self.d_model = d_model
+ self.nhead = nhead
+ self.two_stage = two_stage
+ self.two_stage_num_proposals = two_stage_num_proposals
+ self.use_dab = use_dab
+
+ encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, enc_n_points)
+ self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)
+
+ decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
+ dropout, activation,
+ num_feature_levels, nhead, dec_n_points)
+ self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec,
+ use_dab=use_dab, d_model=d_model, high_dim_query_update=high_dim_query_update, no_sine_embed=no_sine_embed)
+
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+ if two_stage:
+ self.enc_output = nn.Linear(d_model, d_model)
+ self.enc_output_norm = nn.LayerNorm(d_model)
+ self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
+ self.pos_trans_norm = nn.LayerNorm(d_model * 2)
+ else:
+ if not self.use_dab:
+ self.reference_points = nn.Linear(d_model, 2)
+
+ self.high_dim_query_update = high_dim_query_update
+ if high_dim_query_update:
+ assert not self.use_dab, "high_dim_query_update is not supported together with use_dab"
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ if not self.two_stage and not self.use_dab:
+ xavier_uniform_(self.reference_points.weight.data, gain=1.0)
+ constant_(self.reference_points.bias.data, 0.)
+ normal_(self.level_embed)
+
+ def get_proposal_pos_embed(self, proposals):
+ num_pos_feats = 128
+ temperature = 10000
+ scale = 2 * math.pi
+
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
+ dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+ # N, L, 4
+ proposals = proposals.sigmoid() * scale
+ # N, L, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # N, L, 4, 64, 2
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
+ return pos
+
+ def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
+ N_, S_, C_ = memory.shape
+ base_scale = 4.0
+ proposals = []
+ _cur = 0
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
+ proposals.append(proposal)
+ _cur += (H_ * W_)
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+ return output_memory, output_proposals
+
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def forward(self, srcs, masks, pos_embeds, query_embed=None):
+ """
+ Input:
+ - srcs: List([bs, c, h, w])
+ - masks: List([bs, h, w])
+ """
+ assert self.two_stage or query_embed is not None
+
+ # prepare input for encoder
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
+ mask = mask.flatten(1) # bs, hw
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
+ mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ # encoder
+ memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
+ # import ipdb; ipdb.set_trace()
+
+ # prepare input for decoder
+ bs, _, c = memory.shape
+ if self.two_stage:
+ output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
+
+ # hack implementation for two-stage Deformable DETR
+ enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
+ enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
+
+ topk = self.two_stage_num_proposals
+ topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
+ topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
+ topk_coords_unact = topk_coords_unact.detach()
+ reference_points = topk_coords_unact.sigmoid()
+ init_reference_out = reference_points
+ pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
+ query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
+ elif self.use_dab:
+ reference_points = query_embed[..., self.d_model:].sigmoid()
+ tgt = query_embed[..., :self.d_model]
+ tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
+ init_reference_out = reference_points
+ else:
+ query_embed, tgt = torch.split(query_embed, c, dim=1)
+ query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
+ tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
+ reference_points = self.reference_points(query_embed).sigmoid()
+ # bs, num_queries, 2
+ init_reference_out = reference_points
+
+ # decoder
+ # import ipdb; ipdb.set_trace()
+ hs, inter_references = self.decoder(tgt, reference_points, memory,
+ spatial_shapes, level_start_index, valid_ratios,
+ query_pos=query_embed if not self.use_dab else None,
+ src_padding_mask=mask_flatten)
+
+ inter_references_out = inter_references
+ if self.two_stage:
+ return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
+ return hs, init_reference_out, inter_references_out, None, None
+
+
+class DeformableTransformerEncoderLayer(nn.Module):
+ def __init__(self,
+ d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4,
+ add_channel_attention=False,
+ use_deformable_box_attn=False,
+ box_attn_type='roi_align',
+ ):
+ super().__init__()
+
+ # self attention
+ if use_deformable_box_attn:
+ self.self_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type)
+ else:
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation, d_model=d_ffn)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # channel attention
+ self.add_channel_attention = add_channel_attention
+ if add_channel_attention:
+ self.activ_channel = _get_activation_fn('dyrelu', d_model=d_model)
+ self.norm_channel = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, src):
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+ src = src + self.dropout3(src2)
+ src = self.norm2(src)
+ return src
+
+ def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, key_padding_mask=None):
+ # self attention
+ # import ipdb; ipdb.set_trace()
+ src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, key_padding_mask)
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ # ffn
+ src = self.forward_ffn(src)
+
+ # channel attn
+ if self.add_channel_attention:
+ src = self.norm_channel(src + self.activ_channel(src))
+
+ return src
+
+
+class DeformableTransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super().__init__()
+ if num_layers > 0:
+ self.layers = _get_clones(encoder_layer, num_layers)
+ else:
+ self.layers = []
+ del encoder_layer
+ self.num_layers = num_layers
+ self.norm = norm
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+
+ ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
+ """
+ Input:
+ - src: [bs, sum(hi*wi), 256]
+ - spatial_shapes: h,w of each level [num_level, 2]
+ - level_start_index: [num_level] start point of level in sum(hi*wi).
+ - valid_ratios: [bs, num_level, 2]
+ - pos: pos embed for src. [bs, sum(hi*wi), 256]
+ - padding_mask: [bs, sum(hi*wi)]
+ Intermediate:
+ - reference_points: [bs, sum(hi*wi), num_level, 2]
+ """
+ output = src
+ # bs, sum(hi*wi), 256
+ # import ipdb; ipdb.set_trace()
+ if self.num_layers > 0:
+ reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
+ for _, layer in enumerate(self.layers):
+ output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
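+
+
+# Hedged sketch (not part of the original file): the encoder's reference points are the
+# per-level pixel centres, normalised by the valid (unpadded) fraction of each feature
+# map. Defined for illustration only; nothing in this repo calls it.
+def _demo_get_reference_points():
+    spatial_shapes = torch.as_tensor([[2, 3], [1, 2]])   # two levels: 2x3 and 1x2
+    valid_ratios = torch.ones(1, 2, 2)                   # bs=1, no padding anywhere
+    ref = DeformableTransformerEncoder.get_reference_points(
+        spatial_shapes, valid_ratios, device='cpu')
+    assert ref.shape == (1, 2 * 3 + 1 * 2, 2, 2)         # bs, sum(hi*wi), num_levels, 2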
+
+
+class DeformableTransformerDecoderLayer(nn.Module):
+ def __init__(self, d_model=256, d_ffn=1024,
+ dropout=0.1, activation="relu",
+ n_levels=4, n_heads=8, n_points=4,
+ use_deformable_box_attn=False,
+ box_attn_type='roi_align',
+ key_aware_type=None,
+ decoder_sa_type='ca',
+ module_seq=['sa', 'ca', 'ffn'],
+ ):
+ super().__init__()
+ self.module_seq = module_seq
+ assert sorted(module_seq) == ['ca', 'ffn', 'sa']
+
+ # cross attention
+ # self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ if use_deformable_box_attn:
+ self.cross_attn = MSDeformableBoxAttention(d_model, n_levels, n_heads, n_boxes=n_points, used_func=box_attn_type)
+ else:
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # self attention
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1)
+ self.dropout3 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ self.key_aware_type = key_aware_type
+ self.key_aware_proj = None
+ self.decoder_sa_type = decoder_sa_type
+ assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
+
+ if decoder_sa_type == 'ca_content':
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+
+
+
+
+ def rm_self_attn_modules(self):
+ self.self_attn = None
+ self.dropout2 = None
+ self.norm2 = None
+
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ def forward_sa(self,
+ # for tgt
+ tgt: Optional[Tensor], # nq, bs, d_model
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
+
+ # for memory
+ memory: Optional[Tensor] = None, # hw, bs, d_model
+ memory_key_padding_mask: Optional[Tensor] = None,
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ memory_pos: Optional[Tensor] = None, # pos for memory
+
+ # sa
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
+ ):
+ # self attention
+ if self.self_attn is not None:
+ # import ipdb; ipdb.set_trace()
+ if self.decoder_sa_type == 'sa':
+ q = k = self.with_pos_embed(tgt, tgt_query_pos)
+ tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ elif self.decoder_sa_type == 'ca_label':
+ # import ipdb; ipdb.set_trace()
+ # q = self.with_pos_embed(tgt, tgt_query_pos)
+ bs = tgt.shape[1]
+ k = v = self.label_embedding.weight[:, None, :].repeat(1, bs, 1)
+ tgt2 = self.self_attn(tgt, k, v, attn_mask=self_attn_mask)[0]
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ elif self.decoder_sa_type == 'ca_content':
+ tgt2 = self.self_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
+ tgt_reference_points.transpose(0, 1).contiguous(),
+ memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+ else:
+ raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type))
+
+ return tgt
+
+ def forward_ca(self,
+ # for tgt
+ tgt: Optional[Tensor], # nq, bs, d_model
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
+
+ # for memory
+ memory: Optional[Tensor] = None, # hw, bs, d_model
+ memory_key_padding_mask: Optional[Tensor] = None,
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ memory_pos: Optional[Tensor] = None, # pos for memory
+
+ # sa
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
+ ):
+ # cross attention
+ # import ipdb; ipdb.set_trace()
+ if self.key_aware_type is not None:
+
+ if self.key_aware_type == 'mean':
+ tgt = tgt + memory.mean(0, keepdim=True)
+ elif self.key_aware_type == 'proj_mean':
+ tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
+ else:
+ raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type))
+ tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
+ tgt_reference_points.transpose(0, 1).contiguous(),
+ memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ return tgt
+
+ def forward(self,
+ # for tgt
+ tgt: Optional[Tensor], # nq, bs, d_model
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
+
+ # for memory
+ memory: Optional[Tensor] = None, # hw, bs, d_model
+ memory_key_padding_mask: Optional[Tensor] = None,
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ memory_pos: Optional[Tensor] = None, # pos for memory
+
+ # sa
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
+ ):
+
+ for funcname in self.module_seq:
+ if funcname == 'ffn':
+ tgt = self.forward_ffn(tgt)
+ elif funcname == 'ca':
+ tgt = self.forward_ca(tgt, tgt_query_pos, tgt_query_sine_embed, \
+ tgt_key_padding_mask, tgt_reference_points, \
+ memory, memory_key_padding_mask, memory_level_start_index, \
+ memory_spatial_shapes, memory_pos, self_attn_mask, cross_attn_mask)
+ elif funcname == 'sa':
+ tgt = self.forward_sa(tgt, tgt_query_pos, tgt_query_sine_embed, \
+ tgt_key_padding_mask, tgt_reference_points, \
+ memory, memory_key_padding_mask, memory_level_start_index, \
+ memory_spatial_shapes, memory_pos, self_attn_mask, cross_attn_mask)
+ else:
+ raise ValueError('unknown funcname {}'.format(funcname))
+
+ return tgt
+
+ # def forward(self,
+ # # for tgt
+ # tgt: Optional[Tensor], # nq, bs, d_model
+ # tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
+ # tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
+ # tgt_key_padding_mask: Optional[Tensor] = None,
+ # tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
+
+ # # for memory
+ # memory: Optional[Tensor] = None, # hw, bs, d_model
+ # memory_key_padding_mask: Optional[Tensor] = None,
+ # memory_level_start_index: Optional[Tensor] = None, # num_levels
+ # memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
+ # memory_pos: Optional[Tensor] = None, # pos for memory
+
+ # # sa
+ # self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
+ # cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
+ # ):
+ # """
+ # Input:
+ # - tgt/tgt_query_pos: nq, bs, d_model
+ # -
+ # """
+ # assert cross_attn_mask is None
+
+ # # self attention
+ # if self.self_attn is not None:
+ # # import ipdb; ipdb.set_trace()
+ # if self.decoder_sa_type == 'sa':
+ # q = k = self.with_pos_embed(tgt, tgt_query_pos)
+ # tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0]
+ # tgt = tgt + self.dropout2(tgt2)
+ # tgt = self.norm2(tgt)
+ # elif self.decoder_sa_type == 'ca_label':
+ # # import ipdb; ipdb.set_trace()
+ # # q = self.with_pos_embed(tgt, tgt_query_pos)
+ # bs = tgt.shape[1]
+ # k = v = self.label_embedding.weight[:, None, :].repeat(1, bs, 1)
+ # tgt2 = self.self_attn(tgt, k, v, attn_mask=self_attn_mask)[0]
+ # tgt = tgt + self.dropout2(tgt2)
+ # tgt = self.norm2(tgt)
+ # elif self.decoder_sa_type == 'ca_content':
+ # tgt2 = self.self_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
+ # tgt_reference_points.transpose(0, 1).contiguous(),
+ # memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
+ # tgt = tgt + self.dropout2(tgt2)
+ # tgt = self.norm2(tgt)
+ # else:
+ # raise NotImplementedError("Unknown decoder_sa_type {}".format(self.decoder_sa_type))
+
+
+ # # cross attention
+ # # import ipdb; ipdb.set_trace()
+ # if self.key_aware_type is not None:
+ # if self.key_aware_type == 'mean':
+ # tgt = tgt + memory.mean(0, keepdim=True)
+ # elif self.key_aware_type == 'proj_mean':
+ # tgt = tgt + self.key_aware_proj(memory).mean(0, keepdim=True)
+ # else:
+ # raise NotImplementedError("Unknown key_aware_type: {}".format(self.key_aware_type))
+ # tgt2 = self.cross_attn(self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1),
+ # tgt_reference_points.transpose(0, 1).contiguous(),
+ # memory.transpose(0, 1), memory_spatial_shapes, memory_level_start_index, memory_key_padding_mask).transpose(0, 1)
+ # tgt = tgt + self.dropout1(tgt2)
+ # tgt = self.norm1(tgt)
+
+ # # ffn
+ # tgt = self.forward_ffn(tgt)
+
+ # return tgt
+
+
+class DeformableTransformerDecoder(nn.Module):
+ def __init__(self, decoder_layer, num_layers, return_intermediate=False, use_dab=False, d_model=256, query_dim=4):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.return_intermediate = return_intermediate
+ assert return_intermediate
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
+ self.bbox_embed = None
+ self.class_embed = None
+ self.use_dab = use_dab
+ self.d_model = d_model
+ self.query_dim = query_dim
+ if use_dab:
+ self.query_scale = MLP(d_model, d_model, d_model, 2)
+ self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)
+
+
+ def forward(self, tgt, reference_points, src, src_spatial_shapes,
+ src_level_start_index, src_valid_ratios,
+ query_pos=None, src_padding_mask=None):
+ output = tgt
+ if self.use_dab:
+ assert query_pos is None
+
+ intermediate = []
+ intermediate_reference_points = [reference_points]
+ for layer_id, layer in enumerate(self.layers):
+ # import ipdb; ipdb.set_trace()
+ if reference_points.shape[-1] == 4:
+ reference_points_input = reference_points[:, :, None] \
+ * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] # bs, nq, 4, 4
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
+
+ if self.use_dab:
+ # import ipdb; ipdb.set_trace()
+ query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 256*2
+ raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256
+ pos_scale = self.query_scale(output) if layer_id != 0 else 1
+ query_pos = pos_scale * raw_query_pos
+
+ output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)
+
+ # hack implementation for iterative bounding box refinement
+ if self.bbox_embed is not None:
+ box_holder = self.bbox_embed(output)
+ box_holder[..., :self.query_dim] += inverse_sigmoid(reference_points)
+ new_reference_points = box_holder[..., :self.query_dim].sigmoid()
+ reference_points = new_reference_points.detach()
+ if layer_id != self.num_layers - 1:
+ intermediate_reference_points.append(new_reference_points)
+
+ intermediate.append(output)
+
+ return torch.stack(intermediate), torch.stack(intermediate_reference_points)
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def build_deforamble_transformer(args):
+ return DeformableTransformer(
+ d_model=args.hidden_dim,
+ nhead=args.nheads,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ dim_feedforward=args.dim_feedforward,
+ dropout=args.dropout,
+ activation="relu",
+ return_intermediate_dec=True,
+ num_feature_levels=args.ddetr_num_feature_levels,
+ dec_n_points=args.ddetr_dec_n_points,
+ enc_n_points=args.ddetr_enc_n_points,
+ two_stage=args.ddetr_two_stage,
+ two_stage_num_proposals=args.num_queries,
+ use_dab=args.ddetr_use_dab,
+ high_dim_query_update=args.ddetr_high_dim_query_update,
+ no_sine_embed=args.ddetr_no_sine_embed)
+
+
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/__init__.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..168f9979a4623806934b0ff1102ac166704e7dec
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/box_loss.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/box_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf7c7e527723cf3e0d58f5c944e69e264ecd392c
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/box_loss.py
@@ -0,0 +1,113 @@
+# borrow from https://github.com/Zzh-tju/CIoU/blob/master/layers/modules/multibox_loss.py
+
+import torch, math
+
+
+
+def ciou(bboxes1, bboxes2):
+ bboxes1 = torch.sigmoid(bboxes1)
+ bboxes2 = torch.sigmoid(bboxes2)
+ rows = bboxes1.shape[0]
+ cols = bboxes2.shape[0]
+ cious = torch.zeros((rows, cols))
+ if rows * cols == 0:
+ return cious
+ exchange = False
+ if bboxes1.shape[0] > bboxes2.shape[0]:
+ bboxes1, bboxes2 = bboxes2, bboxes1
+ cious = torch.zeros((cols, rows))
+ exchange = True
+ w1 = torch.exp(bboxes1[:, 2])
+ h1 = torch.exp(bboxes1[:, 3])
+ w2 = torch.exp(bboxes2[:, 2])
+ h2 = torch.exp(bboxes2[:, 3])
+ area1 = w1 * h1
+ area2 = w2 * h2
+ center_x1 = bboxes1[:, 0]
+ center_y1 = bboxes1[:, 1]
+ center_x2 = bboxes2[:, 0]
+ center_y2 = bboxes2[:, 1]
+
+ inter_l = torch.max(center_x1 - w1 / 2,center_x2 - w2 / 2)
+ inter_r = torch.min(center_x1 + w1 / 2,center_x2 + w2 / 2)
+ inter_t = torch.max(center_y1 - h1 / 2,center_y2 - h2 / 2)
+ inter_b = torch.min(center_y1 + h1 / 2,center_y2 + h2 / 2)
+ inter_area = torch.clamp((inter_r - inter_l),min=0) * torch.clamp((inter_b - inter_t),min=0)
+
+ c_l = torch.min(center_x1 - w1 / 2,center_x2 - w2 / 2)
+ c_r = torch.max(center_x1 + w1 / 2,center_x2 + w2 / 2)
+ c_t = torch.min(center_y1 - h1 / 2,center_y2 - h2 / 2)
+ c_b = torch.max(center_y1 + h1 / 2,center_y2 + h2 / 2)
+
+ inter_diag = (center_x2 - center_x1)**2 + (center_y2 - center_y1)**2
+ c_diag = torch.clamp((c_r - c_l),min=0)**2 + torch.clamp((c_b - c_t),min=0)**2
+
+ union = area1+area2-inter_area
+ u = (inter_diag) / c_diag
+ iou = inter_area / union
+ v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(w2 / h2) - torch.atan(w1 / h1)), 2)
+ with torch.no_grad():
+ S = (iou>0.5).float()
+ alpha= S*v/(1-iou+v)
+ cious = iou - u - alpha * v
+ cious = torch.clamp(cious,min=-1.0,max = 1.0)
+ if exchange:
+ cious = cious.T
+ return 1-cious
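+
+# Hedged note (not from the original file) on the loss returned by ciou() above: with
+# rho^2 the squared centre distance (inter_diag) and c^2 the squared diagonal of the
+# smallest enclosing box (c_diag),
+#     v     = (4 / pi^2) * (atan(w2 / h2) - atan(w1 / h1))^2
+#     alpha = S * v / (1 - IoU + v),  where S = 1 if IoU > 0.5 else 0
+# the function returns 1 - clamp(IoU - rho^2 / c^2 - alpha * v, -1, 1). diou() below is
+# the same quantity without the alpha * v aspect-ratio term.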
+
+def diou(bboxes1, bboxes2):
+ bboxes1 = torch.sigmoid(bboxes1)
+ bboxes2 = torch.sigmoid(bboxes2)
+ rows = bboxes1.shape[0]
+ cols = bboxes2.shape[0]
+ cious = torch.zeros((rows, cols))
+ if rows * cols == 0:
+ return cious
+ exchange = False
+ if bboxes1.shape[0] > bboxes2.shape[0]:
+ bboxes1, bboxes2 = bboxes2, bboxes1
+ cious = torch.zeros((cols, rows))
+ exchange = True
+ w1 = torch.exp(bboxes1[:, 2])
+ h1 = torch.exp(bboxes1[:, 3])
+ w2 = torch.exp(bboxes2[:, 2])
+ h2 = torch.exp(bboxes2[:, 3])
+ area1 = w1 * h1
+ area2 = w2 * h2
+ center_x1 = bboxes1[:, 0]
+ center_y1 = bboxes1[:, 1]
+ center_x2 = bboxes2[:, 0]
+ center_y2 = bboxes2[:, 1]
+
+ inter_l = torch.max(center_x1 - w1 / 2,center_x2 - w2 / 2)
+ inter_r = torch.min(center_x1 + w1 / 2,center_x2 + w2 / 2)
+ inter_t = torch.max(center_y1 - h1 / 2,center_y2 - h2 / 2)
+ inter_b = torch.min(center_y1 + h1 / 2,center_y2 + h2 / 2)
+ inter_area = torch.clamp((inter_r - inter_l),min=0) * torch.clamp((inter_b - inter_t),min=0)
+
+ c_l = torch.min(center_x1 - w1 / 2,center_x2 - w2 / 2)
+ c_r = torch.max(center_x1 + w1 / 2,center_x2 + w2 / 2)
+ c_t = torch.min(center_y1 - h1 / 2,center_y2 - h2 / 2)
+ c_b = torch.max(center_y1 + h1 / 2,center_y2 + h2 / 2)
+
+ inter_diag = (center_x2 - center_x1)**2 + (center_y2 - center_y1)**2
+ c_diag = torch.clamp((c_r - c_l),min=0)**2 + torch.clamp((c_b - c_t),min=0)**2
+
+ union = area1+area2-inter_area
+ u = (inter_diag) / c_diag
+ iou = inter_area / union
+ dious = iou - u
+ dious = torch.clamp(dious,min=-1.0,max = 1.0)
+ if exchange:
+ dious = dious.T
+ return 1-dious
+
+
+if __name__ == "__main__":
+ x = torch.rand(10, 4)
+ y = torch.rand(10,4)
+ import ipdb;ipdb.set_trace()
+ cxy = ciou(x, y)
+ dxy = diou(x, y)
+ print(cxy.shape, dxy.shape)
+ import ipdb; ipdb.set_trace()
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/box_ops.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/box_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d8e29aedcb747393e7348d0c6aaf8a6dbe5d7c9
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/box_ops.py
@@ -0,0 +1,139 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Utilities for bounding box manipulation and GIoU.
+"""
+import torch, os
+from torchvision.ops.boxes import box_area
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [(x0 + x1) / 2, (y0 + y1) / 2,
+ (x1 - x0), (y1 - y0)]
+ return torch.stack(b, dim=-1)
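+
+
+# Hedged sketch (not part of the original file): round-trip check between the two box
+# encodings used throughout this file. Defined for illustration only; never called here.
+def _demo_box_conversion():
+    boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])            # cx, cy, w, h
+    xyxy = box_cxcywh_to_xyxy(boxes)                        # [[0.4, 0.3, 0.6, 0.7]]
+    assert torch.allclose(box_xyxy_to_cxcywh(xyxy), boxes)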
+
+
+# modified from torchvision to also return the union
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ # import ipdb; ipdb.set_trace()
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / (union + 1e-6)
+ return iou, union
+
+
+def generalized_box_iou(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/
+
+ The boxes should be in [x0, y0, x1, y1] format
+
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
+ and M = len(boxes2)
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+ # except:
+ # import ipdb; ipdb.set_trace()
+ iou, union = box_iou(boxes1, boxes2)
+
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ area = wh[:, :, 0] * wh[:, :, 1]
+
+ return iou - (area - union) / (area + 1e-6)
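+
+
+# Hedged sketch (not part of the original file): GIoU equals IoU when the smallest
+# enclosing box is no larger than the union, and turns negative for disjoint boxes.
+# Defined for illustration only; never called here.
+def _demo_generalized_box_iou():
+    a = torch.tensor([[0., 0., 2., 2.]])
+    b = torch.tensor([[0., 0., 2., 3.]])    # contains a, enclosing box == union
+    c = torch.tensor([[3., 3., 4., 4.]])    # disjoint from a
+    assert generalized_box_iou(a, b).item() > 0
+    assert generalized_box_iou(a, c).item() < 0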
+
+
+
+# modified from torchvision to also return the union
+def box_iou_pairwise(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ lt = torch.max(boxes1[:, :2], boxes2[:, :2]) # [N,2]
+ rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) # [N,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,2]
+ inter = wh[:, 0] * wh[:, 1] # [N]
+
+ union = area1 + area2 - inter
+
+ iou = inter / union
+ return iou, union
+
+
+def generalized_box_iou_pairwise(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/
+
+ Input:
+ - boxes1, boxes2: N,4
+ Output:
+ - giou: N
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+ assert boxes1.shape == boxes2.shape
+ iou, union = box_iou_pairwise(boxes1, boxes2)  # N
+
+ lt = torch.min(boxes1[:, :2], boxes2[:, :2])
+ rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])
+
+ wh = (rb - lt).clamp(min=0) # [N,2]
+ area = wh[:, 0] * wh[:, 1]
+
+ return iou - (area - union) / area
+
+def masks_to_boxes(masks):
+ """Compute the bounding boxes around the provided masks
+
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
+
+ Returns a [N, 4] tensor, with the boxes in xyxy format
+ """
+ if masks.numel() == 0:
+ return torch.zeros((0, 4), device=masks.device)
+
+ h, w = masks.shape[-2:]
+
+ y = torch.arange(0, h, dtype=torch.float)
+ x = torch.arange(0, w, dtype=torch.float)
+ y, x = torch.meshgrid(y, x)
+
+ x_mask = (masks * x.unsqueeze(0))
+ x_max = x_mask.flatten(1).max(-1)[0]
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ y_mask = (masks * y.unsqueeze(0))
+ y_max = y_mask.flatten(1).max(-1)[0]
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
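+
+
+# Hedged sketch (not part of the original file): a single rectangular mask maps back to
+# its own tight xyxy box. Defined for illustration only; never called here.
+def _demo_masks_to_boxes():
+    mask = torch.zeros(1, 10, 10)
+    mask[0, 2:5, 3:8] = 1                                    # rows 2..4, cols 3..7
+    assert torch.equal(masks_to_boxes(mask), torch.tensor([[3., 2., 7., 4.]]))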
+
+if __name__ == '__main__':
+ x = torch.rand(5, 4)
+ y = torch.rand(3, 4)
+ iou, union = box_iou(x, y)
+ import ipdb; ipdb.set_trace()
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/coco_id2name.json b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/coco_id2name.json
new file mode 100644
index 0000000000000000000000000000000000000000..52ad2e96e7974593e13ee56cd3eff0d22c59f96d
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/coco_id2name.json
@@ -0,0 +1 @@
+{"1": "person", "2": "bicycle", "3": "car", "4": "motorcycle", "5": "airplane", "6": "bus", "7": "train", "8": "truck", "9": "boat", "10": "traffic light", "11": "fire hydrant", "13": "stop sign", "14": "parking meter", "15": "bench", "16": "bird", "17": "cat", "18": "dog", "19": "horse", "20": "sheep", "21": "cow", "22": "elephant", "23": "bear", "24": "zebra", "25": "giraffe", "27": "backpack", "28": "umbrella", "31": "handbag", "32": "tie", "33": "suitcase", "34": "frisbee", "35": "skis", "36": "snowboard", "37": "sports ball", "38": "kite", "39": "baseball bat", "40": "baseball glove", "41": "skateboard", "42": "surfboard", "43": "tennis racket", "44": "bottle", "46": "wine glass", "47": "cup", "48": "fork", "49": "knife", "50": "spoon", "51": "bowl", "52": "banana", "53": "apple", "54": "sandwich", "55": "orange", "56": "broccoli", "57": "carrot", "58": "hot dog", "59": "pizza", "60": "donut", "61": "cake", "62": "chair", "63": "couch", "64": "potted plant", "65": "bed", "67": "dining table", "70": "toilet", "72": "tv", "73": "laptop", "74": "mouse", "75": "remote", "76": "keyboard", "77": "cell phone", "78": "microwave", "79": "oven", "80": "toaster", "81": "sink", "82": "refrigerator", "84": "book", "85": "clock", "86": "vase", "87": "scissors", "88": "teddy bear", "89": "hair drier", "90": "toothbrush"}
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/get_param_dicts.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/get_param_dicts.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf42d9f1f61fd5e4f6b50889026c608557042759
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/get_param_dicts.py
@@ -0,0 +1,84 @@
+import json
+import torch
+import torch.nn as nn
+
+
+def match_name_keywords(n: str, name_keywords: list):
+ out = False
+ for b in name_keywords:
+ if b in n:
+ out = True
+ break
+ return out
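+
+
+# Hedged sketch (not part of the original file): keyword matching is how get_param_dict
+# below routes parameters into different learning-rate groups. Illustration only.
+def _demo_match_name_keywords():
+    assert match_name_keywords("backbone.layers.0.weight", ["backbone"])
+    assert not match_name_keywords("transformer.decoder.weight", ["backbone"])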
+
+
+def get_param_dict(args, model_without_ddp: nn.Module):
+ param_dict_type = getattr(args, 'param_dict_type', 'default')
+ assert param_dict_type in ['default', 'ddetr_in_mmdet', 'large_wd']
+
+ # by default
+ if param_dict_type == 'default':
+ param_dicts = [
+ {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
+ "lr": args.lr_backbone,
+ }
+ ]
+ return param_dicts
+
+ if param_dict_type == 'ddetr_in_mmdet':
+ param_dicts = [
+ {
+ "params":
+ [p for n, p in model_without_ddp.named_parameters()
+ if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
+ "lr": args.lr,
+ },
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters()
+ if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad],
+ "lr": args.lr_backbone,
+ },
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters()
+ if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
+ "lr": args.lr * args.lr_linear_proj_mult,
+ }
+ ]
+ return param_dicts
+
+ if param_dict_type == 'large_wd':
+ param_dicts = [
+ {
+ "params":
+ [p for n, p in model_without_ddp.named_parameters()
+ if not match_name_keywords(n, ['backbone']) and not match_name_keywords(n, ['norm', 'bias']) and p.requires_grad],
+ },
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters()
+ if match_name_keywords(n, ['backbone']) and match_name_keywords(n, ['norm', 'bias']) and p.requires_grad],
+ "lr": args.lr_backbone,
+ "weight_decay": 0.0,
+ },
+ {
+ "params": [p for n, p in model_without_ddp.named_parameters()
+ if match_name_keywords(n, ['backbone']) and not match_name_keywords(n, ['norm', 'bias']) and p.requires_grad],
+ "lr": args.lr_backbone,
+ "weight_decay": args.weight_decay,
+ },
+ {
+ "params":
+ [p for n, p in model_without_ddp.named_parameters()
+ if not match_name_keywords(n, ['backbone']) and match_name_keywords(n, ['norm', 'bias']) and p.requires_grad],
+ "lr": args.lr,
+ "weight_decay": 0.0,
+ }
+ ]
+
+ # print("param_dicts: {}".format(param_dicts))
+
+ return param_dicts
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/logger.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..2052a5c2ba18dee1d8d9237e21ba9d2f49f78f23
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/logger.py
@@ -0,0 +1,95 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import functools
+import logging
+import os
+import sys
+from termcolor import colored
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+# so that calling setup_logger multiple times won't add many handlers
+@functools.lru_cache()
+def setup_logger(
+ output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None
+):
+ """
+ Initialize a detectron2-style logger and set its verbosity level to "DEBUG".
+
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ name (str): the root module name of this logger
+
+ Returns:
+ logging.Logger: a logger
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+
+ if abbrev_name is None:
+ abbrev_name = name
+
+ plain_formatter = logging.Formatter(
+ '[%(asctime)s.%(msecs)03d]: %(message)s',
+ datefmt='%m/%d %H:%M:%S'
+ )
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ abbrev_name=str(abbrev_name),
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + f".rank{distributed_rank}"
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+
+ return logger
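+
+
+# Hedged usage sketch (not part of the original file; the output path is hypothetical):
+# rank 0 logs to stdout in colour, and every rank appends to its own log file.
+# Defined for illustration only; never called here.
+def _demo_setup_logger():
+    logger = setup_logger(output="work_dirs/log.txt", distributed_rank=0, name="dino")
+    logger.info("hello from rank 0")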
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ return open(filename, "a")
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/misc.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d2f4d7fdfb303bddef004a6536983914024da89
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/misc.py
@@ -0,0 +1,587 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import os
+import random
+import subprocess
+import time
+from collections import OrderedDict, defaultdict, deque
+import datetime
+import pickle
+from typing import Optional, List
+
+import json, time
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+import colorsys
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+__torchvision_need_compat_flag = float(torchvision.__version__.split('.')[1]) < 7
+if __torchvision_need_compat_flag:
+ from torchvision.ops import _new_empty_tensor
+ from torchvision.ops.misc import _output_size
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ if d.shape[0] == 0:
+ return 0
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
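+
+
+# Hedged sketch (not part of the original file): the median/avg properties look at the
+# last `window_size` updates, while global_avg tracks the whole series. Defined for
+# illustration only; never called here.
+def _demo_smoothed_value():
+    meter = SmoothedValue(window_size=3, fmt="{median:.1f} ({global_avg:.1f})")
+    for v in [1.0, 2.0, 3.0, 10.0]:
+        meter.update(v)
+    assert meter.median == 3.0        # median of the last 3 values: 2, 3, 10
+    assert meter.global_avg == 4.0    # (1 + 2 + 3 + 10) / 4
+    assert str(meter) == "3.0 (4.0)"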
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device="cuda")
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+ if local_size != max_size:
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that all processes
+ have the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.all_reduce(values)
+ if average:
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
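+
+
+# Hedged sketch (not part of the original file; assumes get_world_size() defined later in
+# this module returns 1 outside distributed runs): with a single process the dict is
+# returned unchanged, under DDP each value is reduced across ranks. Illustration only.
+def _demo_reduce_dict():
+    losses = {"loss_ce": torch.tensor(0.5), "loss_bbox": torch.tensor(0.2)}
+    reduced = reduce_dict(losses, average=True)
+    assert reduced["loss_ce"].item() == 0.5   # single-process: a no-op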
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ # print(name, str(meter))
+ # import ipdb;ipdb.set_trace()
+ if meter.count > 0:
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None, logger=None):
+ if logger is None:
+ print_func = print
+ else:
+ print_func = logger.info
+
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}',
+ 'max mem: {memory:.0f}'
+ ])
+ else:
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ])
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ # import ipdb; ipdb.set_trace()
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print_func(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print_func(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print_func('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
+ sha = 'N/A'
+ diff = "clean"
+ branch = 'N/A'
+ try:
+ sha = _run(['git', 'rev-parse', 'HEAD'])
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
+ diff = _run(['git', 'diff-index', 'HEAD'])
+ diff = "has uncommited changes" if diff else "clean"
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+def collate_fn(batch):
+ # import ipdb; ipdb.set_trace()
+ batch = list(zip(*batch))
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
+ return tuple(batch)
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+ if mask == 'auto':
+ self.mask = torch.zeros_like(tensors).to(tensors.device)
+ if self.mask.dim() == 3:
+ self.mask = self.mask.sum(0).to(bool)
+ elif self.mask.dim() == 4:
+ self.mask = self.mask.sum(1).to(bool)
+ else:
+ raise ValueError("tensors dim must be 3 or 4 but {}({})".format(self.tensors.dim(), self.tensors.shape))
+
+ def imgsize(self):
+ res = []
+ for i in range(self.tensors.shape[0]):
+ mask = self.mask[i]
+ maxH = (~mask).sum(0).max()
+ maxW = (~mask).sum(1).max()
+ res.append(torch.Tensor([maxH, maxW]))
+ return res
+
+ def to(self, device):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def to_img_list_single(self, tensor, mask):
+ assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
+ maxH = (~mask).sum(0).max()
+ maxW = (~mask).sum(1).max()
+ img = tensor[:, :maxH, :maxW]
+ return img
+
+ def to_img_list(self):
+ """remove the padding and convert to img list
+
+ Returns:
+ Tensor or list[Tensor]: image(s) with the padding removed.
+ """
+ if self.tensors.dim() == 3:
+ return self.to_img_list_single(self.tensors, self.mask)
+ else:
+ res = []
+ for i in range(self.tensors.shape[0]):
+ tensor_i = self.tensors[i]
+ mask_i = self.mask[i]
+ res.append(self.to_img_list_single(tensor_i, mask_i))
+ return res
+
+ @property
+ def device(self):
+ return self.tensors.device
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
+ @property
+ def shape(self):
+ return {
+ 'tensors.shape': self.tensors.shape,
+ 'mask.shape': self.mask.shape
+ }
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ if torchvision._is_tracing():
+ # nested_tensor_from_tensor_list() does not export well to ONNX
+ # call _onnx_nested_tensor_from_tensor_list() instead
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], :img.shape[2]] = False
+ else:
+ raise ValueError('not supported')
+ return NestedTensor(tensor, mask)
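+
+# Usage sketch (illustrative, hypothetical shapes): batch two images of different
+# sizes; the result is zero-padded to the per-axis maximum and the mask marks
+# padded pixels with True.
+#   imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
+#   samples = nested_tensor_from_tensor_list(imgs)
+#   tensors, mask = samples.decompose()          # (2, 3, 512, 640), (2, 512, 640)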
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+ max_size = []
+ for i in range(tensor_list[0].dim()):
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
+ max_size.append(max_size_i)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # m[: img.shape[1], :img.shape[2]] = False
+ # which is not yet supported in onnx
+ padded_imgs = []
+ padded_masks = []
+ for img in tensor_list:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
+ padded_imgs.append(padded_img)
+
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
+ padded_masks.append(padded_mask.to(torch.bool))
+
+ tensor = torch.stack(padded_imgs)
+ mask = torch.stack(padded_masks)
+
+ return NestedTensor(tensor, mask=mask)
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in the master process.
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != '': # 'RANK' in os.environ and
+ # args.rank = int(os.environ["RANK"])
+ # args.world_size = int(os.environ['WORLD_SIZE'])
+ # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
+
+ # launch by torch.distributed.launch
+ # Single node
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
+ # Multi nodes
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
+ # python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
+
+ local_world_size = int(os.environ['WORLD_SIZE'])
+ args.world_size = args.world_size * local_world_size
+ args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
+ args.rank = args.rank * local_world_size + args.local_rank
+ print('world size: {}, rank: {}, local rank: {}'.format(args.world_size, args.rank, args.local_rank))
+ print(json.dumps(dict(os.environ), indent=2))
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.local_rank = int(os.environ['SLURM_LOCALID'])
+ args.world_size = int(os.environ['SLURM_NPROCS'])
+
+ print('world size: {}, world rank: {}, local rank: {}, device_count: {}'.format(args.world_size, args.rank, args.local_rank, torch.cuda.device_count()))
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ args.world_size = 1
+ args.rank = 0
+ args.local_rank = 0
+ return
+
+ print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
+ args.distributed = True
+ torch.cuda.set_device(args.local_rank)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ print("Before torch.distributed.barrier()")
+ torch.distributed.barrier()
+ print("End torch.distributed.barrier()")
+ setup_for_distributed(args.rank == 0)
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ if target.numel() == 0:
+ return [torch.zeros([], device=output.device)]
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
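+
+# Usage sketch (illustrative, hypothetical tensors): top-1 and top-5 accuracy of
+# classification logits against integer targets.
+#   logits = torch.randn(32, 91)
+#   target = torch.randint(0, 91, (32,))
+#   top1, top5 = accuracy(logits, target, topk=(1, 5))  # percentages in [0, 100]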
+
+
+def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+ """
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+ This will eventually be supported natively by PyTorch, and this
+ class can go away.
+ """
+ if __torchvision_need_compat_flag < 0.7:
+ if input.numel() > 0:
+ return torch.nn.functional.interpolate(
+ input, size, scale_factor, mode, align_corners
+ )
+
+ output_shape = _output_size(2, input, size, scale_factor)
+ output_shape = list(input.shape[:-2]) + list(output_shape)
+ return _new_empty_tensor(input, output_shape)
+ else:
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+
+class color_sys():
+ def __init__(self, num_colors) -> None:
+ self.num_colors = num_colors
+ colors=[]
+ for i in np.arange(0., 360., 360. / num_colors):
+ hue = i/360.
+ lightness = (50 + np.random.rand() * 10)/100.
+ saturation = (90 + np.random.rand() * 10)/100.
+ colors.append(tuple([int(j*255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)]))
+ self.colors = colors
+
+ def __call__(self, idx):
+ return self.colors[idx]
+
+def inverse_sigmoid(x, eps=1e-3):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1/x2)
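+
+# Note (editor's sketch): inverse_sigmoid is a clamped logit function, so for
+# inputs away from 0 and 1 it inverts torch.sigmoid:
+#   t = torch.tensor([-2.0, 0.0, 2.0])
+#   assert torch.allclose(inverse_sigmoid(torch.sigmoid(t)), t, atol=1e-2)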
+
+def clean_state_dict(state_dict):
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k[:7] == 'module.':
+ k = k[7:] # remove `module.`
+ new_state_dict[k] = v
+ return new_state_dict
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/plot_utils.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/plot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af67acd9f7a3ad7ec61908adfa2c67acb734ca08
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/plot_utils.py
@@ -0,0 +1,112 @@
+"""
+Plotting utilities to visualize training logs.
+"""
+import torch
+import pandas as pd
+import numpy as np
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from pathlib import Path, PurePath
+
+
+def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
+ '''
+ Function to plot specific fields from training log(s). Plots both training and test results.
+
+ :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
+ - fields = which results to plot from each log file - plots both training and test for each field.
+ - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
+ - log_name = optional, name of log file if different than default 'log.txt'.
+
+ :: Outputs - matplotlib plots of results in fields, color coded for each log file.
+ - solid lines are training results, dashed lines are test results.
+
+ '''
+ func_name = "plot_utils.py::plot_logs"
+
+ # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
+ # convert single Path to list to avoid 'not iterable' error
+
+ if not isinstance(logs, list):
+ if isinstance(logs, PurePath):
+ logs = [logs]
+ print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
+ else:
+ raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
+ Expect list[Path] or single Path obj, received {type(logs)}")
+
+ # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
+ for i, dir in enumerate(logs):
+ if not isinstance(dir, PurePath):
+ raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
+ if not dir.exists():
+ raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
+ # verify log_name exists
+ fn = Path(dir / log_name)
+ if not fn.exists():
+ print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
+ print(f"--> full path of missing log file: {fn}")
+ return
+
+ # load log file(s) and plot
+ dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
+
+ fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
+
+ for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
+ for j, field in enumerate(fields):
+ if field == 'mAP':
+ coco_eval = pd.DataFrame(
+ np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
+ ).ewm(com=ewm_col).mean()
+ axs[j].plot(coco_eval, c=color)
+ else:
+ df.interpolate().ewm(com=ewm_col).mean().plot(
+ y=[f'train_{field}', f'test_{field}'],
+ ax=axs[j],
+ color=[color] * 2,
+ style=['-', '--']
+ )
+ for ax, field in zip(axs, fields):
+ if field == 'mAP':
+ ax.legend([Path(p).name for p in logs])
+ ax.set_title(field)
+ else:
+ ax.legend(['train', 'test'])
+ ax.set_title(field)
+
+ return fig, axs
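+
+# Usage sketch (illustrative, hypothetical paths): compare two training runs that
+# each wrote a DETR-style log.txt of per-epoch JSON lines.
+#   fig, axs = plot_logs([Path('work_dirs/run_a'), Path('work_dirs/run_b')],
+#                        fields=('loss', 'mAP'))
+#   fig.savefig('compare_runs.png')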
+
+def plot_precision_recall(files, naming_scheme='iter'):
+ if naming_scheme == 'exp_id':
+ # name becomes exp_id
+ names = [f.parts[-3] for f in files]
+ elif naming_scheme == 'iter':
+ names = [f.stem for f in files]
+ else:
+ raise ValueError(f'not supported {naming_scheme}')
+ fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
+ for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
+ data = torch.load(f)
+ # precision is n_iou, n_points, n_cat, n_area, max_det
+ precision = data['precision']
+ recall = data['params'].recThrs
+ scores = data['scores']
+ # take precision for all classes, all areas and 100 detections
+ precision = precision[0, :, :, 0, -1].mean(1)
+ scores = scores[0, :, :, 0, -1].mean(1)
+ prec = precision.mean()
+ rec = data['recall'][0, :, 0, -1].mean()
+ print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
+ f'score={scores.mean():0.3f}, ' +
+ f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
+ )
+ axs[0].plot(recall, precision, c=color)
+ axs[1].plot(recall, scores, c=color)
+
+ axs[0].set_title('Precision / Recall')
+ axs[0].legend(names)
+ axs[1].set_title('Scores / Recall')
+ axs[1].legend(names)
+ return fig, axs
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/slconfig.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/slconfig.py
new file mode 100644
index 0000000000000000000000000000000000000000..0982c3e548434e5151831cbbcb88eadd6bcb44a9
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/slconfig.py
@@ -0,0 +1,435 @@
+# ==========================================================
+# Modified from mmcv
+# ==========================================================
+import os, sys
+import os.path as osp
+import ast
+import tempfile
+import shutil
+from importlib import import_module
+
+from argparse import Action
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+BASE_KEY = '_base_'
+DELETE_KEY = '_delete_'
+RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'get', 'dump', 'merge_from_dict']
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+class ConfigDict(Dict):
+
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no "
+ f"attribute '{name}'")
+ except Exception as e:
+ ex = e
+ else:
+ return value
+ raise ex
+
+
+class SLConfig(object):
+ """
+ Configuration object loaded from a file.
+ Supports .py (preferred) as well as .yml/.yaml/.json config files.
+
+ ref: mmcv.utils.config
+
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+ @staticmethod
+ def _validate_py_syntax(filename):
+ with open(filename) as f:
+ content = f.read()
+ try:
+ ast.parse(content)
+ except SyntaxError:
+ raise SyntaxError('There are syntax errors in config '
+ f'file {filename}')
+
+ @staticmethod
+ def _file2dict(filename):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ if filename.lower().endswith('.py'):
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(
+ dir=temp_config_dir, suffix='.py')
+ temp_config_name = osp.basename(temp_config_file.name)
+ shutil.copyfile(filename,
+ osp.join(temp_config_dir, temp_config_name))
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ SLConfig._validate_py_syntax(filename)
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith('__')
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ # close temp file
+ temp_config_file.close()
+ elif filename.lower().endswith(('.yml', '.yaml', '.json')):
+ from .slio import slload
+ cfg_dict = slload(filename)
+ else:
+ raise IOError('Only py/yml/yaml/json types are supported for now!')
+
+ cfg_text = filename + '\n'
+ with open(filename, 'r') as f:
+ cfg_text += f.read()
+
+ # parse the base file
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = base_filename if isinstance(
+ base_filename, list) else [base_filename]
+
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:
+ _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ if len(base_cfg_dict.keys() & c.keys()) > 0:
+ raise KeyError('Duplicate key is not allowed among bases')
+ # TODO: allow duplicate keys while warning the user
+ base_cfg_dict.update(c)
+
+ base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
+ cfg_dict = base_cfg_dict
+
+ # merge cfg_text
+ cfg_text_list.append(cfg_text)
+ cfg_text = '\n'.join(cfg_text_list)
+
+ return cfg_dict, cfg_text
+
+ @staticmethod
+ def _merge_a_into_b(a, b):
+ """merge dict `a` into dict `b` (non-inplace).
+ values in `a` will overwrite `b`.
+ copy first to avoid inplace modification
+
+ Args:
+ a ([type]): [description]
+ b ([type]): [description]
+
+ Returns:
+ [dict]: [description]
+ """
+ # import ipdb; ipdb.set_trace()
+ if not isinstance(a, dict):
+ return a
+
+ b = b.copy()
+ for k, v in a.items():
+ if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
+
+ if not isinstance(b[k], dict) and not isinstance(b[k], list):
+ # if :
+ # import ipdb; ipdb.set_trace()
+ raise TypeError(
+ f'{k}={v} in child config cannot inherit from base '
+ f'because {k} is a dict in the child config but is of '
+ f'type {type(b[k])} in base config. You may set '
+ f'`{DELETE_KEY}=True` to ignore the base config')
+ b[k] = SLConfig._merge_a_into_b(v, b[k])
+ elif isinstance(b, list):
+ try:
+ _ = int(k)
+ except ValueError:
+ raise TypeError(
+ f'b is a list, so index {k} should be an int, '
+ f'but got {type(k)}'
+ )
+ b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
+ else:
+ b[k] = v
+
+ return b
+
+ @staticmethod
+ def fromfile(filename):
+ cfg_dict, cfg_text = SLConfig._file2dict(filename)
+ return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError('cfg_dict must be a dict, but '
+ f'got {type(cfg_dict)}')
+ for key in cfg_dict:
+ if key in RESERVED_KEYS:
+ raise KeyError(f'{key} is reserved for config file')
+
+ super(SLConfig, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
+ super(SLConfig, self).__setattr__('_filename', filename)
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, 'r') as f:
+ text = f.read()
+ else:
+ text = ''
+ super(SLConfig, self).__setattr__('_text', text)
+
+
+ @property
+ def filename(self):
+ return self._filename
+
+ @property
+ def text(self):
+ return self._text
+
+ @property
+ def pretty_text(self):
+
+ indent = 4
+
+ def _indent(s_, num_spaces):
+ s = s_.split('\n')
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(num_spaces * ' ') + line for line in s]
+ s = '\n'.join(s)
+ s = first + '\n' + s
+ return s
+
+ def _format_basic_types(k, v, use_mapping=False):
+ if isinstance(v, str):
+ v_str = f"'{v}'"
+ else:
+ v_str = str(v)
+
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent)
+
+ return attr_str
+
+ def _format_list(k, v, use_mapping=False):
+ # check if all items in the list are dict
+ if all(isinstance(_, dict) for _ in v):
+ v_str = '[\n'
+ v_str += '\n'.join(
+ f'dict({_indent(_format_dict(v_), indent)}),'
+ for v_ in v).rstrip(',')
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent) + ']'
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping)
+ return attr_str
+
+ def _contain_invalid_identifier(dict_str):
+ contain_invalid_identifier = False
+ for key_name in dict_str:
+ contain_invalid_identifier |= \
+ (not str(key_name).isidentifier())
+ return contain_invalid_identifier
+
+ def _format_dict(input_dict, outest_level=False):
+ r = ''
+ s = []
+
+ use_mapping = _contain_invalid_identifier(input_dict)
+ if use_mapping:
+ r += '{'
+ for idx, (k, v) in enumerate(input_dict.items()):
+ is_last = idx >= len(input_dict) - 1
+ end = '' if outest_level or is_last else ','
+ if isinstance(v, dict):
+ v_str = '\n' + _format_dict(v)
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: dict({v_str}'
+ else:
+ attr_str = f'{str(k)}=dict({v_str}'
+ attr_str = _indent(attr_str, indent) + ')' + end
+ elif isinstance(v, list):
+ attr_str = _format_list(k, v, use_mapping) + end
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping) + end
+
+ s.append(attr_str)
+ r += '\n'.join(s)
+ if use_mapping:
+ r += '}'
+ return r
+
+ cfg_dict = self._cfg_dict.to_dict()
+ text = _format_dict(cfg_dict, outest_level=True)
+ # copied from setup.cfg
+ yapf_style = dict(
+ based_on_style='pep8',
+ blank_line_before_nested_class_or_def=True,
+ split_before_expression_after_opening_paren=True)
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+
+ return text
+
+
+ def __repr__(self):
+ return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
+
+ def __len__(self):
+ return len(self._cfg_dict)
+
+ def __getattr__(self, name):
+ # # debug
+ # print('+'*15)
+ # print('name=%s' % name)
+ # print("addr:", id(self))
+ # # print('type(self):', type(self))
+ # print(self.__dict__)
+ # print('+'*15)
+ # if self.__dict__ == {}:
+ # raise ValueError
+
+ return getattr(self._cfg_dict, name)
+
+ def __getitem__(self, name):
+ return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+ def __iter__(self):
+ return iter(self._cfg_dict)
+
+ def dump(self, file=None):
+ # import ipdb; ipdb.set_trace()
+ if file is None:
+ return self.pretty_text
+ else:
+ with open(file, 'w') as f:
+ f.write(self.pretty_text)
+
+ def merge_from_dict(self, options):
+ """Merge list into cfg_dict
+
+ Merge the dict parsed by MultipleKVAction into this cfg.
+
+ Examples:
+ >>> options = {'model.backbone.depth': 50,
+ ... 'model.backbone.with_cp':True}
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
+ >>> cfg.merge_from_dict(options)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
+
+ Args:
+ options (dict): dict of configs to merge from.
+ """
+ option_cfg_dict = {}
+ for full_key, v in options.items():
+ d = option_cfg_dict
+ key_list = full_key.split('.')
+ for subkey in key_list[:-1]:
+ d.setdefault(subkey, ConfigDict())
+ d = d[subkey]
+ subkey = key_list[-1]
+ d[subkey] = v
+
+ cfg_dict = super(SLConfig, self).__getattribute__('_cfg_dict')
+ super(SLConfig, self).__setattr__(
+ '_cfg_dict', SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict))
+
+ # for multiprocess
+ def __setstate__(self, state):
+ self.__init__(state)
+
+
+ def copy(self):
+ return SLConfig(self._cfg_dict.copy())
+
+ def deepcopy(self):
+ return SLConfig(self._cfg_dict.deepcopy())
+
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options should
+ be passed as comma separated values, i.e KEY=V1,V2,V3
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ['true', 'false']:
+ return True if val.lower() == 'true' else False
+ if val.lower() in ['none', 'null']:
+ return None
+ return val
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split('=', maxsplit=1)
+ val = [self._parse_int_float_bool(v) for v in val.split(',')]
+ if len(val) == 1:
+ val = val[0]
+ options[key] = val
+ setattr(namespace, self.dest, options)
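+
+# Usage sketch (illustrative): wiring DictAction into argparse so that
+#   python train.py --options lr=1e-4 milestones=8,11 use_ema=true
+# yields {'lr': 0.0001, 'milestones': [8, 11], 'use_ema': True}.
+#   import argparse
+#   parser = argparse.ArgumentParser()
+#   parser.add_argument('--options', nargs='+', action=DictAction, default={})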
+
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/slio.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/slio.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b8f4dad2441b8352ab7311dbf16019515441331
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/slio.py
@@ -0,0 +1,173 @@
+# ==========================================================
+# Modified from mmcv
+# ==========================================================
+
+import json, pickle, yaml
+try:
+ from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+ from yaml import Loader, Dumper
+
+from pathlib import Path
+from abc import ABCMeta, abstractmethod
+
+# ===========================
+# Register handlers
+# ===========================
+
+class BaseFileHandler(metaclass=ABCMeta):
+
+ @abstractmethod
+ def load_from_fileobj(self, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_str(self, obj, **kwargs):
+ pass
+
+ def load_from_path(self, filepath, mode='r', **kwargs):
+ with open(filepath, mode) as f:
+ return self.load_from_fileobj(f, **kwargs)
+
+ def dump_to_path(self, obj, filepath, mode='w', **kwargs):
+ with open(filepath, mode) as f:
+ self.dump_to_fileobj(obj, f, **kwargs)
+
+class JsonHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file):
+ return json.load(file)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ json.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ return json.dumps(obj, **kwargs)
+
+class PickleHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file, **kwargs):
+ return pickle.load(file, **kwargs)
+
+ def load_from_path(self, filepath, **kwargs):
+ return super(PickleHandler, self).load_from_path(
+ filepath, mode='rb', **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ return pickle.dumps(obj, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('protocol', 2)
+ pickle.dump(obj, file, **kwargs)
+
+ def dump_to_path(self, obj, filepath, **kwargs):
+ super(PickleHandler, self).dump_to_path(
+ obj, filepath, mode='wb', **kwargs)
+
+class YamlHandler(BaseFileHandler):
+
+ def load_from_fileobj(self, file, **kwargs):
+ kwargs.setdefault('Loader', Loader)
+ return yaml.load(file, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ yaml.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault('Dumper', Dumper)
+ return yaml.dump(obj, **kwargs)
+
+file_handlers = {
+ 'json': JsonHandler(),
+ 'yaml': YamlHandler(),
+ 'yml': YamlHandler(),
+ 'pickle': PickleHandler(),
+ 'pkl': PickleHandler()
+}
+
+# ===========================
+# load and dump
+# ===========================
+
+def is_str(x):
+ """Whether the input is an string instance.
+
+ Note: This method is deprecated since python 2 is no longer supported.
+ """
+ return isinstance(x, str)
+
+def slload(file, file_format=None, **kwargs):
+ """Load data from json/yaml/pickle files.
+
+ This method provides a unified api for loading data from serialized files.
+
+ Args:
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
+ object.
+ file_format (str, optional): If not specified, the file format will be
+ inferred from the file extension, otherwise use the specified one.
+ Currently supported formats include "json", "yaml/yml" and
+ "pickle/pkl".
+
+ Returns:
+ The content from the file.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None and is_str(file):
+ file_format = file.split('.')[-1]
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+
+ handler = file_handlers[file_format]
+ if is_str(file):
+ obj = handler.load_from_path(file, **kwargs)
+ elif hasattr(file, 'read'):
+ obj = handler.load_from_fileobj(file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filepath str or a file-object')
+ return obj
+
+
+def sldump(obj, file=None, file_format=None, **kwargs):
+ """Dump data to json/yaml/pickle strings or files.
+
+ This method provides a unified api for dumping data as strings or to files,
+ and also supports custom arguments for each file format.
+
+ Args:
+ obj (any): The python object to be dumped.
+ file (str or :obj:`Path` or file-like object, optional): If not
+ specified, then the object is dumped to a str, otherwise to a file
+ specified by the filename or file-like object.
+ file_format (str, optional): Same as :func:`load`.
+
+ Returns:
+ str or None: the dumped string if ``file`` is None, otherwise None.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None:
+ if is_str(file):
+ file_format = file.split('.')[-1]
+ elif file is None:
+ raise ValueError(
+ 'file_format must be specified since file is None')
+ if file_format not in file_handlers:
+ raise TypeError(f'Unsupported format: {file_format}')
+
+ handler = file_handlers[file_format]
+ if file is None:
+ return handler.dump_to_str(obj, **kwargs)
+ elif is_str(file):
+ handler.dump_to_path(obj, file, **kwargs)
+ elif hasattr(file, 'write'):
+ handler.dump_to_fileobj(obj, file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filename str or a file-object')
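+
+# Usage sketch (illustrative, hypothetical paths): the helpers infer the format
+# from the file extension.
+#   cfg = slload('config.json')                       # dict parsed from JSON
+#   sldump(cfg, 'config_backup.yaml')                 # re-serialized as YAML
+#   text = sldump(cfg, file_format='json', indent=2)  # dumped to a str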
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/static_data_path.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/static_data_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef8b6e46b871faca3d2143b0bc552253b00239aa
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/static_data_path.py
@@ -0,0 +1,10 @@
+coco = dict(
+ train = dict(
+ img_folder = '/comp_robot/cv_public_dataset/COCO2017/train2017',
+ ann_file = '/comp_robot/cv_public_dataset/COCO2017/annotations/instances_train2017.json'
+ ),
+ val = dict(
+ img_folder = '/comp_robot/cv_public_dataset/COCO2017/val2017',
+ ann_file = '/comp_robot/cv_public_dataset/COCO2017/annotations/instances_val2017.json'
+ )
+)
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/time_counter.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/time_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..19dc2e640bb0f19e686a5078d8e4d7db7ddaad96
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/time_counter.py
@@ -0,0 +1,60 @@
+import json
+import time
+
+class TimeCounter:
+ def __init__(self) -> None:
+ # start from a clean state so timeit() can be called without an explicit clear()
+ self.clear()
+
+ def clear(self):
+ self.timedict = {}
+ self.basetime = time.perf_counter()
+
+ def timeit(self, name):
+ nowtime = time.perf_counter() - self.basetime
+ self.timedict[name] = nowtime
+ self.basetime = time.perf_counter()
+
+
+class TimeHolder:
+ def __init__(self) -> None:
+ self.timedict = {}
+
+ def update(self, _timedict:dict):
+ for k,v in _timedict.items():
+ if k not in self.timedict:
+ self.timedict[k] = AverageMeter(name=k, val_only=True)
+ self.timedict[k].update(val=v)
+
+ def final_res(self):
+ return {k:v.avg for k,v in self.timedict.items()}
+
+ def __str__(self):
+ return json.dumps(self.final_res(), indent=2)
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self, name, fmt=':f', val_only=False):
+ self.name = name
+ self.fmt = fmt
+ self.val_only = val_only
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def __str__(self):
+ if self.val_only:
+ fmtstr = '{name} {val' + self.fmt + '}'
+ else:
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+ return fmtstr.format(**self.__dict__)
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/utils.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d747bef2541d5cd4d17b61778c3c84b413795467
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/utils.py
@@ -0,0 +1,473 @@
+from collections import OrderedDict
+from copy import deepcopy
+import json
+import warnings
+
+import torch
+import numpy as np
+
+def slprint(x, name='x'):
+ if isinstance(x, (torch.Tensor, np.ndarray)):
+ print(f'{name}.shape:', x.shape)
+ elif isinstance(x, (tuple, list)):
+ print('type x:', type(x))
+ for i in range(min(10, len(x))):
+ slprint(x[i], f'{name}[{i}]')
+ elif isinstance(x, dict):
+ for k,v in x.items():
+ slprint(v, f'{name}[{k}]')
+ else:
+ print(f'{name}.type:', type(x))
+
+def clean_state_dict(state_dict):
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k[:7] == 'module.':
+ k = k[7:] # remove `module.`
+ new_state_dict[k] = v
+ return new_state_dict
+
+def renorm(img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) \
+ -> torch.FloatTensor:
+ # img: tensor(3,H,W) or tensor(B,3,H,W)
+ # return: same as img
+ assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
+ if img.dim() == 3:
+ assert img.size(0) == 3, 'img.size(0) should be 3 but "%d". (%s)' % (img.size(0), str(img.size()))
+ img_perm = img.permute(1,2,0)
+ mean = torch.Tensor(mean)
+ std = torch.Tensor(std)
+ img_res = img_perm * std + mean
+ return img_res.permute(2,0,1)
+ else: # img.dim() == 4
+ assert img.size(1) == 3, 'img.size(1) should be 3 but "%d". (%s)' % (img.size(1), str(img.size()))
+ img_perm = img.permute(0,2,3,1)
+ mean = torch.Tensor(mean)
+ std = torch.Tensor(std)
+ img_res = img_perm * std + mean
+ return img_res.permute(0,3,1,2)
+
+
+
+class CocoClassMapper():
+ def __init__(self) -> None:
+ self.category_map_str = {"1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "10": 10, "11": 11, "13": 12, "14": 13, "15": 14, "16": 15, "17": 16, "18": 17, "19": 18, "20": 19, "21": 20, "22": 21, "23": 22, "24": 23, "25": 24, "27": 25, "28": 26, "31": 27, "32": 28, "33": 29, "34": 30, "35": 31, "36": 32, "37": 33, "38": 34, "39": 35, "40": 36, "41": 37, "42": 38, "43": 39, "44": 40, "46": 41, "47": 42, "48": 43, "49": 44, "50": 45, "51": 46, "52": 47, "53": 48, "54": 49, "55": 50, "56": 51, "57": 52, "58": 53, "59": 54, "60": 55, "61": 56, "62": 57, "63": 58, "64": 59, "65": 60, "67": 61, "70": 62, "72": 63, "73": 64, "74": 65, "75": 66, "76": 67, "77": 68, "78": 69, "79": 70, "80": 71, "81": 72, "82": 73, "84": 74, "85": 75, "86": 76, "87": 77, "88": 78, "89": 79, "90": 80}
+ self.origin2compact_mapper = {int(k):v-1 for k,v in self.category_map_str.items()}
+ self.compact2origin_mapper = {int(v-1):int(k) for k,v in self.category_map_str.items()}
+
+ def origin2compact(self, idx):
+ return self.origin2compact_mapper[int(idx)]
+
+ def compact2origin(self, idx):
+ return self.compact2origin_mapper[int(idx)]
+
+def to_device(item, device):
+ if isinstance(item, torch.Tensor):
+ return item.to(device)
+ elif isinstance(item, list):
+ return [to_device(i, device) for i in item]
+ elif isinstance(item, dict):
+ return {k: to_device(v, device) for k,v in item.items()}
+ else:
+ raise NotImplementedError("Call Shilong if you use other containers! type: {}".format(type(item)))
+
+
+
+#
+def get_gaussian_mean(x, axis, other_axis, softmax=True):
+ """
+
+ Args:
+ x (float): Input images(BxCxHxW)
+ axis (int): The index for weighted mean
+ other_axis (int): The other index
+
+ Returns: weighted index for axis, BxC
+
+ """
+ mat2line = torch.sum(x, axis=other_axis)
+ # mat2line = mat2line / mat2line.mean() * 10
+ if softmax:
+ u = torch.softmax(mat2line, axis=2)
+ else:
+ u = mat2line / (mat2line.sum(2, keepdim=True) + 1e-6)
+ size = x.shape[axis]
+ ind = torch.linspace(0, 1, size).to(x.device)
+ batch = x.shape[0]
+ channel = x.shape[1]
+ index = ind.repeat([batch, channel, 1])
+ mean_position = torch.sum(index * u, dim=2)
+ return mean_position
+
+def get_expected_points_from_map(hm, softmax=True):
+ """get_gaussian_map_from_points
+ B,C,H,W -> B,N,2 float(0, 1) float(0, 1)
+ softargmax function
+
+ Args:
+ hm (float): Input images(BxCxHxW)
+
+ Returns:
+ weighted index for axis, BxCx2. float between 0 and 1.
+
+ """
+ # hm = 10*hm
+ B,C,H,W = hm.shape
+ y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C
+ x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C
+ # return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
+ return torch.stack([x_mean, y_mean], dim=2)
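+
+# Usage sketch (illustrative, hypothetical heatmaps): soft-argmax of BxC heatmaps
+# into normalized (x, y) coordinates in [0, 1].
+#   hm = torch.rand(2, 17, 64, 48)               # e.g. 17 keypoint heatmaps
+#   pts = get_expected_points_from_map(hm)       # shape (2, 17, 2), (x, y) order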
+
+# Positional encoding (section 5.1)
+# borrow from nerf
+class Embedder:
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.create_embedding_fn()
+
+ def create_embedding_fn(self):
+ embed_fns = []
+ d = self.kwargs['input_dims']
+ out_dim = 0
+ if self.kwargs['include_input']:
+ embed_fns.append(lambda x : x)
+ out_dim += d
+
+ max_freq = self.kwargs['max_freq_log2']
+ N_freqs = self.kwargs['num_freqs']
+
+ if self.kwargs['log_sampling']:
+ freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
+ else:
+ freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
+
+ for freq in freq_bands:
+ for p_fn in self.kwargs['periodic_fns']:
+ embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
+ out_dim += d
+
+ self.embed_fns = embed_fns
+ self.out_dim = out_dim
+
+ def embed(self, inputs):
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
+
+
+def get_embedder(multires, i=0):
+ import torch.nn as nn
+ if i == -1:
+ return nn.Identity(), 3
+
+ embed_kwargs = {
+ 'include_input' : True,
+ 'input_dims' : 3,
+ 'max_freq_log2' : multires-1,
+ 'num_freqs' : multires,
+ 'log_sampling' : True,
+ 'periodic_fns' : [torch.sin, torch.cos],
+ }
+
+ embedder_obj = Embedder(**embed_kwargs)
+ embed = lambda x, eo=embedder_obj : eo.embed(x)
+ return embed, embedder_obj.out_dim
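+
+# Usage sketch (illustrative): a NeRF-style positional encoding of 3-D points.
+#   embed_fn, embed_dim = get_embedder(multires=10)
+#   xyz = torch.rand(1024, 3)
+#   enc = embed_fn(xyz)                          # shape (1024, 63): embed_dim = 3 + 3*2*10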
+
+class APOPMeter():
+ def __init__(self) -> None:
+ self.tp = 0
+ self.fp = 0
+ self.tn = 0
+ self.fn = 0
+
+ def update(self, pred, gt):
+ """
+ Input:
+ pred, gt: Tensor()
+ """
+ assert pred.shape == gt.shape
+ self.tp += torch.logical_and(pred == 1, gt == 1).sum().item()
+ self.fp += torch.logical_and(pred == 1, gt == 0).sum().item()
+ self.tn += torch.logical_and(pred == 0, gt == 0).sum().item()
+ self.fn += torch.logical_and(pred == 0, gt == 1).sum().item()
+
+ def update_cm(self, tp, fp, tn, fn):
+ self.tp += tp
+ self.fp += fp
+ self.tn += tn
+ self.fn += fn
+
+def inverse_sigmoid(x, eps=1e-5):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1/x2)
+
+import argparse
+from util.slconfig import SLConfig
+def get_raw_dict(args):
+ """
+ Return the dict contained in args.
+
+ e.g.:
+ >>> with open(path, 'w') as f:
+ json.dump(get_raw_dict(args), f, indent=2)
+ """
+ if isinstance(args, argparse.Namespace):
+ return vars(args)
+ elif isinstance(args, dict):
+ return args
+ elif isinstance(args, SLConfig):
+ return args._cfg_dict
+ else:
+ raise NotImplementedError("Unknown type {}".format(type(args)))
+
+
+def stat_tensors(tensor):
+ assert tensor.dim() == 1
+ tensor_sm = tensor.softmax(0)
+ entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()
+
+ return {
+ 'max': tensor.max(),
+ 'min': tensor.min(),
+ 'mean': tensor.mean(),
+ 'var': tensor.var(),
+ 'std': tensor.var() ** 0.5,
+ 'entropy': entropy
+ }
+
+
+class NiceRepr:
+ """Inherit from this class and define ``__nice__`` to "nicely" print your
+ objects.
+
+ Defines ``__str__`` and ``__repr__`` in terms of the ``__nice__`` function.
+ Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
+ If the inheriting class has a ``__len__`` method, then the default
+ ``__nice__`` method will return its length.
+
+ Example:
+ >>> class Foo(NiceRepr):
+ ... def __nice__(self):
+ ... return 'info'
+ >>> foo = Foo()
+ >>> assert str(foo) == '<Foo(info)>'
+ >>> assert repr(foo).startswith('<Foo(info) at ')
+
+ Example:
+ >>> class Bar(NiceRepr):
+ ... pass
+ >>> bar = Bar()
+ >>> import pytest
+ >>> with pytest.warns(None) as record:
+ >>> assert 'object at' in str(bar)
+ >>> assert 'object at' in repr(bar)
+
+ Example:
+ >>> class Baz(NiceRepr):
+ ... def __len__(self):
+ ... return 5
+ >>> baz = Baz()
+ >>> assert str(baz) == '<Baz(5)>'
+ """
+
+ def __nice__(self):
+ """str: a "nice" summary string describing this module"""
+ if hasattr(self, '__len__'):
+ # It is a common pattern for objects to use __len__ in __nice__
+ # As a convenience we define a default __nice__ for these objects
+ return str(len(self))
+ else:
+ # In all other cases force the subclass to overload __nice__
+ raise NotImplementedError(
+ f'Define the __nice__ method for {self.__class__!r}')
+
+ def __repr__(self):
+ """str: the string of the module"""
+ try:
+ nice = self.__nice__()
+ classname = self.__class__.__name__
+ return f'<{classname}({nice}) at {hex(id(self))}>'
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
+
+ def __str__(self):
+ """str: the string of the module"""
+ try:
+ classname = self.__class__.__name__
+ nice = self.__nice__()
+ return f'<{classname}({nice})>'
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
+
+
+
+def ensure_rng(rng=None):
+ """Coerces input into a random number generator.
+
+ If the input is None, then a global random state is returned.
+
+ If the input is a numeric value, then that is used as a seed to construct a
+ random state. Otherwise the input is returned as-is.
+
+ Adapted from [1]_.
+
+ Args:
+ rng (int | numpy.random.RandomState | None):
+ if None, then defaults to the global rng. Otherwise this can be an
+ integer or a RandomState class
+ Returns:
+ (numpy.random.RandomState) : rng -
+ a numpy random number generator
+
+ References:
+ .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
+ """
+
+ if rng is None:
+ rng = np.random.mtrand._rand
+ elif isinstance(rng, int):
+ rng = np.random.RandomState(rng)
+ else:
+ rng = rng
+ return rng
+
+def random_boxes(num=1, scale=1, rng=None):
+ """Simple version of ``kwimage.Boxes.random``
+
+ Returns:
+ Tensor: shape (n, 4) in x1, y1, x2, y2 format.
+
+ References:
+ https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
+
+ Example:
+ >>> num = 3
+ >>> scale = 512
+ >>> rng = 0
+ >>> boxes = random_boxes(num, scale, rng)
+ >>> print(boxes)
+ tensor([[280.9925, 278.9802, 308.6148, 366.1769],
+ [216.9113, 330.6978, 224.0446, 456.5878],
+ [405.3632, 196.3221, 493.3953, 270.7942]])
+ """
+ rng = ensure_rng(rng)
+
+ tlbr = rng.rand(num, 4).astype(np.float32)
+
+ tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
+ tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
+ br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
+ br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
+
+ tlbr[:, 0] = tl_x * scale
+ tlbr[:, 1] = tl_y * scale
+ tlbr[:, 2] = br_x * scale
+ tlbr[:, 3] = br_y * scale
+
+ boxes = torch.from_numpy(tlbr)
+ return boxes
+
+
+class ModelEma(torch.nn.Module):
+ def __init__(self, model, decay=0.9997, device=None):
+ super(ModelEma, self).__init__()
+ # make a copy of the model for accumulating moving average of weights
+ self.module = deepcopy(model)
+ self.module.eval()
+
+ # import ipdb; ipdb.set_trace()
+
+ self.decay = decay
+ self.device = device # perform ema on different device from model if set
+ if self.device is not None:
+ self.module.to(device=device)
+
+ def _update(self, model, update_fn):
+ with torch.no_grad():
+ for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+ if self.device is not None:
+ model_v = model_v.to(device=self.device)
+ ema_v.copy_(update_fn(ema_v, model_v))
+
+ def update(self, model):
+ self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+
+ def set(self, model):
+ self._update(model, update_fn=lambda e, m: m)
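+
+# Usage sketch (illustrative, hypothetical training loop): keep an exponential
+# moving average of the model weights and refresh it after every optimizer step.
+#   ema = ModelEma(model, decay=0.9997)
+#   for samples, targets in data_loader:
+#       ...                                      # forward / backward / optimizer.step()
+#       ema.update(model)
+#   # evaluate with ema.module instead of model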
+
+class BestMetricSingle():
+ def __init__(self, init_res=0.0, better='large') -> None:
+ self.init_res = init_res
+ self.best_res = init_res
+ self.best_ep = -1
+
+ self.better = better
+ assert better in ['large', 'small']
+
+ def isbetter(self, new_res, old_res):
+ if self.better == 'large':
+ return new_res > old_res
+ if self.better == 'small':
+ return new_res < old_res
+
+ def update(self, new_res, ep):
+ if self.isbetter(new_res, self.best_res):
+ self.best_res = new_res
+ self.best_ep = ep
+ return True
+ return False
+
+ def __str__(self) -> str:
+ return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep)
+
+ def __repr__(self) -> str:
+ return self.__str__()
+
+ def summary(self) -> dict:
+ return {
+ 'best_res': self.best_res,
+ 'best_ep': self.best_ep,
+ }
+
+
+class BestMetricHolder():
+ def __init__(self, init_res=0.0, better='large', use_ema=False) -> None:
+ self.best_all = BestMetricSingle(init_res, better)
+ self.use_ema = use_ema
+ if use_ema:
+ self.best_ema = BestMetricSingle(init_res, better)
+ self.best_regular = BestMetricSingle(init_res, better)
+
+
+ def update(self, new_res, epoch, is_ema=False):
+ """
+ Return whether the new result is the best so far.
+ """
+ if not self.use_ema:
+ return self.best_all.update(new_res, epoch)
+ else:
+ if is_ema:
+ self.best_ema.update(new_res, epoch)
+ return self.best_all.update(new_res, epoch)
+ else:
+ self.best_regular.update(new_res, epoch)
+ return self.best_all.update(new_res, epoch)
+
+ def summary(self):
+ if not self.use_ema:
+ return self.best_all.summary()
+
+ res = {}
+ res.update({f'all_{k}':v for k,v in self.best_all.summary().items()})
+ res.update({f'regular_{k}':v for k,v in self.best_regular.summary().items()})
+ res.update({f'ema_{k}':v for k,v in self.best_ema.summary().items()})
+ return res
+
+ def __repr__(self) -> str:
+ return json.dumps(self.summary(), indent=2)
+
+ def __str__(self) -> str:
+ return self.__repr__()
+
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/vis_utils.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c90906862a92680271c233696b053d78287108e
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/vis_utils.py
@@ -0,0 +1,92 @@
+import cv2
+import numpy as np
+
+from util.utils import renorm
+from util.misc import color_sys
+
+_color_getter = color_sys(100)
+
+# plot known and unknown box
+def add_box_to_img(img, boxes, colorlist, brands=None):
+ """[summary]
+
+ Args:
+ img ([type]): np.array, H,W,3
+ boxes ([type]): list of list(4)
+ colorlist: list of colors.
+ brands: text.
+
+ Return:
+ img: np.array. H,W,3.
+ """
+ H, W = img.shape[:2]
+ for _i, (box, color) in enumerate(zip(boxes, colorlist)):
+ x, y, w, h = box[0] * W, box[1] * H, box[2] * W, box[3] * H
+ img = cv2.rectangle(img.copy(), (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), color, 2)
+ if brands is not None:
+ brand = brands[_i]
+ org = (int(x-w/2), int(y+h/2))
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ fontScale = 0.5
+ thickness = 1
+ img = cv2.putText(img.copy(), str(brand), org, font,
+ fontScale, color, thickness, cv2.LINE_AA)
+ return img
+
+def plot_dual_img(img, boxes, labels, idxs, probs=None):
+ """[summary]
+
+ Args:
+ img ([type]): 3,H,W. tensor.
+ boxes (): tensor(Kx4) or list of tensor(1x4).
+ labels ([type]): list of ints.
+ idxs ([type]): list of ints.
+ probs (optional): list of floats.
+
+ Returns:
+ img_classcolor: np.array. H,W,3. img with class-wise label.
+ img_seqcolor: np.array. H,W,3. img with seq-wise label.
+ """
+ # import ipdb; ipdb.set_trace()
+ boxes = [i.cpu().tolist() for i in boxes]
+ img = (renorm(img.cpu()).permute(1,2,0).numpy() * 255).astype(np.uint8)
+ # plot with class
+ class_colors = [_color_getter(i) for i in labels]
+ if probs is not None:
+ brands = ["{},{:.2f}".format(j,k) for j,k in zip(labels, probs)]
+ else:
+ brands = labels
+ img_classcolor = add_box_to_img(img, boxes, class_colors, brands=brands)
+ # plot with seq
+ seq_colors = [_color_getter((i * 11) % 100) for i in idxs]
+ img_seqcolor = add_box_to_img(img, boxes, seq_colors, brands=idxs)
+ return img_classcolor, img_seqcolor
+
+
+def plot_raw_img(img, boxes, labels):
+ """[summary]
+
+ Args:
+ img ([type]): 3,H,W. tensor.
+ boxes ([type]): Kx4. tensor
+ labels ([type]): K. tensor.
+
+ return:
+ img: np.array. H,W,3. img with bbox annos.
+
+ """
+ img = (renorm(img.cpu()).permute(1,2,0).numpy() * 255).astype(np.uint8)
+ H, W = img.shape[:2]
+ for box, label in zip(boxes.tolist(), labels.tolist()):
+ x, y, w, h = box[0] * W, box[1] * H, box[2] * W, box[3] * H
+ # import ipdb; ipdb.set_trace()
+ img = cv2.rectangle(img.copy(), (int(x-w/2), int(y-h/2)), (int(x+w/2), int(y+h/2)), _color_getter(label), 2)
+ # add text
+ org = (int(x-w/2), int(y+h/2))
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ fontScale = 1
+ thickness = 1
+ img = cv2.putText(img.copy(), str(label), org, font,
+ fontScale, _color_getter(label), thickness, cv2.LINE_AA)
+
+ return img
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/visualizer.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fdbdf04ea3cdaf43b6b12c3cb2e8004d897c55e
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/util/visualizer.py
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+'''
+@File : visualizer.py
+@Time : 2022/04/05 11:39:33
+@Author : Shilong Liu
+@Contact : liusl20@mail.tsinghua.edu.cn; slongliu86@gmail.com
+Modified from COCO evaluator
+'''
+
+import os, sys
+from textwrap import wrap
+import torch
+import numpy as np
+import cv2
+import datetime
+
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon
+from pycocotools import mask as maskUtils
+from matplotlib import transforms
+
+def renorm(img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) \
+ -> torch.FloatTensor:
+ # img: tensor(3,H,W) or tensor(B,3,H,W)
+ # return: same as img
+ assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
+ if img.dim() == 3:
+ assert img.size(0) == 3, 'img.size(0) should be 3 but "%d". (%s)' % (img.size(0), str(img.size()))
+ img_perm = img.permute(1,2,0)
+ mean = torch.Tensor(mean)
+ std = torch.Tensor(std)
+ img_res = img_perm * std + mean
+ return img_res.permute(2,0,1)
+ else: # img.dim() == 4
+ assert img.size(1) == 3, 'img.size(1) should be 3 but "%d". (%s)' % (img.size(1), str(img.size()))
+ img_perm = img.permute(0,2,3,1)
+ mean = torch.Tensor(mean)
+ std = torch.Tensor(std)
+ img_res = img_perm * std + mean
+ return img_res.permute(0,3,1,2)
+
+class ColorMap():
+ def __init__(self, basergb=[255,255,0]):
+ self.basergb = np.array(basergb)
+ def __call__(self, attnmap):
+ # attnmap: h, w. np.uint8.
+ # return: h, w, 4. np.uint8.
+ assert attnmap.dtype == np.uint8
+ h, w = attnmap.shape
+ res = self.basergb.copy()
+ res = res[None][None].repeat(h, 0).repeat(w, 1) # h, w, 3
+ attn1 = attnmap.copy()[..., None] # h, w, 1
+ res = np.concatenate((res, attn1), axis=-1).astype(np.uint8)
+ return res
+
+
+class COCOVisualizer():
+ def __init__(self) -> None:
+ pass
+
+ def visualize(self, img, tgt, caption=None, dpi=120, savedir=None, show_in_console=True):
+ """
+ img: tensor(3, H, W)
+ tgt: make sure they are all on cpu.
+ must have items: 'image_id', 'boxes', 'size'
+ """
+ plt.figure(dpi=dpi)
+ plt.rcParams['font.size'] = '5'
+ ax = plt.gca()
+ img = renorm(img).permute(1, 2, 0)
+ # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
+ # import ipdb; ipdb.set_trace()
+ ax.imshow(img)
+
+ self.addtgt(tgt)
+ if show_in_console:
+ plt.show()
+
+ if savedir is not None:
+ if caption is None:
+ savename = '{}/{}-{}.png'.format(savedir, int(tgt['image_id']), str(datetime.datetime.now()).replace(' ', '-'))
+ else:
+ savename = '{}/{}-{}-{}.png'.format(savedir, caption, int(tgt['image_id']), str(datetime.datetime.now()).replace(' ', '-'))
+ print("savename: {}".format(savename))
+ os.makedirs(os.path.dirname(savename), exist_ok=True)
+ plt.savefig(savename)
+ plt.close()
+
+ def addtgt(self, tgt):
+ """
+        - tgt: dict. args:
+            - boxes: num_boxes, 4. cxcywh, [0,1].
+            - box_label: num_boxes.
+            - size: 2. (H, W) of the image.
+ """
+ assert 'boxes' in tgt
+ ax = plt.gca()
+ H, W = tgt['size'].tolist()
+ numbox = tgt['boxes'].shape[0]
+
+ color = []
+ polygons = []
+ boxes = []
+ for box in tgt['boxes'].cpu():
+ unnormbbox = box * torch.Tensor([W, H, W, H])
+ unnormbbox[:2] -= unnormbbox[2:] / 2
+ [bbox_x, bbox_y, bbox_w, bbox_h] = unnormbbox.tolist()
+ boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
+ poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
+ np_poly = np.array(poly).reshape((4,2))
+ polygons.append(Polygon(np_poly))
+ c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
+ color.append(c)
+
+ p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
+ ax.add_collection(p)
+ p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
+ ax.add_collection(p)
+
+
+ if 'box_label' in tgt:
+            assert len(tgt['box_label']) == numbox, f"{len(tgt['box_label'])} != {numbox}"
+ for idx, bl in enumerate(tgt['box_label']):
+ _string = str(bl)
+ bbox_x, bbox_y, bbox_w, bbox_h = boxes[idx]
+ # ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': 'yellow', 'alpha': 1.0, 'pad': 1})
+ ax.text(bbox_x, bbox_y, _string, color='black', bbox={'facecolor': color[idx], 'alpha': 0.6, 'pad': 1})
+
+ if 'caption' in tgt:
+ ax.set_title(tgt['caption'], wrap=True)
+
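+# Minimal usage sketch (added for illustration, not part of the original DINO
+# visualizer; the dummy tensors below are assumptions): `visualize` expects a
+# normalized (3, H, W) image tensor and a CPU-side target dict with the keys
+# documented above ('image_id', 'boxes', 'size', optionally 'box_label').
+def _demo_coco_visualizer():
+    vis = COCOVisualizer()
+    img = torch.rand(3, 480, 640)                       # stand-in for a normalized image
+    tgt = {
+        'image_id': 0,
+        'size': torch.tensor([480, 640]),               # (H, W)
+        'boxes': torch.tensor([[0.5, 0.5, 0.2, 0.3]]),  # cxcywh in [0, 1]
+        'box_label': ['example'],
+    }
+    vis.visualize(img, tgt, show_in_console=False)
+    plt.close()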
+
diff --git a/projects/instance_segment_anything/models/focalnet_dino/models/dino/utils.py b/projects/instance_segment_anything/models/focalnet_dino/models/dino/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..991b9d3d9d81d299e629233fc01cc8770aeb189f
--- /dev/null
+++ b/projects/instance_segment_anything/models/focalnet_dino/models/dino/utils.py
@@ -0,0 +1,177 @@
+# ------------------------------------------------------------------------
+# DINO
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+import torch
+import random
+from torch import nn, Tensor
+import os
+
+import math
+import torch.nn.functional as F
+from torch import nn
+
+
+
+
+def gen_encoder_output_proposals(memory:Tensor, memory_padding_mask:Tensor, spatial_shapes:Tensor, learnedwh=None):
+ """
+ Input:
+ - memory: bs, \sum{hw}, d_model
+ - memory_padding_mask: bs, \sum{hw}
+ - spatial_shapes: nlevel, 2
+ - learnedwh: 2
+ Output:
+ - output_memory: bs, \sum{hw}, d_model
+ - output_proposals: bs, \sum{hw}, 4
+ """
+ N_, S_, C_ = memory.shape
+ base_scale = 4.0
+ proposals = []
+ _cur = 0
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ # import ipdb; ipdb.set_trace()
+
+ grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
+ torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
+
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
+
+ if learnedwh is not None:
+ # import ipdb; ipdb.set_trace()
+ wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0 ** lvl)
+ else:
+ wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
+
+ # scale = torch.cat([W_[None].unsqueeze(-1), H_[None].unsqueeze(-1)], 1).view(1, 1, 1, 2).repeat(N_, 1, 1, 1)
+ # grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
+ # wh = torch.ones_like(grid) / scale
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
+ proposals.append(proposal)
+ _cur += (H_ * W_)
+ # import ipdb; ipdb.set_trace()
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals)) # unsigmoid
+ output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+
+ # output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
+ # output_memory = output_memory.masked_fill(~output_proposals_valid, float('inf'))
+
+ return output_memory, output_proposals
+
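+# Shape sketch for gen_encoder_output_proposals (illustration only; the sizes
+# below are made up): flattened multi-scale memory goes in, and per-token
+# proposals in unsigmoided cxcywh come out alongside the (masked) memory.
+def _demo_gen_encoder_output_proposals():
+    bs, d_model = 2, 256
+    spatial_shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long)  # two levels
+    num_tokens = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())     # 1280 tokens
+    memory = torch.rand(bs, num_tokens, d_model)
+    padding_mask = torch.zeros(bs, num_tokens, dtype=torch.bool)              # no padding
+    out_memory, out_proposals = gen_encoder_output_proposals(memory, padding_mask, spatial_shapes)
+    assert out_memory.shape == (bs, num_tokens, d_model)
+    assert out_proposals.shape == (bs, num_tokens, 4)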
+
+class RandomBoxPerturber():
+ def __init__(self, x_noise_scale=0.2, y_noise_scale=0.2, w_noise_scale=0.2, h_noise_scale=0.2) -> None:
+ self.noise_scale = torch.Tensor([x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale])
+
+ def __call__(self, refanchors: Tensor) -> Tensor:
+ nq, bs, query_dim = refanchors.shape
+ device = refanchors.device
+
+ noise_raw = torch.rand_like(refanchors)
+ noise_scale = self.noise_scale.to(device)[:query_dim]
+
+ new_refanchors = refanchors * (1 + (noise_raw - 0.5) * noise_scale)
+ return new_refanchors.clamp_(0, 1)
+
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0,1) to balance
+               positive vs negative examples. Default = 0.25; a negative value disables it.
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
+ Returns:
+ Loss tensor
+ """
+ prob = inputs.sigmoid()
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+
+ return loss.mean(1).sum() / num_boxes
+
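+# Small numerical sketch of sigmoid_focal_loss (the shapes are assumptions matching
+# how the criterion typically calls it: logits and one-hot targets of shape
+# (batch, num_queries, num_classes), normalized by the number of GT boxes).
+def _demo_sigmoid_focal_loss():
+    bs, num_queries, num_classes, num_boxes = 2, 5, 3, 4.0
+    logits = torch.randn(bs, num_queries, num_classes)
+    targets = torch.zeros(bs, num_queries, num_classes)
+    targets[:, 0, 1] = 1.0                       # pretend query 0 matches class 1
+    loss = sigmoid_focal_loss(logits, targets, num_boxes)
+    assert loss.dim() == 0                       # reduced to a scalar
+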
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+def _get_activation_fn(activation, d_model=256, batch_dim=0):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ if activation == "prelu":
+ return nn.PReLU()
+ if activation == "selu":
+ return F.selu
+
+    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")
+
+
+
+
+def gen_sineembed_for_position(pos_tensor):
+ # n_query, bs, _ = pos_tensor.size()
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
+ scale = 2 * math.pi
+ dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
+ dim_t = 10000 ** (2 * (dim_t // 2) / 128)
+ x_embed = pos_tensor[:, :, 0] * scale
+ y_embed = pos_tensor[:, :, 1] * scale
+ pos_x = x_embed[:, :, None] / dim_t
+ pos_y = y_embed[:, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+ if pos_tensor.size(-1) == 2:
+ pos = torch.cat((pos_y, pos_x), dim=2)
+ elif pos_tensor.size(-1) == 4:
+ w_embed = pos_tensor[:, :, 2] * scale
+ pos_w = w_embed[:, :, None] / dim_t
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ h_embed = pos_tensor[:, :, 3] * scale
+ pos_h = h_embed[:, :, None] / dim_t
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+ else:
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
+ return pos
\ No newline at end of file
diff --git a/projects/instance_segment_anything/models/hdetr/hdetr_wrapper.py b/projects/instance_segment_anything/models/hdetr/hdetr_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dd767f5f798088905b1322730197ca60ef6cceb
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/hdetr_wrapper.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn.functional as F
+from mmcv.runner import BaseModule
+
+from .models import build_model
+from .models.util.misc import NestedTensor, inverse_sigmoid
+
+
+class HDetrWrapper(BaseModule):
+ def __init__(self,
+ args=None,
+ init_cfg=None):
+ super(HDetrWrapper, self).__init__(init_cfg)
+ model, box_postprocessor = build_model(args)
+ self.model = model
+ self.box_postprocessor = box_postprocessor
+
+ self.model.num_queries = self.model.num_queries_one2one
+ self.model.transformer.two_stage_num_proposals = self.model.num_queries
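+        # The list below is presumably the 80 valid COCO category ids (1-90 with
+        # gaps); it is used in forward() to pick those channels out of the
+        # classification logits.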
+ self.cls_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28,
+ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54,
+ 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
+ 82, 84, 85, 86, 87, 88, 89, 90]
+
+ def forward(self,
+ img,
+ img_metas):
+        """Forward function.
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ """
+ input_img_h, input_img_w = img_metas[0]["batch_input_shape"]
+ batch_size = img.size(0)
+ img_masks = img.new_ones((batch_size, input_img_h, input_img_w),
+ dtype=torch.bool)
+ for img_id in range(batch_size):
+ img_h, img_w, _ = img_metas[img_id]["img_shape"]
+ img_masks[img_id, :img_h, :img_w] = False
+ samples = NestedTensor(tensors=img, mask=img_masks)
+ features, pos = self.model.backbone(samples)
+
+ srcs = []
+ masks = []
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ srcs.append(self.model.input_proj[l](src))
+ masks.append(mask)
+ assert mask is not None
+ if self.model.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.model.num_feature_levels):
+ if l == _len_srcs:
+ src = self.model.input_proj[l](features[-1].tensors)
+ else:
+ src = self.model.input_proj[l](srcs[-1])
+ m = samples.mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
+ torch.bool
+ )[0]
+ pos_l = self.model.backbone[1](NestedTensor(src, mask)).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ pos.append(pos_l)
+
+ query_embeds = None
+ if not self.model.two_stage or self.model.mixed_selection:
+ query_embeds = self.model.query_embed.weight[0: self.model.num_queries, :]
+
+ # make attn mask
+ """ attention mask to prevent information leakage
+ """
+ self_attn_mask = (
+ torch.zeros([self.model.num_queries, self.model.num_queries, ]).bool().to(src.device)
+ )
+ self_attn_mask[self.model.num_queries_one2one:, 0: self.model.num_queries_one2one, ] = True
+ self_attn_mask[0: self.model.num_queries_one2one, self.model.num_queries_one2one:, ] = True
+
+ (
+ hs,
+ init_reference,
+ inter_references,
+ enc_outputs_class,
+ enc_outputs_coord_unact,
+ ) = self.model.transformer(srcs, masks, pos, query_embeds, self_attn_mask)
+
+ outputs_classes_one2one = []
+ outputs_coords_one2one = []
+ outputs_classes_one2many = []
+ outputs_coords_one2many = []
+ for lvl in range(hs.shape[0]):
+ if lvl == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[lvl - 1]
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.model.class_embed[lvl](hs[lvl])
+ tmp = self.model.bbox_embed[lvl](hs[lvl])
+ if reference.shape[-1] == 4:
+ tmp += reference
+ else:
+ assert reference.shape[-1] == 2
+ tmp[..., :2] += reference
+ outputs_coord = tmp.sigmoid()
+
+ outputs_classes_one2one.append(
+ outputs_class[:, 0: self.model.num_queries_one2one]
+ )
+ outputs_classes_one2many.append(
+ outputs_class[:, self.model.num_queries_one2one:]
+ )
+ outputs_coords_one2one.append(
+ outputs_coord[:, 0: self.model.num_queries_one2one]
+ )
+ outputs_coords_one2many.append(outputs_coord[:, self.model.num_queries_one2one:])
+ outputs_classes_one2one = torch.stack(outputs_classes_one2one)
+ outputs_coords_one2one = torch.stack(outputs_coords_one2one)
+
+ sampled_logits = outputs_classes_one2one[-1][:, :, self.cls_index]
+ out = {
+ "pred_logits": sampled_logits,
+ "pred_boxes": outputs_coords_one2one[-1],
+ }
+ return out
+
+ def simple_test(self, img, img_metas, rescale=False):
+ # out: dict
+ out = self(img, img_metas)
+ if rescale:
+ ori_target_sizes = [meta_info['ori_shape'][:2] for meta_info in img_metas]
+ else:
+ ori_target_sizes = [meta_info['img_shape'][:2] for meta_info in img_metas]
+ ori_target_sizes = out['pred_logits'].new_tensor(ori_target_sizes, dtype=torch.int64)
+ # results: List[dict(scores, labels, boxes)]
+ results = self.box_postprocessor(out, ori_target_sizes)
+ return results
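+
+
+# Rough usage sketch (assumed shapes and metas, not taken from this file):
+# `simple_test` returns a list with one dict per image, each holding 'scores',
+# 'labels' and 'boxes' as produced by the box postprocessor above.
+#
+#   wrapper = HDetrWrapper(args=hdetr_args)          # hdetr_args built elsewhere
+#   img = torch.zeros(1, 3, 800, 1216)
+#   img_metas = [dict(batch_input_shape=(800, 1216),
+#                     img_shape=(800, 1199, 3),
+#                     ori_shape=(480, 720, 3))]
+#   results = wrapper.simple_test(img, img_metas, rescale=True)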
diff --git a/projects/instance_segment_anything/models/hdetr/models/__init__.py b/projects/instance_segment_anything/models/hdetr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a59c33484884af523013d1ed2ef57032646336a
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/__init__.py
@@ -0,0 +1,15 @@
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+from .deformable_detr import build
+
+
+def build_model(args):
+ return build(args)
+
diff --git a/projects/instance_segment_anything/models/hdetr/models/backbone.py b/projects/instance_segment_anything/models/hdetr/models/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc4f220e45e93efa48965fab866df74888fd7bac
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/backbone.py
@@ -0,0 +1,273 @@
+# ------------------------------------------------------------------------
+# H-DETR
+# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
+# Licensed under the MIT-style license found in the LICENSE file in the root directory
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Backbone modules.
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
+
+from .util.misc import NestedTensor, is_main_process
+
+from .position_encoding import build_position_encoding
+from .swin_transformer import SwinTransformer
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
+ without which any other models than torchvision.models.resnet[18,34,50,101]
+ produce nans.
+ """
+
+ def __init__(self, n, eps=1e-5):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+ self.eps = eps
+
+ def _load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ num_batches_tracked_key = prefix + "num_batches_tracked"
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it fuser-friendly
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ eps = self.eps
+ scale = w * (rv + eps).rsqrt()
+ bias = b - rm * scale
+ return x * scale + bias
+
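+# Quick sanity sketch (added for illustration, not part of the original H-DETR
+# code): with the default unit running stats the frozen norm reduces to
+# scale ~= 1 and bias = 0, i.e. a near-identity affine map.
+def _demo_frozen_batchnorm():
+    fbn = FrozenBatchNorm2d(8)
+    x = torch.randn(2, 8, 4, 4)
+    y = fbn(x)
+    assert y.shape == x.shape
+    assert torch.allclose(y, x, atol=1e-4)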
+
+class BackboneBase(nn.Module):
+ def __init__(
+ self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool
+ ):
+ super().__init__()
+ for name, parameter in backbone.named_parameters():
+ if (
+ not train_backbone
+ or "layer2" not in name
+ and "layer3" not in name
+ and "layer4" not in name
+ ):
+ parameter.requires_grad_(False)
+ if return_interm_layers:
+ # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+ return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
+ self.strides = [8, 16, 32]
+ self.num_channels = [512, 1024, 2048]
+ else:
+ return_layers = {"layer4": "0"}
+ self.strides = [32]
+ self.num_channels = [2048]
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self.body(tensor_list.tensors)
+ out: Dict[str, NestedTensor] = {}
+ for name, x in xs.items():
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ return out
+
+
+class Backbone(BackboneBase):
+ """ResNet backbone with frozen BatchNorm."""
+
+ def __init__(
+ self,
+ name: str,
+ train_backbone: bool,
+ return_interm_layers: bool,
+ dilation: bool,
+ ):
+ norm_layer = FrozenBatchNorm2d
+ backbone = getattr(torchvision.models, name)(
+ replace_stride_with_dilation=[False, False, dilation],
+ pretrained=is_main_process(),
+ norm_layer=norm_layer,
+ )
+        assert name not in ("resnet18", "resnet34"), "number of channels is hard-coded"
+ super().__init__(backbone, train_backbone, return_interm_layers)
+ if dilation:
+ self.strides[-1] = self.strides[-1] // 2
+
+
+class TransformerBackbone(nn.Module):
+ def __init__(
+ self, backbone: str, train_backbone: bool, return_interm_layers: bool, args
+ ):
+ super().__init__()
+ out_indices = (1, 2, 3)
+ if backbone == "swin_tiny":
+ backbone = SwinTransformer(
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ ape=False,
+ drop_path_rate=args.drop_path_rate,
+ patch_norm=True,
+ use_checkpoint=True,
+ out_indices=out_indices,
+ )
+ embed_dim = 96
+ # backbone.init_weights(args.pretrained_backbone_path)
+ elif backbone == "swin_small":
+ backbone = SwinTransformer(
+ embed_dim=96,
+ depths=[2, 2, 18, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ ape=False,
+ drop_path_rate=args.drop_path_rate,
+ patch_norm=True,
+ use_checkpoint=True,
+ out_indices=out_indices,
+ )
+ embed_dim = 96
+ # backbone.init_weights(args.pretrained_backbone_path)
+ elif backbone == "swin_large":
+ backbone = SwinTransformer(
+ embed_dim=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=7,
+ ape=False,
+ drop_path_rate=args.drop_path_rate,
+ patch_norm=True,
+ use_checkpoint=True,
+ out_indices=out_indices,
+ )
+ embed_dim = 192
+ # backbone.init_weights(args.pretrained_backbone_path)
+ elif backbone == "swin_large_window12":
+ backbone = SwinTransformer(
+ pretrain_img_size=384,
+ embed_dim=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=12,
+ ape=False,
+ drop_path_rate=args.drop_path_rate,
+ patch_norm=True,
+ use_checkpoint=True,
+ out_indices=out_indices,
+ )
+ embed_dim = 192
+ # backbone.init_weights(args.pretrained_backbone_path)
+ else:
+ raise NotImplementedError
+
+ for name, parameter in backbone.named_parameters():
+ # TODO: freeze some layers?
+ if not train_backbone:
+ parameter.requires_grad_(False)
+
+ if return_interm_layers:
+
+ self.strides = [8, 16, 32]
+ self.num_channels = [
+ embed_dim * 2,
+ embed_dim * 4,
+ embed_dim * 8,
+ ]
+ else:
+ self.strides = [32]
+ self.num_channels = [embed_dim * 8]
+
+ self.body = backbone
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self.body(tensor_list.tensors)
+
+ out: Dict[str, NestedTensor] = {}
+ for name, x in xs.items():
+ m = tensor_list.mask
+ assert m is not None
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
+ out[name] = NestedTensor(x, mask)
+ return out
+
+
+class Joiner(nn.Sequential):
+ def __init__(self, backbone, position_embedding):
+ super().__init__(backbone, position_embedding)
+ self.strides = backbone.strides
+ self.num_channels = backbone.num_channels
+
+ def forward(self, tensor_list: NestedTensor):
+ xs = self[0](tensor_list)
+ out: List[NestedTensor] = []
+ pos = []
+ for name, x in sorted(xs.items()):
+ out.append(x)
+
+ # position encoding
+ for x in out:
+ pos.append(self[1](x).to(x.tensors.dtype))
+
+ return out, pos
+
+
+def build_backbone(args):
+ position_embedding = build_position_encoding(args)
+ train_backbone = False
+ return_interm_layers = args.masks or (args.num_feature_levels > 1)
+ if "resnet" in args.backbone:
+ backbone = Backbone(
+ args.backbone, train_backbone, return_interm_layers, args.dilation,
+ )
+ else:
+ backbone = TransformerBackbone(
+ args.backbone, train_backbone, return_interm_layers, args
+ )
+ model = Joiner(backbone, position_embedding)
+ return model
diff --git a/projects/instance_segment_anything/models/hdetr/models/deformable_detr.py b/projects/instance_segment_anything/models/hdetr/models/deformable_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..afdfaf216b21f96f8ca5123208f03360e2acfa64
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/deformable_detr.py
@@ -0,0 +1,619 @@
+# ------------------------------------------------------------------------
+# H-DETR
+# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
+# Licensed under the MIT-style license found in the LICENSE file in the root directory
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Deformable DETR model and criterion classes.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+import math
+
+from .util import box_ops
+from .util.misc import (
+ NestedTensor,
+ nested_tensor_from_tensor_list,
+ accuracy,
+ get_world_size,
+ interpolate,
+ is_dist_avail_and_initialized,
+ inverse_sigmoid,
+)
+
+from .backbone import build_backbone
+from .matcher import build_matcher
+from .segmentation import (
+ DETRsegm,
+ PostProcessPanoptic,
+ PostProcessSegm,
+ dice_loss,
+ sigmoid_focal_loss,
+)
+from .deformable_transformer import build_deforamble_transformer
+import copy
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class DeformableDETR(nn.Module):
+ """ This is the Deformable DETR module that performs object detection """
+
+ def __init__(
+ self,
+ backbone,
+ transformer,
+ num_classes,
+ num_feature_levels,
+ aux_loss=True,
+ with_box_refine=False,
+ two_stage=False,
+ num_queries_one2one=300,
+ num_queries_one2many=0,
+ mixed_selection=False,
+ ):
+ """ Initializes the model.
+ Parameters:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_classes: number of object classes
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+ with_box_refine: iterative bounding box refinement
+ two_stage: two-stage Deformable DETR
+ num_queries_one2one: number of object queries for one-to-one matching part
+ num_queries_one2many: number of object queries for one-to-many matching part
+ mixed_selection: a trick for Deformable DETR two stage
+
+ """
+ super().__init__()
+ num_queries = num_queries_one2one + num_queries_one2many
+ self.num_queries = num_queries
+ self.transformer = transformer
+ hidden_dim = transformer.d_model
+ self.class_embed = nn.Linear(hidden_dim, num_classes)
+ self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+ self.num_feature_levels = num_feature_levels
+ if not two_stage:
+ self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
+ elif mixed_selection:
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
+ if num_feature_levels > 1:
+ num_backbone_outs = len(backbone.strides)
+ input_proj_list = []
+ for _ in range(num_backbone_outs):
+ in_channels = backbone.num_channels[_]
+ input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )
+ )
+ for _ in range(num_feature_levels - num_backbone_outs):
+ input_proj_list.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels, hidden_dim, kernel_size=3, stride=2, padding=1
+ ),
+ nn.GroupNorm(32, hidden_dim),
+ )
+ )
+ in_channels = hidden_dim
+ self.input_proj = nn.ModuleList(input_proj_list)
+ else:
+ self.input_proj = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
+ nn.GroupNorm(32, hidden_dim),
+ )
+ ]
+ )
+ self.backbone = backbone
+ self.aux_loss = aux_loss
+ self.with_box_refine = with_box_refine
+ self.two_stage = two_stage
+
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ self.class_embed.bias.data = torch.ones(num_classes) * bias_value
+ nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+ for proj in self.input_proj:
+ nn.init.xavier_uniform_(proj[0].weight, gain=1)
+ nn.init.constant_(proj[0].bias, 0)
+
+ # if two-stage, the last class_embed and bbox_embed is for region proposal generation
+ num_pred = (
+ (transformer.decoder.num_layers + 1)
+ if two_stage
+ else transformer.decoder.num_layers
+ )
+ if with_box_refine:
+ self.class_embed = _get_clones(self.class_embed, num_pred)
+ self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
+ nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
+ # hack implementation for iterative bounding box refinement
+ self.transformer.decoder.bbox_embed = self.bbox_embed
+ else:
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
+ self.class_embed = nn.ModuleList(
+ [self.class_embed for _ in range(num_pred)]
+ )
+ self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
+ self.transformer.decoder.bbox_embed = None
+ if two_stage:
+ # hack implementation for two-stage
+ self.transformer.decoder.class_embed = self.class_embed
+ for box_embed in self.bbox_embed:
+ nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
+ self.num_queries_one2one = num_queries_one2one
+ self.mixed_selection = mixed_selection
+
+ def forward(self, samples: NestedTensor):
+ """ The forward expects a NestedTensor, which consists of:
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+ - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
+
+ It returns a dict with the following elements:
+               - "pred_logits": the classification logits for all queries.
+                                Shape= [batch_size x num_queries x num_classes]
+               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
+                               (center_x, center_y, width, height). These values are normalized in [0, 1],
+ relative to the size of each individual image (disregarding possible padding).
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
+               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
+                                dictionaries containing the two above keys for each decoder layer.
+ """
+ if not isinstance(samples, NestedTensor):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, pos = self.backbone(samples)
+
+ srcs = []
+ masks = []
+ for l, feat in enumerate(features):
+ src, mask = feat.decompose()
+ srcs.append(self.input_proj[l](src))
+ masks.append(mask)
+ assert mask is not None
+ if self.num_feature_levels > len(srcs):
+ _len_srcs = len(srcs)
+ for l in range(_len_srcs, self.num_feature_levels):
+ if l == _len_srcs:
+ src = self.input_proj[l](features[-1].tensors)
+ else:
+ src = self.input_proj[l](srcs[-1])
+ m = samples.mask
+ mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(
+ torch.bool
+ )[0]
+ pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
+ srcs.append(src)
+ masks.append(mask)
+ pos.append(pos_l)
+
+ query_embeds = None
+ if not self.two_stage or self.mixed_selection:
+ query_embeds = self.query_embed.weight[0 : self.num_queries, :]
+
+ # make attn mask
+ """ attention mask to prevent information leakage
+ """
+ self_attn_mask = (
+ torch.zeros([self.num_queries, self.num_queries,]).bool().to(src.device)
+ )
+ self_attn_mask[self.num_queries_one2one :, 0 : self.num_queries_one2one,] = True
+ self_attn_mask[0 : self.num_queries_one2one, self.num_queries_one2one :,] = True
+
+ (
+ hs,
+ init_reference,
+ inter_references,
+ enc_outputs_class,
+ enc_outputs_coord_unact,
+ ) = self.transformer(srcs, masks, pos, query_embeds, self_attn_mask)
+
+ outputs_classes_one2one = []
+ outputs_coords_one2one = []
+ outputs_classes_one2many = []
+ outputs_coords_one2many = []
+ for lvl in range(hs.shape[0]):
+ if lvl == 0:
+ reference = init_reference
+ else:
+ reference = inter_references[lvl - 1]
+ reference = inverse_sigmoid(reference)
+ outputs_class = self.class_embed[lvl](hs[lvl])
+ tmp = self.bbox_embed[lvl](hs[lvl])
+ if reference.shape[-1] == 4:
+ tmp += reference
+ else:
+ assert reference.shape[-1] == 2
+ tmp[..., :2] += reference
+ outputs_coord = tmp.sigmoid()
+
+ outputs_classes_one2one.append(
+ outputs_class[:, 0 : self.num_queries_one2one]
+ )
+ outputs_classes_one2many.append(
+ outputs_class[:, self.num_queries_one2one :]
+ )
+ outputs_coords_one2one.append(
+ outputs_coord[:, 0 : self.num_queries_one2one]
+ )
+ outputs_coords_one2many.append(outputs_coord[:, self.num_queries_one2one :])
+ outputs_classes_one2one = torch.stack(outputs_classes_one2one)
+ outputs_coords_one2one = torch.stack(outputs_coords_one2one)
+ outputs_classes_one2many = torch.stack(outputs_classes_one2many)
+ outputs_coords_one2many = torch.stack(outputs_coords_one2many)
+
+ out = {
+ "pred_logits": outputs_classes_one2one[-1],
+ "pred_boxes": outputs_coords_one2one[-1],
+ "pred_logits_one2many": outputs_classes_one2many[-1],
+ "pred_boxes_one2many": outputs_coords_one2many[-1],
+ }
+ if self.aux_loss:
+ out["aux_outputs"] = self._set_aux_loss(
+ outputs_classes_one2one, outputs_coords_one2one
+ )
+ out["aux_outputs_one2many"] = self._set_aux_loss(
+ outputs_classes_one2many, outputs_coords_one2many
+ )
+
+ if self.two_stage:
+ enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
+ out["enc_outputs"] = {
+ "pred_logits": enc_outputs_class,
+ "pred_boxes": enc_outputs_coord,
+ }
+ return out
+
+ @torch.jit.unused
+ def _set_aux_loss(self, outputs_class, outputs_coord):
+ # this is a workaround to make torchscript happy, as torchscript
+ # doesn't support dictionary with non-homogeneous values, such
+ # as a dict having both a Tensor and a list.
+ return [
+ {"pred_logits": a, "pred_boxes": b}
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
+ ]
+
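+# Sketch of how the forward() output dict is typically read (key names come from
+# the docstring above; the tensors here are placeholders, not a runnable snippet
+# on its own):
+#
+#   out = model(samples)                 # samples: NestedTensor of padded images + mask
+#   logits = out["pred_logits"]          # (bs, num_queries_one2one, num_classes)
+#   boxes = out["pred_boxes"]            # (bs, num_queries_one2one, 4), cxcywh in [0, 1]
+#   scores = logits.sigmoid()            # focal-loss head, so sigmoid rather than softmax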
+
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DETR.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+
+ def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
+ """ Create the criterion.
+ Parameters:
+ num_classes: number of object categories, omitting the special no-object category
+ matcher: module able to compute a matching between targets and proposals
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
+ focal_alpha: alpha in Focal Loss
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.matcher = matcher
+ self.weight_dict = weight_dict
+ self.losses = losses
+ self.focal_alpha = focal_alpha
+
+ def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
+        """Classification loss (sigmoid focal loss)
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+ """
+ assert "pred_logits" in outputs
+ src_logits = outputs["pred_logits"]
+
+ idx = self._get_src_permutation_idx(indices)
+ target_classes_o = torch.cat(
+ [t["labels"][J] for t, (_, J) in zip(targets, indices)]
+ )
+ target_classes = torch.full(
+ src_logits.shape[:2],
+ self.num_classes,
+ dtype=torch.int64,
+ device=src_logits.device,
+ )
+ target_classes[idx] = target_classes_o
+
+ target_classes_onehot = torch.zeros(
+ [src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+ dtype=src_logits.dtype,
+ layout=src_logits.layout,
+ device=src_logits.device,
+ )
+ target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+
+ target_classes_onehot = target_classes_onehot[:, :, :-1]
+ loss_ce = (
+ sigmoid_focal_loss(
+ src_logits,
+ target_classes_onehot,
+ num_boxes,
+ alpha=self.focal_alpha,
+ gamma=2,
+ )
+ * src_logits.shape[1]
+ )
+ losses = {"loss_ce": loss_ce}
+
+ if log:
+ # TODO this should probably be a separate loss, not hacked in this one here
+ losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+ return losses
+
+ @torch.no_grad()
+ def loss_cardinality(self, outputs, targets, indices, num_boxes):
+        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
+ """
+ pred_logits = outputs["pred_logits"]
+ device = pred_logits.device
+ tgt_lengths = torch.as_tensor(
+ [len(v["labels"]) for v in targets], device=device
+ )
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+ losses = {"cardinality_error": card_err}
+ return losses
+
+ def loss_boxes(self, outputs, targets, indices, num_boxes):
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+ """
+ assert "pred_boxes" in outputs
+ idx = self._get_src_permutation_idx(indices)
+ src_boxes = outputs["pred_boxes"][idx]
+ target_boxes = torch.cat(
+ [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
+ )
+
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
+
+ losses = {}
+ losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+ loss_giou = 1 - torch.diag(
+ box_ops.generalized_box_iou(
+ box_ops.box_cxcywh_to_xyxy(src_boxes),
+ box_ops.box_cxcywh_to_xyxy(target_boxes),
+ )
+ )
+ losses["loss_giou"] = loss_giou.sum() / num_boxes
+ return losses
+
+ def loss_masks(self, outputs, targets, indices, num_boxes):
+ """Compute the losses related to the masks: the focal loss and the dice loss.
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+ """
+ assert "pred_masks" in outputs
+
+ src_idx = self._get_src_permutation_idx(indices)
+ tgt_idx = self._get_tgt_permutation_idx(indices)
+
+ src_masks = outputs["pred_masks"]
+
+ # TODO use valid to mask invalid areas due to padding in loss
+ target_masks, valid = nested_tensor_from_tensor_list(
+ [t["masks"] for t in targets]
+ ).decompose()
+ target_masks = target_masks.to(src_masks)
+
+ src_masks = src_masks[src_idx]
+ # upsample predictions to the target size
+ src_masks = interpolate(
+ src_masks[:, None],
+ size=target_masks.shape[-2:],
+ mode="bilinear",
+ align_corners=False,
+ )
+ src_masks = src_masks[:, 0].flatten(1)
+
+ target_masks = target_masks[tgt_idx].flatten(1)
+
+ losses = {
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
+ "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
+ }
+ return losses
+
+ def _get_src_permutation_idx(self, indices):
+ # permute predictions following indices
+ batch_idx = torch.cat(
+ [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
+ )
+ src_idx = torch.cat([src for (src, _) in indices])
+ return batch_idx, src_idx
+
+ def _get_tgt_permutation_idx(self, indices):
+ # permute targets following indices
+ batch_idx = torch.cat(
+ [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
+ )
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+ return batch_idx, tgt_idx
+
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+ loss_map = {
+ "labels": self.loss_labels,
+ "cardinality": self.loss_cardinality,
+ "boxes": self.loss_boxes,
+ "masks": self.loss_masks,
+ }
+ assert loss in loss_map, f"do you really want to compute {loss} loss?"
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+ def forward(self, outputs, targets):
+ """ This performs the loss computation.
+ Parameters:
+ outputs: dict of tensors, see the output specification of the model for the format
+ targets: list of dicts, such that len(targets) == batch_size.
+ The expected keys in each dict depends on the losses applied, see each loss' doc
+                      The expected keys in each dict depend on the losses applied, see each loss' doc
+ outputs_without_aux = {
+ k: v
+ for k, v in outputs.items()
+ if k != "aux_outputs" and k != "enc_outputs"
+ }
+
+ # Retrieve the matching between the outputs of the last layer and the targets
+ indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+ num_boxes = sum(len(t["labels"]) for t in targets)
+ num_boxes = torch.as_tensor(
+ [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
+ )
+ if is_dist_avail_and_initialized():
+ torch.distributed.all_reduce(num_boxes)
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+ # Compute all the requested losses
+ losses = {}
+ for loss in self.losses:
+ kwargs = {}
+ losses.update(
+ self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)
+ )
+
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+ if "aux_outputs" in outputs:
+ for i, aux_outputs in enumerate(outputs["aux_outputs"]):
+ indices = self.matcher(aux_outputs, targets)
+ for loss in self.losses:
+ if loss == "masks":
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == "labels":
+ # Logging is enabled only for the last layer
+ kwargs["log"] = False
+ l_dict = self.get_loss(
+ loss, aux_outputs, targets, indices, num_boxes, **kwargs
+ )
+ l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if "enc_outputs" in outputs:
+ enc_outputs = outputs["enc_outputs"]
+ bin_targets = copy.deepcopy(targets)
+ for bt in bin_targets:
+ bt["labels"] = torch.zeros_like(bt["labels"])
+ indices = self.matcher(enc_outputs, bin_targets)
+ for loss in self.losses:
+ if loss == "masks":
+ # Intermediate masks losses are too costly to compute, we ignore them.
+ continue
+ kwargs = {}
+ if loss == "labels":
+ # Logging is enabled only for the last layer
+ kwargs["log"] = False
+ l_dict = self.get_loss(
+ loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs
+ )
+ l_dict = {k + f"_enc": v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ return losses
+
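+# Typical wiring sketch (hedged: the matcher arguments, class count and loss
+# weights below are placeholders, not values taken from this repository):
+#
+#   criterion = SetCriterion(num_classes=80,
+#                            matcher=build_matcher(args),
+#                            weight_dict={"loss_ce": 2, "loss_bbox": 5, "loss_giou": 2},
+#                            losses=["labels", "boxes", "cardinality"])
+#   loss_dict = criterion(outputs, targets)
+#   total_loss = sum(loss_dict[k] * criterion.weight_dict[k]
+#                    for k in loss_dict if k in criterion.weight_dict)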
+
+class PostProcess(nn.Module):
+ """ This module converts the model's output into the format expected by the coco api"""
+
+ def __init__(self, topk=100):
+ super().__init__()
+ self.topk = topk
+ print("topk for eval:", self.topk)
+
+ @torch.no_grad()
+ def forward(self, outputs, target_sizes):
+ """ Perform the computation
+ Parameters:
+ outputs: raw outputs of the model
+            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
+ For evaluation, this must be the original image size (before any data augmentation)
+ For visualization, this should be the image size after data augment, but before padding
+ """
+ out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
+
+ assert len(out_logits) == len(target_sizes)
+ assert target_sizes.shape[1] == 2
+
+ prob = out_logits.sigmoid()
+ topk_values, topk_indexes = torch.topk(
+ prob.view(out_logits.shape[0], -1), self.topk, dim=1
+ )
+ scores = topk_values
+ topk_boxes = topk_indexes // out_logits.shape[2]
+ labels = topk_indexes % out_logits.shape[2]
+ boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
+ boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+ # and from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = target_sizes.unbind(1)
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+ boxes = boxes * scale_fct[:, None, :]
+
+ results = [
+ {"scores": s, "labels": l, "boxes": b}
+ for s, l, b in zip(scores, labels, boxes)
+ ]
+
+ return results
+
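+# Minimal interface sketch for PostProcess (made-up tensors; topk must not exceed
+# num_queries * num_classes): the result is a per-image dict of scores, labels
+# and boxes in absolute xyxy pixel coordinates.
+def _demo_postprocess():
+    bs, num_queries, num_classes = 2, 100, 3
+    outputs = {
+        "pred_logits": torch.randn(bs, num_queries, num_classes),
+        "pred_boxes": torch.rand(bs, num_queries, 4),       # cxcywh in [0, 1]
+    }
+    target_sizes = torch.tensor([[480, 640], [600, 800]])   # (H, W) per image
+    results = PostProcess(topk=50)(outputs, target_sizes)
+    assert len(results) == bs
+    assert results[0]["boxes"].shape == (50, 4)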
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+def build(args):
+ backbone = build_backbone(args)
+
+ transformer = build_deforamble_transformer(args)
+ model = DeformableDETR(
+ backbone,
+ transformer,
+ num_classes=args.num_classes,
+ num_feature_levels=args.num_feature_levels,
+ aux_loss=args.aux_loss,
+ with_box_refine=args.with_box_refine,
+ two_stage=args.two_stage,
+ num_queries_one2one=args.num_queries_one2one,
+ num_queries_one2many=args.num_queries_one2many,
+ mixed_selection=args.mixed_selection,
+ )
+
+ box_postprocessor = PostProcess(topk=args.topk)
+
+ return model, box_postprocessor
diff --git a/projects/instance_segment_anything/models/hdetr/models/deformable_transformer.py b/projects/instance_segment_anything/models/hdetr/models/deformable_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7b23fd5a1514d477be119f54c18fc83cbf78dab
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/deformable_transformer.py
@@ -0,0 +1,636 @@
+# ------------------------------------------------------------------------
+# H-DETR
+# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
+# Licensed under the MIT-style license found in the LICENSE file in the root directory
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+import copy
+from typing import Optional, List
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+import torch.utils.checkpoint as checkpoint
+from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_
+
+from .util.misc import inverse_sigmoid
+from projects.instance_segment_anything.ops.modules import MSDeformAttn
+
+
+class DeformableTransformer(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ nhead=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ dim_feedforward=1024,
+ dropout=0.1,
+ activation="relu",
+ return_intermediate_dec=False,
+ num_feature_levels=4,
+ dec_n_points=4,
+ enc_n_points=4,
+ two_stage=False,
+ two_stage_num_proposals=300,
+ look_forward_twice=False,
+ mixed_selection=False,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+
+ self.d_model = d_model
+ self.nhead = nhead
+ self.two_stage = two_stage
+ self.two_stage_num_proposals = two_stage_num_proposals
+
+ encoder_layer = DeformableTransformerEncoderLayer(
+ d_model,
+ dim_feedforward,
+ dropout,
+ activation,
+ num_feature_levels,
+ nhead,
+ enc_n_points,
+ )
+ self.encoder = DeformableTransformerEncoder(
+ encoder_layer, num_encoder_layers, use_checkpoint
+ )
+
+ decoder_layer = DeformableTransformerDecoderLayer(
+ d_model,
+ dim_feedforward,
+ dropout,
+ activation,
+ num_feature_levels,
+ nhead,
+ dec_n_points,
+ )
+ self.decoder = DeformableTransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ return_intermediate_dec,
+ look_forward_twice,
+ use_checkpoint,
+ )
+
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
+
+ if two_stage:
+ self.enc_output = nn.Linear(d_model, d_model)
+ self.enc_output_norm = nn.LayerNorm(d_model)
+ self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
+ self.pos_trans_norm = nn.LayerNorm(d_model * 2)
+ else:
+ self.reference_points = nn.Linear(d_model, 2)
+
+ self.mixed_selection = mixed_selection
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MSDeformAttn):
+ m._reset_parameters()
+ if not self.two_stage:
+ xavier_uniform_(self.reference_points.weight.data, gain=1.0)
+ constant_(self.reference_points.bias.data, 0.0)
+ normal_(self.level_embed)
+
+ def get_proposal_pos_embed(self, proposals):
+ num_pos_feats = 128
+ temperature = 10000
+ scale = 2 * math.pi
+
+ dim_t = torch.arange(
+ num_pos_feats, dtype=torch.float32, device=proposals.device
+ )
+ dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+ # N, L, 4
+ proposals = proposals.sigmoid() * scale
+ # N, L, 4, 128
+ pos = proposals[:, :, :, None] / dim_t
+ # N, L, 4, 64, 2
+ pos = torch.stack(
+ (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4
+ ).flatten(2)
+ return pos
+
+ def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
+ N_, S_, C_ = memory.shape
+ base_scale = 4.0
+ proposals = []
+ _cur = 0
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H_ * W_)].view(
+ N_, H_, W_, 1
+ )
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(
+ 0, H_ - 1, H_, dtype=torch.float32, device=memory.device
+ ),
+ torch.linspace(
+ 0, W_ - 1, W_, dtype=torch.float32, device=memory.device
+ ),
+ )
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(
+ N_, 1, 1, 2
+ )
+ grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
+ proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
+ proposals.append(proposal)
+ _cur += H_ * W_
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = (
+ (output_proposals > 0.01) & (output_proposals < 0.99)
+ ).all(-1, keepdim=True)
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(
+ memory_padding_mask.unsqueeze(-1), float("inf")
+ )
+ output_proposals = output_proposals.masked_fill(
+ ~output_proposals_valid, float("inf")
+ )
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(
+ memory_padding_mask.unsqueeze(-1), float(0)
+ )
+ output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+ output_memory = self.enc_output_norm(self.enc_output(output_memory))
+ return output_memory, output_proposals
+
+ def get_valid_ratio(self, mask):
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(self, srcs, masks, pos_embeds, query_embed=None, self_attn_mask=None):
+
+ # prepare input for encoder
+ src_flatten = []
+ mask_flatten = []
+ lvl_pos_embed_flatten = []
+ spatial_shapes = []
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
+ bs, c, h, w = src.shape
+ spatial_shape = (h, w)
+ spatial_shapes.append(spatial_shape)
+ src = src.flatten(2).transpose(1, 2)
+ mask = mask.flatten(1)
+ pos_embed = pos_embed.flatten(2).transpose(1, 2)
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ src_flatten.append(src)
+ mask_flatten.append(mask)
+ src_flatten = torch.cat(src_flatten, 1)
+ mask_flatten = torch.cat(mask_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ spatial_shapes = torch.as_tensor(
+ spatial_shapes, dtype=torch.long, device=src_flatten.device
+ )
+ level_start_index = torch.cat(
+ (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
+ )
+ valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
+
+ # encoder
+ memory = self.encoder(
+ src_flatten,
+ spatial_shapes,
+ level_start_index,
+ valid_ratios,
+ lvl_pos_embed_flatten,
+ mask_flatten,
+ )
+
+ # prepare input for decoder
+ bs, _, c = memory.shape
+ if self.two_stage:
+ output_memory, output_proposals = self.gen_encoder_output_proposals(
+ memory, mask_flatten, spatial_shapes
+ )
+
+ # hack implementation for two-stage Deformable DETR
+ enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](
+ output_memory
+ )
+ enc_outputs_coord_unact = (
+ self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
+ + output_proposals
+ )
+
+ topk = self.two_stage_num_proposals
+ topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
+ topk_coords_unact = torch.gather(
+ enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+ )
+ topk_coords_unact = topk_coords_unact.detach()
+ reference_points = topk_coords_unact.sigmoid()
+ init_reference_out = reference_points
+ pos_trans_out = self.pos_trans_norm(
+ self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))
+ )
+
+ if not self.mixed_selection:
+ query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
+ else:
+ # query_embed here is the content embed for deformable DETR
+ tgt = query_embed.unsqueeze(0).expand(bs, -1, -1)
+ query_embed, _ = torch.split(pos_trans_out, c, dim=2)
+ else:
+ query_embed, tgt = torch.split(query_embed, c, dim=1)
+ query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
+ tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
+ reference_points = self.reference_points(query_embed).sigmoid()
+ init_reference_out = reference_points
+
+ # decoder
+ hs, inter_references = self.decoder(
+ tgt,
+ reference_points,
+ memory,
+ spatial_shapes,
+ level_start_index,
+ valid_ratios,
+ query_embed,
+ mask_flatten,
+ self_attn_mask,
+ )
+
+ inter_references_out = inter_references
+ if self.two_stage:
+ return (
+ hs,
+ init_reference_out,
+ inter_references_out,
+ enc_outputs_class,
+ enc_outputs_coord_unact,
+ )
+ return hs, init_reference_out, inter_references_out, None, None
+
+
+class DeformableTransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ d_ffn=1024,
+ dropout=0.1,
+ activation="relu",
+ n_levels=4,
+ n_heads=8,
+ n_points=4,
+ ):
+ super().__init__()
+
+ # self attention
+ self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout2 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout3 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, src):
+ src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+ src = src + self.dropout3(src2)
+ src = self.norm2(src)
+ return src
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(
+ self,
+ src,
+ pos,
+ reference_points,
+ spatial_shapes,
+ level_start_index,
+ padding_mask=None,
+ ):
+ # self attention
+ src2 = self.self_attn(
+ self.with_pos_embed(src, pos),
+ reference_points,
+ src,
+ spatial_shapes,
+ level_start_index,
+ padding_mask,
+ )
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ # ffn
+ src = self.forward_ffn(src)
+
+ return src
+
+
+class DeformableTransformerEncoder(nn.Module):
+ def __init__(self, encoder_layer, num_layers, use_checkpoint=False):
+ super().__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.use_checkpoint = use_checkpoint
+
+ @staticmethod
+ def get_reference_points(spatial_shapes, valid_ratios, device):
+ reference_points_list = []
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
+
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
+ torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
+ )
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(
+ self,
+ src,
+ spatial_shapes,
+ level_start_index,
+ valid_ratios,
+ pos=None,
+ padding_mask=None,
+ ):
+ output = src
+ reference_points = self.get_reference_points(
+ spatial_shapes, valid_ratios, device=src.device
+ )
+ for _, layer in enumerate(self.layers):
+ if self.use_checkpoint:
+ output = checkpoint.checkpoint(
+ layer,
+ output,
+ pos,
+ reference_points,
+ spatial_shapes,
+ level_start_index,
+ padding_mask,
+ )
+ else:
+ output = layer(
+ output,
+ pos,
+ reference_points,
+ spatial_shapes,
+ level_start_index,
+ padding_mask,
+ )
+
+ return output
+
+
+class DeformableTransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ d_ffn=1024,
+ dropout=0.1,
+ activation="relu",
+ n_levels=4,
+ n_heads=8,
+ n_points=4,
+ ):
+ super().__init__()
+
+ # cross attention
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+ self.dropout1 = nn.Dropout(dropout)
+ self.norm1 = nn.LayerNorm(d_model)
+
+ # self attention
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.norm2 = nn.LayerNorm(d_model)
+
+ # ffn
+ self.linear1 = nn.Linear(d_model, d_ffn)
+ self.activation = _get_activation_fn(activation)
+ self.dropout3 = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(d_ffn, d_model)
+ self.dropout4 = nn.Dropout(dropout)
+ self.norm3 = nn.LayerNorm(d_model)
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward_ffn(self, tgt):
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
+ tgt = tgt + self.dropout4(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(
+ self,
+ tgt,
+ query_pos,
+ reference_points,
+ src,
+ src_spatial_shapes,
+ level_start_index,
+ src_padding_mask=None,
+ self_attn_mask=None,
+ ):
+ # self attention
+ q = k = self.with_pos_embed(tgt, query_pos)
+ tgt2 = self.self_attn(
+ q.transpose(0, 1),
+ k.transpose(0, 1),
+ tgt.transpose(0, 1),
+ attn_mask=self_attn_mask,
+ )[0].transpose(0, 1)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+ # cross attention
+ tgt2 = self.cross_attn(
+ self.with_pos_embed(tgt, query_pos),
+ reference_points,
+ src,
+ src_spatial_shapes,
+ level_start_index,
+ src_padding_mask,
+ )
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+
+ # ffn
+ tgt = self.forward_ffn(tgt)
+
+ return tgt
+
+
+class DeformableTransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ decoder_layer,
+ num_layers,
+ return_intermediate=False,
+ look_forward_twice=False,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.return_intermediate = return_intermediate
+ self.look_forward_twice = look_forward_twice
+ self.use_checkpoint = use_checkpoint
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
+ self.bbox_embed = None
+ self.class_embed = None
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(
+ self,
+ tgt,
+ reference_points,
+ src,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_valid_ratios,
+ query_pos=None,
+ src_padding_mask=None,
+ self_attn_mask=None,
+ ):
+ output = tgt
+
+ intermediate = []
+ intermediate_reference_points = []
+ for lid, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = (
+ reference_points[:, :, None]
+ * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
+ )
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = (
+ reference_points[:, :, None] * src_valid_ratios[:, None]
+ )
+ if self.use_checkpoint:
+ output = checkpoint.checkpoint(
+ layer,
+ output,
+ query_pos,
+ reference_points_input,
+ src,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_padding_mask,
+ self_attn_mask,
+ )
+ else:
+ output = layer(
+ output,
+ query_pos,
+ reference_points_input,
+ src,
+ src_spatial_shapes,
+ src_level_start_index,
+ src_padding_mask,
+ self_attn_mask,
+ )
+
+ # hack implementation for iterative bounding box refinement
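+ # Each layer predicts a box delta that is applied in logit space,
+ # i.e. new_ref = sigmoid(delta + inverse_sigmoid(ref)); the refined
+ # reference points are detached so no gradient flows through them.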
+ if self.bbox_embed is not None:
+ tmp = self.bbox_embed[lid](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp
+ new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
+ reference_points
+ )
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(
+ new_reference_points
+ if self.look_forward_twice
+ else reference_points
+ )
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(intermediate_reference_points)
+
+ return output, reference_points
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
+
+
+def build_deforamble_transformer(args):
+ return DeformableTransformer(
+ d_model=args.hidden_dim,
+ nhead=args.nheads,
+ num_encoder_layers=args.enc_layers,
+ num_decoder_layers=args.dec_layers,
+ dim_feedforward=args.dim_feedforward,
+ dropout=args.dropout,
+ activation="relu",
+ return_intermediate_dec=True,
+ num_feature_levels=args.num_feature_levels,
+ dec_n_points=args.dec_n_points,
+ enc_n_points=args.enc_n_points,
+ two_stage=args.two_stage,
+ two_stage_num_proposals=args.num_queries_one2one + args.num_queries_one2many,
+ mixed_selection=args.mixed_selection,
+ look_forward_twice=args.look_forward_twice,
+ use_checkpoint=args.use_checkpoint,
+ )
+
diff --git a/projects/instance_segment_anything/models/hdetr/models/matcher.py b/projects/instance_segment_anything/models/hdetr/models/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..e49a851a84686bf199cebfb9f62d1a2c349755ae
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/matcher.py
@@ -0,0 +1,124 @@
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Modules to compute the matching cost and solve the corresponding LSAP.
+"""
+import torch
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+
+from .util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+class HungarianMatcher(nn.Module):
+ """This class computes an assignment between the targets and the predictions of the network
+
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+ while the others are un-matched (and thus treated as non-objects).
+ """
+
+ def __init__(
+ self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1
+ ):
+ """Creates the matcher
+
+ Params:
+ cost_class: This is the relative weight of the classification error in the matching cost
+ cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+ cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+ """
+ super().__init__()
+ self.cost_class = cost_class
+ self.cost_bbox = cost_bbox
+ self.cost_giou = cost_giou
+ assert (
+ cost_class != 0 or cost_bbox != 0 or cost_giou != 0
+ ), "all costs cant be 0"
+
+ def forward(self, outputs, targets):
+ """ Performs the matching
+
+ Params:
+ outputs: This is a dict that contains at least these entries:
+ "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+ "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+ "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+ objects in the target) containing the class labels
+ "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+
+ Returns:
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
+ - index_i is the indices of the selected predictions (in order)
+ - index_j is the indices of the corresponding selected targets (in order)
+ For each batch element, it holds:
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+ """
+ with torch.no_grad():
+ bs, num_queries = outputs["pred_logits"].shape[:2]
+
+ # We flatten to compute the cost matrices in a batch
+ out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
+ out_bbox = outputs["pred_boxes"].flatten(
+ 0, 1
+ ) # [batch_size * num_queries, 4]
+
+ # Also concat the target labels and boxes
+ tgt_ids = torch.cat([v["labels"] for v in targets])
+ tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+ # Compute the classification cost.
+ alpha = 0.25
+ gamma = 2.0
+ neg_cost_class = (
+ (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
+ )
+ pos_cost_class = (
+ alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+ )
+ cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+
+ # Compute the L1 cost between boxes
+ cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+ # Compute the GIoU cost between boxes
+ cost_giou = -generalized_box_iou(
+ box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
+ )
+
+ # Final cost matrix
+ C = (
+ self.cost_bbox * cost_bbox
+ + self.cost_class * cost_class
+ + self.cost_giou * cost_giou
+ )
+ C = C.view(bs, num_queries, -1).cpu()
+
+ sizes = [len(v["boxes"]) for v in targets]
+ indices = [
+ linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
+ ]
+ return [
+ (
+ torch.as_tensor(i, dtype=torch.int64),
+ torch.as_tensor(j, dtype=torch.int64),
+ )
+ for i, j in indices
+ ]
+
+
+def build_matcher(args):
+ return HungarianMatcher(
+ cost_class=args.set_cost_class,
+ cost_bbox=args.set_cost_bbox,
+ cost_giou=args.set_cost_giou,
+ )
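+
+
+# Illustrative usage sketch, not part of the original file. The cost weights and
+# tensor shapes below (300 queries, 91 classes) are assumptions for the example;
+# in practice they come from the args passed to build_matcher:
+#
+#   matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
+#   outputs = {
+#       "pred_logits": torch.randn(1, 300, 91),
+#       "pred_boxes": torch.rand(1, 300, 4),  # (cx, cy, w, h), normalized
+#   }
+#   targets = [{"labels": torch.tensor([3, 17]), "boxes": torch.rand(2, 4)}]
+#   indices = matcher(outputs, targets)
+#   # -> [(tensor of 2 matched query ids, tensor with a permutation of [0, 1])]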
diff --git a/projects/instance_segment_anything/models/hdetr/models/position_encoding.py b/projects/instance_segment_anything/models/hdetr/models/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf2868460c603df90fbd1215a6ac4f8c616af9c6
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/position_encoding.py
@@ -0,0 +1,113 @@
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Various positional encodings for the transformer.
+"""
+import math
+import torch
+from torch import nn
+
+from .util.misc import NestedTensor
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ mask = tensor_list.mask
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
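+ # This realizes the standard sinusoidal encoding
+ # pos[..., 2i] = sin(x / T^(2i / d)), pos[..., 2i+1] = cos(x / T^(2i / d)),
+ # applied independently to the (optionally normalized) x and y coordinates.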
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+ """
+ Absolute pos embedding, learned.
+ """
+
+ def __init__(self, num_pos_feats=256):
+ super().__init__()
+ self.row_embed = nn.Embedding(50, num_pos_feats)
+ self.col_embed = nn.Embedding(50, num_pos_feats)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.row_embed.weight)
+ nn.init.uniform_(self.col_embed.weight)
+
+ def forward(self, tensor_list: NestedTensor):
+ x = tensor_list.tensors
+ h, w = x.shape[-2:]
+ i = torch.arange(w, device=x.device)
+ j = torch.arange(h, device=x.device)
+ x_emb = self.col_embed(i)
+ y_emb = self.row_embed(j)
+ pos = (
+ torch.cat(
+ [
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
+ y_emb.unsqueeze(1).repeat(1, w, 1),
+ ],
+ dim=-1,
+ )
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .repeat(x.shape[0], 1, 1, 1)
+ )
+ return pos
+
+
+def build_position_encoding(args):
+ N_steps = args.hidden_dim // 2
+ if args.position_embedding in ("v2", "sine"):
+ # TODO find a better way of exposing other arguments
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+ elif args.position_embedding in ("v3", "learned"):
+ position_embedding = PositionEmbeddingLearned(N_steps)
+ else:
+ raise ValueError(f"not supported {args.position_embedding}")
+
+ return position_embedding
diff --git a/projects/instance_segment_anything/models/hdetr/models/segmentation.py b/projects/instance_segment_anything/models/hdetr/models/segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c70cca99a5bb274b2d77298ac236d75663cc28
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/segmentation.py
@@ -0,0 +1,427 @@
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+This file provides the definition of the convolutional heads used to predict masks, as well as the losses
+"""
+import io
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+
+from .util import box_ops
+from .util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
+
+try:
+ from panopticapi.utils import id2rgb, rgb2id
+except ImportError:
+ pass
+
+
+class DETRsegm(nn.Module):
+ def __init__(self, detr, freeze_detr=False):
+ super().__init__()
+ self.detr = detr
+
+ if freeze_detr:
+ for p in self.parameters():
+ p.requires_grad_(False)
+
+ hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
+ self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0)
+ self.mask_head = MaskHeadSmallConv(
+ hidden_dim + nheads, [1024, 512, 256], hidden_dim
+ )
+
+ def forward(self, samples: NestedTensor):
+ if not isinstance(samples, NestedTensor):
+ samples = nested_tensor_from_tensor_list(samples)
+ features, pos = self.detr.backbone(samples)
+
+ bs = features[-1].tensors.shape[0]
+
+ src, mask = features[-1].decompose()
+ src_proj = self.detr.input_proj(src)
+ hs, memory = self.detr.transformer(
+ src_proj, mask, self.detr.query_embed.weight, pos[-1]
+ )
+
+ outputs_class = self.detr.class_embed(hs)
+ outputs_coord = self.detr.bbox_embed(hs).sigmoid()
+ out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
+ if self.detr.aux_loss:
+ out["aux_outputs"] = [
+ {"pred_logits": a, "pred_boxes": b}
+ for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
+ ]
+
+ # FIXME h_boxes takes the last one computed, keep this in mind
+ bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
+
+ seg_masks = self.mask_head(
+ src_proj,
+ bbox_mask,
+ [features[2].tensors, features[1].tensors, features[0].tensors],
+ )
+ outputs_seg_masks = seg_masks.view(
+ bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]
+ )
+
+ out["pred_masks"] = outputs_seg_masks
+ return out
+
+
+class MaskHeadSmallConv(nn.Module):
+ """
+ Simple convolutional head, using group norm.
+ Upsampling is done using an FPN approach
+ """
+
+ def __init__(self, dim, fpn_dims, context_dim):
+ super().__init__()
+
+ inter_dims = [
+ dim,
+ context_dim // 2,
+ context_dim // 4,
+ context_dim // 8,
+ context_dim // 16,
+ context_dim // 64,
+ ]
+ self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
+ self.gn1 = torch.nn.GroupNorm(8, dim)
+ self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
+ self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
+ self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
+ self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
+ self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
+ self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
+ self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
+ self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
+ self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)
+
+ self.dim = dim
+
+ self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
+ self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
+ self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_uniform_(m.weight, a=1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x, bbox_mask, fpns):
+ def expand(tensor, length):
+ return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
+
+ x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
+
+ x = self.lay1(x)
+ x = self.gn1(x)
+ x = F.relu(x)
+ x = self.lay2(x)
+ x = self.gn2(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter1(fpns[0])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay3(x)
+ x = self.gn3(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter2(fpns[1])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay4(x)
+ x = self.gn4(x)
+ x = F.relu(x)
+
+ cur_fpn = self.adapter3(fpns[2])
+ if cur_fpn.size(0) != x.size(0):
+ cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
+ x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
+ x = self.lay5(x)
+ x = self.gn5(x)
+ x = F.relu(x)
+
+ x = self.out_lay(x)
+ return x
+
+
+class MHAttentionMap(nn.Module):
+ """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
+
+ def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True):
+ super().__init__()
+ self.num_heads = num_heads
+ self.hidden_dim = hidden_dim
+ self.dropout = nn.Dropout(dropout)
+
+ self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
+ self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
+
+ nn.init.zeros_(self.k_linear.bias)
+ nn.init.zeros_(self.q_linear.bias)
+ nn.init.xavier_uniform_(self.k_linear.weight)
+ nn.init.xavier_uniform_(self.q_linear.weight)
+ self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
+
+ def forward(self, q, k, mask=None):
+ q = self.q_linear(q)
+ k = F.conv2d(
+ k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias
+ )
+ qh = q.view(
+ q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads
+ )
+ kh = k.view(
+ k.shape[0],
+ self.num_heads,
+ self.hidden_dim // self.num_heads,
+ k.shape[-2],
+ k.shape[-1],
+ )
+ weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
+
+ if mask is not None:
+ weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
+ weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
+ weights = self.dropout(weights)
+ return weights
+
+
+def dice_loss(inputs, targets, num_boxes):
+ """
+ Compute the DICE loss, similar to generalized IOU for masks
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ """
+ inputs = inputs.sigmoid()
+ inputs = inputs.flatten(1)
+ numerator = 2 * (inputs * targets).sum(1)
+ denominator = inputs.sum(-1) + targets.sum(-1)
+ loss = 1 - (numerator + 1) / (denominator + 1)
+ return loss.sum() / num_boxes
+
+
+def sigmoid_focal_loss(
+ inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2
+):
+ """
+ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+ Args:
+ inputs: A float tensor of arbitrary shape.
+ The predictions for each example.
+ targets: A float tensor with the same shape as inputs. Stores the binary
+ classification label for each element in inputs
+ (0 for the negative class and 1 for the positive class).
+ alpha: (optional) Weighting factor in range (0, 1) to balance
+ positive vs negative examples. Default: 0.25; set alpha < 0 for no weighting.
+ gamma: Exponent of the modulating factor (1 - p_t) to
+ balance easy vs hard examples.
+ Returns:
+ Loss tensor
+ """
+ prob = inputs.sigmoid()
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ loss = ce_loss * ((1 - p_t) ** gamma)
+
+ if alpha >= 0:
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+ loss = alpha_t * loss
+
+ return loss.mean(1).sum() / num_boxes
+
+
+class PostProcessSegm(nn.Module):
+ def __init__(self, threshold=0.5):
+ super().__init__()
+ self.threshold = threshold
+
+ @torch.no_grad()
+ def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
+ assert len(orig_target_sizes) == len(max_target_sizes)
+ max_h, max_w = max_target_sizes.max(0)[0].tolist()
+ outputs_masks = outputs["pred_masks"].squeeze(2)
+ outputs_masks = F.interpolate(
+ outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
+ )
+ outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
+
+ for i, (cur_mask, t, tt) in enumerate(
+ zip(outputs_masks, max_target_sizes, orig_target_sizes)
+ ):
+ img_h, img_w = t[0], t[1]
+ results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
+ results[i]["masks"] = F.interpolate(
+ results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
+ ).byte()
+
+ return results
+
+
+class PostProcessPanoptic(nn.Module):
+ """This class converts the output of the model to the final panoptic result, in the format expected by the
+ coco panoptic API """
+
+ def __init__(self, is_thing_map, threshold=0.85):
+ """
+ Parameters:
+ is_thing_map: This is a dict whose keys are the class ids, and whose values are a boolean indicating whether
+ the class is a thing (True) or a stuff (False) class
+ threshold: confidence threshold: segments with confidence lower than this will be deleted
+ """
+ super().__init__()
+ self.threshold = threshold
+ self.is_thing_map = is_thing_map
+
+ def forward(self, outputs, processed_sizes, target_sizes=None):
+ """ This function computes the panoptic prediction from the model's predictions.
+ Parameters:
+ outputs: This is a dict coming directly from the model. See the model doc for the content.
+ processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the
+ model, i.e. the size after data augmentation but before batching.
+ target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
+ of each prediction. If left to None, it will default to the processed_sizes
+ """
+ if target_sizes is None:
+ target_sizes = processed_sizes
+ assert len(processed_sizes) == len(target_sizes)
+ out_logits, raw_masks, raw_boxes = (
+ outputs["pred_logits"],
+ outputs["pred_masks"],
+ outputs["pred_boxes"],
+ )
+ assert len(out_logits) == len(raw_masks) == len(target_sizes)
+ preds = []
+
+ def to_tuple(tup):
+ if isinstance(tup, tuple):
+ return tup
+ return tuple(tup.cpu().tolist())
+
+ for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
+ out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
+ ):
+ # we filter empty queries and detections below the confidence threshold
+ scores, labels = cur_logits.softmax(-1).max(-1)
+ keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (
+ scores > self.threshold
+ )
+ cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
+ cur_scores = cur_scores[keep]
+ cur_classes = cur_classes[keep]
+ cur_masks = cur_masks[keep]
+ cur_masks = interpolate(
+ cur_masks[None], to_tuple(size), mode="bilinear"
+ ).squeeze(0)
+ cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])
+
+ h, w = cur_masks.shape[-2:]
+ assert len(cur_boxes) == len(cur_classes)
+
+ # It may be that we have several predicted masks for the same stuff class.
+ # In the following, we track the list of mask ids for each stuff class (they are merged later on)
+ cur_masks = cur_masks.flatten(1)
+ stuff_equiv_classes = defaultdict(lambda: [])
+ for k, label in enumerate(cur_classes):
+ if not self.is_thing_map[label.item()]:
+ stuff_equiv_classes[label.item()].append(k)
+
+ def get_ids_area(masks, scores, dedup=False):
+ # This helper function creates the final panoptic segmentation image
+ # It also returns the area of the masks that appear on the image
+
+ m_id = masks.transpose(0, 1).softmax(-1)
+
+ if m_id.shape[-1] == 0:
+ # We didn't detect any mask :(
+ m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
+ else:
+ m_id = m_id.argmax(-1).view(h, w)
+
+ if dedup:
+ # Merge the masks corresponding to the same stuff class
+ for equiv in stuff_equiv_classes.values():
+ if len(equiv) > 1:
+ for eq_id in equiv:
+ m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
+
+ final_h, final_w = to_tuple(target_size)
+
+ seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
+ seg_img = seg_img.resize(
+ size=(final_w, final_h), resample=Image.NEAREST
+ )
+
+ np_seg_img = (
+ torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
+ .view(final_h, final_w, 3)
+ .numpy()
+ )
+ m_id = torch.from_numpy(rgb2id(np_seg_img))
+
+ area = []
+ for i in range(len(scores)):
+ area.append(m_id.eq(i).sum().item())
+ return area, seg_img
+
+ area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
+ if cur_classes.numel() > 0:
+ # We now filter out (near-)empty masks as long as some detections remain
+ while True:
+ filtered_small = torch.as_tensor(
+ [area[i] <= 4 for i, c in enumerate(cur_classes)],
+ dtype=torch.bool,
+ device=keep.device,
+ )
+ if filtered_small.any().item():
+ cur_scores = cur_scores[~filtered_small]
+ cur_classes = cur_classes[~filtered_small]
+ cur_masks = cur_masks[~filtered_small]
+ area, seg_img = get_ids_area(cur_masks, cur_scores)
+ else:
+ break
+
+ else:
+ cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)
+
+ segments_info = []
+ for i, a in enumerate(area):
+ cat = cur_classes[i].item()
+ segments_info.append(
+ {
+ "id": i,
+ "isthing": self.is_thing_map[cat],
+ "category_id": cat,
+ "area": a,
+ }
+ )
+ del cur_classes
+
+ with io.BytesIO() as out:
+ seg_img.save(out, format="PNG")
+ predictions = {
+ "png_string": out.getvalue(),
+ "segments_info": segments_info,
+ }
+ preds.append(predictions)
+ return preds
diff --git a/projects/instance_segment_anything/models/hdetr/models/swin_transformer.py b/projects/instance_segment_anything/models/hdetr/models/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e335af2d67b1b2a772fd018b98d2b7b4455525c
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/swin_transformer.py
@@ -0,0 +1,741 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from mmdet.utils import get_root_logger
+
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ )
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(
+ B, H // window_size, W // window_size, window_size, window_size, -1
+ )
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both shifted and non-shifted windows.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(
+ self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
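+ # Worked example: for a 2x2 window the per-axis offsets lie in [-1, 1], become
+ # [0, 2] after the shift, and the row offset is scaled by 2*Ww - 1 = 3, so the
+ # flattened index spans 0..8, matching the (2*2-1)*(2*2-1) = 9 bias-table rows.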
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """ Forward function.
+
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
+ 1
+ ).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ """ Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(
+ self,
+ dim,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert (
+ 0 <= self.shift_size < self.window_size
+ ), "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ mask_matrix: Attention mask for cyclic shift.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
+ )
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(
+ shifted_x, self.window_size
+ ) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(
+ -1, self.window_size * self.window_size, C
+ ) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(
+ x_windows, mask=attn_mask
+ ) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
+ )
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """ Patch Merging Layer
+
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ num_heads (int): Number of attention head.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList(
+ [
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list)
+ else drop_path,
+ norm_layer=norm_layer,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ w_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(
+ img_mask, self.window_size
+ ) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0)
+ )
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class SwinTransformer(nn.Module):
+ """ Swin Transformer backbone.
+ A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute position embedding. Default 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.drop_path_rate = drop_path_rate
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None,
+ )
+
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [
+ pretrain_img_size[0] // patch_size[0],
+ pretrain_img_size[1] // patch_size[1],
+ ]
+
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
+ )
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ )
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f"norm{i_layer}"
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ if isinstance(pretrained, str):
+ self.apply(_init_weights)
+ logger = get_root_logger()
+ elif pretrained is None:
+ self.apply(_init_weights)
+ else:
+ raise TypeError("pretrained must be a str or None")
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(
+ self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
+ )
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = {}
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f"norm{i}")
+ x_out = norm_layer(x_out)
+
+ out = (
+ x_out.view(-1, H, W, self.num_features[i])
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
+ outs[str(i)] = out
+
+ return outs
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
diff --git a/projects/instance_segment_anything/models/hdetr/models/util/__init__.py b/projects/instance_segment_anything/models/hdetr/models/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ebdc90b7f3ac2ed5a085066dcf20722b90cbc77
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/util/__init__.py
@@ -0,0 +1,8 @@
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
diff --git a/projects/instance_segment_anything/models/hdetr/models/util/box_ops.py b/projects/instance_segment_anything/models/hdetr/models/util/box_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..de099802a63d388c310b933447b7008cfffcc773
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/util/box_ops.py
@@ -0,0 +1,94 @@
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Utilities for bounding box manipulation and GIoU.
+"""
+import torch
+from torchvision.ops.boxes import box_area
+
+
+def box_cxcywh_to_xyxy(x):
+ x_c, y_c, w, h = x.unbind(-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+ return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+ return torch.stack(b, dim=-1)
+
+
+# modified from torchvision to also return the union
+def box_iou(boxes1, boxes2):
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ union = area1[:, None] + area2 - inter
+
+ iou = inter / union
+ return iou, union
+
+
+def generalized_box_iou(boxes1, boxes2):
+ """
+ Generalized IoU from https://giou.stanford.edu/
+
+ The boxes should be in [x0, y0, x1, y1] format
+
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
+ and M = len(boxes2)
+ """
+ # degenerate boxes gives inf / nan results
+ # so do an early check
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+ iou, union = box_iou(boxes1, boxes2)
+
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ area = wh[:, :, 0] * wh[:, :, 1]
+
+ return iou - (area - union) / area
+
+
+def masks_to_boxes(masks):
+ """Compute the bounding boxes around the provided masks
+
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
+
+ Returns a [N, 4] tensor, with the boxes in xyxy format
+ """
+ if masks.numel() == 0:
+ return torch.zeros((0, 4), device=masks.device)
+
+ h, w = masks.shape[-2:]
+
+ y = torch.arange(0, h, dtype=torch.float)
+ x = torch.arange(0, w, dtype=torch.float)
+ y, x = torch.meshgrid(y, x)
+
+ x_mask = masks * x.unsqueeze(0)
+ x_max = x_mask.flatten(1).max(-1)[0]
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ y_mask = masks * y.unsqueeze(0)
+ y_max = y_mask.flatten(1).max(-1)[0]
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
+
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
diff --git a/projects/instance_segment_anything/models/hdetr/models/util/misc.py b/projects/instance_segment_anything/models/hdetr/models/util/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..877a1d93387a9d377c9cbf4a586b43e689c4beec
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/util/misc.py
@@ -0,0 +1,518 @@
+# ------------------------------------------------------------------------
+# H-DETR
+# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
+# Licensed under the MIT-style license found in the LICENSE file in the root directory
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import os
+import subprocess
+import time
+from collections import defaultdict, deque
+import datetime
+import pickle
+from typing import Optional, List
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch import Tensor
+
+# needed due to empty tensor bug in pytorch and torchvision 0.5
+import torchvision
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.tensor([tensor.numel()], device="cuda")
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+ if local_size != max_size:
+ padding = torch.empty(
+ size=(max_size - local_size,), dtype=torch.uint8, device="cuda"
+ )
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
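+# Usage sketch (illustrative, not part of the original file): gather per-rank
+# results (any picklable object) onto every process, e.g. to merge predictions
+# after a distributed evaluation run. `my_predictions` is a placeholder.
+#
+#   local_results = {"rank": get_rank(), "predictions": my_predictions}
+#   merged = all_gather(local_results)  # list with world_size entries, identical on every rank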
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that all processes
+ have the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.all_reduce(values)
+ if average:
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
+
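+# Usage sketch (illustrative): average a dict of scalar loss tensors across ranks
+# before logging, so every process reports the same numbers.
+#
+#   loss_dict = {"loss_ce": loss_ce, "loss_bbox": loss_bbox}   # per-rank scalar tensors
+#   loss_dict_reduced = reduce_dict(loss_dict)                  # averaged over world_size
+#   total_loss = sum(loss_dict_reduced.values()).item()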
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+ )
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join(
+ [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ "max mem: {memory:.0f}",
+ ]
+ )
+ else:
+ log_msg = self.delimiter.join(
+ [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ )
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len(iterable)
+ )
+ )
+
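+# Usage sketch (illustrative; data_loader and the loss computation are placeholders):
+#
+#   metric_logger = MetricLogger(delimiter="  ")
+#   for samples, targets in metric_logger.log_every(data_loader, 10, header="Epoch: [0]"):
+#       loss = train_step(samples, targets)   # hypothetical forward/backward step
+#       metric_logger.update(loss=loss)
+#   metric_logger.synchronize_between_processes()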
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommited changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+def collate_fn(batch):
+ batch = list(zip(*batch))
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
+ return tuple(batch)
+
+
+def _max_by_axis(the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+ # TODO make this more general
+ if tensor_list[0].ndim == 3:
+ # TODO make it support different-sized images
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+ batch_shape = [len(tensor_list)] + max_size
+ b, c, h, w = batch_shape
+ dtype = tensor_list[0].dtype
+ device = tensor_list[0].device
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ m[: img.shape[1], : img.shape[2]] = False
+ else:
+ raise ValueError("not supported")
+ return NestedTensor(tensor, mask)
+
+
+class NestedTensor(object):
+ def __init__(self, tensors, mask: Optional[Tensor]):
+ self.tensors = tensors
+ self.mask = mask
+
+ def to(self, device, non_blocking=False):
+ # type: (Device) -> NestedTensor # noqa
+ cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
+ mask = self.mask
+ if mask is not None:
+ assert mask is not None
+ cast_mask = mask.to(device, non_blocking=non_blocking)
+ else:
+ cast_mask = None
+ return NestedTensor(cast_tensor, cast_mask)
+
+ def record_stream(self, *args, **kwargs):
+ self.tensors.record_stream(*args, **kwargs)
+ if self.mask is not None:
+ self.mask.record_stream(*args, **kwargs)
+
+ def decompose(self):
+ return self.tensors, self.mask
+
+ def __repr__(self):
+ return str(self.tensors)
+
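+# Usage sketch (illustrative): batch two differently sized 3xHxW images. The result
+# is zero-padded to the per-batch maximum size, and mask == True marks padded pixels.
+#
+#   imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
+#   nt = nested_tensor_from_tensor_list(imgs)
+#   nt.tensors.shape   # torch.Size([2, 3, 512, 640])
+#   nt.mask.shape      # torch.Size([2, 512, 640])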
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return int(os.environ["LOCAL_SIZE"])
+
+
+def get_local_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return int(os.environ["LOCAL_RANK"])
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ args.dist_url = "env://"
+ os.environ["LOCAL_SIZE"] = str(torch.cuda.device_count())
+ elif "SLURM_PROCID" in os.environ:
+ proc_id = int(os.environ["SLURM_PROCID"])
+ ntasks = int(os.environ["SLURM_NTASKS"])
+ node_list = os.environ["SLURM_NODELIST"]
+ num_gpus = torch.cuda.device_count()
+ addr = subprocess.getoutput(
+ "scontrol show hostname {} | head -n1".format(node_list)
+ )
+ os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
+ os.environ["MASTER_ADDR"] = addr
+ os.environ["WORLD_SIZE"] = str(ntasks)
+ os.environ["RANK"] = str(proc_id)
+ os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
+ os.environ["LOCAL_SIZE"] = str(num_gpus)
+ args.dist_url = "env://"
+ args.world_size = ntasks
+ args.rank = proc_id
+ args.gpu = proc_id % num_gpus
+ else:
+ print("Not using distributed mode")
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print(
+ "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
+ )
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ if target.numel() == 0:
+ return [torch.zeros([], device=output.device)]
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+        correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape handles non-contiguous slices
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
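+# Worked example (illustrative): 3 samples, 4 classes, top-1 accuracy.
+#
+#   logits = torch.tensor([[2.0, 0.1, 0.1, 0.1],
+#                          [0.1, 2.0, 0.1, 0.1],
+#                          [0.1, 0.1, 0.1, 2.0]])
+#   target = torch.tensor([0, 1, 2])
+#   accuracy(logits, target)[0]   # tensor(66.6667): two of three predictions are correct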
+
+def interpolate(
+ input, size=None, scale_factor=None, mode="nearest", align_corners=None
+):
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+ """
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+ This will eventually be supported natively by PyTorch, and this
+ class can go away.
+ """
+ return torchvision.ops.misc.interpolate(
+ input, size, scale_factor, mode, align_corners
+ )
+
+
+def get_total_grad_norm(parameters, norm_type=2):
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ device = parameters[0].grad.device
+ total_norm = torch.norm(
+ torch.stack(
+ [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
+ ),
+ norm_type,
+ )
+ return total_norm
+
+
+def inverse_sigmoid(x, eps=1e-5):
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
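+# Note (illustrative): inverse_sigmoid is the clamped logit function, so it inverts
+# torch.sigmoid up to the eps clamping; it lets reference points stored in [0, 1]
+# be refined additively in unbounded logit space.
+#
+#   t = torch.tensor([0.25, 0.5, 0.9])
+#   torch.sigmoid(inverse_sigmoid(t))   # ~tensor([0.2500, 0.5000, 0.9000])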
diff --git a/projects/instance_segment_anything/models/hdetr/models/util/plot_utils.py b/projects/instance_segment_anything/models/hdetr/models/util/plot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..759f34d252493fd93187ea3cf2ab0d63a3e2b280
--- /dev/null
+++ b/projects/instance_segment_anything/models/hdetr/models/util/plot_utils.py
@@ -0,0 +1,111 @@
+# ------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+"""
+Plotting utilities to visualize training logs.
+"""
+import torch
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+from pathlib import Path, PurePath
+
+
+def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
+ '''
+ Function to plot specific fields from training log(s). Plots both training and test results.
+
+ :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
+ - fields = which results to plot from each log file - plots both training and test for each field.
+ - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
+ - log_name = optional, name of log file if different than default 'log.txt'.
+
+ :: Outputs - matplotlib plots of results in fields, color coded for each log file.
+ - solid lines are training results, dashed lines are test results.
+
+ '''
+ func_name = "plot_utils.py::plot_logs"
+
+ # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
+ # convert single Path to list to avoid 'not iterable' error
+
+ if not isinstance(logs, list):
+ if isinstance(logs, PurePath):
+ logs = [logs]
+ print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
+ else:
+ raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
+ Expect list[Path] or single Path obj, received {type(logs)}")
+
+ # verify valid dir(s) and that every item in list is Path object
+ for i, dir in enumerate(logs):
+ if not isinstance(dir, PurePath):
+ raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
+ if dir.exists():
+ continue
+ raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
+
+ # load log file(s) and plot
+ dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
+
+ fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
+
+ for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
+ for j, field in enumerate(fields):
+ if field == 'mAP':
+                coco_eval = pd.DataFrame(np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean()
+ axs[j].plot(coco_eval, c=color)
+ else:
+ df.interpolate().ewm(com=ewm_col).mean().plot(
+ y=[f'train_{field}', f'test_{field}'],
+ ax=axs[j],
+ color=[color] * 2,
+ style=['-', '--']
+ )
+ for ax, field in zip(axs, fields):
+ ax.legend([Path(p).name for p in logs])
+ ax.set_title(field)
+
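+# Usage sketch (illustrative; the directory names and fields are placeholders for
+# whatever the training logs actually contain):
+#
+#   from pathlib import Path
+#   plot_logs([Path("work_dirs/exp1"), Path("work_dirs/exp2")],
+#             fields=("loss", "mAP"), ewm_col=0, log_name="log.txt")
+#   plt.show()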
+
+def plot_precision_recall(files, naming_scheme='iter'):
+ if naming_scheme == 'exp_id':
+ # name becomes exp_id
+ names = [f.parts[-3] for f in files]
+ elif naming_scheme == 'iter':
+ names = [f.stem for f in files]
+ else:
+ raise ValueError(f'not supported {naming_scheme}')
+ fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
+ for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
+ data = torch.load(f)
+ # precision is n_iou, n_points, n_cat, n_area, max_det
+ precision = data['precision']
+ recall = data['params'].recThrs
+ scores = data['scores']
+ # take precision for all classes, all areas and 100 detections
+ precision = precision[0, :, :, 0, -1].mean(1)
+ scores = scores[0, :, :, 0, -1].mean(1)
+ prec = precision.mean()
+ rec = data['recall'][0, :, 0, -1].mean()
+ print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
+ f'score={scores.mean():0.3f}, ' +
+ f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
+ )
+ axs[0].plot(recall, precision, c=color)
+ axs[1].plot(recall, scores, c=color)
+
+ axs[0].set_title('Precision / Recall')
+ axs[0].legend(names)
+ axs[1].set_title('Scores / Recall')
+ axs[1].legend(names)
+ return fig, axs
+
+
+
diff --git a/projects/instance_segment_anything/models/segment_anything/__init__.py b/projects/instance_segment_anything/models/segment_anything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..34383d83f5e76bc801f31b20e5651e383be348b6
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .build_sam import (
+ build_sam,
+ build_sam_vit_h,
+ build_sam_vit_l,
+ build_sam_vit_b,
+ sam_model_registry,
+)
+from .predictor import SamPredictor
+from .automatic_mask_generator import SamAutomaticMaskGenerator
diff --git a/projects/instance_segment_anything/models/segment_anything/automatic_mask_generator.py b/projects/instance_segment_anything/models/segment_anything/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..23264971b7ff5aa0b4f499ade7773b68dce984b6
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/automatic_mask_generator.py
@@ -0,0 +1,372 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area # type: ignore
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from .modeling import Sam
+from .predictor import SamPredictor
+from .utils.amg import (
+ MaskData,
+ area_from_rle,
+ batch_iterator,
+ batched_mask_to_box,
+ box_xyxy_to_xywh,
+ build_all_layer_point_grids,
+ calculate_stability_score,
+ coco_encode_rle,
+ generate_crop_boxes,
+ is_box_near_crop_edge,
+ mask_to_rle_pytorch,
+ remove_small_regions,
+ rle_to_mask,
+ uncrop_boxes_xyxy,
+ uncrop_masks,
+ uncrop_points,
+)
+
+
+class SamAutomaticMaskGenerator:
+ def __init__(
+ self,
+ model: Sam,
+ points_per_side: Optional[int] = 32,
+ points_per_batch: int = 64,
+ pred_iou_thresh: float = 0.88,
+ stability_score_thresh: float = 0.95,
+ stability_score_offset: float = 1.0,
+ box_nms_thresh: float = 0.7,
+ crop_n_layers: int = 0,
+ crop_nms_thresh: float = 0.7,
+ crop_overlap_ratio: float = 512 / 1500,
+ crop_n_points_downscale_factor: int = 1,
+ point_grids: Optional[List[np.ndarray]] = None,
+ min_mask_region_area: int = 0,
+ output_mode: str = "binary_mask",
+ ) -> None:
+ """
+ Using a SAM model, generates masks for the entire image.
+ Generates a grid of point prompts over the image, then filters
+ low quality and duplicate masks. The default settings are chosen
+ for SAM with a ViT-H backbone.
+
+ Arguments:
+ model (Sam): The SAM model to use for mask prediction.
+ points_per_side (int or None): The number of points to be sampled
+ along one side of the image. The total number of points is
+ points_per_side**2. If None, 'point_grids' must provide explicit
+ point sampling.
+ points_per_batch (int): Sets the number of points run simultaneously
+ by the model. Higher numbers may be faster but use more GPU memory.
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
+ model's predicted mask quality.
+ stability_score_thresh (float): A filtering threshold in [0,1], using
+ the stability of the mask under changes to the cutoff used to binarize
+ the model's mask predictions.
+ stability_score_offset (float): The amount to shift the cutoff when
+            calculating the stability score.
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks.
+          crop_n_layers (int): If >0, mask prediction will be run again on
+ crops of the image. Sets the number of layers to run, where each
+ layer has 2**i_layer number of image crops.
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+ suppression to filter duplicate masks between different crops.
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
+ In the first crop layer, crops will overlap by this fraction of
+ the image length. Later layers with more crops scale down this overlap.
+ crop_n_points_downscale_factor (int): The number of points-per-side
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+ point_grids (list(np.ndarray) or None): A list over explicit grids
+ of points used for sampling, normalized to [0,1]. The nth grid in the
+ list is used in the nth crop layer. Exclusive with points_per_side.
+ min_mask_region_area (int): If >0, postprocessing will be applied
+ to remove disconnected regions and holes in masks with area smaller
+ than min_mask_region_area. Requires opencv.
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+ For large resolutions, 'binary_mask' may consume large amounts of
+ memory.
+ """
+
+ assert (points_per_side is None) != (
+ point_grids is None
+ ), "Exactly one of points_per_side or point_grid must be provided."
+ if points_per_side is not None:
+ self.point_grids = build_all_layer_point_grids(
+ points_per_side,
+ crop_n_layers,
+ crop_n_points_downscale_factor,
+ )
+ elif point_grids is not None:
+ self.point_grids = point_grids
+ else:
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+ assert output_mode in [
+ "binary_mask",
+ "uncompressed_rle",
+ "coco_rle",
+ ], f"Unknown output_mode {output_mode}."
+ if output_mode == "coco_rle":
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+
+ if min_mask_region_area > 0:
+ import cv2 # type: ignore # noqa: F401
+
+ self.predictor = SamPredictor(model)
+ self.points_per_batch = points_per_batch
+ self.pred_iou_thresh = pred_iou_thresh
+ self.stability_score_thresh = stability_score_thresh
+ self.stability_score_offset = stability_score_offset
+ self.box_nms_thresh = box_nms_thresh
+ self.crop_n_layers = crop_n_layers
+ self.crop_nms_thresh = crop_nms_thresh
+ self.crop_overlap_ratio = crop_overlap_ratio
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+ self.min_mask_region_area = min_mask_region_area
+ self.output_mode = output_mode
+
+ @torch.no_grad()
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+ """
+ Generates masks for the given image.
+
+ Arguments:
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+ Returns:
+ list(dict(str, any)): A list over records for masks. Each record is
+ a dict containing the following keys:
+ segmentation (dict(str, any) or np.ndarray): The mask. If
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
+ is a dictionary containing the RLE.
+ bbox (list(float)): The box around the mask, in XYWH format.
+ area (int): The area in pixels of the mask.
+ predicted_iou (float): The model's own prediction of the mask's
+ quality. This is filtered by the pred_iou_thresh parameter.
+ point_coords (list(list(float))): The point coordinates input
+ to the model to generate this mask.
+ stability_score (float): A measure of the mask's quality. This
+ is filtered on using the stability_score_thresh parameter.
+ crop_box (list(float)): The crop of the image used to generate
+ the mask, given in XYWH format.
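+
+        Example (illustrative sketch; the checkpoint and image paths below are
+        placeholders):
+            sam = sam_model_registry["vit_b"](checkpoint="ckpt/sam_vit_b.pth")
+            generator = SamAutomaticMaskGenerator(sam)
+            image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
+            masks = generator.generate(image)   # list of per-mask record dicts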
+ """
+
+ # Generate masks
+ mask_data = self._generate_masks(image)
+
+ # Filter small disconnected regions and holes in masks
+ if self.min_mask_region_area > 0:
+ mask_data = self.postprocess_small_regions(
+ mask_data,
+ self.min_mask_region_area,
+ max(self.box_nms_thresh, self.crop_nms_thresh),
+ )
+
+ # Encode masks
+ if self.output_mode == "coco_rle":
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+ elif self.output_mode == "binary_mask":
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+ else:
+ mask_data["segmentations"] = mask_data["rles"]
+
+ # Write mask records
+ curr_anns = []
+ for idx in range(len(mask_data["segmentations"])):
+ ann = {
+ "segmentation": mask_data["segmentations"][idx],
+ "area": area_from_rle(mask_data["rles"][idx]),
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
+ "point_coords": [mask_data["points"][idx].tolist()],
+ "stability_score": mask_data["stability_score"][idx].item(),
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+ }
+ curr_anns.append(ann)
+
+ return curr_anns
+
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
+ orig_size = image.shape[:2]
+ crop_boxes, layer_idxs = generate_crop_boxes(
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
+ )
+
+ # Iterate over image crops
+ data = MaskData()
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+ data.cat(crop_data)
+
+ # Remove duplicate masks between crops
+ if len(crop_boxes) > 1:
+ # Prefer masks from smaller crops
+ scores = 1 / box_area(data["crop_boxes"])
+ scores = scores.to(data["boxes"].device)
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ scores,
+ torch.zeros(len(data["boxes"])), # categories
+ iou_threshold=self.crop_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ data.to_numpy()
+ return data
+
+ def _process_crop(
+ self,
+ image: np.ndarray,
+ crop_box: List[int],
+ crop_layer_idx: int,
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ # Crop the image and calculate embeddings
+ x0, y0, x1, y1 = crop_box
+ cropped_im = image[y0:y1, x0:x1, :]
+ cropped_im_size = cropped_im.shape[:2]
+ self.predictor.set_image(cropped_im)
+
+ # Get points for this crop
+ points_scale = np.array(cropped_im_size)[None, ::-1]
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+ # Generate masks for this crop in batches
+ data = MaskData()
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+ data.cat(batch_data)
+ del batch_data
+ self.predictor.reset_image()
+
+ # Remove duplicates within this crop.
+ keep_by_nms = batched_nms(
+ data["boxes"].float(),
+ data["iou_preds"],
+ torch.zeros(len(data["boxes"])), # categories
+ iou_threshold=self.box_nms_thresh,
+ )
+ data.filter(keep_by_nms)
+
+ # Return to the original image frame
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+ data["points"] = uncrop_points(data["points"], crop_box)
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+ return data
+
+ def _process_batch(
+ self,
+ points: np.ndarray,
+ im_size: Tuple[int, ...],
+ crop_box: List[int],
+ orig_size: Tuple[int, ...],
+ ) -> MaskData:
+ orig_h, orig_w = orig_size
+
+ # Run model on this batch
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+ masks, iou_preds, _ = self.predictor.predict_torch(
+ in_points[:, None, :],
+ in_labels[:, None],
+ multimask_output=True,
+ return_logits=True,
+ )
+
+ # Serialize predictions and store in MaskData
+ data = MaskData(
+ masks=masks.flatten(0, 1),
+ iou_preds=iou_preds.flatten(0, 1),
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+ )
+ del masks
+
+ # Filter by predicted IoU
+ if self.pred_iou_thresh > 0.0:
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
+ data.filter(keep_mask)
+
+ # Calculate stability score
+ data["stability_score"] = calculate_stability_score(
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+ )
+ if self.stability_score_thresh > 0.0:
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
+ data.filter(keep_mask)
+
+ # Threshold masks and calculate boxes
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+ data["boxes"] = batched_mask_to_box(data["masks"])
+
+ # Filter boxes that touch crop boundaries
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+ if not torch.all(keep_mask):
+ data.filter(keep_mask)
+
+ # Compress to RLE
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
+ del data["masks"]
+
+ return data
+
+ @staticmethod
+ def postprocess_small_regions(
+ mask_data: MaskData, min_area: int, nms_thresh: float
+ ) -> MaskData:
+ """
+ Removes small disconnected regions and holes in masks, then reruns
+ box NMS to remove any new duplicates.
+
+ Edits mask_data in place.
+
+ Requires open-cv as a dependency.
+ """
+ if len(mask_data["rles"]) == 0:
+ return mask_data
+
+ # Filter small disconnected regions and holes
+ new_masks = []
+ scores = []
+ for rle in mask_data["rles"]:
+ mask = rle_to_mask(rle)
+
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
+ unchanged = not changed
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
+ unchanged = unchanged and not changed
+
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+ # Give score=0 to changed masks and score=1 to unchanged masks
+ # so NMS will prefer ones that didn't need postprocessing
+ scores.append(float(unchanged))
+
+ # Recalculate boxes and remove any new duplicates
+ masks = torch.cat(new_masks, dim=0)
+ boxes = batched_mask_to_box(masks)
+ keep_by_nms = batched_nms(
+ boxes.float(),
+ torch.as_tensor(scores),
+ torch.zeros(len(boxes)), # categories
+ iou_threshold=nms_thresh,
+ )
+
+ # Only recalculate RLEs for masks that have changed
+ for i_mask in keep_by_nms:
+ if scores[i_mask] == 0.0:
+ mask_torch = masks[i_mask].unsqueeze(0)
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
+ mask_data.filter(keep_by_nms)
+
+ return mask_data
diff --git a/projects/instance_segment_anything/models/segment_anything/build_sam.py b/projects/instance_segment_anything/models/segment_anything/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..07abfca24e96eced7f13bdefd3212ce1b77b8999
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/build_sam.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+
+
+def build_sam_vit_h(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=1280,
+ encoder_depth=32,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[7, 15, 23, 31],
+ checkpoint=checkpoint,
+ )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=1024,
+ encoder_depth=24,
+ encoder_num_heads=16,
+ encoder_global_attn_indexes=[5, 11, 17, 23],
+ checkpoint=checkpoint,
+ )
+
+
+def build_sam_vit_b(checkpoint=None):
+ return _build_sam(
+ encoder_embed_dim=768,
+ encoder_depth=12,
+ encoder_num_heads=12,
+ encoder_global_attn_indexes=[2, 5, 8, 11],
+ checkpoint=checkpoint,
+ )
+
+
+sam_model_registry = {
+ "default": build_sam,
+ "vit_h": build_sam,
+ "vit_l": build_sam_vit_l,
+ "vit_b": build_sam_vit_b,
+}
+
+
+def _build_sam(
+ encoder_embed_dim,
+ encoder_depth,
+ encoder_num_heads,
+ encoder_global_attn_indexes,
+ checkpoint=None,
+):
+ prompt_embed_dim = 256
+ image_size = 1024
+ vit_patch_size = 16
+ image_embedding_size = image_size // vit_patch_size
+ sam = Sam(
+ image_encoder=ImageEncoderViT(
+ depth=encoder_depth,
+ embed_dim=encoder_embed_dim,
+ img_size=image_size,
+ mlp_ratio=4,
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+ num_heads=encoder_num_heads,
+ patch_size=vit_patch_size,
+ qkv_bias=True,
+ use_rel_pos=True,
+ global_attn_indexes=encoder_global_attn_indexes,
+ window_size=14,
+ out_chans=prompt_embed_dim,
+ ),
+ prompt_encoder=PromptEncoder(
+ embed_dim=prompt_embed_dim,
+ image_embedding_size=(image_embedding_size, image_embedding_size),
+ input_image_size=(image_size, image_size),
+ mask_in_chans=16,
+ ),
+ mask_decoder=MaskDecoder(
+ num_multimask_outputs=3,
+ transformer=TwoWayTransformer(
+ depth=2,
+ embedding_dim=prompt_embed_dim,
+ mlp_dim=2048,
+ num_heads=8,
+ ),
+ transformer_dim=prompt_embed_dim,
+ iou_head_depth=3,
+ iou_head_hidden_dim=256,
+ ),
+ pixel_mean=[123.675, 116.28, 103.53],
+ pixel_std=[58.395, 57.12, 57.375],
+ )
+ sam.eval()
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ sam.load_state_dict(state_dict)
+ return sam
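+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative sketch, not part of the original SAM code):
+    # build the smallest ViT-B variant without a checkpoint and push a dummy image
+    # through the image encoder to confirm the expected 1x256x64x64 embedding shape.
+    sam = sam_model_registry["vit_b"](checkpoint=None)
+    dummy = torch.zeros(1, 3, 1024, 1024)
+    with torch.no_grad():
+        embeddings = sam.image_encoder(dummy)
+    print(embeddings.shape)  # expected: torch.Size([1, 256, 64, 64])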
diff --git a/projects/instance_segment_anything/models/segment_anything/modeling/__init__.py b/projects/instance_segment_anything/models/segment_anything/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38e906243d898d7fc071c0fe218338c5cace3ea1
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/modeling/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .sam import Sam
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+from .transformer import TwoWayTransformer
diff --git a/projects/instance_segment_anything/models/segment_anything/modeling/common.py b/projects/instance_segment_anything/models/segment_anything/modeling/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/modeling/common.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ super().__init__()
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+ self.act = act()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
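+
+
+if __name__ == "__main__":
+    # Quick check (illustrative, not part of the original file): LayerNorm2d
+    # normalizes over the channel dimension of an NCHW tensor, so per-pixel
+    # statistics across channels should be ~zero mean and ~unit variance.
+    x = torch.randn(2, 8, 4, 4)
+    y = LayerNorm2d(8)(x)
+    print(y.mean(dim=1).abs().max().item(), y.var(dim=1, unbiased=False).mean().item())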
diff --git a/projects/instance_segment_anything/models/segment_anything/modeling/image_encoder.py b/projects/instance_segment_anything/models/segment_anything/modeling/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6ad9ad2938842308e482a05c9d35ab08db9b2c3
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/modeling/image_encoder.py
@@ -0,0 +1,395 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Type
+
+from .common import LayerNorm2d, MLPBlock
+
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+ def __init__(
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ out_chans: int = 256,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_abs_pos: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ global_attn_indexes: Tuple[int, ...] = (),
+ ) -> None:
+ """
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_abs_pos (bool): If True, use absolute positional embeddings.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks.
+ global_attn_indexes (list): Indexes for blocks using global attention.
+ """
+ super().__init__()
+ self.img_size = img_size
+
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+
+ self.pos_embed: Optional[nn.Parameter] = None
+ if use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+ )
+
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ window_size=window_size if i not in global_attn_indexes else 0,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ )
+ self.blocks.append(block)
+
+ self.neck = nn.Sequential(
+ nn.Conv2d(
+ embed_dim,
+ out_chans,
+ kernel_size=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ nn.Conv2d(
+ out_chans,
+ out_chans,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.neck(x.permute(0, 3, 1, 2))
+
+ return x
+
+
+class Block(nn.Module):
+ """Transformer blocks with support of window attention and residual propagation blocks"""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks. If it equals 0, then
+ use global attention.
+ input_size (int or None): Input resolution for calculating the relative positional
+ parameter size.
+ """
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ )
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+ self.window_size = window_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, self.window_size)
+
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+ x = shortcut + x
+ x = x + self.mlp(self.norm2(x))
+
+ return x
+
+
+class Attention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ input_size (int or None): Input resolution for calculating the relative positional
+ parameter size.
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert (
+ input_size is not None
+ ), "Input size must be provided if using relative positional encoding."
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ x = self.proj(x)
+
+ return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(
+ windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
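+# Shape sketch (illustrative): with x of shape (B, H, W, C) = (1, 64, 64, 768) and
+# window_size = 14, H and W are padded to 70, giving 5 * 5 = 25 windows:
+#
+#   windows, (Hp, Wp) = window_partition(x, 14)                   # (25, 14, 14, 768), (70, 70)
+#   x_back = window_unpartition(windows, 14, (Hp, Wp), (64, 64))  # back to (1, 64, 64, 768)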
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+ Args:
+ attn (Tensor): attention map.
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+ Returns:
+ attn (Tensor): attention map with added relative positional embeddings.
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+ attn = (
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+ ).view(B, q_h * q_w, k_h * k_w)
+
+ return attn
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
diff --git a/projects/instance_segment_anything/models/segment_anything/modeling/mask_decoder.py b/projects/instance_segment_anything/models/segment_anything/modeling/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e86f7cc9ad95582a08ef2531c68d03fa4af8d99
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/modeling/mask_decoder.py
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ transformer_dim: int,
+ transformer: nn.Module,
+ num_multimask_outputs: int = 3,
+ activation: Type[nn.Module] = nn.GELU,
+ iou_head_depth: int = 3,
+ iou_head_hidden_dim: int = 256,
+ ) -> None:
+ """
+ Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+ Arguments:
+ transformer_dim (int): the channel dimension of the transformer
+ transformer (nn.Module): the transformer used to predict masks
+ num_multimask_outputs (int): the number of masks to predict
+ when disambiguating masks
+ activation (nn.Module): the type of activation to use when
+ upscaling masks
+ iou_head_depth (int): the depth of the MLP used to predict
+ mask quality
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
+ used to predict mask quality
+ """
+ super().__init__()
+ self.transformer_dim = transformer_dim
+ self.transformer = transformer
+
+ self.num_multimask_outputs = num_multimask_outputs
+
+ self.iou_token = nn.Embedding(1, transformer_dim)
+ self.num_mask_tokens = num_multimask_outputs + 1
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ LayerNorm2d(transformer_dim // 4),
+ activation(),
+ nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ activation(),
+ )
+ self.output_hypernetworks_mlps = nn.ModuleList(
+ [
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+ for i in range(self.num_mask_tokens)
+ ]
+ )
+
+ self.iou_prediction_head = MLP(
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+ )
+
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ multimask_output: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict masks given image and prompt embeddings.
+
+ Arguments:
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+ multimask_output (bool): Whether to return multiple masks or a single
+ mask.
+
+ Returns:
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality
+ """
+ masks, iou_pred = self.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=image_pe,
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
+ dense_prompt_embeddings=dense_prompt_embeddings,
+ )
+
+        # Select the correct mask or masks for output
+ if multimask_output:
+ mask_slice = slice(1, None)
+ else:
+ mask_slice = slice(0, 1)
+ masks = masks[:, mask_slice, :, :]
+ iou_pred = iou_pred[:, mask_slice]
+
+ # Prepare output
+ return masks, iou_pred
+
+ def predict_masks(
+ self,
+ image_embeddings: torch.Tensor,
+ image_pe: torch.Tensor,
+ sparse_prompt_embeddings: torch.Tensor,
+ dense_prompt_embeddings: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Predicts masks. See 'forward' for more details."""
+ # Concatenate output tokens
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+ # Expand per-image data in batch direction to be per-mask
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+ src = src + dense_prompt_embeddings
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+ b, c, h, w = src.shape
+
+ # Run the transformer
+ hs, src = self.transformer(src, pos_src, tokens)
+ iou_token_out = hs[:, 0, :]
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+ # Upscale mask embeddings and predict masks using the mask tokens
+ src = src.transpose(1, 2).view(b, c, h, w)
+ upscaled_embedding = self.output_upscaling(src)
+ hyper_in_list: List[torch.Tensor] = []
+ for i in range(self.num_mask_tokens):
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+ hyper_in = torch.stack(hyper_in_list, dim=1)
+ b, c, h, w = upscaled_embedding.shape
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+ # Generate mask quality predictions
+ iou_pred = self.iou_prediction_head(iou_token_out)
+
+ return masks, iou_pred
+
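+# Shape sketch (illustrative): with the default SAM decoder (transformer_dim=256,
+# num_multimask_outputs=3) and a single box prompt (sparse embeddings of shape
+# (1, 2, 256)), predict_masks returns masks of shape (1, 4, 256, 256) -- one per
+# mask token at the 4x-upscaled 64x64 resolution -- and iou_pred of shape (1, 4).
+# forward() then keeps 3 masks (multimask_output=True) or 1 mask (False).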
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ output_dim: int,
+ num_layers: int,
+ sigmoid_output: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+            x = torch.sigmoid(x)
+ return x
diff --git a/projects/instance_segment_anything/models/segment_anything/modeling/prompt_encoder.py b/projects/instance_segment_anything/models/segment_anything/modeling/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/modeling/prompt_encoder.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch import nn
+
+from typing import Any, Optional, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ image_embedding_size: Tuple[int, int],
+ input_image_size: Tuple[int, int],
+ mask_in_chans: int,
+ activation: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ """
+ Encodes prompts for input to SAM's mask decoder.
+
+ Arguments:
+ embed_dim (int): The prompts' embedding dimension
+ image_embedding_size (tuple(int, int)): The spatial size of the
+ image embedding, as (H, W).
+          input_image_size (tuple(int, int)): The padded size of the image as input
+ to the image encoder, as (H, W).
+ mask_in_chans (int): The number of hidden channels used for
+ encoding input masks.
+ activation (nn.Module): The activation to use when encoding
+ input masks.
+ """
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.input_image_size = input_image_size
+ self.image_embedding_size = image_embedding_size
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
+ self.point_embeddings = nn.ModuleList(point_embeddings)
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+ self.mask_downscaling = nn.Sequential(
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans // 4),
+ activation(),
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+ LayerNorm2d(mask_in_chans),
+ activation(),
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+ )
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+ def get_dense_pe(self) -> torch.Tensor:
+ """
+ Returns the positional encoding used to encode point prompts,
+ applied to a dense set of points the shape of the image encoding.
+
+ Returns:
+ torch.Tensor: Positional encoding with shape
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
+ """
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+ def _embed_points(
+ self,
+ points: torch.Tensor,
+ labels: torch.Tensor,
+ pad: bool,
+ ) -> torch.Tensor:
+ """Embeds point prompts."""
+ points = points + 0.5 # Shift to center of pixel
+ if pad:
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+ points = torch.cat([points, padding_point], dim=1)
+ labels = torch.cat([labels, padding_label], dim=1)
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
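+ # Label -1 marks padding points: their positional encoding is zeroed and
+ # replaced by the dedicated not_a_point embedding; labels 0/1 add the
+ # background/foreground point embeddings on top of the positional encoding.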
+ point_embedding[labels == -1] = 0.0
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
+ return point_embedding
+
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+ """Embeds box prompts."""
+ boxes = boxes + 0.5 # Shift to center of pixel
+ coords = boxes.reshape(-1, 2, 2)
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+ return corner_embedding
+
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+ """Embeds mask inputs."""
+ mask_embedding = self.mask_downscaling(masks)
+ return mask_embedding
+
+ def _get_batch_size(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> int:
+ """
+ Gets the batch size of the output given the batch size of the input prompts.
+ """
+ if points is not None:
+ return points[0].shape[0]
+ elif boxes is not None:
+ return boxes.shape[0]
+ elif masks is not None:
+ return masks.shape[0]
+ else:
+ return 1
+
+ def _get_device(self) -> torch.device:
+ return self.point_embeddings[0].weight.device
+
+ def forward(
+ self,
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+ boxes: Optional[torch.Tensor],
+ masks: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Embeds different types of prompts, returning both sparse and dense
+ embeddings.
+
+ Arguments:
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+ and labels to embed.
+ boxes (torch.Tensor or none): boxes to embed
+ masks (torch.Tensor or none): masks to embed
+
+ Returns:
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
+ BxNx(embed_dim), where N is determined by the number of input points
+ and boxes.
+ torch.Tensor: dense embeddings for the masks, in the shape
+ Bx(embed_dim)x(embed_H)x(embed_W)
+ """
+ bs = self._get_batch_size(points, boxes, masks)
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ if points is not None:
+ coords, labels = points
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+ if boxes is not None:
+ box_embeddings = self._embed_boxes(boxes)
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+ if masks is not None:
+ dense_embeddings = self._embed_masks(masks)
+ else:
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+ )
+
+ return sparse_embeddings, dense_embeddings
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer(
+ "positional_encoding_gaussian_matrix",
+ scale * torch.randn((2, num_pos_feats)),
+ )
+
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+ """Positionally encode points that are normalized to [0,1]."""
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+ coords = 2 * coords - 1
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w), device=device, dtype=torch.float32)
+ y_embed = grid.cumsum(dim=0) - 0.5
+ x_embed = grid.cumsum(dim=1) - 0.5
+ y_embed = y_embed / h
+ x_embed = x_embed / w
+
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+ return pe.permute(2, 0, 1) # C x H x W
+
+ def forward_with_coords(
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+ ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
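+
+
+# A minimal usage sketch (illustration only, not part of the upstream file); the
+# tiny dimensions below are arbitrary, whereas SAM itself uses embed_dim=256 and
+# a 64x64 image embedding. Run as a module so the relative import of .common resolves.
+if __name__ == "__main__":
+    enc = PromptEncoder(
+        embed_dim=8,
+        image_embedding_size=(4, 4),
+        input_image_size=(64, 64),
+        mask_in_chans=8,
+    )
+    coords = torch.tensor([[[10.0, 20.0], [30.0, 40.0]]])  # 1x2x2 point coordinates
+    labels = torch.tensor([[1, 0]])  # one foreground, one background point
+    sparse, dense = enc(points=(coords, labels), boxes=None, masks=None)
+    # A padding point is appended when no boxes are given, hence N = 3.
+    print(sparse.shape, dense.shape)  # (1, 3, 8) and (1, 8, 4, 4)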
diff --git a/projects/instance_segment_anything/models/segment_anything/modeling/sam.py b/projects/instance_segment_anything/models/segment_anything/modeling/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..303bc2f40c3dbc84f5d4286bb73336e075a86589
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/modeling/sam.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import Any, Dict, List, Tuple
+
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+
+
+class Sam(nn.Module):
+ mask_threshold: float = 0.0
+ image_format: str = "RGB"
+
+ def __init__(
+ self,
+ image_encoder: ImageEncoderViT,
+ prompt_encoder: PromptEncoder,
+ mask_decoder: MaskDecoder,
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
+ ) -> None:
+ """
+ SAM predicts object masks from an image and input prompts.
+
+ Arguments:
+ image_encoder (ImageEncoderViT): The backbone used to encode the
+ image into image embeddings that allow for efficient mask prediction.
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+ and encoded prompts.
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
+ """
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.prompt_encoder = prompt_encoder
+ self.mask_decoder = mask_decoder
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+ @property
+ def device(self) -> Any:
+ return self.pixel_mean.device
+
+ @torch.no_grad()
+ def forward(
+ self,
+ batched_input: List[Dict[str, Any]],
+ multimask_output: bool,
+ ) -> List[Dict[str, torch.Tensor]]:
+ """
+ Predicts masks end-to-end from provided images and prompts.
+ If prompts are not known in advance, using SamPredictor is
+ recommended over calling the model directly.
+
+ Arguments:
+ batched_input (list(dict)): A list over input images, each a
+ dictionary with the following keys. A prompt key can be
+ excluded if it is not present.
+ 'image': The image as a torch tensor in 3xHxW format,
+ already transformed for input to the model.
+ 'original_size': (tuple(int, int)) The original size of
+ the image before transformation, as (H, W).
+ 'point_coords': (torch.Tensor) Batched point prompts for
+ this image, with shape BxNx2. Already transformed to the
+ input frame of the model.
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
+ with shape BxN.
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
+ Already transformed to the input frame of the model.
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
+ in the form Bx1xHxW.
+ multimask_output (bool): Whether the model should predict multiple
+ disambiguating masks, or return a single mask.
+
+ Returns:
+ (list(dict)): A list over input images, where each element is
+ a dictionary with the following keys.
+ 'masks': (torch.Tensor) Batched binary mask predictions,
+ with shape BxCxHxW, where B is the number of input prompts,
+ C is determined by multimask_output, and (H, W) is the
+ original size of the image.
+ 'iou_predictions': (torch.Tensor) The model's predictions
+ of mask quality, in shape BxC.
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
+ to subsequent iterations of prediction.
+ """
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
+ image_embeddings = self.image_encoder(input_images)
+
+ outputs = []
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
+ if "point_coords" in image_record:
+ points = (image_record["point_coords"], image_record["point_labels"])
+ else:
+ points = None
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
+ points=points,
+ boxes=image_record.get("boxes", None),
+ masks=image_record.get("mask_inputs", None),
+ )
+ low_res_masks, iou_predictions = self.mask_decoder(
+ image_embeddings=curr_embedding.unsqueeze(0),
+ image_pe=self.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+ masks = self.postprocess_masks(
+ low_res_masks,
+ input_size=image_record["image"].shape[-2:],
+ original_size=image_record["original_size"],
+ )
+ masks = masks > self.mask_threshold
+ outputs.append(
+ {
+ "masks": masks,
+ "iou_predictions": iou_predictions,
+ "low_res_logits": low_res_masks,
+ }
+ )
+ return outputs
+
+ def postprocess_masks(
+ self,
+ masks: torch.Tensor,
+ input_size: Tuple[int, ...],
+ original_size: Tuple[int, ...],
+ ) -> torch.Tensor:
+ """
+ Remove padding and upscale masks to the original image size.
+
+ Arguments:
+ masks (torch.Tensor): Batched masks from the mask_decoder,
+ in BxCxHxW format.
+ input_size (tuple(int, int)): The size of the image input to the
+ model, in (H, W) format. Used to remove padding.
+ original_size (tuple(int, int)): The original size of the image
+ before resizing for input to the model, in (H, W) format.
+
+ Returns:
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+ is given by original_size.
+ """
+ masks = F.interpolate(
+ masks,
+ (self.image_encoder.img_size, self.image_encoder.img_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+ masks = masks[..., : input_size[0], : input_size[1]]
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+ return masks
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.image_encoder.img_size - h
+ padw = self.image_encoder.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
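+
+
+# A hedged sketch of the per-image record consumed by Sam.forward (illustration
+# only): instantiating a full Sam needs an image encoder, prompt encoder and mask
+# decoder with pretrained weights, so only the documented input layout is shown;
+# all sizes below are placeholders.
+if __name__ == "__main__":
+    record = {
+        # Already resized so the long side matches image_encoder.img_size (1024 for SAM).
+        "image": torch.rand(3, 1024, 768),
+        "original_size": (1500, 1125),
+        "point_coords": torch.tensor([[[512.0, 384.0]]]),  # BxNx2, in the input frame
+        "point_labels": torch.tensor([[1]]),  # BxN, 1 = foreground
+    }
+    batched_input = [record]
+    print(sorted(batched_input[0].keys()))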
diff --git a/projects/instance_segment_anything/models/segment_anything/modeling/transformer.py b/projects/instance_segment_anything/models/segment_anything/modeling/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1a2812f613cc55b1d0b3e3e1d0c84a760d1fb87
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/modeling/transformer.py
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor, nn
+
+import math
+from typing import Tuple, Type
+
+from .common import MLPBlock
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+
+ Args:
+ depth (int): number of layers in the transformer
+ embedding_dim (int): the channel dimension for the input embeddings
+ num_heads (int): the number of heads for multihead attention. Must
+ divide embedding_dim
+ mlp_dim (int): the channel dimension internal to the MLP block
+ activation (nn.Module): the activation to use in the MLP block
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+ self.final_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self,
+ image_embedding: Tensor,
+ image_pe: Tensor,
+ point_embedding: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ image_embedding (torch.Tensor): image to attend to. Should be shape
+ B x embedding_dim x h x w for any h and w.
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding (torch.Tensor): the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding
+ torch.Tensor: the processed image_embedding
+ """
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+ bs, c, h, w = image_embedding.shape
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(
+ queries=queries,
+ keys=keys,
+ query_pe=point_embedding,
+ key_pe=image_pe,
+ )
+
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: Type[nn.Module] = nn.ReLU,
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+
+ Arguments:
+ embedding_dim (int): the channel dimension of the embeddings
+ num_heads (int): the number of heads in the attention layers
+ mlp_dim (int): the hidden dimension of the mlp block
+ activation (nn.Module): the activation of the mlp block
+ skip_first_layer_pe (bool): skip the PE on the first layer
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+ )
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ downsample_rate: int = 1,
+ ) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim // downsample_rate."
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
+
+ def _recombine_heads(self, x: Tensor) -> Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
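+
+
+# A shape-check sketch with deliberately tiny dimensions (illustration only); the
+# real SAM decoder uses depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048.
+# Run as a module so the relative import of .common resolves.
+if __name__ == "__main__":
+    transformer = TwoWayTransformer(depth=2, embedding_dim=32, num_heads=4, mlp_dim=64)
+    image_embedding = torch.rand(1, 32, 8, 8)
+    image_pe = torch.rand(1, 32, 8, 8)
+    point_embedding = torch.rand(1, 5, 32)
+    queries, keys = transformer(image_embedding, image_pe, point_embedding)
+    print(queries.shape, keys.shape)  # (1, 5, 32) and (1, 64, 32)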
diff --git a/projects/instance_segment_anything/models/segment_anything/predictor.py b/projects/instance_segment_anything/models/segment_anything/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..00211b36bd566456c96508419108ed5788e08777
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/predictor.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .modeling import Sam
+
+from typing import Optional, Tuple
+
+from .utils.transforms import ResizeLongestSide
+
+
+class SamPredictor(nn.Module):
+ def __init__(
+ self,
+ sam_model: Sam,
+ ) -> None:
+ """
+ Uses SAM to calculate the image embedding for an image, and then
+ allows repeated, efficient mask prediction given prompts.
+
+ Arguments:
+ sam_model (Sam): The model to use for mask prediction.
+ """
+ super().__init__()
+ self.model = sam_model
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
+ self.reset_image()
+
+ def set_image(
+ self,
+ image: np.ndarray,
+ image_format: str = "RGB",
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method.
+
+ Arguments:
+ image (np.ndarray): The image for calculating masks. Expects an
+ image in HWC uint8 format, with pixel values in [0, 255].
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
+ """
+ assert image_format in [
+ "RGB",
+ "BGR",
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+ if image_format != self.model.image_format:
+ image = image[..., ::-1]
+
+ # Transform the image to the form expected by the model
+ input_image = self.transform.apply_image(image)
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+ self.set_torch_image(input_image_torch, image.shape[:2])
+
+ @torch.no_grad()
+ def set_torch_image(
+ self,
+ transformed_image: torch.Tensor,
+ original_image_size: Tuple[int, ...],
+ ) -> None:
+ """
+ Calculates the image embeddings for the provided image, allowing
+ masks to be predicted with the 'predict' method. Expects the input
+ image to be already transformed to the format expected by the model.
+
+ Arguments:
+ transformed_image (torch.Tensor): The input image, with shape
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
+ original_image_size (tuple(int, int)): The size of the image
+ before transformation, in (H, W) format.
+ """
+ assert (
+ len(transformed_image.shape) == 4
+ and transformed_image.shape[1] == 3
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
+ self.reset_image()
+
+ self.original_size = original_image_size
+ self.input_size = tuple(transformed_image.shape[-2:])
+ input_image = self.model.preprocess(transformed_image)
+ self.features = self.model.image_encoder(input_image)
+ self.is_image_set = True
+
+ def predict(
+ self,
+ point_coords: Optional[np.ndarray] = None,
+ point_labels: Optional[np.ndarray] = None,
+ box: Optional[np.ndarray] = None,
+ mask_input: Optional[np.ndarray] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+
+ Arguments:
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (np.ndarray or None): A length N array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ box (np.ndarray or None): A length 4 array giving a box prompt to the
+ model, in XYXY format.
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form 1xHxW, where
+ for SAM, H=W=256.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (np.ndarray): The output masks in CxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (np.ndarray): An array of length C containing the model's
+ predictions for the quality of each mask.
+ (np.ndarray): An array of shape CxHxW, where C is the number
+ of masks and H=W=256. These low resolution logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ # Transform input prompts
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+ if point_coords is not None:
+ assert (
+ point_labels is not None
+ ), "point_labels must be supplied if point_coords is supplied."
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+ if box is not None:
+ box = self.transform.apply_boxes(box, self.original_size)
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
+ box_torch = box_torch[None, :]
+ if mask_input is not None:
+ mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
+ mask_input_torch = mask_input_torch[None, :, :, :]
+
+ masks, iou_predictions, low_res_masks = self.predict_torch(
+ coords_torch,
+ labels_torch,
+ box_torch,
+ mask_input_torch,
+ multimask_output,
+ return_logits=return_logits,
+ )
+
+ masks = masks[0].detach().cpu().numpy()
+ iou_predictions = iou_predictions[0].detach().cpu().numpy()
+ low_res_masks = low_res_masks[0].detach().cpu().numpy()
+ return masks, iou_predictions, low_res_masks
+
+ @torch.no_grad()
+ def predict_torch(
+ self,
+ point_coords: Optional[torch.Tensor],
+ point_labels: Optional[torch.Tensor],
+ boxes: Optional[torch.Tensor] = None,
+ mask_input: Optional[torch.Tensor] = None,
+ multimask_output: bool = True,
+ return_logits: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Predict masks for the given input prompts, using the currently set image.
+ Input prompts are batched torch tensors and are expected to already be
+ transformed to the input frame using ResizeLongestSide.
+
+ Arguments:
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+ model. Each point is in (X,Y) in pixels.
+ point_labels (torch.Tensor or None): A BxN array of labels for the
+ point prompts. 1 indicates a foreground point and 0 indicates a
+ background point.
+ boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
+ model, in XYXY format.
+ mask_input (torch.Tensor or None): A low resolution mask input to the model, typically
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
+ for SAM, H=W=256. Masks returned by a previous iteration of the
+ predict method do not need further transformation.
+ multimask_output (bool): If true, the model will return three masks.
+ For ambiguous input prompts (such as a single click), this will often
+ produce better masks than a single prediction. If only a single
+ mask is needed, the model's predicted quality score can be used
+ to select the best mask. For non-ambiguous prompts, such as multiple
+ input prompts, multimask_output=False can give better results.
+ return_logits (bool): If true, returns un-thresholded masks logits
+ instead of a binary mask.
+
+ Returns:
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
+ number of masks, and (H, W) is the original image size.
+ (torch.Tensor): An array of shape BxC containing the model's
+ predictions for the quality of each mask.
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
+ of masks and H=W=256. These low res logits can be passed to
+ a subsequent iteration as mask input.
+ """
+ if not self.is_image_set:
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+ if point_coords is not None:
+ points = (point_coords, point_labels)
+ else:
+ points = None
+
+ # Embed prompts
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+ points=points,
+ boxes=boxes,
+ masks=mask_input,
+ )
+
+ # Predict masks
+ low_res_masks, iou_predictions = self.model.mask_decoder(
+ image_embeddings=self.features,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embeddings,
+ dense_prompt_embeddings=dense_embeddings,
+ multimask_output=multimask_output,
+ )
+
+ # Upscale the masks to the original image resolution
+ masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
+
+ if not return_logits:
+ masks = masks > self.model.mask_threshold
+
+ return masks, iou_predictions, low_res_masks
+
+ def get_image_embedding(self) -> torch.Tensor:
+ """
+ Returns the image embeddings for the currently set image, with
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
+ """
+ if not self.is_image_set:
+ raise RuntimeError(
+ "An image must be set with .set_image(...) to generate an embedding."
+ )
+ assert self.features is not None, "Features must exist if an image has been set."
+ return self.features
+
+ @property
+ def device(self) -> torch.device:
+ return self.model.device
+
+ def reset_image(self) -> None:
+ """Resets the currently set image."""
+ self.is_image_set = False
+ self.features = None
+ self.orig_h = None
+ self.orig_w = None
+ self.input_h = None
+ self.input_w = None
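+
+
+# Typical usage sketch (hedged): it assumes the accompanying build_sam module
+# exposes sam_model_registry, as in the upstream segment_anything package, and
+# that a ViT-B checkpoint is available locally; the checkpoint path is a placeholder.
+if __name__ == "__main__":
+    from .build_sam import sam_model_registry  # assumed to be vendored alongside
+
+    sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b_01ec64.pth")
+    predictor = SamPredictor(sam)
+    image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC uint8, RGB
+    predictor.set_image(image)
+    masks, scores, low_res = predictor.predict(
+        point_coords=np.array([[320, 240]]),
+        point_labels=np.array([1]),
+        multimask_output=True,
+    )
+    print(masks.shape, scores.shape)  # (3, 480, 640) and (3,)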
diff --git a/projects/instance_segment_anything/models/segment_anything/utils/__init__.py b/projects/instance_segment_anything/models/segment_anything/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/projects/instance_segment_anything/models/segment_anything/utils/amg.py b/projects/instance_segment_anything/models/segment_anything/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a137778e45c464c079658ecb87ec53270e789f7
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/utils/amg.py
@@ -0,0 +1,346 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+
+class MaskData:
+ """
+ A structure for storing masks and their related data in batched format.
+ Implements basic filtering and concatenation.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ for v in kwargs.values():
+ assert isinstance(
+ v, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats = dict(**kwargs)
+
+ def __setitem__(self, key: str, item: Any) -> None:
+ assert isinstance(
+ item, (list, np.ndarray, torch.Tensor)
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
+ self._stats[key] = item
+
+ def __delitem__(self, key: str) -> None:
+ del self._stats[key]
+
+ def __getitem__(self, key: str) -> Any:
+ return self._stats[key]
+
+ def items(self) -> ItemsView[str, Any]:
+ return self._stats.items()
+
+ def filter(self, keep: torch.Tensor) -> None:
+ for k, v in self._stats.items():
+ if v is None:
+ self._stats[k] = None
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = v[keep.detach().cpu().numpy()]
+ elif isinstance(v, list) and keep.dtype == torch.bool:
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+ elif isinstance(v, list):
+ self._stats[k] = [v[i] for i in keep]
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def cat(self, new_stats: "MaskData") -> None:
+ for k, v in new_stats.items():
+ if k not in self._stats or self._stats[k] is None:
+ self._stats[k] = deepcopy(v)
+ elif isinstance(v, torch.Tensor):
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+ elif isinstance(v, np.ndarray):
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+ elif isinstance(v, list):
+ self._stats[k] = self._stats[k] + deepcopy(v)
+ else:
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+ def to_numpy(self) -> None:
+ for k, v in self._stats.items():
+ if isinstance(v, torch.Tensor):
+ self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+ return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+ box_xywh = deepcopy(box_xyxy)
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
+ return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+ assert len(args) > 0 and all(
+ len(a) == len(args[0]) for a in args
+ ), "Batched iteration must have inputs of all the same size."
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+ for b in range(n_batches):
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+ """
+ Encodes masks to an uncompressed RLE, in the format expected by
+ pycoco tools.
+ """
+ # Put in fortran order and flatten h,w
+ b, h, w = tensor.shape
+ tensor = tensor.permute(0, 2, 1).flatten(1)
+
+ # Compute change indices
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
+ change_indices = diff.nonzero()
+
+ # Encode run length
+ out = []
+ for i in range(b):
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+ cur_idxs = torch.cat(
+ [
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ cur_idxs + 1,
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+ ]
+ )
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+ counts = [] if tensor[i, 0] == 0 else [0]
+ counts.extend(btw_idxs.detach().cpu().tolist())
+ out.append({"size": [h, w], "counts": counts})
+ return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+ """Compute a binary mask from an uncompressed RLE."""
+ h, w = rle["size"]
+ mask = np.empty(h * w, dtype=bool)
+ idx = 0
+ parity = False
+ for count in rle["counts"]:
+ mask[idx : idx + count] = parity
+ idx += count
+ parity ^= True
+ mask = mask.reshape(w, h)
+ return mask.transpose() # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+ return sum(rle["counts"][1::2])
+
+
+def calculate_stability_score(
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+ """
+ Computes the stability score for a batch of masks. The stability
+ score is the IoU between the binary masks obtained by thresholding
+ the predicted mask logits at high and low values.
+ """
+ # One mask is always contained inside the other.
+ # Save memory by preventing unnecessary cast to torch.int64
+ intersections = (
+ (masks > (mask_threshold + threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ unions = (
+ (masks > (mask_threshold - threshold_offset))
+ .sum(-1, dtype=torch.int16)
+ .sum(-1, dtype=torch.int32)
+ )
+ return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+ offset = 1 / (2 * n_per_side)
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+ return points
+
+
+def build_all_layer_point_grids(
+ n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+ """Generates point grids for all crop layers."""
+ points_by_layer = []
+ for i in range(n_layers + 1):
+ n_points = int(n_per_side / (scale_per_layer**i))
+ points_by_layer.append(build_point_grid(n_points))
+ return points_by_layer
+
+
+def generate_crop_boxes(
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+ """
+ Generates a list of crop boxes of different sizes. The i-th layer
+ has (2**i)**2 boxes.
+ """
+ crop_boxes, layer_idxs = [], []
+ im_h, im_w = im_size
+ short_side = min(im_h, im_w)
+
+ # Original image
+ crop_boxes.append([0, 0, im_w, im_h])
+ layer_idxs.append(0)
+
+ def crop_len(orig_len, n_crops, overlap):
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+ for i_layer in range(n_layers):
+ n_crops_per_side = 2 ** (i_layer + 1)
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+ # Crop boxes in XYXY format
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+ crop_boxes.append(box)
+ layer_idxs.append(i_layer + 1)
+
+ return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+ # Check if boxes has a channel dimension
+ if len(boxes.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+ x0, y0, _, _ = crop_box
+ offset = torch.tensor([[x0, y0]], device=points.device)
+ # Check if points has a channel dimension
+ if len(points.shape) == 3:
+ offset = offset.unsqueeze(1)
+ return points + offset
+
+
+def uncrop_masks(
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+ x0, y0, x1, y1 = crop_box
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+ return masks
+ # Coordinate transform masks
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
+ return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+ mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+ """
+ Removes small disconnected regions and holes in a mask. Returns the
+ mask and an indicator of whether the mask has been modified.
+ """
+ import cv2 # type: ignore
+
+ assert mode in ["holes", "islands"]
+ correct_holes = mode == "holes"
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+ sizes = stats[:, -1][1:] # Row 0 is background label
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+ if len(small_regions) == 0:
+ return mask, False
+ fill_labels = [0] + small_regions
+ if not correct_holes:
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+ # If every region is below threshold, keep largest
+ if len(fill_labels) == 0:
+ fill_labels = [int(np.argmax(sizes)) + 1]
+ mask = np.isin(regions, fill_labels)
+ return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+ from pycocotools import mask as mask_utils # type: ignore
+
+ h, w = uncompressed_rle["size"]
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
+ return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+ """
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+ """
+ # torch.max below raises an error on empty inputs, just skip in this case
+ if torch.numel(masks) == 0:
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+ # Normalize shape to CxHxW
+ shape = masks.shape
+ h, w = shape[-2:]
+ if len(shape) > 2:
+ masks = masks.flatten(0, -3)
+ else:
+ masks = masks.unsqueeze(0)
+
+ # Get top and bottom edges
+ in_height, _ = torch.max(masks, dim=-1)
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+ in_height_coords = in_height_coords + h * (~in_height)
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+ # Get left and right edges
+ in_width, _ = torch.max(masks, dim=-2)
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
+ in_width_coords = in_width_coords + w * (~in_width)
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+ # If the mask is empty the right edge will be to the left of the left edge.
+ # Replace these boxes with [0, 0, 0, 0]
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+ out = out * (~empty_filter).unsqueeze(-1)
+
+ # Return to original shape
+ if len(shape) > 2:
+ out = out.reshape(*shape[:-2], 4)
+ else:
+ out = out[0]
+
+ return out
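+
+
+# A small hand-checked round trip for the RLE helpers and box extraction
+# (illustration only, not part of the upstream file).
+if __name__ == "__main__":
+    m = torch.zeros(1, 4, 6, dtype=torch.bool)
+    m[0, 1:3, 2:5] = True  # a 2x3 foreground block
+    rle = mask_to_rle_pytorch(m)[0]
+    assert area_from_rle(rle) == 6
+    assert np.array_equal(rle_to_mask(rle), m[0].numpy())
+    print(batched_mask_to_box(m))  # expected: tensor([[2, 1, 4, 2]])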
diff --git a/projects/instance_segment_anything/models/segment_anything/utils/onnx.py b/projects/instance_segment_anything/models/segment_anything/utils/onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..4297b31291e036700d6ad0b818afb7dd72da3054
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/utils/onnx.py
@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from typing import Tuple
+
+from ..modeling import Sam
+from .amg import calculate_stability_score
+
+
+class SamOnnxModel(nn.Module):
+ """
+ This model should not be called directly, but is used in ONNX export.
+ It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+ with some functions modified to enable model tracing. Also supports extra
+ options controlling what information is returned. See the ONNX export script for details.
+ """
+
+ def __init__(
+ self,
+ model: Sam,
+ return_single_mask: bool,
+ use_stability_score: bool = False,
+ return_extra_metrics: bool = False,
+ ) -> None:
+ super().__init__()
+ self.mask_decoder = model.mask_decoder
+ self.model = model
+ self.img_size = model.image_encoder.img_size
+ self.return_single_mask = return_single_mask
+ self.use_stability_score = use_stability_score
+ self.stability_score_offset = 1.0
+ self.return_extra_metrics = return_extra_metrics
+
+ @staticmethod
+ def resize_longest_image_size(
+ input_image_size: torch.Tensor, longest_side: int
+ ) -> torch.Tensor:
+ input_image_size = input_image_size.to(torch.float32)
+ scale = longest_side / torch.max(input_image_size)
+ transformed_size = scale * input_image_size
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+ return transformed_size
+
+ def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+ point_coords = point_coords + 0.5
+ point_coords = point_coords / self.img_size
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+ point_embedding = point_embedding * (point_labels != -1)
+ point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
+ point_labels == -1
+ )
+
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
+ point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
+ i
+ ].weight * (point_labels == i)
+
+ return point_embedding
+
+ def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
+ mask_embedding = mask_embedding + (
+ 1 - has_mask_input
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+ return mask_embedding
+
+ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
+ masks = F.interpolate(
+ masks,
+ size=(self.img_size, self.img_size),
+ mode="bilinear",
+ align_corners=False,
+ )
+
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
+ masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
+
+ orig_im_size = orig_im_size.to(torch.int64)
+ h, w = orig_im_size[0], orig_im_size[1]
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+ return masks
+
+ def select_masks(
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Determine if we should return the multiclick mask or not from the number of points.
+ # The reweighting is used to avoid control flow.
+ score_reweight = torch.tensor(
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
+ ).to(iou_preds.device)
+ score = iou_preds + (num_points - 2.5) * score_reweight
+ best_idx = torch.argmax(score, dim=1)
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
+
+ return masks, iou_preds
+
+ @torch.no_grad()
+ def forward(
+ self,
+ image_embeddings: torch.Tensor,
+ point_coords: torch.Tensor,
+ point_labels: torch.Tensor,
+ mask_input: torch.Tensor,
+ has_mask_input: torch.Tensor,
+ orig_im_size: torch.Tensor,
+ ):
+ sparse_embedding = self._embed_points(point_coords, point_labels)
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+ masks, scores = self.model.mask_decoder.predict_masks(
+ image_embeddings=image_embeddings,
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
+ sparse_prompt_embeddings=sparse_embedding,
+ dense_prompt_embeddings=dense_embedding,
+ )
+
+ if self.use_stability_score:
+ scores = calculate_stability_score(
+ masks, self.model.mask_threshold, self.stability_score_offset
+ )
+
+ if self.return_single_mask:
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+ if self.return_extra_metrics:
+ stability_scores = calculate_stability_score(
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+ )
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+ return upscaled_masks, scores, stability_scores, areas, masks
+
+ return upscaled_masks, scores, masks
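+
+
+# A quick check of the static resize helper (illustration only); exporting the
+# full ONNX wrapper additionally requires a loaded Sam model and is not shown.
+if __name__ == "__main__":
+    size = SamOnnxModel.resize_longest_image_size(torch.tensor([600.0, 800.0]), 1024)
+    print(size)  # expected: tensor([ 768, 1024])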
diff --git a/projects/instance_segment_anything/models/segment_anything/utils/transforms.py b/projects/instance_segment_anything/models/segment_anything/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ad346661f84b0647026e130a552c4b38b83e2ac
--- /dev/null
+++ b/projects/instance_segment_anything/models/segment_anything/utils/transforms.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image # type: ignore
+
+from copy import deepcopy
+from typing import Tuple
+
+
+class ResizeLongestSide:
+ """
+ Resizes images to longest side 'target_length', as well as provides
+ methods for resizing coordinates and boxes. Provides methods for
+ transforming both numpy array and batched torch tensors.
+ """
+
+ def __init__(self, target_length: int) -> None:
+ self.target_length = target_length
+
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
+ """
+ Expects a numpy array with shape HxWxC in uint8 format.
+ """
+ target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ return np.array(resize(to_pil_image(image), target_size))
+
+ def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array of length 2 in the final dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).astype(float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ """
+ Expects a numpy array shape Bx4. Requires the original image size
+ in (H, W) format.
+ """
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Expects batched images with shape BxCxHxW and float format. This
+ transformation may not exactly match apply_image. apply_image is
+ the transformation expected by the model.
+ """
+ # Expects an image in BCHW format. May not exactly match apply_image.
+ target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+ return F.interpolate(
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
+ )
+
+ def apply_coords_torch(
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with length 2 in the last dimension. Requires the
+ original image size in (H, W) format.
+ """
+ old_h, old_w = original_size
+ new_h, new_w = self.get_preprocess_shape(
+ original_size[0], original_size[1], self.target_length
+ )
+ coords = deepcopy(coords).to(torch.float)
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
+ return coords
+
+ def apply_boxes_torch(
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+ ) -> torch.Tensor:
+ """
+ Expects a torch tensor with shape Bx4. Requires the original image
+ size in (H, W) format.
+ """
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+ return boxes.reshape(-1, 4)
+
+ @staticmethod
+ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+ """
+ Compute the output size given input size and target long side length.
+ """
+ scale = long_side_length * 1.0 / max(oldh, oldw)
+ newh, neww = oldh * scale, oldw * scale
+ neww = int(neww + 0.5)
+ newh = int(newh + 0.5)
+ return (newh, neww)
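+
+
+# A worked example of the rescaling (illustration only): a 480x640 image gets a
+# 1024 long side, so both coordinates scale by the same factor of 1.6.
+if __name__ == "__main__":
+    t = ResizeLongestSide(1024)
+    print(t.get_preprocess_shape(480, 640, 1024))  # (768, 1024)
+    print(t.apply_coords(np.array([[100.0, 200.0]]), (480, 640)))  # [[160. 320.]]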
diff --git a/projects/instance_segment_anything/ops/functions/__init__.py b/projects/instance_segment_anything/ops/functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f682455af45d3687f0266acce6018741fe7c303
--- /dev/null
+++ b/projects/instance_segment_anything/ops/functions/__init__.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
diff --git a/projects/instance_segment_anything/ops/functions/ms_deform_attn_func.py b/projects/instance_segment_anything/ops/functions/ms_deform_attn_func.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3037fc0ad822aeb2e1fbbf1dbc6789ca0259c9d
--- /dev/null
+++ b/projects/instance_segment_anything/ops/functions/ms_deform_attn_func.py
@@ -0,0 +1,117 @@
+# ------------------------------------------------------------------------
+# H-DETR
+# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
+# Licensed under the MIT-style license found in the LICENSE file in the root directory
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+try:
+ import MultiScaleDeformableAttention as MSDA
+except ImportError:
+ # The compiled CUDA extension is optional; the pure PyTorch reference
+ # (ms_deform_attn_core_pytorch) remains available when it is not built.
+ pass
+
+
+class MSDeformAttnFunction(Function):
+ @staticmethod
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(
+ ctx,
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ im2col_step,
+ ):
+ ctx.im2col_step = im2col_step
+ output = MSDA.ms_deform_attn_forward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ ctx.im2col_step,
+ )
+ ctx.save_for_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ )
+ return output
+
+ @staticmethod
+ @once_differentiable
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, grad_output):
+ (
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ ) = ctx.saved_tensors
+ grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
+ value,
+ value_spatial_shapes,
+ value_level_start_index,
+ sampling_locations,
+ attention_weights,
+ grad_output,
+ ctx.im2col_step,
+ )
+
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def ms_deform_attn_core_pytorch(
+ value, value_spatial_shapes, sampling_locations, attention_weights
+):
+ # for debug and test only,
+ # need to use cuda version instead
+ N_, S_, M_, D_ = value.shape
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+ value_l_ = (
+ value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
+ )
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+ # N_*M_, D_, Lq_, P_
+ sampling_value_l_ = F.grid_sample(
+ value_l_,
+ sampling_grid_l_,
+ mode="bilinear",
+ padding_mode="zeros",
+ align_corners=False,
+ )
+ sampling_value_list.append(sampling_value_l_)
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+ attention_weights = attention_weights.transpose(1, 2).reshape(
+ N_ * M_, 1, Lq_, L_ * P_
+ )
+ output = (
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+ .sum(-1)
+ .view(N_, M_ * D_, Lq_)
+ )
+ return output.transpose(1, 2).contiguous()
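+
+
+# A hedged shape check for the PyTorch reference path (debug/test only, as noted
+# above); spatial shapes are passed here as plain (H, W) tuples.
+if __name__ == "__main__":
+    N, M, D, L, P, Lq = 2, 4, 8, 2, 3, 5
+    shapes = [(8, 8), (4, 4)]
+    S = sum(h * w for h, w in shapes)  # total number of value tokens over all levels
+    value = torch.rand(N, S, M, D)
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2)  # normalized to [0, 1]
+    attention_weights = torch.rand(N, Lq, M, L * P).softmax(-1).view(N, Lq, M, L, P)
+    out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
+    print(out.shape)  # expected: torch.Size([2, 5, 32]) == (N, Lq, M * D)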
diff --git a/projects/instance_segment_anything/ops/make.sh b/projects/instance_segment_anything/ops/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..106b685722bc6ed70a06bf04309e75e62f73a430
--- /dev/null
+++ b/projects/instance_segment_anything/ops/make.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+python setup.py build install
diff --git a/projects/instance_segment_anything/ops/modules/__init__.py b/projects/instance_segment_anything/ops/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f82cb1ad9d634a87b54ba6a71b58a230bcade5fe
--- /dev/null
+++ b/projects/instance_segment_anything/ops/modules/__init__.py
@@ -0,0 +1,9 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn import MSDeformAttn
diff --git a/projects/instance_segment_anything/ops/modules/ms_deform_attn.py b/projects/instance_segment_anything/ops/modules/ms_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..09ccc2d5951952850e5f12552ef6adf5afd2181a
--- /dev/null
+++ b/projects/instance_segment_anything/ops/modules/ms_deform_attn.py
@@ -0,0 +1,174 @@
+# ------------------------------------------------------------------------
+# H-DETR
+# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
+# Licensed under the MIT-style license found in the LICENSE file in the root directory
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import warnings
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
+from ..functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError(
+ "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))
+ )
+ return (n & (n - 1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
+ """
+ Multi-Scale Deformable Attention Module
+ :param d_model hidden dimension
+ :param n_levels number of feature levels
+ :param n_heads number of attention heads
+ :param n_points number of sampling points per attention head per feature level
+ """
+ super().__init__()
+ if d_model % n_heads != 0:
+ raise ValueError(
+ "d_model must be divisible by n_heads, but got {} and {}".format(
+ d_model, n_heads
+ )
+ )
+ _d_per_head = d_model // n_heads
+ # setting _d_per_head to a power of 2 makes the CUDA implementation more efficient
+ if not _is_power_of_2(_d_per_head):
+ warnings.warn(
+ "For best efficiency in the CUDA implementation, set d_model in MSDeformAttn "
+ "so that the dimension of each attention head is a power of 2."
+ )
+
+ self.im2col_step = 64
+
+ self.d_model = d_model
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.output_proj = nn.Linear(d_model, d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.sampling_offsets.weight.data, 0.0)
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (
+ 2.0 * math.pi / self.n_heads
+ )
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+ .view(self.n_heads, 1, 1, 2)
+ .repeat(1, self.n_levels, self.n_points, 1)
+ )
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias.data = grid_init.view(-1)
+ constant_(self.attention_weights.weight.data, 0.0)
+ constant_(self.attention_weights.bias.data, 0.0)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.0)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.0)
+
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(
+ self,
+ query,
+ reference_points,
+ input_flatten,
+ input_spatial_shapes,
+ input_level_start_index,
+ input_padding_mask=None,
+ ):
+ """
+ :param query (N, Length_{query}, C)
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
+
+ :return output (N, Length_{query}, C)
+ """
+ N, Len_q, _ = query.shape
+ N, Len_in, _ = input_flatten.shape
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
+ sampling_offsets = self.sampling_offsets(query).view(
+ N, Len_q, self.n_heads, self.n_levels, self.n_points, 2
+ )
+ attention_weights = self.attention_weights(query).view(
+ N, Len_q, self.n_heads, self.n_levels * self.n_points
+ )
+ attention_weights = F.softmax(attention_weights, -1).view(
+ N, Len_q, self.n_heads, self.n_levels, self.n_points
+ )
+ # N, Len_q, n_heads, n_levels, n_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack(
+ [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1
+ )
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :]
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ )
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = (
+ reference_points[:, :, None, :, None, :2]
+ + sampling_offsets
+ / self.n_points
+ * reference_points[:, :, None, :, None, 2:]
+ * 0.5
+ )
+ else:
+ raise ValueError(
+ "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
+ reference_points.shape[-1]
+ )
+ )
+ if ((IS_CUDA_AVAILABLE and value.is_cuda)
+ or (IS_MLU_AVAILABLE and value.is_mlu)):
+ output = MSDeformAttnFunction.apply(
+ value,
+ input_spatial_shapes,
+ input_level_start_index,
+ sampling_locations,
+ attention_weights,
+ self.im2col_step,
+ )
+ else:
+ output = ms_deform_attn_core_pytorch(value,
+ input_spatial_shapes,
+ sampling_locations,
+ attention_weights)
+ output = self.output_proj(output)
+ return output
diff --git a/projects/instance_segment_anything/ops/setup.py b/projects/instance_segment_anything/ops/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0131bc21cf1b45b90fcf174e2c53e4c08e9c641
--- /dev/null
+++ b/projects/instance_segment_anything/ops/setup.py
@@ -0,0 +1,71 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "src")
+
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+ sources = main_file + source_cpu
+ extension = CppExtension
+ extra_compile_args = {"cxx": []}
+ define_macros = []
+
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ else:
+ raise NotImplementedError('CUDA is not available; the MultiScaleDeformableAttention extension requires a CUDA-enabled PyTorch build.')
+
+ sources = [os.path.join(extensions_dir, s) for s in sources]
+ include_dirs = [extensions_dir]
+ ext_modules = [
+ extension(
+ "MultiScaleDeformableAttention",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+ return ext_modules
+
+setup(
+ name="MultiScaleDeformableAttention",
+ version="1.0",
+ author="Weijie Su",
+ url="https://github.com/fundamentalvision/Deformable-DETR",
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+ packages=find_packages(exclude=("configs", "tests",)),
+ ext_modules=get_extensions(),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
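
After building with make.sh, a short, hypothetical smoke test can confirm that the extension installed under the name declared in setup.py; the bound symbol name is an assumption based on the C++ bindings added later in this patch.

    import torch
    import MultiScaleDeformableAttention as MSDA

    assert torch.cuda.is_available()       # the compiled kernels are CUDA-only
    print(MSDA.ms_deform_attn_forward)     # bound forward entry point
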
diff --git a/projects/instance_segment_anything/ops/src/cpu/ms_deform_attn_cpu.cpp b/projects/instance_segment_anything/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e1bf854de1f3860d20b6fef5c1a17817c268e70a
--- /dev/null
+++ b/projects/instance_segment_anything/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,41 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
diff --git a/projects/instance_segment_anything/ops/src/cpu/ms_deform_attn_cpu.h b/projects/instance_segment_anything/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..81b7b58a3d9502bbb684dc84687a526dedf94cae
--- /dev/null
+++ b/projects/instance_segment_anything/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,33 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor>
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
+
diff --git a/projects/instance_segment_anything/ops/src/cuda/ms_deform_attn_cuda.cu b/projects/instance_segment_anything/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d6d583647cce987196d5ad1968a8a365a379e774
--- /dev/null
+++ b/projects/instance_segment_anything/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,153 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include <vector>
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data<int64_t>(),
+ level_start_index.data<int64_t>(),
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data<scalar_t>());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data<scalar_t>(),
+ value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data<int64_t>(),
+ level_start_index.data<int64_t>(),
+ sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
\ No newline at end of file
diff --git a/projects/instance_segment_anything/ops/src/cuda/ms_deform_attn_cuda.h b/projects/instance_segment_anything/ops/src/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..c7ae53f99c820ce6193b608ad344550348a0b42c
--- /dev/null
+++ b/projects/instance_segment_anything/ops/src/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,30 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector<at::Tensor> ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
diff --git a/projects/instance_segment_anything/ops/src/cuda/ms_deform_im2col_cuda.cuh b/projects/instance_segment_anything/ops/src/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..6bc2acb7aea0eab2e9e91e769a16861e1652c284
--- /dev/null
+++ b/projects/instance_segment_anything/ops/src/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include <cstdio>
+#include <algorithm>
+#include <cstring>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THCAtomics.cuh>
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
+template <typename scalar_t>
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template <typename scalar_t>
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t, unsigned int blockSize>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template <typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template <typename scalar_t>
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm<scalar_t>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
+ <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
+        <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+          ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, num_threads*3*sizeof(scalar_t), stream>>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
\ No newline at end of file
diff --git a/projects/instance_segment_anything/ops/src/ms_deform_attn.h b/projects/instance_segment_anything/ops/src/ms_deform_attn.h
new file mode 100644
index 0000000000000000000000000000000000000000..ac0ef2ec25f7d0ee51ca2d807b159ddf85652017
--- /dev/null
+++ b/projects/instance_segment_anything/ops/src/ms_deform_attn.h
@@ -0,0 +1,62 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
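+// Thin dispatchers: both entry points require the value tensor to live on the
+// GPU and forward to the CUDA implementation; no CPU implementation is provided.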
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/projects/instance_segment_anything/ops/src/vision.cpp b/projects/instance_segment_anything/ops/src/vision.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2201f63a51dca16d0b31148ed2c9e8e47ec15bdc
--- /dev/null
+++ b/projects/instance_segment_anything/ops/src/vision.cpp
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
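+// Register the forward/backward entry points as a Python extension module.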
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
diff --git a/projects/instance_segment_anything/ops/test.py b/projects/instance_segment_anything/ops/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbf6d5547d131f01a8c5c28b76557bd27a9334b
--- /dev/null
+++ b/projects/instance_segment_anything/ops/test.py
@@ -0,0 +1,89 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
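+# Shape conventions assumed by the checks below (inferred from the tensor layouts):
+# N = batch, M = heads, D = channels per head, Lq = queries,
+# L = feature levels, P = sampling points per level.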
+N, M, D = 1, 2, 2
+Lq, L, P = 2, 2, 2
+shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
+S = sum([(H*W).item() for H, W in shapes])
+
+
+torch.manual_seed(3)
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_double():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_float():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ func = MSDeformAttnFunction.apply
+
+ value.requires_grad = grad_value
+ sampling_locations.requires_grad = grad_sampling_loc
+ attention_weights.requires_grad = grad_attn_weight
+
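+    # gradcheck compares the analytical CUDA gradients against numerical ones;
+    # inputs are cast to double precision for a stable comparison.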
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+ check_forward_equal_with_pytorch_double()
+ check_forward_equal_with_pytorch_float()
+
+ for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
+ check_gradient_numerical(channels, True, True, True)
+
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..19b4c600ff62a5fa2811f09dd6955cd59240ba6e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+torch==1.10.0
+torchvision==0.11.0
+cython
+numpy
+matplotlib
+pycocotools
+scipy
+six
+terminaltables
+mmcv-full
+gradio
diff --git a/requirements/albu.txt b/requirements/albu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f421fbbdc472527e6010cb62a7d0236cf034f24f
--- /dev/null
+++ b/requirements/albu.txt
@@ -0,0 +1 @@
+albumentations>=0.3.2 --no-binary qudida,albumentations
diff --git a/requirements/build.txt b/requirements/build.txt
new file mode 100644
index 0000000000000000000000000000000000000000..81558298594a9619f3187d220f1accede1865de7
--- /dev/null
+++ b/requirements/build.txt
@@ -0,0 +1,3 @@
+# These must be installed before building mmdetection
+cython
+numpy
diff --git a/requirements/docs.txt b/requirements/docs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b5626007318406548f03ebe65a7afbd90e3c82d0
--- /dev/null
+++ b/requirements/docs.txt
@@ -0,0 +1,8 @@
+docutils==0.16.0
+markdown>=3.4.0
+myst-parser
+-e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
+sphinx==5.3.0
+sphinx-copybutton
+sphinx_markdown_tables>=0.0.17
+sphinx_rtd_theme
diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b53dbf4108e40ef58b06e02ee38f89f9bea4f806
--- /dev/null
+++ b/requirements/mminstall.txt
@@ -0,0 +1 @@
+mmcv-full>=1.3.17
diff --git a/requirements/optional.txt b/requirements/optional.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4f0065a9b4d86ff0909ce8db000ebe54e51743c8
--- /dev/null
+++ b/requirements/optional.txt
@@ -0,0 +1,3 @@
+cityscapesscripts
+imagecorruptions
+scikit-learn
diff --git a/requirements/readthedocs.txt b/requirements/readthedocs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e1bf21b3ecead215c158e5a3c57f0651ac9c7155
--- /dev/null
+++ b/requirements/readthedocs.txt
@@ -0,0 +1,4 @@
+mmcv
+scipy
+torch
+torchvision
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c815aef86c2b8744bb9fa1da64582c43b520ca9f
--- /dev/null
+++ b/requirements/runtime.txt
@@ -0,0 +1,6 @@
+matplotlib
+numpy
+pycocotools
+scipy
+six
+terminaltables
diff --git a/requirements/tests.txt b/requirements/tests.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2ff795a58f9781bf4b399cee47ad5a48059aaf58
--- /dev/null
+++ b/requirements/tests.txt
@@ -0,0 +1,15 @@
+asynctest
+codecov
+flake8
+interrogate
+isort==4.3.21
+# Note: used for kwarray.group_items, this may be ported to mmcv in the future.
+kwarray
+-e git+https://github.com/open-mmlab/mmtracking#egg=mmtrack
+onnx==1.7.0
+onnxruntime>=1.8.0
+protobuf<=3.20.1
+pytest
+ubelt
+xdoctest>=0.10.0
+yapf
diff --git a/tools/convert_ckpt.py b/tools/convert_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cac1a066682ede92cbb879ce275399d2006aa10
--- /dev/null
+++ b/tools/convert_ckpt.py
@@ -0,0 +1,11 @@
+import torch
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('source_file')
+parser.add_argument('des_file')
+args = parser.parse_args()
+
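+# Load the full training checkpoint and keep only the model weights.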
+ckpt = torch.load(args.source_file, map_location='cpu')
+ckpt = ckpt['model']
+torch.save(ckpt, args.des_file)
\ No newline at end of file