Commit 4712f6c by arijitbhowmick: Updated model bitnet-b1.58-2B-4T-bf16 with 1bitLLM-bitnet_b1_58-3B config files
import math

import torch
from torch import nn


def weight_quant(weight, num_bits=1):
    """Absmean (BitNet b1.58) weight quantization: snap every weight onto the
    ternary grid {-m, 0, +m}, where m is the mean absolute weight, then
    dequantize so downstream matmuls stay in floating point."""
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)      # per-tensor absmean scale
    result = (weight * s).round().clamp(-1, 1) / s   # ternarize, then rescale
    return result.type(dtype)


def activation_quant(x, num_bits=8):
    """Per-token absmax quantization of activations to signed num_bits
    integers (int8 by default), followed by dequantization."""
    dtype = x.dtype
    x = x.float()
    Qn = -(2 ** (num_bits - 1))       # e.g. -128 for 8 bits
    Qp = 2 ** (num_bits - 1) - 1      # e.g.  127 for 8 bits
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)  # per-token scale
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)
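

# Illustrative sanity check (not part of the original file; the helper name and
# the example values below are assumptions). weight_quant() should land every
# weight on {-m, 0, +m} with m the mean absolute weight, and activation_quant()
# should keep each token on a signed 8-bit absmax grid. The function is never
# called at import time; run it manually if desired.
def _quantizer_sanity_check():
    w = torch.tensor([[0.40, -0.20, 0.05]])
    m = w.abs().mean()                                   # ternary step size
    assert torch.allclose(weight_quant(w), m * torch.tensor([[1.0, -1.0, 0.0]]))
    x = torch.tensor([[0.50, -1.00, 0.25]])
    xq = activation_quant(x)
    assert xq.abs().max() <= x.abs().max() + 1e-6        # absmax scale respected
    assert xq.shape == x.shape and xq.dtype == x.dtype   # shape/dtype preserved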


class BitLinear(nn.Linear):
    """nn.Linear with 1.58-bit (ternary) weights and 8-bit activations.

    RMSNorm is placed outside BitLinear.
    """

    def __init__(self, *args, weight_bits=1, input_bits=8, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, input):
        # Straight-through estimator: the forward pass uses the quantized
        # tensors, while gradients flow through the full-precision ones.
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()
        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out
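

# Minimal usage sketch (illustrative, not part of the original file). The class
# docstring notes that RMSNorm sits outside BitLinear, so a sub-layer looks like
# norm -> BitLinear. nn.RMSNorm needs a recent PyTorch (>= 2.4); any RMSNorm
# implementation works here. The sizes and batch shape below are assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    hidden = 64
    norm = nn.RMSNorm(hidden)
    proj = BitLinear(hidden, 4 * hidden, bias=False)
    x = torch.randn(2, 5, hidden)            # (batch, seq_len, hidden)
    y = proj(norm(x))
    print(y.shape)                           # torch.Size([2, 5, 256])
    # Gradients still flow through the straight-through estimator in forward().
    y.sum().backward()
    print(proj.weight.grad is not None)      # True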