# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}

ddp_factory = {'cuda': MMDistributedDataParallel}
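# Note (added comment): mlu/npu wrappers are registered into these factories
# lazily inside build_dp/build_ddp, so mmcv.device is only imported when one
# of those backends is actually requested.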


def build_dp(model, device='cuda', dim=0, *args, **kwargs):
    """Build DataParallel module by device type.

    If device is cuda, return an MMDataParallel model; if device is mlu,
    return an MLUDataParallel model; if device is npu, return an
    NPUDataParallel model.

    Args:
        model (:class:`nn.Module`): model to be parallelized.
        device (str): device type, cuda, cpu, mlu or npu. Defaults to cuda.
        dim (int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        nn.Module: the parallelized model.
    """
    if device == 'npu':
        from mmcv.device.npu import NPUDataParallel
        dp_factory['npu'] = NPUDataParallel
        torch.npu.set_device(kwargs['device_ids'][0])
        torch.npu.set_compile_mode(jit_compile=False)
        model = model.npu()
    elif device == 'cuda':
        model = model.cuda(kwargs['device_ids'][0])
    elif device == 'mlu':
        from mmcv.device.mlu import MLUDataParallel
        dp_factory['mlu'] = MLUDataParallel
        model = model.mlu()

    return dp_factory[device](model, dim=dim, *args, **kwargs)
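

# Example usage of build_dp (a sketch, not part of the original file): wrap a
# model for single-machine parallelism on the first visible GPU. The 'cuda'
# and 'npu' branches index kwargs['device_ids'][0], so device_ids must be
# supplied for those devices:
#
#   model = build_dp(model, device='cuda', device_ids=[0])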


def build_ddp(model, device='cuda', *args, **kwargs):
    """Build DistributedDataParallel module by device type.

    If device is cuda, return an MMDistributedDataParallel model;
    if device is mlu, return an MLUDistributedDataParallel model;
    if device is npu, return an NPUDistributedDataParallel model.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, cuda, mlu or npu.

    Returns:
        :class:`nn.Module`: the parallelized module.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    assert device in ['cuda', 'mlu', 'npu'], \
        'Only available for cuda, mlu or npu devices.'
    if device == 'npu':
        from mmcv.device.npu import NPUDistributedDataParallel
        torch.npu.set_compile_mode(jit_compile=False)
        ddp_factory['npu'] = NPUDistributedDataParallel
        model = model.npu()
    elif device == 'cuda':
        model = model.cuda()
    elif device == 'mlu':
        from mmcv.device.mlu import MLUDistributedDataParallel
        ddp_factory['mlu'] = MLUDistributedDataParallel
        model = model.mlu()

    return ddp_factory[device](model, *args, **kwargs)
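

# Example usage of build_ddp (a sketch, not part of the original file):
# assumes torch.distributed.init_process_group has already been called and
# that the launcher exports LOCAL_RANK. The cuda branch calls model.cuda()
# with no index, so call torch.cuda.set_device(local_rank) beforehand:
#
#   import os
#   local_rank = int(os.environ['LOCAL_RANK'])
#   torch.cuda.set_device(local_rank)
#   model = build_ddp(
#       model, device='cuda',
#       device_ids=[local_rank],
#       broadcast_buffers=False)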


def is_npu_available():
    """Returns a bool indicating if NPU is currently available."""
    return hasattr(torch, 'npu') and torch.npu.is_available()


def is_mlu_available():
    """Returns a bool indicating if MLU is currently available."""
    return hasattr(torch, 'is_mlu_available') and torch.is_mlu_available()


def get_device():
    """Returns an available device: npu, cuda, mlu or cpu.

    Devices are probed in the order npu, cuda, mlu; cpu is the fallback
    when none of them is available.
    """
    is_device_available = {
        'npu': is_npu_available(),
        'cuda': torch.cuda.is_available(),
        'mlu': is_mlu_available()
    }
    device_list = [k for k, v in is_device_available.items() if v]
    return device_list[0] if len(device_list) >= 1 else 'cpu'
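

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original file); it only
    # exercises the cuda and cpu paths, since the mlu/npu branches need
    # vendor-specific builds of mmcv.
    import torch.nn as nn

    device = get_device()
    toy = nn.Linear(4, 2)
    if device == 'cuda':
        wrapped = build_dp(toy, device='cuda', device_ids=[0])
    else:
        wrapped = build_dp(toy, device='cpu')
    print(f'detected device: {device}, wrapper: {type(wrapped).__name__}')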