"""

PyTorch-native implementation of the L-Mul algorithm.

"""
from __future__ import annotations

import torch

__all__ = [
    "l_mul_tensor",
    "l_mul_attention",
]
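
# L-Mul approximates a floating-point product by summing the exponents and
# adding the mantissas (plus a small constant offset) instead of multiplying
# the mantissas.  The two functions below emulate that scheme with standard
# PyTorch ops.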

def l_mul_tensor(
    x: torch.Tensor, y: torch.Tensor, *, offset: float = 1.0 / 16.0
) -> torch.Tensor:
    """

    Approximates `x * y` element-wise using the L-Mul algorithm.

    """
    sign = torch.sign(x) * torch.sign(y)
    
    x_abs = torch.abs(x)
    y_abs = torch.abs(y)
    
    # Decompose tensors into mantissa and exponent.
    # torch.frexp gives mantissa in [0.5, 1.0)
    mx, ex = torch.frexp(x_abs)
    my, ey = torch.frexp(y_abs)

    # The paper's logic implies a mantissa in [1.0, 2.0).
    # We reconstruct this by multiplying the mantissa by 2 and adjusting the exponent.
    mant_a = mx * 2.0
    mant_b = my * 2.0
    exp_a = ex - 1
    exp_b = ey - 1

    # Approximate multiplication using the L-Mul formula
    result_mant = (mant_a - 1.0) + (mant_b - 1.0) + 1.0 + offset
    result_exp = exp_a + exp_b
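    # Worked example of the two lines above (illustrative numbers, not from
    # the original source):
    #   x = 3.0 -> mant_a = 1.50, exp_a = 1
    #   y = 5.0 -> mant_b = 1.25, exp_b = 2
    #   result_mant = 0.5 + 0.25 + 1.0 + 1/16 = 1.8125, result_exp = 3
    #   ldexp(1.8125, 3) = 14.5, versus the exact product 15.0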
    
    # Reconstruct the final number from the approximated mantissa and exponent.
    result = torch.ldexp(result_mant, result_exp)

    # Apply the correct sign and force exact zeros where either input is zero.
    # torch.where (rather than boolean-mask assignment) keeps this correct when
    # x and y are broadcast against each other, as in l_mul_attention below.
    final_result = sign * result
    final_result = torch.where(
        (x == 0) | (y == 0), torch.zeros_like(final_result), final_result
    )

    return final_result

def l_mul_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor | None = None,
    dropout: torch.nn.Module | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """

    Scaled dot-product attention where matrix multiplications are replaced

    by the L-Mul approximation for performance.

    """
    d_k = query.size(-1)

    # Approximate Q @ K.T using l_mul_tensor.
    # This requires broadcasting and summing to perform the matrix multiplication.
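    # Shape sketch (batch/head dims omitted; names are illustrative):
    #   query.unsqueeze(-1)                 -> (seq_q, d_k, 1)
    #   key.transpose(-2, -1).unsqueeze(-3) -> (1, d_k, seq_k)
    #   element-wise L-Mul, then .sum(-2)   -> (seq_q, seq_k)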
    scores = l_mul_tensor(
        query.unsqueeze(-1),
        key.transpose(-2, -1).unsqueeze(-3)
    ).sum(-2)
    
    scores = scores / (d_k ** 0.5)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    attn_probs = torch.nn.functional.softmax(scores, dim=-1)
    if dropout is not None:
        attn_probs = dropout(attn_probs)

    # Approximate Attn @ V using l_mul_tensor
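    # Shape sketch (batch/head dims omitted):
    #   (seq_q, seq_k, 1) * (1, seq_k, d_v) -> sum over seq_k -> (seq_q, d_v)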
    output = l_mul_tensor(
        attn_probs.unsqueeze(-1),
        value.unsqueeze(-3)
    ).sum(-2)
    
    return output, attn_probs
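

if __name__ == "__main__":
    # Usage sketch (not part of the original module): compare the L-Mul
    # approximations against exact PyTorch results on random inputs.
    torch.manual_seed(0)

    # Element-wise approximation versus exact multiplication.
    a = torch.randn(4, 8)
    b = torch.randn(4, 8)
    print(
        "max |l_mul_tensor - exact|:",
        (l_mul_tensor(a, b) - a * b).abs().max().item(),
    )

    # Attention approximation versus exact scaled dot-product attention.
    q = torch.randn(2, 4, 5, 16)  # (batch, heads, seq, d_k)
    k = torch.randn(2, 4, 5, 16)
    v = torch.randn(2, 4, 5, 16)
    out, _ = l_mul_attention(q, k, v)
    exact = torch.softmax(q @ k.transpose(-2, -1) / 16**0.5, dim=-1) @ v
    print(
        "max |l_mul_attention - exact|:",
        (out - exact).abs().max().item(),
    )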