| # -*- coding: utf-8 -*- | |
| # File : unittest.py | |
| # Author : Jiayuan Mao | |
| # Email : [email protected] | |
| # Date : 27/01/2018 | |
| # | |
| # This file is part of Synchronized-BatchNorm-PyTorch. | |
| # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch | |
| # Distributed under MIT License. | |
| import unittest | |
| import numpy as np | |
| from torch.autograd import Variable | |
def as_numpy(v):
    """Return the contents of ``v`` as a NumPy array.

    A ``Variable`` is first unwrapped through its ``.data`` attribute; any
    other input is used as-is. The result is then moved to the CPU and
    converted with ``.numpy()``.
    """
    tensor = v.data if isinstance(v, Variable) else v
    on_cpu = tensor.cpu()
    return on_cpu.numpy()
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` extended with an approximate tensor comparison."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Fail unless ``a`` and ``b`` are element-wise close.

        Both arguments are converted with ``as_numpy`` and compared with
        ``np.allclose``, i.e. the check is
        ``|a - b| <= atol + rtol * |b|`` for every element.

        Args:
            a: First tensor; also the reference for the relative difference
                reported in the failure message.
            b: Second tensor.
            atol: Absolute tolerance forwarded to ``np.allclose``.
            rtol: Relative tolerance forwarded to ``np.allclose``.

        Raises:
            AssertionError: (``self.failureException``) if the values are not
                close; the message contains both inputs plus the largest
                absolute and relative differences.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        # Forward rtol as well as atol; otherwise the caller's relative
        # tolerance is silently replaced by NumPy's default.
        if np.allclose(npa, npb, rtol=rtol, atol=atol):
            return

        # Diagnostics are computed only on failure: ``.max()`` raises on
        # zero-size arrays, which would turn a passing comparison of empty
        # tensors into an error.
        abs_diff = np.abs(npa - npb)
        # Use the magnitude of the reference (floored at 1e-5 to avoid a
        # division by zero) so negative entries do not inflate the figure.
        rel_diff = abs_diff / np.fmax(np.abs(npa), 1e-5)
        self.fail(
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(
                a, b, abs_diff.max(), rel_diff.max()
            )
        )

