Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Standalone Model Quantization Script | |
Quick quantization of trained models using torchao | |
""" | |
import os | |
import sys | |
import argparse | |
import logging | |
from pathlib import Path | |
# Add the project root to the path | |
project_root = Path(__file__).parent.parent.parent | |
sys.path.append(str(project_root)) | |
from scripts.model_tonic.quantize_model import ModelQuantizer | |
def _build_parser():
    """Build the argument parser for the standalone quantization CLI."""
    parser = argparse.ArgumentParser(description="Quantize a trained model using torchao")
    parser.add_argument("model_path", help="Path to the trained model")
    parser.add_argument("repo_name", help="Hugging Face repository name for quantized model")
    parser.add_argument(
        "--quant-type",
        choices=["int8_weight_only", "int4_weight_only", "int8_dynamic"],
        default="int8_weight_only",
        help="Quantization type",
    )
    parser.add_argument("--device", default="auto", help="Device for quantization (auto, cpu, cuda)")
    parser.add_argument("--group-size", type=int, default=128, help="Group size for quantization")
    parser.add_argument("--token", help="Hugging Face token")
    parser.add_argument("--private", action="store_true", help="Create private repository")
    parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
    parser.add_argument("--experiment-name", help="Experiment name for tracking")
    parser.add_argument("--dataset-repo", help="HF Dataset repository")
    parser.add_argument(
        "--save-only",
        action="store_true",
        help="Save quantized model locally without pushing to HF",
    )
    return parser


def main(argv=None):
    """Standalone quantization script.

    Parses the command line, prints a summary banner, builds a
    ``ModelQuantizer`` and either quantizes locally (``--save-only``) or runs
    the quantize-and-push workflow.

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. When ``None`` (the default) argparse reads ``sys.argv[1:]``,
            which is what running the script from a shell does.

    Returns:
        int: ``0`` on success, ``1`` when the quantizer reports failure
        (a falsy result from ``quantize_model`` or ``quantize_and_push``).

    Raises:
        SystemExit: Raised by argparse for ``--help`` (code 0) or invalid
            arguments (code 2).
    """
    args = _build_parser().parse_args(argv)

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    print("🚀 Starting Model Quantization")
    print("=" * 40)
    print(f"Model: {args.model_path}")
    print(f"Quantization: {args.quant_type}")
    print(f"Device: {args.device}")
    print(f"Repository: {args.repo_name}")
    print(f"Save only: {args.save_only}")
    print("=" * 40)

    # Initialize quantizer
    quantizer = ModelQuantizer(
        model_path=args.model_path,
        repo_name=args.repo_name,
        token=args.token,
        private=args.private,
        trackio_url=args.trackio_url,
        experiment_name=args.experiment_name,
        dataset_repo=args.dataset_repo
    )

    if args.save_only:
        # Just quantize and save locally
        print("💾 Quantizing and saving locally...")
        quantized_path = quantizer.quantize_model(
            quant_type=args.quant_type,
            device=args.device,
            group_size=args.group_size
        )
        # A falsy result is treated as failure; presumably the method returns
        # the output path on success -- confirm against ModelQuantizer.
        if not quantized_path:
            print("❌ Quantization failed")
            return 1
        print(f"✅ Quantized model saved to: {quantized_path}")
        print(f"📁 You can find the quantized model in: {quantized_path}")
    else:
        # Full quantization and push workflow
        success = quantizer.quantize_and_push(
            quant_type=args.quant_type,
            device=args.device,
            group_size=args.group_size
        )
        if not success:
            print("❌ Quantization and push failed")
            return 1

    print("🎉 Quantization completed successfully!")
    return 0
if __name__ == "__main__":
    # Use sys.exit rather than the bare exit() helper: exit() is injected by
    # the site module for interactive sessions and is not guaranteed to exist
    # (e.g. under `python -S` or in frozen executables).
    sys.exit(main())