"""PyTorch-native implementation of the L-Mul algorithm."""

from __future__ import annotations

import torch

__all__ = [
    "l_mul_tensor",
    "l_mul_attention",
]


def l_mul_tensor(
|
|
x: torch.Tensor, y: torch.Tensor, *, offset: float = 1.0 / 16.0
|
|
) -> torch.Tensor:
|
|
"""
|
|
Approximates `x * y` element-wise using the L-Mul algorithm.
|
|
"""
|
|
sign = torch.sign(x) * torch.sign(y)

    x_abs = torch.abs(x)
    y_abs = torch.abs(y)

    # Decompose magnitudes: torch.frexp returns a mantissa m in [0.5, 1) and an
    # integer exponent e such that the input equals m * 2**e.
    mx, ex = torch.frexp(x_abs)
    my, ey = torch.frexp(y_abs)

    # Rescale to the IEEE-style convention with the mantissa in [1, 2).
    mant_a = mx * 2.0
    mant_b = my * 2.0
    exp_a = ex - 1
    exp_b = ey - 1

    # L-Mul core: the exact mantissa product (1 + xm) * (1 + ym) expands to
    # 1 + xm + ym + xm * ym; L-Mul drops the xm * ym term and substitutes the
    # constant `offset`. The exponents are still added exactly.
    result_mant = (mant_a - 1.0) + (mant_b - 1.0) + 1.0 + offset
    result_exp = exp_a + exp_b
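    # Worked example with illustrative values (not taken from the source): for
    # x = 1.5 and y = 2.0, frexp yields mantissas 0.75 / 0.5 and exponents 1 / 2,
    # rescaled to 1.5 / 1.0 with exponents 0 / 1. Then result_mant =
    # 0.5 + 0.0 + 1.0 + 1/16 = 1.5625 and result_exp = 1, so the approximation
    # returns 3.125 against the exact product 3.0.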

    # Reassemble the approximate product as result_mant * 2**result_exp.
    result = torch.ldexp(result_mant, result_exp)

    final_result = sign * result
    # Zero inputs already yield a zero sign; zero the result explicitly as well.
    # torch.where keeps this valid when x and y broadcast to different shapes,
    # where boolean-mask assignment would fail.
    zero_mask = (x == 0) | (y == 0)
    final_result = torch.where(zero_mask, torch.zeros_like(final_result), final_result)

    return final_result


def l_mul_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor | None = None,
    dropout: torch.nn.Module | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention where the matrix multiplications are
    replaced by the L-Mul approximation for performance.
    """
    d_k = query.size(-1)

    # Q @ K^T replaced by L-Mul: broadcast the approximate element-wise products
    # and sum over the head dimension.
    # (..., L_q, d_k, 1) * (..., 1, d_k, L_k) -> sum(-2) -> (..., L_q, L_k)
    scores = l_mul_tensor(
        query.unsqueeze(-1),
        key.transpose(-2, -1).unsqueeze(-3),
    ).sum(-2)

    scores = scores / (d_k ** 0.5)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    attn_probs = torch.nn.functional.softmax(scores, dim=-1)
    if dropout is not None:
        attn_probs = dropout(attn_probs)

    # Attention-weighted sum of the values, again via L-Mul:
    # (..., L_q, L_k, 1) * (..., 1, L_k, d_v) -> sum(-2) -> (..., L_q, d_v)
    output = l_mul_tensor(
        attn_probs.unsqueeze(-1),
        value.unsqueeze(-3),
    ).sum(-2)

    return output, attn_probs
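

# Minimal usage sketch: compares l_mul_tensor against exact element-wise
# multiplication and runs l_mul_attention on small random tensors. Shapes and
# the seed are arbitrary illustrative choices, not values from the source.
if __name__ == "__main__":
    torch.manual_seed(0)

    a = torch.randn(4, 8)
    b = torch.randn(4, 8)
    print("max |l_mul - exact|:", (l_mul_tensor(a, b) - a * b).abs().max().item())

    q = torch.randn(2, 5, 16)
    k = torch.randn(2, 5, 16)
    v = torch.randn(2, 5, 16)
    out, probs = l_mul_attention(q, k, v)
    print("attention output:", tuple(out.shape), "probs:", tuple(probs.shape))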