Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import importlib | |
| import os | |
| import pkgutil | |
| import warnings | |
| from collections import namedtuple | |
| import torch | |
# Two implementations of ``load_ext`` exist. Which one is defined depends on
# whether the running ``torch`` reports its version as ``'parrots'``.
if torch.__version__ != 'parrots':

    def load_ext(name, funcs):
        """Import the extension module ``mmcv.<name>`` and check its API.

        Args:
            name (str): Sub-module name under the ``mmcv`` package, e.g.
                ``'_ext'``.
            funcs (list[str]): Attribute names the module must expose.

        Returns:
            module: The imported module object itself.

        Raises:
            ImportError: If ``mmcv.<name>`` cannot be imported.
            AssertionError: If any name in ``funcs`` is not an attribute of
                the imported module.
        """
        ext = importlib.import_module('mmcv.' + name)
        for fun in funcs:
            # Raise explicitly rather than using an ``assert`` statement.
            # ``assert`` is stripped under ``python -O``, which would let an
            # incomplete extension load silently. AssertionError is kept so
            # callers that already catch it keep working.
            if not hasattr(ext, fun):
                raise AssertionError(f'{fun} miss in module {name}')
        return ext

else:
    from parrots import extension
    from parrots.base import ParrotsException

    # Ops named here are bound via ``ext_fun.op`` in ``load_ext`` below; all
    # other ops are bound via ``ext_fun.op_``. Judging by the list's name,
    # these are the ops that return a value -- confirm against parrots docs.
    has_return_value_ops = [
        'nms',
        'softnms',
        'nms_match',
        'nms_rotated',
        'top_pool_forward',
        'top_pool_backward',
        'bottom_pool_forward',
        'bottom_pool_backward',
        'left_pool_forward',
        'left_pool_backward',
        'right_pool_forward',
        'right_pool_backward',
        'fused_bias_leakyrelu',
        'upfirdn2d',
        'ms_deform_attn_forward',
        'pixel_group',
        'contour_expand',
    ]

    def get_fake_func(name, e):
        """Build a stand-in for an op that failed to load under parrots.

        Args:
            name (str): Op name, used in the warning message.
            e (Exception): The exception raised while loading the op.

        Returns:
            callable: A function accepting any arguments. When called, it
            emits a warning and then re-raises ``e``.
        """

        def fake_func(*args, **kwargs):
            warnings.warn(f'{name} is not supported in parrots now')
            raise e

        return fake_func

    def load_ext(name, funcs):
        """Load each op in ``funcs`` from the parrots extension ``name``.

        Args:
            name (str): Extension name passed to ``extension.load``.
            funcs (list[str]): Op names to load. They also become the field
                names of the returned namedtuple.

        Returns:
            ExtModule: A namedtuple with one field per entry of ``funcs``.
            An op whose load raised ``ParrotsException`` is replaced by a
            stub from ``get_fake_func``, which warns and re-raises that
            exception when called.
        """
        ExtModule = namedtuple('ExtModule', funcs)
        ext_list = []
        # The parent of the directory containing this file (symlinks resolved).
        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        for fun in funcs:
            try:
                ext_fun = extension.load(fun, name, lib_dir=lib_root)
            except ParrotsException as e:
                # A "No element registered" failure is not warned about at
                # load time; any other load error is surfaced immediately.
                if 'No element registered' not in e.message:
                    warnings.warn(e.message)
                ext_fun = get_fake_func(fun, e)
                ext_list.append(ext_fun)
            else:
                if fun in has_return_value_ops:
                    ext_list.append(ext_fun.op)
                else:
                    ext_list.append(ext_fun.op_)
        return ExtModule(*ext_list)
def check_ops_exist(module_name='mmcv._ext'):
    """Return whether the compiled extension module can be located.

    The module itself is only located, not executed. Locating a dotted name
    imports its parent package (e.g. ``mmcv``) as a side effect.

    Args:
        module_name (str): Dotted name of the extension module to look for.
            Defaults to ``'mmcv._ext'``.

    Returns:
        bool: ``True`` if an import spec with a loader exists for
        ``module_name``, ``False`` otherwise.

    Raises:
        ImportError: If the lookup itself fails, e.g. when the parent
            package of ``module_name`` cannot be imported.
    """
    # ``pkgutil.find_loader`` is deprecated since Python 3.12 and removed in
    # 3.14. ``importlib.util.find_spec`` is its documented replacement.
    import importlib.util

    try:
        spec = importlib.util.find_spec(module_name)
    except (AttributeError, TypeError, ValueError) as err:
        # pkgutil.find_loader reported these failures as ImportError; keep
        # that contract for existing callers.
        raise ImportError(
            f'Error while finding loader for {module_name!r} '
            f'({type(err)}: {err})') from err
    return spec is not None and spec.loader is not None