import math

import torch
import torch.nn as nn


def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    # Simulated FP8 quantization: divide by the scale, clamp to the representable
    # range of torch.float8_e4m3fnuz (+/-240), round by casting to FP8, then cast
    # back to the original dtype so downstream ops keep running in that dtype.
    dtype = tensor.dtype
    quant_tensor = torch.clamp(tensor / scale, -240.0, 240.0).to(torch.float8_e4m3fnuz).to(dtype)
    return quant_tensor


def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor):
    return tensor * scale


def qdq_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # QDQ (quantize-dequantize, i.e. fake-quant) reference: Q, K, V and the softmax
    # output take an FP8 round trip, while every matmul runs in the original dtype.
    query = dequantize_fp8(quantize_fp8(query, query_scale), query_scale)
    key = dequantize_fp8(quantize_fp8(key, key_scale), key_scale)
    value = dequantize_fp8(quantize_fp8(value, value_scale), value_scale)
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = (query @ key.transpose(-2, -1)) * scale_factor
    attn_weight += attn_bias
    attn_weight = dequantize_fp8(quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value


def qop_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # QOp reference: Q, K, V stay in the (simulated) FP8 domain and the
    # dequantization scales are folded into the matmul scale factors instead.
    query = quantize_fp8(query, query_scale)
    key = quantize_fp8(key, key_scale)
    value = quantize_fp8(value, value_scale)

    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Fold the Q and K dequantization scales into the attention scale factor.
    scale_factor *= (query_scale * key_scale)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = (query @ key.transpose(-2, -1)) * scale_factor
    attn_weight += attn_bias
    attn_weight = quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    # Fold the softmax and V dequantization scales into the output.
    return (attn_weight @ value) * (softmax_scale * value_scale)


class QuantScaledDotProductAttention(nn.Module):
    def __init__(self, quant_param):
        super().__init__()
        q_scale = torch.tensor(quant_param['out_q']['act_scale']).view(quant_param['out_q']['act_scale_shape'])
        k_scale = torch.tensor(quant_param['out_k']['act_scale']).view(quant_param['out_k']['act_scale_shape'])
        v_scale = torch.tensor(quant_param['out_v']['act_scale']).view(quant_param['out_v']['act_scale_shape'])
        sm_scale = torch.tensor(quant_param['output_softmax_quant']['act_scale']).view(quant_param['output_softmax_quant']['act_scale_shape'])

        assert quant_param['out_q']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Q Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_q']['act_zp_dtype']}"
        assert quant_param['out_k']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"K Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_k']['act_zp_dtype']}"
        assert quant_param['out_v']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"V Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_v']['act_zp_dtype']}"
        assert quant_param['output_softmax_quant']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"SoftMax Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['output_softmax_quant']['act_zp_dtype']}"

        self.register_buffer('q_scale', q_scale)
        self.register_buffer('k_scale', k_scale)
        self.register_buffer('v_scale', v_scale)
        self.register_buffer('sm_scale', sm_scale)

    def qdq_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        return qdq_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale)

    def qop_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        return qop_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale)

    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, qop=False):
        # qop=True runs attention on the quantized tensors directly (QOp path);
        # otherwise the fake-quant QDQ path is used.
        if qop:
            return self.qop_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)
        else:
            return self.qdq_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)
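

# Usage sketch (illustrative, not part of the original module): builds a
# QuantScaledDotProductAttention from a hand-written quant_param dict whose layout
# matches the keys read in __init__ above. The scale values and the _dummy_entry
# helper are made up for the example; real scales would come from a calibration flow.
if __name__ == "__main__":
    torch.manual_seed(0)

    def _dummy_entry(scale):
        # Per-tensor scale plus the zero-point dtype checked by the asserts above.
        return {'act_scale': scale, 'act_scale_shape': [], 'act_zp_dtype': 'torch.float8_e4m3fnuz'}

    quant_param = {
        'out_q': _dummy_entry(0.02),
        'out_k': _dummy_entry(0.02),
        'out_v': _dummy_entry(0.02),
        'output_softmax_quant': _dummy_entry(1.0 / 240.0),
    }

    sdpa = QuantScaledDotProductAttention(quant_param)
    q = torch.randn(1, 8, 128, 64)
    k = torch.randn(1, 8, 128, 64)
    v = torch.randn(1, 8, 128, 64)

    out_qdq = sdpa(q, k, v, is_causal=True)            # fake-quant (QDQ) path
    out_qop = sdpa(q, k, v, is_causal=True, qop=True)  # quantized-op path
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

    # With per-tensor scalar scales the two quantized paths should agree closely,
    # and both should stay reasonably close to the unquantized reference.
    print((out_qdq - out_qop).abs().max())
    print((out_qdq - ref).abs().max())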