"""
PyTorch-native implementation of the L-Mul algorithm.
"""
from __future__ import annotations
import torch

__all__ = [
    "l_mul_tensor",
    "l_mul_attention",
]


def l_mul_tensor(
    x: torch.Tensor, y: torch.Tensor, *, offset: float = 1.0 / 16.0
) -> torch.Tensor:
    """
    Approximate ``x * y`` element-wise using the L-Mul algorithm.

    The inputs may broadcast against each other. ``offset`` is the constant
    correction term added in place of the mantissa product; the default of
    1/16 corresponds to the paper's 2**-l(m) term with l(m) = 4.
    """
    sign = torch.sign(x) * torch.sign(y)
    x_abs = torch.abs(x)
    y_abs = torch.abs(y)

    # Decompose the inputs into mantissa and exponent.
    # torch.frexp returns a mantissa in [0.5, 1.0).
    mx, ex = torch.frexp(x_abs)
    my, ey = torch.frexp(y_abs)
    # The paper's formulation assumes a mantissa in [1.0, 2.0). Rescale the
    # frexp mantissa by 2 and compensate in the exponent.
    mant_a = mx * 2.0
    mant_b = my * 2.0
    exp_a = ex - 1
    exp_b = ey - 1

    # L-Mul formula: replace the mantissa product with an addition plus a
    # constant offset term.
    result_mant = (mant_a - 1.0) + (mant_b - 1.0) + 1.0 + offset
    result_exp = exp_a + exp_b
    # Reconstruct the final number: result_mant * 2**result_exp.
    result = torch.ldexp(result_mant, result_exp)
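
    # Worked example (illustrative): x = 1.5, y = 2.0. frexp gives
    # (0.75, 1) and (0.5, 2), so mant_a = 1.5, exp_a = 0 and mant_b = 1.0,
    # exp_b = 1. Then result_mant = 0.5 + 0.0 + 1.0 + 1/16 = 1.5625 and
    # result_exp = 1, giving ldexp(1.5625, 1) = 3.125 against an exact
    # product of 3.0.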
    # Apply the sign, and zero out positions where either input is zero.
    # Using torch.where (instead of boolean indexing) keeps this correct
    # when x and y broadcast to a larger result shape.
    final_result = sign * result
    zero_mask = (x == 0) | (y == 0)
    final_result = torch.where(
        zero_mask, torch.zeros_like(final_result), final_result
    )
    return final_result

def l_mul_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor | None = None,
    dropout: torch.nn.Module | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled dot-product attention in which both matrix multiplications are
    replaced by the L-Mul approximation.

    Note that this reference implementation materialises the broadcasted
    element-wise products, so it trades memory for fidelity to the
    algorithm; L-Mul's efficiency gains target hardware that can exploit
    its addition-only arithmetic.
    """
    d_k = query.size(-1)

    # Approximate Q @ K^T with l_mul_tensor: broadcast to a
    # (..., L, d_k, S) product and sum over the shared d_k dimension.
    scores = l_mul_tensor(
        query.unsqueeze(-1),                  # (..., L, d_k, 1)
        key.transpose(-2, -1).unsqueeze(-3),  # (..., 1, d_k, S)
    ).sum(-2)                                 # -> (..., L, S)
    scores = scores / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    attn_probs = torch.nn.functional.softmax(scores, dim=-1)
    if dropout is not None:
        attn_probs = dropout(attn_probs)
    # Approximate attn_probs @ V the same way, summing over the key axis.
    output = l_mul_tensor(
        attn_probs.unsqueeze(-1),  # (..., L, S, 1)
        value.unsqueeze(-3),       # (..., 1, S, d_v)
    ).sum(-2)                      # -> (..., L, d_v)

    return output, attn_probs
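

if __name__ == "__main__":
    # Minimal smoke test (illustrative; the shapes, tolerances, and the
    # comparison against PyTorch's exact kernels are this sketch's
    # assumptions, not part of the original module). It checks how closely
    # the L-Mul approximation tracks exact arithmetic.
    torch.manual_seed(0)

    # Element-wise approximation vs. exact multiplication.
    a = torch.randn(1024)
    b = torch.randn(1024)
    err = (l_mul_tensor(a, b) - a * b).abs().mean().item()
    print(f"l_mul_tensor mean abs error: {err:.6f}")

    # Approximate attention vs. exact scaled dot-product attention.
    q = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, d_k)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    out, _ = l_mul_attention(q, k, v)
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print(f"l_mul_attention mean abs error: {(out - ref).abs().mean().item():.6f}")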