import bitblas
import torch
import torch.nn as nn

from dataclasses import dataclass
from typing import Literal, Optional
from bitblas.cache import OperatorCache
from torch.nn import functional as F


def gelu_approx(x):
    return F.gelu(x, approximate="tanh")


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


class Linear(nn.Module):
    """
    Linear layer with support for bitblas quantization.
    If dtype is torch.int8, it uses bitblas for quantization.
    Otherwise, it uses a standard nn.Linear layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        group_size: int = 128,
    ):
        super().__init__()

        if dtype == torch.int8:
            self.linear = bitblas.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=bias,
                with_zeros=True,
                zeros_mode="original",
                with_scaling=True,
                A_dtype="float16",
                W_dtype="uint4",
                accum_dtype="float16",
                out_dtype="float16",
                fast_decoding=True,
                enable_tuning=True,
                group_size=group_size,
            )
        else:
            self.linear = nn.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=bias,
                dtype=torch.float16,
            )

    def forward(self, x):
        return self.linear(x)

    @property
    def weight(self) -> torch.Tensor:
        try:
            return self.linear.weight
        except AttributeError:
            # bitblas.Linear exposes its packed quantized weight as `qweight`.
            return self.linear.qweight

    @property
    def bias(self) -> torch.Tensor:
        return self.linear.bias
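
# Usage sketch (hypothetical shapes/device, not part of the original module):
# dtype=torch.int8 selects the bitblas path (uint4 weights, float16
# activations); any other dtype falls back to a plain float16 nn.Linear.
#
#   fc_q = Linear(4096, 4096, dtype=torch.int8, group_size=128)  # bitblas-quantized
#   fc_f = Linear(4096, 4096)                                    # float16 nn.Linear
#   y = fc_q(torch.randn(1, 4096, dtype=torch.float16, device="cuda"))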


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    # fc1/fc2 are invoked as callable linear layers (e.g. nn.Linear or the
    # Linear wrapper above), even though MLPWeights annotates them as
    # LinearWeights; attn below instead uses the functional `linear` helper.
    x = w.fc1(x)
    x = gelu_approx(x)
    x = w.fc2(x)
    return x


@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights


def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # Fused QKV projection, split into q/k/v and reshaped to
    # (bsz, n_heads, q_len, head_dim) for scaled_dot_product_attention.
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge heads back to (bsz, q_len, d_model) and apply the output projection.
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
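

# Sketch (assumed wiring, not from the original file): a pre-norm transformer
# block composed from the helpers above. `ln1`, `ln2`, `attn_w`, and `mlp_w`
# are hypothetical weight containers supplied by the caller.
#
#   def block(x: torch.Tensor, ln1: LayerNormWeights, attn_w: AttentionWeights,
#             ln2: LayerNormWeights, mlp_w: MLPWeights, n_heads: int) -> torch.Tensor:
#       x = x + attn(layer_norm(x, ln1), attn_w, n_heads)
#       x = x + mlp(layer_norm(x, ln2), mlp_w)
#       return x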