This is a calibrated, weight-only MXFP4-quantized Meta-Llama-3.1-8B-Instruct model, as presented in our blog post.
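For background: MXFP4 is a 4-bit floating-point (E2M1) micro-scaling format. Each weight is stored as one of the 16 FP4 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} together with a scale shared by a small block of consecutive weights (32 elements in the MX spec). The snippet below is only a minimal sketch of that idea: quantize_fp4_group is a hypothetical helper, it uses an unconstrained float scale instead of the power-of-two scale of the MX spec, and the exact packing/calibration recipe of this checkpoint may differ.

import torch

# Hypothetical illustration of FP4 (E2M1) group quantization; not the recipe used for this checkpoint.
fp4_grid = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6])

def quantize_fp4_group(w, group_size=32):
    w = w.reshape(-1, group_size)
    # One shared scale per group, chosen so the largest magnitude maps onto 6 (the largest FP4 value).
    scales = (w.abs().amax(dim=1, keepdim=True) / 6.0).clamp(min=1e-8)
    # Snap each scaled weight to its nearest FP4 grid point; idx holds the 4-bit codes.
    idx = ((w / scales).unsqueeze(-1) - fp4_grid).abs().argmin(dim=-1)
    return idx, scales

w = torch.randn(4, 64)
idx, scales = quantize_fp4_group(w)
w_dq = (fp4_grid[idx] * scales).reshape(w.shape)  # dequantized approximation of w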
Install the required safetensors build:

pip install safetensors==0.6.0.dev0
import os, torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from os.path import join as pjoin
from safetensors import safe_open
@torch.compile(fullgraph=True)
def matmul_fp4(x, W_q, scales, group_size, fp4_values):
    # Unpack two 4-bit codes per byte along the column dimension.
    def unpack_over_cols(W_q_packed, W_nbits, num_output_cols, dtype):
        n_rows, n_cols = W_q_packed.shape
        device = W_q_packed.device
        shifts = torch.arange(num_output_cols // n_cols, device=device, dtype=W_q_packed.dtype) * W_nbits
        W_q_unpacked = ((W_q_packed.unsqueeze(-1) >> shifts) & ((1 << W_nbits) - 1)).to(dtype)
        W_q_unpacked = W_q_unpacked.view(n_rows, num_output_cols)
        return W_q_unpacked

    N, K = W_q.shape[0], W_q.shape[1] * 2
    # Map the 4-bit codes to their FP4 values, then rescale each group by its shared scale.
    W_q = fp4_values[unpack_over_cols(W_q, W_nbits=4, num_output_cols=K, dtype=torch.int32)]
    W_r = (W_q.float().view([-1, group_size]) * scales.float()).reshape([N, K]).to(x.dtype).T
    return torch.matmul(x, W_r)
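As a quick sanity check, matmul_fp4 can be exercised on dummy packed data. The shapes below are assumptions inferred from the unpacking logic (uint8 weights holding two 4-bit codes per byte, one scale per group_size weights) and assume a CUDA device; they are not guaranteed to match the checkpoint layout.

# Hypothetical smoke test with random packed data (shapes inferred from matmul_fp4 above).
device = "cuda"
N, K, group_size = 128, 256, 32
fp4_values = torch.tensor(
    [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6],
    dtype=torch.bfloat16, device=device,
)
W_q = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=device)  # two 4-bit codes per byte
scales = torch.rand((N * K) // group_size, 1, dtype=torch.bfloat16, device=device)
x = torch.randn(1, 8, K, dtype=torch.bfloat16, device=device)
out = matmul_fp4(x, W_q, scales, group_size, fp4_values)  # expected shape: (1, 8, N)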
class AutoModelForCausalLMFP4:
    @classmethod
    def from_pretrained(
        cls,
        save_dir_or_hub,
        torch_dtype=torch.bfloat16,
        cache_dir=None,
        device_map="cuda:0",
        *args,
        **kwargs
    ):
        # Download the snapshot (or use a local directory directly)
        if os.path.exists(save_dir_or_hub):
            save_dir = save_dir_or_hub
        else:
            save_dir = snapshot_download(repo_id=save_dir_or_hub, cache_dir=cache_dir)

        # Create the model from the config without allocating weights
        config = AutoConfig.from_pretrained(pjoin(save_dir, "config.json"))
        config.torch_dtype = str(torch_dtype).split('.')[-1]
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        # Load the quantized state dict and patch the model
        state_dict = {}
        with safe_open(pjoin(save_dir, "model.safetensors"), framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                dtype = torch_dtype if tensor.is_floating_point() else tensor.dtype
                state_dict[key] = tensor.to(device=device_map, dtype=dtype, non_blocking=True)

        cls.patch_model_for_fp4_inference(model=model, torch_dtype=torch_dtype, device=device_map, state_dict=state_dict)
        return model

    @classmethod
    def patch_model_for_fp4_inference(cls, model, torch_dtype, device, state_dict):
        # FP4 (E2M1) lookup table used to decode the 4-bit weight codes
        model.fp4_values = torch.tensor(
            [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6],
            dtype=torch_dtype,
            device=device,
        )

        def patch_linearlayers(model, fct):
            for name, layer in model.named_children():
                if isinstance(layer, torch.nn.Linear):
                    setattr(model, name, fct(layer, name))
                else:
                    patch_linearlayers(layer, fct)

        def patch_enable_fp4(layer, arg):
            # Load params (lm_head is kept in full precision)
            if 'lm_head' in layer.name:
                return layer
            if hasattr(layer, 'weight'):
                del layer.weight
            for key in ['W_q', 'scales', 'shift', 'post_scale', 'meta']:
                param_tag, param = layer.name + '.' + key, None
                if param_tag in state_dict:
                    param = state_dict[param_tag].tolist() if key in ["meta"] else state_dict[param_tag]
                setattr(layer, key, param)

            # Set forward pass
            def forward(self, x):
                if hasattr(self, 'weight'):
                    out = torch.matmul(x, self.weight.data.T)
                else:
                    out = matmul_fp4(x, self.W_q, self.scales, self.meta[-1], model.fp4_values)
                if self.post_scale is not None:
                    out *= self.post_scale
                if self.shift is not None:
                    out += self.shift
                if self.bias is not None:
                    out += self.bias
                return out

            layer.forward = lambda x: forward(layer, x)
            return layer

        try:  # FP4 params are unexpected keys here and will fail; they are attached manually in patch_enable_fp4
            model.load_state_dict(state_dict, assign=True)
        except Exception:
            pass

        # Tag each module with its fully-qualified name so patch_enable_fp4 can look up its params
        for name, module in model.named_modules():
            module.name = name

        patch_linearlayers(model, patch_enable_fp4)
        model = model.to(device)
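Putting it together, each patched linear layer computes out = matmul_fp4(x, W_q, scales, group_size, fp4_values), then multiplies the result element-wise by post_scale and adds shift (plus the bias, if present). post_scale and shift are the extra parameters trained during calibration; when they are absent from the state dict they stay None and the layer falls back to the plain dequantized matmul.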
model_id = "mobiuslabsgmbh/Llama-3.1-8B-Instruct_mxfp4_weights_calib_demo"
model = AutoModelForCausalLMFP4.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Check the trained params
# print(model.model.layers[-1].self_attn.v_proj.shift)
# tensor([ 0.0034, -0.0036,  0.0054,  ...,  0.0036, -0.0076, -0.0068],
#        device='cuda:0', dtype=torch.bfloat16)
# print(model.model.layers[-1].self_attn.v_proj.post_scale)
# tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0', dtype=torch.bfloat16)
outputs = model.generate(
    tokenizer.apply_chat_template(
        [{"role": "user", "content": "Solve the following equation: x^2 + 1 = -1"}],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device),
    max_new_tokens=256,
)
print(tokenizer.decode(outputs[0]))