#!/usr/bin/env python3
"""
Quantize and Push Script
Quantizes the uploaded model and pushes quantized versions to the same repository
"""
import os
import sys
import logging
from pathlib import Path
# Add src to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def main():
    """Quantize and push the model"""
    # Configuration
    MODEL_PATH = "/output-checkpoint"
    REPO_NAME = "Tonic/smollm3-finetuned"
    HF_TOKEN = os.getenv('HF_TOKEN')

    if not HF_TOKEN:
        logger.error("❌ HF_TOKEN not set")
        return 1

    if not Path(MODEL_PATH).exists():
        logger.error(f"❌ Model path not found: {MODEL_PATH}")
        return 1

    logger.info("✅ Model files validated")

    # Import and run quantization
    try:
        from scripts.model_tonic.quantize_model import ModelQuantizer

        # Quantization types to process
        quant_types = ["int8_weight_only", "int4_weight_only"]
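
        # Note: these type names mirror torchao's weight-only quantization
        # configs (int8_weight_only / int4_weight_only). ModelQuantizer is
        # assumed to apply something along the lines of:
        #     from torchao.quantization import quantize_, int4_weight_only
        #     quantize_(model, int4_weight_only(group_size=128))
        # This is only a sketch of the underlying technique; the actual
        # quantize_model implementation is not shown in this file.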
        success_count = 0
        total_count = len(quant_types)

        for quant_type in quant_types:
            logger.info(f"Processing quantization type: {quant_type}")

            # Initialize quantizer
            quantizer = ModelQuantizer(
                model_path=MODEL_PATH,
                repo_name=REPO_NAME,
                token=HF_TOKEN,
                private=False,
                hf_token=HF_TOKEN
            )

            # Perform quantization and push
            success = quantizer.quantize_and_push(
                quant_type=quant_type,
                device="auto",
                group_size=128
            )

            if success:
                logger.info(f"✅ {quant_type} quantization and push completed")
                success_count += 1
            else:
                logger.error(f"❌ {quant_type} quantization and push failed")
logger.info(f"π Quantization summary: {success_count}/{total_count} successful")
if success_count > 0:
logger.info("β
Quantization completed successfully!")
logger.info(f"π View your models at: https://huggingface.co/{REPO_NAME}")
logger.info("π Quantized models available at:")
logger.info(f" - https://huggingface.co/{REPO_NAME}/int8 (GPU optimized)")
logger.info(f" - https://huggingface.co/{REPO_NAME}/int4 (CPU optimized)")
return 0
else:
logger.error("β All quantization attempts failed!")
return 1
except Exception as e:
logger.error(f"β Error during quantization: {e}")
return 1

if __name__ == "__main__":
    sys.exit(main())
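
# Example invocation (the script filename and token value are placeholders; the
# script expects a fine-tuned checkpoint at /output-checkpoint and a write-scoped
# Hugging Face token in the HF_TOKEN environment variable):
#   HF_TOKEN=hf_xxx python quantize_and_push.py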