Spaces:
Running
Running
#!/usr/bin/env python3
"""
Test script to verify quantization fixes
"""
import os | |
import sys | |
import logging | |
from pathlib import Path | |
# Setup logging: configure the root logger once at import time so that all
# messages from this script carry a timestamp and a level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)
def test_quantization_imports() -> bool:
    """Check that the quantization-related third-party imports resolve.

    The ``transformers``/``torchao`` names and ``huggingface_hub.HfApi`` are
    required: an ``ImportError`` from any of them fails the check.
    ``bitsandbytes`` is optional and only produces a warning when missing.

    Returns:
        bool: ``True`` if every required import succeeded, ``False`` if one
        of them raised ``ImportError`` (the error is logged).
    """
    # NOTE(review): the leading glyphs in the log messages look like
    # mis-decoded emoji -- confirm against the upstream file's encoding.
    try:
        # Test torchao imports. The names are imported only to prove they
        # exist, so they are intentionally unused.
        from transformers import (  # noqa: F401
            AutoModelForCausalLM,
            AutoTokenizer,
            TorchAoConfig,
        )
        from torchao.quantization import (  # noqa: F401
            Int8WeightOnlyConfig,
            Int4WeightOnlyConfig,
            Int8DynamicActivationInt8WeightConfig,
        )
        from torchao.dtypes import Int4CPULayout  # noqa: F401
        logger.info("β torchao imports successful")

        # Test bitsandbytes imports. A missing package here is tolerated and
        # must not fail the whole check, hence the nested handler.
        try:
            import bitsandbytes as bnb  # noqa: F401
            from transformers import BitsAndBytesConfig  # noqa: F401
            logger.info("β bitsandbytes imports successful")
        except ImportError:
            logger.warning("β οΈ bitsandbytes not available - alternative quantization disabled")

        # Test HF imports
        from huggingface_hub import HfApi  # noqa: F401
        logger.info("β huggingface_hub imports successful")
        return True
    except ImportError as e:
        logger.error(f"β Import failed: {e}")
        return False
def test_model_quantizer() -> bool:
    """Check that ``ModelQuantizer`` can be imported and constructed.

    The quantizer is built with placeholder arguments; only the import and
    the constructor call are exercised.

    Returns:
        bool: ``True`` if construction succeeded, ``False`` if the import or
        the constructor raised (the error is logged).
    """
    try:
        from scripts.model_tonic.quantize_model import ModelQuantizer

        # Test with dummy values. The instance itself is not needed, so it is
        # not bound to a name.
        ModelQuantizer(
            model_path="/output-checkpoint",
            repo_name="test/test-repo",
            token="dummy_token",
        )
        logger.info("β ModelQuantizer initialization successful")
        return True
    except Exception as e:
        # Broad catch is deliberate: any failure is reported as a failed
        # check rather than aborting the whole run.
        logger.error(f"β ModelQuantizer test failed: {e}")
        return False
def test_quantization_configs() -> bool:
    """Check that quantization configs can be created for int8 and int4.

    Builds a ``ModelQuantizer`` with placeholder arguments and calls
    ``create_quantization_config`` for ``"int8_weight_only"`` and then
    ``"int4_weight_only"``. The returned configs are not inspected.

    Returns:
        bool: ``True`` if both calls completed, ``False`` if anything raised
        (the error is logged).
    """
    try:
        from scripts.model_tonic.quantize_model import ModelQuantizer

        quantizer = ModelQuantizer(
            model_path="/output-checkpoint",
            repo_name="test/test-repo",
            token="dummy_token",
        )

        # int8 first, then int4 -- same order and same log lines as the
        # original back-to-back calls.
        for quant_type in ("int8_weight_only", "int4_weight_only"):
            # NOTE(review): 128 is presumably the group size -- confirm
            # against ModelQuantizer.create_quantization_config.
            quantizer.create_quantization_config(quant_type, 128)
            logger.info(f"β {quant_type} config creation successful")
        return True
    except Exception as e:
        # Broad catch is deliberate: any failure is reported as a failed
        # check rather than aborting the whole run.
        logger.error(f"β Quantization config test failed: {e}")
        return False
def test_device_selection() -> bool:
    """Check that ``get_optimal_device`` runs for int8 and int4 quantization.

    Builds a ``ModelQuantizer`` with placeholder arguments, asks it for the
    device for ``"int8_weight_only"`` and ``"int4_weight_only"``, and logs
    whatever it returns. The returned values are not validated.

    Returns:
        bool: ``True`` if both calls completed, ``False`` if anything raised
        (the error is logged).
    """
    try:
        from scripts.model_tonic.quantize_model import ModelQuantizer

        quantizer = ModelQuantizer(
            model_path="/output-checkpoint",
            repo_name="test/test-repo",
            token="dummy_token",
        )

        # (short label used in the log line, quantization type passed in)
        cases = (
            ("int8", "int8_weight_only"),
            ("int4", "int4_weight_only"),
        )
        for label, quant_type in cases:
            device = quantizer.get_optimal_device(quant_type)
            logger.info(f"β {label} device selection: {device}")
        return True
    except Exception as e:
        # Broad catch is deliberate: any failure is reported as a failed
        # check rather than aborting the whole run.
        logger.error(f"β Device selection test failed: {e}")
        return False
def main() -> int:
    """Run every quantization check in order and log a summary.

    Each check is a zero-argument callable returning a truthy value on
    success. All checks are run even if an earlier one fails.

    Returns:
        int: ``0`` when all checks passed, ``1`` otherwise, so the value can
        be used directly as a process exit status.
    """
    # NOTE(review): the leading glyphs in the log messages look like
    # mis-decoded emoji -- confirm against the upstream file's encoding.
    logger.info("π§ͺ Testing quantization fixes...")

    tests = [
        ("Import Test", test_quantization_imports),
        ("ModelQuantizer Test", test_model_quantizer),
        ("Config Creation Test", test_quantization_configs),
        ("Device Selection Test", test_device_selection),
    ]

    passed = 0
    total = len(tests)

    for test_name, test_func in tests:
        logger.info(f"\nπ Running {test_name}...")
        if test_func():
            passed += 1
            logger.info(f"β {test_name} passed")
        else:
            logger.error(f"β {test_name} failed")

    logger.info(f"\nπ Test Results: {passed}/{total} tests passed")

    if passed == total:
        logger.info("π All tests passed! Quantization fixes are working.")
        return 0
    else:
        logger.error("β Some tests failed. Please check the errors above.")
        return 1
if __name__ == "__main__":
    # Use sys.exit rather than the bare ``exit`` helper: ``exit`` is injected
    # by the ``site`` module for interactive use and is not guaranteed to
    # exist (e.g. when Python is started with -S).
    sys.exit(main())