Model Card for Model ID

Command

python ao/prep_model.py --quant_type fp8 --push_to_hub True

Script used to generate this model:

#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0

"""
Script for quantizing LLM models with TorchAO.
Supports various quantization configurations and model types.
"""

import random
import numpy as np
import torch
import time
from pathlib import Path
from typing import Optional, Literal

from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    PerRow,
    PerTensor,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt4WeightConfig,
    CutlassInt4PackedLayout,
)
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.prototype.mx_formats import MXGemmKernelChoice
from jsonargparse import CLI, Namespace
from rich import print


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, the torch
    default generator, and the generators of all CUDA devices.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def get_quantization_config(args):
    """Build the ``TorchAoConfig`` selected by ``args.quant_type``.

    Args:
        args: Namespace-like object. Reads ``quant_type`` and ``granularity``,
            and ``min_sqnr`` when ``quant_type`` is "autoquant".

    Returns:
        A ``TorchAoConfig`` wrapping the matching TorchAO quantization config.

    Raises:
        KeyError: If ``args.granularity`` is not "per_row" or "per_tensor".
        ValueError: If ``args.quant_type`` is not a supported quantization type.
    """
    # Resolved up front for every quant type, even though only "fp8" uses it.
    granularity = {"per_row": PerRow(), "per_tensor": PerTensor()}[args.granularity]

    # Each entry lazily constructs the config for one quantization type, so
    # only the selected one is ever built.
    builders = {
        "autoquant": lambda: TorchAoConfig("autoquant", min_sqnr=args.min_sqnr),
        "fp8": lambda: TorchAoConfig(
            Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
        ),
        "int4_weight_only": lambda: TorchAoConfig(Int4WeightOnlyConfig(group_size=128)),
        "int8_weight_only": lambda: TorchAoConfig(Int8WeightOnlyConfig()),
        "int8_dynamic_act_int8_weight": lambda: TorchAoConfig(
            Int8DynamicActivationInt8WeightConfig()
        ),
        "gemlite": lambda: TorchAoConfig(GemliteUIntXWeightOnlyConfig()),
        "A4W4": lambda: TorchAoConfig(Int4DynamicActivationInt4WeightConfig()),
        "A8W4": lambda: TorchAoConfig(
            Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
        ),
        "mxfp8": lambda: TorchAoConfig(MXFPInferenceConfig()),
        # Activations and weights both in float4_e2m1fn_x2, 32-element blocks,
        # dispatched to the CUTLASS GEMM kernel.
        "mxfp4": lambda: TorchAoConfig(
            MXFPInferenceConfig(
                activation_dtype=torch.float4_e2m1fn_x2,
                weight_dtype=torch.float4_e2m1fn_x2,
                block_size=32,
                gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
            )
        ),
    }

    build = builders.get(args.quant_type)
    if build is None:
        raise ValueError(f"Unsupported quantization type: {args.quant_type}")
    return build()


def benchmark_model(model, input_ids, max_new_tokens, name=""):
    """Benchmark static-cache generation speed of ``model``.

    Args:
        model: Model exposing a ``generate`` method (called with
            ``max_new_tokens`` and ``cache_implementation="static"``).
        input_ids: Mapping of model inputs (e.g. tokenizer output), unpacked
            into ``generate``.
        max_new_tokens: Number of new tokens requested per timed call.
        name: Label used in the printed report.

    Returns:
        Generation latency in milliseconds (same unit for both timing paths).
    """

    def _report(elapsed_ms):
        # Print latency and throughput. Throughput assumes exactly
        # ``max_new_tokens`` tokens were produced; guard a zero measurement.
        tokens_per_second = (
            max_new_tokens / (elapsed_ms / 1000) if elapsed_ms > 0 else float("inf")
        )
        print(
            f"{name} model: {elapsed_ms:.2f}ms for {max_new_tokens} tokens ({tokens_per_second:.2f} tokens/sec)"
        )
        return elapsed_ms

    try:
        # As its name says, this helper returns *microseconds*; it is converted
        # to milliseconds below so the "ms" label and tokens/sec are correct.
        time_us = benchmark_cuda_function_in_microseconds(
            model.generate,
            **input_ids,
            max_new_tokens=max_new_tokens,
            cache_implementation="static",
        )
    except ImportError:
        # Fallback to simple wall-clock timing if the benchmark helper's
        # dependencies are unavailable.
        print("torch._inductor.utils not available, using simple timing")
        cuda = torch.cuda.is_available()
        if cuda:
            # Don't charge previously queued GPU work to this measurement.
            torch.cuda.synchronize()
        start = time.perf_counter()
        model.generate(
            **input_ids, max_new_tokens=max_new_tokens, cache_implementation="static"
        )
        if cuda:
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize()
        return _report((time.perf_counter() - start) * 1000)

    return _report(time_us / 1000)


def _tokenize_for_generation(tokenizer, prompts, device):
    """Tokenize a batch of prompts left-padded for batched decoder-only generation.

    ``padding_side`` and (if missing) ``pad_token`` are patched only for this
    call and restored afterwards, so the tokenizer that is later saved/pushed
    keeps its original configuration.

    Args:
        tokenizer: Hugging Face tokenizer.
        prompts: List of prompt strings.
        device: Device the resulting tensors are moved to.

    Returns:
        The tokenizer output (tensors) on ``device``.
    """
    original_padding_side = tokenizer.padding_side
    missing_pad_token = tokenizer.pad_token is None
    # Batched generation with a decoder-only model needs left padding, and
    # tokenizers shipped without a pad token (e.g. Llama 3's) raise on
    # padding=True, so borrow the EOS token as the pad token temporarily.
    tokenizer.padding_side = "left"
    if missing_pad_token:
        tokenizer.pad_token = tokenizer.eos_token
    try:
        return tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    finally:
        tokenizer.padding_side = original_padding_side
        if missing_pad_token:
            tokenizer.pad_token = None


def main(
    model_name: str = "facebook/opt-125m",
    output_dir: Optional[str] = None,
    push_to_hub: bool = False,
    quant_type: Literal[
        "fp8",
        "int4_weight_only",
        "int8_weight_only",
        "int8_dynamic_act_int8_weight",
        "autoquant",
        "gemlite",
        "A4W4",
        "A8W4",
        "mxfp8",
        "mxfp4",
    ] = "fp8",
    granularity: Literal["per_row", "per_tensor"] = "per_row",
    min_sqnr: Optional[float] = None,
    max_new_tokens: int = 64,
    benchmark: bool = False,
    bench_tokens: int = 100,
    device_map: str = "cuda",
):
    """
    Quantize a model with TorchAO, sanity-check generation, save it, and
    optionally push it to the Hub and benchmark it against the BF16 original.

    Args:
        model_name: Model to quantize (e.g., meta-llama/Meta-Llama-3-8B, facebook/opt-125m)
        output_dir: Directory to save the quantized model
            (default: data/<quant_type>-<model base name>)
        push_to_hub: If True, push the quantized model and tokenizer to the
            HuggingFace Hub, using the output directory's name
            (e.g. 'fp8-opt-125m') as the repo id
        quant_type: Quantization type to use
        granularity: Quantization granularity (used by fp8)
        min_sqnr: Minimum SQNR for autoquant
        max_new_tokens: Max tokens to generate for testing
        benchmark: Run benchmarking comparison
        bench_tokens: Number of tokens to generate for benchmarking
        device_map: Device mapping strategy
    """
    # Set seed before creating the model
    set_seed(42)

    # Default output directory: data/<quant_type>-<model base name>
    if output_dir is None:
        model_base_name = model_name.split("/")[-1]
        output_dir = f"data/{quant_type}-{model_base_name}"

    # Bundle the CLI arguments into one namespace (consumed by get_quantization_config)
    args = Namespace(
        model_name=model_name,
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        quant_type=quant_type,
        granularity=granularity,
        min_sqnr=min_sqnr,
        max_new_tokens=max_new_tokens,
        benchmark=benchmark,
        bench_tokens=bench_tokens,
        device_map=device_map,
    )
    print(f"Using Model name: {args.model_name}")
    print(f"Quantization type: {args.quant_type}")

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    quantization_config = get_quantization_config(args)

    # Load and quantize model (quantization happens inside from_pretrained)
    print("Loading and quantizing model...")
    quantized_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype="bfloat16",
        device_map=args.device_map,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    prompts = [
        "Why is Pytorch 2.0 the best machine learning compiler?",
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Batched generation sanity check on the freshly quantized model
    print("\nTesting quantized model generation...")
    input_ids = _tokenize_for_generation(tokenizer, prompts, quantized_model.device)
    outputs = quantized_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)

    for prompt, output in zip(prompts, outputs):
        generated_text = tokenizer.decode(output, skip_special_tokens=True)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Save with safe_serialization=False, as the original script does;
    # presumably the TorchAO-quantized weights are not safetensors-serializable — confirm.
    print(f"\n📁Saving quantized model to: {output_dir}")
    quantized_model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir)

    if args.push_to_hub:
        # The repo id is the output directory's base name; report the id that
        # is actually used instead of a made-up path.
        repo_id = output_dir.name
        print(f"Pushing model to HuggingFace Hub: {repo_id}")
        quantized_model.push_to_hub(repo_id, safe_serialization=False)
        tokenizer.push_to_hub(repo_id)

    # The in-memory quantized model is no longer needed; drop it before
    # loading more model copies so large models are less likely to OOM.
    del quantized_model

    # Reload the saved checkpoint to verify it round-trips
    print("\nLoading saved quantized model to verify...")
    loaded_model = AutoModelForCausalLM.from_pretrained(
        output_dir, device_map=args.device_map, torch_dtype="auto"
    )

    # Single (unpadded) prompt; these inputs are also reused for benchmarking
    test_prompt = prompts[0]
    input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device)
    output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}")

    if args.benchmark:
        print("\nBenchmarking models...")
        print("Benchmarking quantized model:")
        quant_time = benchmark_model(
            loaded_model, input_ids, args.bench_tokens, f"Quantized ({args.quant_type})"
        )

        print("\nLoading original model in BF16 for comparison...")
        bf16_model = AutoModelForCausalLM.from_pretrained(
            args.model_name, device_map=args.device_map, torch_dtype=torch.bfloat16
        )

        print("Benchmarking original BF16 model:")
        bf16_time = benchmark_model(bf16_model, input_ids, args.bench_tokens, "BF16")

        # Both timings are in the same unit, so the ratio is the speedup
        speedup = bf16_time / quant_time if quant_time > 0 else 0
        print(f"\nSpeedup: {speedup:.2f}x")

    print("\nQuantization process completed successfully.")


if __name__ == "__main__":
    CLI(main)
Downloads last month
466
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support