Spaces:
Build error
Build error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| import tempfile | |
| import torch.nn as nn | |
| from tools.deployment.pytorch2onnx import _convert_batchnorm, pytorch2onnx | |
class DummyModel(nn.Module):
    """Minimal module that feeds a 3D convolution into a SyncBatchNorm layer."""

    def __init__(self):
        super().__init__()
        # 1x1x1 convolution taking 1 input channel to 2 output channels.
        self.conv = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=1)
        # Normalisation over the 2 channels produced by the convolution.
        self.bn = nn.SyncBatchNorm(num_features=2)

    def forward(self, x):
        """Apply the convolution, then the batch-norm layer, to ``x``."""
        features = self.conv(x)
        return self.bn(features)

    def forward_dummy(self, x):
        """Return the result of ``forward(x)`` wrapped in a one-element tuple."""
        result = self.forward(x)
        return (result, )
def test_onnx_exporting():
    """Smoke-test that ``pytorch2onnx`` runs on a converted ``DummyModel``.

    The model is passed through ``_convert_batchnorm`` first, and the export
    is written to ``tmp.onnx`` inside a temporary directory that is removed
    when the test finishes. The test makes no assertions of its own; it
    passes as long as nothing raises.
    """
    # Second positional argument of pytorch2onnx; presumably the shape of the
    # dummy input used for export -- confirm against pytorch2onnx.
    input_shape = (1, 1, 1, 1, 1)

    with tempfile.TemporaryDirectory() as workdir:
        onnx_path = osp.join(workdir, 'tmp.onnx')
        export_model = _convert_batchnorm(DummyModel())

        # test exporting
        pytorch2onnx(export_model, input_shape, output_file=onnx_path)