import bitblas
import torch
import torch.nn as nn

from dataclasses import dataclass
from typing import Literal

from bitblas.cache import OperatorCache
from torch.nn import functional as F


def gelu_approx(x):
    return F.gelu(x, approximate="tanh")


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


class Linear(nn.Module):
    """
    Linear layer with support for bitblas quantization.

    If dtype is torch.int8, it uses bitblas for quantization.
    Otherwise, it uses a standard nn.Linear layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        group_size: int = 128,
    ):
        super().__init__()
        if dtype == torch.int8:
            self.linear = bitblas.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=bias,
                with_zeros=True,
                zeros_mode="original",
                with_scaling=True,
                A_dtype="float16",
                W_dtype="uint4",
                accum_dtype="float16",
                out_dtype="float16",
                fast_decoding=True,
                enable_tuning=True,
                group_size=group_size,
            )
        else:
            self.linear = nn.Linear(
                in_features=in_features,
                out_features=out_features,
                bias=bias,
                dtype=torch.float16,
            )

    def forward(self, x):
        return self.linear(x)

    @property
    def weight(self) -> torch.Tensor:
        # Quantized bitblas layers expose packed weights as `qweight`
        # rather than `weight`.
        try:
            return self.linear.weight
        except AttributeError:
            return self.linear.qweight

    @property
    def bias(self) -> torch.Tensor:
        return self.linear.bias


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
    # fc1 and fc2 are LinearWeights dataclasses, so apply them through the
    # functional `linear` helper rather than calling them directly.
    x = linear(x, w.fc1)
    x = gelu_approx(x)
    x = linear(x, w.fc2)
    return x


@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights


def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    # Fused QKV projection, then split heads: (bsz, n_heads, q_len, head_dim).
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge heads back to (bsz, q_len, d_model) and apply the output projection.
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out
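

# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of exercising the functional layers above on CPU with
# float32 tensors; the shapes and random weights are arbitrary assumptions.
# The quantized path (dtype=torch.int8) is not run here, since it needs a
# CUDA device with bitblas kernels available.
if __name__ == "__main__":
    d_model, n_heads = 64, 4
    x = torch.randn(2, 8, d_model)

    # Functional attention: weights live in plain dataclasses.
    qkv = LinearWeights(
        weight=torch.randn(3 * d_model, d_model) * 0.02,
        bias=torch.zeros(3 * d_model),
    )
    proj = LinearWeights(
        weight=torch.randn(d_model, d_model) * 0.02,
        bias=torch.zeros(d_model),
    )
    out = attn(x, AttentionWeights(qkv=qkv, proj=proj), n_heads=n_heads)
    print(out.shape)  # torch.Size([2, 8, 64])

    # Module path: the default (non-int8) dtype falls back to a float16
    # nn.Linear; passing dtype=torch.int8 would select the bitblas kernel.
    fc = Linear(d_model, d_model)
    print(fc)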