"""Compare QuantConv2d's QDQ (quantize-dequantize) path against its QOp path.

Builds a randomly initialised float ``nn.Conv2d`` and derives quantization
parameters from its weights and from a random input:

* per-output-channel weight scale and zero point;
* a per-tensor input scale;
* a random SmoothQuant multiplier.

It then loads the float weights into a ``QuantConv2d`` and prints how far
the two execution paths disagree.
"""
import torch
import torch.nn as nn

from math_model import QuantConv2d

torch.manual_seed(0)

batch_size = 1
out_ch = 128
in_ch = 64
k = 3
h = 5
w = 5

# Uniform random input in [-1, 1).
i = 2 * torch.rand((batch_size, in_ch, h, w)) - 1.0
l = nn.Conv2d(in_ch, out_ch, k, bias=True)

# The quantization parameters are constants derived from l.weight. Compute
# them without autograd so they do not drag a graph (grad_fn) through
# l.weight into QuantConv2d.
with torch.no_grad():
    # Per-output-channel max |w|, shared by the weight scale and zero point.
    weight_absmax = torch.max(
        torch.abs(torch.flatten(l.weight, start_dim=1)), dim=1
    ).values
    weight_mean = torch.mean(l.weight, dim=(1, 2, 3))

    # NOTE(review): dividing by 128 maps +absmax to 128, which is outside
    # the int8 range [-128, 127]. Presumably QuantConv2d clamps that value
    # to 127. Confirm whether 127 was the intended divisor.
    quant_params = {
        'smoothquant_mul': torch.rand((in_ch,)),
        'smoothquant_mul_shape': (1, in_ch, 1, 1),
        # The original script listed 'weight_scale' twice, so a random
        # torch.rand((out_ch,)) value was silently overwritten. Only the
        # effective (absmax-based) value is kept.
        'weight_scale': weight_absmax / 128.,
        'weight_scale_shape': (out_ch, 1, 1, 1),
        # Roughly round(mean / weight_scale), clamped to the int8 range.
        'weight_zp': torch.clamp(
            torch.round(weight_mean * (128 / weight_absmax)), -128, 127
        ),
        'weight_zp_shape': (out_ch, 1, 1, 1),
        'weight_zp_dtype': 'torch.int8',
        'input_scale': torch.max(torch.abs(i)) / 128.,
        'input_scale_shape': tuple(),
        # NOTE(review): this tensor has shape (1,) while its declared shape
        # is () (scalar). It presumably works because QuantConv2d reshapes
        # it. Verify against QuantConv2d.
        'input_zp': torch.zeros((1,)),
        'input_zp_shape': tuple(),
        'input_zp_dtype': 'torch.int8',
    }
print(quant_params)

ql = QuantConv2d(in_ch, out_ch, k, quant_params)
ql.conv2d.load_state_dict(l.state_dict())

# Inference-only comparison; no gradients are needed.
with torch.no_grad():
    o_qdq = ql(i)
    o_qop = ql(i, qop=True)

print(o_qdq.shape)
print(o_qop.shape)
diff = o_qdq - o_qop
print(diff)
# Single-number summary of the disagreement between the two paths.
print(f"max |qdq - qop|: {diff.abs().max().item()}")