import torch

from attn import QuantScaledDotProductAttention

torch.manual_seed(0)

# Random query/key/value activations, uniform in [-1, 1].
batch_size = 1
seq_len = 11
hidden_size = 21
query = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
key = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
value = 2. * torch.rand((batch_size, seq_len, hidden_size)) - 1.
# Per-tensor FP8 quantization parameters. 240 is the largest finite value
# of torch.float8_e4m3fnuz, so act_scale = max(|x|) / 240 maps each
# tensor's absolute maximum onto the FP8 range.
quant_params = {
    "output_softmax_quant": {
        # Softmax outputs lie in [0, 1], so a fixed scale of 1/240 covers them.
        "act_scale": 1. / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    },
    "out_q": {
        "act_scale": torch.max(torch.abs(query)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    },
    "out_k": {
        "act_scale": torch.max(torch.abs(key)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    },
    "out_v": {
        "act_scale": torch.max(torch.abs(value)) / 240.,
        "act_scale_shape": [],
        "act_zp": 0.0,
        "act_zp_shape": [],
        "act_zp_dtype": "torch.float8_e4m3fnuz",
    },
}
print(quant_params)
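# A minimal sanity check, not part of the module's API: quantize-dequantize
# `query` by hand with the scale above (assuming the usual x_q = x / scale
# convention, and a PyTorch build that exposes torch.float8_e4m3fnuz) to see
# the round-trip error that per-tensor FP8 quantization introduces.
q_scale = quant_params["out_q"]["act_scale"]
q_fp8 = (query / q_scale).to(torch.float8_e4m3fnuz)  # quantize to FP8
q_dq = q_fp8.to(torch.float32) * q_scale             # dequantize back to float
print(torch.max(torch.abs(query - q_dq)))            # worst-case round-trip error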
qsdpa = QuantScaledDotProductAttention(quant_params)

# QDQ path: fake quantization (quantize, then immediately dequantize) in float.
o_qdq = qsdpa(query, key, value)
# qop=True selects the quantized-op (QOp) path instead.
o_qop = qsdpa(query, key, value, qop=True)

print(o_qdq.shape)
print(o_qop.shape)
# Elementwise difference between the two paths; it should be close to zero.
print(o_qdq - o_qop)
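# The full difference tensor is hard to eyeball; the maximum absolute error
# is a more compact summary of how closely the two paths agree.
print(torch.max(torch.abs(o_qdq - o_qop)))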