Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| from mmcv.device._functions import Scatter, scatter | |
| from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE | |
def test_scatter():
    """Check ``scatter`` on CPU and, when present, on MLU and MPS."""
    # Each entry: (run this case?, device index passed to ``scatter``,
    # device the reference tensor is moved to before comparing).
    # ``None`` keeps the reference tensor as-is, which is the CPU case
    # where the device index is -1.
    cases = (
        (True, -1, None),
        (IS_MLU_AVAILABLE, 0, 'mlu'),
        (IS_MPS_AVAILABLE, 0, 'mps'),
    )
    for enabled, device_index, target in cases:
        if not enabled:
            continue

        # a single tensor
        tensor = torch.zeros([1, 3, 3, 3])
        scattered = scatter(input=tensor, devices=[device_index])
        reference = tensor if target is None else tensor.to(target)
        assert torch.allclose(reference, scattered)

        # a list of tensors with different shapes
        tensors = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        scattered_list = scatter(input=tensors, devices=[device_index])
        for tensor, scattered in zip(tensors, scattered_list):
            reference = tensor if target is None else tensor.to(target)
            assert torch.allclose(reference, scattered)

    # input should be a tensor or list of tensor
    with pytest.raises(Exception):
        scatter(5, [-1])
def test_Scatter():
    """Check ``Scatter.forward`` on CPU and, when present, on MLU and MPS.

    ``Scatter.forward`` is expected to return a tuple in every case: the
    single-tensor result is read from ``outputs[0]`` and a list of tensors
    yields one entry per input tensor.
    """
    # Each entry: (run this case?, device index handed to
    # ``Scatter.forward``, device the reference tensor is moved to before
    # comparing). ``None`` keeps the reference tensor as-is, which is the
    # CPU case where the device index is -1.
    cases = (
        (True, -1, None),
        (IS_MLU_AVAILABLE, 0, 'mlu'),
        (IS_MPS_AVAILABLE, 0, 'mps'),
    )
    for enabled, device_index, target in cases:
        if not enabled:
            continue

        # a single tensor comes back wrapped in a tuple
        tensor = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward([device_index], tensor)
        assert isinstance(outputs, tuple)
        reference = tensor if target is None else tensor.to(target)
        assert torch.allclose(reference, outputs[0])

        # a list of tensors with different shapes
        tensors = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward([device_index], tensors)
        assert isinstance(outputs, tuple)
        for tensor, output in zip(tensors, outputs):
            reference = tensor if target is None else tensor.to(target)
            # Compare the whole tensor: indexing ``output[0]`` would drop
            # the leading dimension and weaken the check.
            assert torch.allclose(reference, output)