This is an MXFP4, calibrated, weight-only quantized version of the Meta-Llama-3.1-8B-Instruct model, as presented in our blog post.

Usage

Installation

pip install safetensors==0.6.0.dev0
import os, torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from os.path import join as pjoin
from safetensors import safe_open

@torch.compile(fullgraph=True)
def matmul_fp4(x, W_q, scales, group_size, fp4_values):
    """Multiply ``x`` by a 4-bit packed weight matrix, dequantized on the fly.

    Every element of ``W_q`` carries two 4-bit codes, lowest bits first. The
    codes index into ``fp4_values``; the looked-up values are rescaled per
    group of ``group_size`` consecutive entries, reshaped to ``(N, K)`` and
    transposed before the matmul.

    Args:
        x: Activations; multiplied with the ``(K, N)`` dequantized weight,
            where ``K = 2 * W_q.shape[1]``.
        W_q: 2-D integer tensor of packed codes with ``N`` rows.
        scales: Scales broadcast against the ``(-1, group_size)`` view of the
            decoded weights.
        group_size: Number of consecutive decoded weights sharing a scale.
        fp4_values: Lookup table indexed by the 4-bit code.

    Returns:
        ``torch.matmul(x, W.T)`` with ``W`` the dequantized weight cast to
        ``x.dtype``.
    """
    bits_per_code = 4
    out_features, packed_cols = W_q.shape
    in_features = packed_cols * 2

    # One right-shift amount per code stored in a packed element: 0, 4, ...
    codes_per_elem = in_features // packed_cols
    shift_amounts = torch.arange(codes_per_elem, device=W_q.device, dtype=W_q.dtype) * bits_per_code
    code_mask = (1 << bits_per_code) - 1

    # Split each packed element into its codes along a new trailing axis,
    # then flatten that axis back into the column dimension.
    shifted = W_q.unsqueeze(-1) >> shift_amounts
    codes = (shifted & code_mask).to(torch.int32)
    codes = codes.view(out_features, in_features)

    # Code -> value lookup, then per-group rescaling in float32.
    decoded = fp4_values[codes]
    grouped = decoded.float().view([-1, group_size])
    rescaled = grouped * scales.float()

    weight_t = rescaled.reshape([out_features, in_features]).to(x.dtype).T
    return torch.matmul(x, weight_t)

class AutoModelForCausalLMFP4:

    @classmethod
    def from_pretrained(
        cls,
        save_dir_or_hub,
        torch_dtype=torch.bfloat16,
        cache_dir=None,
        device_map="cuda:0",
        *args,
        **kwargs
    ):

        #Download snapshot
        if os.path.exists(save_dir_or_hub):
            save_dir = save_dir_or_hub
        else:
            save_dir = snapshot_download(repo_id=save_dir_or_hub, cache_dir=cache_dir)

        #Create model from config
        config = AutoConfig.from_pretrained(pjoin(save_dir, "config.json"))
        config.torch_dtype = str(torch_dtype).split('.')[-1]
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        #Load and patch
        state_dict = {}
        with safe_open(pjoin(save_dir, "model.safetensors"), framework="pt", device="cpu") as f:
           for key in f.keys():
                tensor = f.get_tensor(key)
                dtype = torch_dtype if tensor.is_floating_point() else tensor.dtype
                state_dict[key] = tensor.to(device=device_map, dtype=dtype, non_blocking=True)

        cls.patch_model_for_fp4_inference(model=model, torch_dtype=torch_dtype, device=device_map, state_dict=state_dict)

        return model

    @classmethod
    def patch_model_for_fp4_inference(cls, model, torch_dtype, device, state_dict):

        model.fp4_values = torch.tensor(
            [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6],
            dtype=torch_dtype,
            device=device,
        )

        def patch_linearlayers(model, fct):
            for name, layer in model.named_children():
                if isinstance(layer, torch.nn.Linear):
                    setattr(model, name, fct(layer, name))
                else:
                    patch_linearlayers(layer, fct)

        def patch_enable_fp4(layer, arg):
            #Load params
            if('lm_head' in layer.name): 
                return layer

            if(hasattr(layer, 'weight')):
                del layer.weight
            for key in ['W_q', 'scales', 'shift', 'post_scale', 'meta']:
                param_tag, param = layer.name + '.' + key, None
                if(param_tag in state_dict):
                    param = state_dict[param_tag].tolist() if key in ["meta"] else state_dict[param_tag]
                    setattr(layer, key, param)

            #Set forward pass
            def forward(self, x):
                if(hasattr(self, 'weight')):
                    out = torch.matmul(x, self.weight.data.T)
                else:
                    out = matmul_fp4(x, self.W_q, self.scales, self.meta[-1], model.fp4_values)
                if(self.post_scale is not None):
                    out *= self.post_scale
                if(self.shift is not None):
                    out += self.shift
                if(self.bias is not None):
                    out += self.bias
                return out

            layer.forward = lambda x: forward(layer, x)

            return layer
        
        try: #FP4 params will fail here
            model.load_state_dict(state_dict, assign=True)
        except:
            pass

        for name, module in model.named_modules():
            module.name = name
        patch_linearlayers(model, patch_enable_fp4)
        model = model.to(device)

Usage

# Load the quantized checkpoint and its tokenizer from the hub.
model_id = "mobiuslabsgmbh/Llama-3.1-8B-Instruct_mxfp4_weights_calib_demo"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLMFP4.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='cuda')

# Check the trained params
# print( model.model.layers[-1].self_attn.v_proj.shift)
# tensor([ 0.0034, -0.0036,  0.0054,  ...,  0.0036, -0.0076, -0.0068],
#       device='cuda:0', dtype=torch.bfloat16)

# print( model.model.layers[-1].self_attn.v_proj.post_scale)
# tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0', dtype=torch.bfloat16)

# Tokenize a single-turn chat and move the prompt ids to the model's device.
messages = [{"role": "user", "content": "Solve the following equation: x^2 + 1 = -1"}]
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
)
input_ids = input_ids.to(model.device)

# Generate up to 256 new tokens and print the decoded sequence.
outputs = model.generate(input_ids, max_new_tokens=256)
print(tokenizer.decode(outputs[0]))
Downloads last month
15
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including mobiuslabsgmbh/Llama-3.1-8B-Instruct_mxfp4_weights_calib_demo