import bitblas
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Literal, Optional
from bitblas.cache import OperatorCache
from torch.nn import functional as F
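

# GELU with the "tanh" approximation; used as the MLP activation ("gelu_approx").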
def gelu_approx(x):
    return F.gelu(x, approximate="tanh")
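

# Bare weight/bias tensors for an unquantized projection, consumed by the
# functional helpers (linear, attn) further below.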
@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


class Linear(nn.Module):
    """
    Linear layer with support for bitblas quantization.

    If dtype is torch.int8, the layer is backed by a bitblas.Linear kernel with
    float16 activations and 4-bit (uint4) weights, using per-group scales and
    zero points. Otherwise it falls back to a standard float16 nn.Linear.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        group_size: int = 128,
    ):
        super().__init__()
        if dtype == torch.int8:
            self.linear = bitblas.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=bias,
                with_zeros=True,
                zeros_mode="original",
                with_scaling=True,
                A_dtype="float16",
                W_dtype="uint4",
                accum_dtype="float16",
                out_dtype="float16",
                fast_decoding=True,
                enable_tuning=True,
                group_size=group_size,
            )
        else:
            self.linear = nn.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=bias,
                dtype=torch.float16,
            )

    def forward(self, x):
        return self.linear(x)
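
    # bitblas.Linear stores its packed quantized weights as `qweight`, while
    # nn.Linear exposes `weight`; the property below hides that difference.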
    @property
    def weight(self) -> torch.Tensor:
        try:
            return self.linear.weight
        except AttributeError:
            return self.linear.qweight

    @property
    def bias(self) -> torch.Tensor:
        return self.linear.bias
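

# The functional helpers below operate on the bare *Weights dataclasses (plain,
# unpacked tensors) rather than on nn.Module state.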
def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor
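

# normalized_shape is taken from the bias, i.e. normalization is over the last dimension.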
def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"
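

# Note: fc1/fc2 are invoked as callables in mlp(), so in the quantized path they are
# expected to hold Linear modules (defined above) rather than bare LinearWeights.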
def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    x = w.fc1(x)
    x = gelu_approx(x)
    x = w.fc2(x)
    return x


@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights
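

# Multi-head self-attention: fused QKV projection, per-head split, PyTorch's
# scaled_dot_product_attention, then the output projection.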
def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]

    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
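

if __name__ == "__main__":
    # Smoke test (illustrative sketch, not part of the original module). It exercises
    # the dtype-agnostic functional helpers with random float32 tensors; the shapes,
    # helper `rand_linear`, and initialization are arbitrary, real weights come from
    # the model checkpoint.
    torch.manual_seed(0)
    d_model, n_heads, seq_len = 64, 4, 8
    x = torch.randn(1, seq_len, d_model)

    def rand_linear(n_in: int, n_out: int) -> LinearWeights:
        return LinearWeights(weight=torch.randn(n_out, n_in) * 0.02, bias=torch.zeros(n_out))

    w_ln = LayerNormWeights(weight=torch.ones(d_model), bias=torch.zeros(d_model))
    w_attn = AttentionWeights(
        qkv=rand_linear(d_model, 3 * d_model),
        proj=rand_linear(d_model, d_model),
    )
    print(layer_norm(x, w_ln).shape)               # torch.Size([1, 8, 64])
    print(attn(x, w_attn, n_heads=n_heads).shape)  # torch.Size([1, 8, 64])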