SmolFactory / tests /test_quantization_fix.py
testtest123's picture
cleanup a bit the files
ad3b15d unverified
raw
history blame
4.58 kB
#!/usr/bin/env python3
"""
Test script to verify quantization fixes
"""
import os
import sys
import logging
from pathlib import Path
# Root logging: INFO and above, each record rendered as
# "<timestamp> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every check in this script.
logger = logging.getLogger(__name__)
def test_quantization_imports():
    """Check that the libraries used for quantization can be imported.

    Required: transformers (incl. ``TorchAoConfig``), the torchao
    quantization configs / ``Int4CPULayout``, and ``huggingface_hub``.
    Optional: bitsandbytes, whose absence only produces a warning.

    Returns:
        True if every required import succeeds, False if one raises
        ImportError.

    NOTE(review): because of its ``test_`` prefix pytest also collects this
    function; it returns a bool rather than asserting, so under pytest a
    failure is not reported as a test failure. ``main()`` relies on the
    bool, so that contract is kept.
    """
    try:
        # Required imports. The names are imported only to prove they exist.
        from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig  # noqa: F401
        from torchao.quantization import (  # noqa: F401
            Int8WeightOnlyConfig,
            Int4WeightOnlyConfig,
            Int8DynamicActivationInt8WeightConfig,
        )
        from torchao.dtypes import Int4CPULayout  # noqa: F401
        logger.info("βœ… torchao imports successful")

        # Optional dependency. Its import may fail with errors other than
        # ImportError (NOTE(review): e.g. a RuntimeError from native/CUDA
        # library setup in some bitsandbytes releases — confirm for the
        # pinned version). Any such failure is downgraded to a warning
        # instead of escaping this function and aborting the run.
        try:
            import bitsandbytes as bnb  # noqa: F401
            from transformers import BitsAndBytesConfig  # noqa: F401
        except Exception as e:
            logger.warning("⚠️ bitsandbytes not available - alternative quantization disabled")
            logger.debug("bitsandbytes import error: %s", e)
        else:
            logger.info("βœ… bitsandbytes imports successful")

        # Required: Hugging Face Hub client.
        from huggingface_hub import HfApi  # noqa: F401
        logger.info("βœ… huggingface_hub imports successful")
        return True
    except ImportError as e:
        logger.error(f"❌ Import failed: {e}")
        return False
def test_model_quantizer():
    """Check that ``ModelQuantizer`` can be constructed with dummy arguments.

    Only construction is exercised. The checkpoint path, repo name and token
    are placeholders. NOTE(review): whether the constructor touches the
    filesystem or network is not visible from this file. The
    ``scripts.model_tonic`` import presumably requires the repository root
    on ``sys.path`` — confirm how this script is launched.

    Returns:
        True if construction succeeds, False if anything raises.
    """
    try:
        from scripts.model_tonic.quantize_model import ModelQuantizer

        # The instance is not used afterwards; successful construction is
        # the whole check.
        ModelQuantizer(
            model_path="/output-checkpoint",
            repo_name="test/test-repo",
            token="dummy_token",
        )
        logger.info("βœ… ModelQuantizer initialization successful")
        return True
    except Exception as e:
        # logger.exception logs at ERROR level and also records the
        # traceback, which the message alone would lose.
        logger.exception(f"❌ ModelQuantizer test failed: {e}")
        return False
def test_quantization_configs():
    """Check that ``ModelQuantizer.create_quantization_config`` does not raise.

    Builds an int8 and an int4 weight-only config with group size 128. The
    returned configs are not inspected. NOTE(review): if the method can
    signal failure by returning None, this check would not notice — confirm
    against ``quantize_model``.

    Returns:
        True if both configs are created without raising, False otherwise.
    """
    try:
        from scripts.model_tonic.quantize_model import ModelQuantizer

        quantizer = ModelQuantizer(
            model_path="/output-checkpoint",
            repo_name="test/test-repo",
            token="dummy_token",
        )
        for quant_type in ("int8_weight_only", "int4_weight_only"):
            quantizer.create_quantization_config(quant_type, 128)
            logger.info(f"βœ… {quant_type} config creation successful")
        return True
    except Exception as e:
        # Keep the traceback so the failing call is identifiable.
        logger.exception(f"❌ Quantization config test failed: {e}")
        return False
def test_device_selection():
    """Check that ``ModelQuantizer.get_optimal_device`` works for int8 and int4.

    The chosen device is only logged, not validated.

    Returns:
        True if both lookups succeed without raising, False otherwise.
    """
    try:
        from scripts.model_tonic.quantize_model import ModelQuantizer

        quantizer = ModelQuantizer(
            model_path="/output-checkpoint",
            repo_name="test/test-repo",
            token="dummy_token",
        )
        # (quantization type passed to the quantizer, short label for the log)
        for quant_type, label in (
            ("int8_weight_only", "int8"),
            ("int4_weight_only", "int4"),
        ):
            device = quantizer.get_optimal_device(quant_type)
            logger.info(f"βœ… {label} device selection: {device}")
        return True
    except Exception as e:
        # Keep the traceback so the failing call is identifiable.
        logger.exception(f"❌ Device selection test failed: {e}")
        return False
def main():
    """Run every check in order and log a pass/fail summary.

    Each check is a zero-argument callable whose truthy return means success.
    A check that raises instead of returning is logged with its traceback and
    counted as a failure. The remaining checks still run and the summary is
    still produced.

    Returns:
        0 if every check passed, 1 otherwise (intended as the exit code).
    """
    logger.info("πŸ§ͺ Testing quantization fixes...")
    tests = [
        ("Import Test", test_quantization_imports),
        ("ModelQuantizer Test", test_model_quantizer),
        ("Config Creation Test", test_quantization_configs),
        ("Device Selection Test", test_device_selection),
    ]
    passed = 0
    total = len(tests)
    for test_name, test_func in tests:
        logger.info(f"\nπŸ” Running {test_name}...")
        try:
            ok = test_func()
        except Exception:
            # e.g. test_quantization_imports only catches ImportError, so any
            # other error from a third-party import would otherwise abort the
            # whole run without a summary.
            logger.exception(f"❌ {test_name} raised an unexpected exception")
            ok = False
        if ok:
            passed += 1
            logger.info(f"βœ… {test_name} passed")
        else:
            logger.error(f"❌ {test_name} failed")
    logger.info(f"\nπŸ“Š Test Results: {passed}/{total} tests passed")
    if passed == total:
        logger.info("πŸŽ‰ All tests passed! Quantization fixes are working.")
        return 0
    logger.error("❌ Some tests failed. Please check the errors above.")
    return 1
if __name__ == "__main__":
    # sys.exit instead of the site-injected exit() builtin, which is missing
    # under `python -S` and in some embedded/frozen interpreters.
    sys.exit(main())