""" PyTorch-native implementation of the L-Mul algorithm. """ from __future__ import annotations import torch __all__ = [ "l_mul_tensor", "l_mul_attention", ] def l_mul_tensor( x: torch.Tensor, y: torch.Tensor, *, offset: float = 1.0 / 16.0 ) -> torch.Tensor: """ Approximates `x * y` element-wise using the L-Mul algorithm. """ sign = torch.sign(x) * torch.sign(y) x_abs = torch.abs(x) y_abs = torch.abs(y) # Decompose tensors into mantissa and exponent. # torch.frexp gives mantissa in [0.5, 1.0) mx, ex = torch.frexp(x_abs) my, ey = torch.frexp(y_abs) # The paper's logic implies a mantissa in [1.0, 2.0). # We reconstruct this by multiplying the mantissa by 2 and adjusting the exponent. mant_a = mx * 2.0 mant_b = my * 2.0 exp_a = ex - 1 exp_b = ey - 1 # Approximate multiplication using the L-Mul formula result_mant = (mant_a - 1.0) + (mant_b - 1.0) + 1.0 + offset result_exp = exp_a + exp_b # Reconstruct the final number result = torch.ldexp(result_mant, result_exp) # Apply the correct sign and handle zero inputs final_result = sign * result final_result[x == 0] = 0 final_result[y == 0] = 0 return final_result def l_mul_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor | None = None, dropout: torch.nn.Module | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """ Scaled dot-product attention where matrix multiplications are replaced by the L-Mul approximation for performance. """ d_k = query.size(-1) # Approximate Q @ K.T using l_mul_tensor. # This requires broadcasting and summing to perform the matrix multiplication. scores = l_mul_tensor( query.unsqueeze(-1), key.transpose(-2, -1).unsqueeze(-3) ).sum(-2) scores = scores / (d_k ** 0.5) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn_probs = torch.nn.functional.softmax(scores, dim=-1) if dropout is not None: attn_probs = dropout(attn_probs) # Approximate Attn @ V using l_mul_tensor output = l_mul_tensor( attn_probs.unsqueeze(-1), value.unsqueeze(-3) ).sum(-2) return output, attn_probs