Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import pytest | |
| import torch | |
| _USING_PARROTS = True | |
| try: | |
| from parrots.autograd import gradcheck | |
| except ImportError: | |
| from torch.autograd import gradcheck, gradgradcheck | |
| _USING_PARROTS = False | |
class TestFusedBiasLeakyReLU:
    """Numerical gradient checks for ``mmcv.ops.FusedBiasLeakyReLU``.

    Every check moves the op and its input to CUDA, so the tests are skipped
    (rather than failing) on machines without a CUDA device.
    """

    @classmethod
    def setup_class(cls):
        """Create the CUDA tensors shared by the tests.

        Does nothing when CUDA is unavailable; the tests themselves are
        skipped in that case, so the missing attributes are never read.
        """
        if not torch.cuda.is_available():
            return
        cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda()
        cls.bias = torch.zeros(2, requires_grad=True).cuda()

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires cuda')
    def test_gradient(self):
        """Run the first-order gradient check on the op."""
        from mmcv.ops import FusedBiasLeakyReLU

        module = FusedBiasLeakyReLU(2).cuda()
        if _USING_PARROTS:
            # The parrots gradcheck takes differently named tolerances.
            gradcheck(module, self.input_tensor, delta=1e-4, pt_atol=1e-3)
        else:
            gradcheck(module, self.input_tensor, eps=1e-4, atol=1e-3)

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='requires cuda')
    def test_gradgradient(self):
        """Run the second-order gradient check on the op (torch only)."""
        if _USING_PARROTS:
            # gradgradcheck is only imported on the torch fallback path, so
            # calling it under parrots would raise NameError.
            pytest.skip('gradgradcheck is not imported under parrots')

        from mmcv.ops import FusedBiasLeakyReLU

        gradgradcheck(
            FusedBiasLeakyReLU(2).cuda(),
            self.input_tensor,
            eps=1e-4,
            atol=1e-3)