import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Literal
from torchao import quantize_
from torchao.quantization import int4_weight_only
from .packing import dequantize_tensor


def gelu_approx(x):
    # GELU with the tanh approximation.
    return F.gelu(x, approximate="tanh")


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)
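

# Usage sketch (illustrative, not part of the original module): `linear` applies
# a plain affine map with weights held in a LinearWeights container. The shapes
# below are arbitrary assumptions.
#
#   w = LinearWeights(weight=torch.randn(8, 4), bias=torch.zeros(8))
#   y = linear(torch.randn(2, 4), w)  # -> shape (2, 8)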


class QuantizedLinear(nn.Module):
    """Linear layer stored as packed int4 weights, lazily unpacked on first use
    and then re-quantized in place with torchao."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dtype: torch.dtype,
    ):
        # TODO: Take group_size as an input instead of hardcoding it here.
        # NOTE: `dtype` is currently unused; unpacked weights are always
        # materialized as bfloat16.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Packed int4 weights: two 4-bit values per uint8, grouped in blocks of
        # 128 elements, with one scale and one zero point per group.
        self.weight = nn.ParameterDict(
            {
                "packed": nn.Parameter(
                    torch.empty(
                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
                    ),
                    requires_grad=False,
                ),
                "scale": nn.Parameter(
                    torch.empty(out_features * in_features // 128, 1),
                    requires_grad=False,
                ),
                "zero_point": nn.Parameter(
                    torch.empty(out_features * in_features // 128, 1),
                    requires_grad=False,
                ),
            }
        )
        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
        self.unpacked = False

    def unpack(self):
        if self.unpacked:
            return

        # Dequantize the packed int4 weights into a dense bfloat16 tensor.
        self.weight = nn.Parameter(
            dequantize_tensor(
                self.weight["packed"],
                self.weight["scale"],
                self.weight["zero_point"],
                (self.out_features, self.in_features),
                torch.bfloat16,
            )
        )
        # Build the nn.Linear shell on the meta device so no extra storage is
        # allocated, then point it at the dequantized weight and bias.
        with torch.device("meta"):
            self.linear = nn.Linear(
                self.in_features, self.out_features, dtype=torch.bfloat16
            )
        self.linear.weight = self.weight
        self.linear.bias = nn.Parameter(
            self.bias.to(torch.bfloat16), requires_grad=False
        )

        # Drop the intermediate parameters, re-quantize the new nn.Linear to
        # torchao's int4 weight-only format (group size 128), and release the
        # cached memory from the temporary dense tensors.
        del self.weight, self.bias
        quantize_(self, int4_weight_only(group_size=128))
        self.unpacked = True
        torch.cuda.empty_cache()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unpack lazily on the first forward pass.
        if not self.unpacked:
            self.unpack()
        return self.linear(x)
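

# Usage sketch (illustrative; `packed_state_dict`, `x`, and the feature sizes
# are assumptions, and torchao's int4 path generally expects a CUDA device):
#
#   layer = QuantizedLinear(in_features=2048, out_features=2048, dtype=torch.bfloat16)
#   layer.load_state_dict(packed_state_dict)  # weight.packed / .scale / .zero_point + bias
#   layer = layer.to("cuda")
#   y = layer(x)  # first call unpacks to bf16, then re-quantizes via torchao int4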


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    x = w.fc1(x)
    x = gelu_approx(x)
    x = w.fc2(x)
    return x
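

# Usage sketch (illustrative): `mlp` calls w.fc1 and w.fc2 directly, so in
# practice the MLPWeights fields hold callable modules (e.g. nn.Linear or
# QuantizedLinear) rather than LinearWeights containers. Sizes are assumptions.
#
#   w = MLPWeights(fc1=nn.Linear(64, 256), fc2=nn.Linear(256, 64))
#   y = mlp(torch.randn(2, 16, 64), w)  # -> shape (2, 16, 64)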


@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights


def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # Fused QKV projection, split into q/k/v and reshaped to per-head views of
    # shape (bsz, n_heads, q_len, head_dim).
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge the heads back together and apply the output projection.
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
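

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. Because of the
    # relative import above, run it as a module, e.g. `python -m <package>.layers`
    # (package name is a placeholder). All shapes and random weights below are
    # assumptions for demonstration only.
    torch.manual_seed(0)
    d_model, n_heads, seq_len = 64, 4, 16
    x = torch.randn(2, seq_len, d_model)

    ln = LayerNormWeights(weight=torch.ones(d_model), bias=torch.zeros(d_model))
    attn_w = AttentionWeights(
        qkv=LinearWeights(
            weight=torch.randn(3 * d_model, d_model) * 0.02,
            bias=torch.zeros(3 * d_model),
        ),
        proj=LinearWeights(
            weight=torch.randn(d_model, d_model) * 0.02,
            bias=torch.zeros(d_model),
        ),
    )
    y = attn(layer_norm(x, ln), attn_w, n_heads=n_heads)
    print("attn:", y.shape)  # torch.Size([2, 16, 64])

    # mlp() calls w.fc1 / w.fc2 directly, so pass callable modules here.
    mlp_w = MLPWeights(
        fc1=nn.Linear(d_model, 4 * d_model), fc2=nn.Linear(4 * d_model, d_model)
    )
    print("mlp:", mlp(y, mlp_w).shape)  # torch.Size([2, 16, 64])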