Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Test script for quantization functionality | |
| """ | |
| import os | |
| import sys | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| import logging | |
| # Add the project root to the path | |
| project_root = Path(__file__).parent.parent | |
| sys.path.append(str(project_root)) | |
| from scripts.model_tonic.quantize_model import ModelQuantizer | |
def test_quantization_imports():
    """Check that every third-party dependency of the quantization flow imports.

    The names are imported only to prove that they resolve; none is used.

    Returns:
        bool: True when all imports succeed, False when any of them raises
        ImportError.
    """
    try:
        import torch  # noqa: F401
        from transformers import (  # noqa: F401
            AutoModelForCausalLM,
            AutoTokenizer,
            TorchAoConfig,
        )
        from torchao.quantization import (  # noqa: F401
            Int8WeightOnlyConfig,
            Int4WeightOnlyConfig,
            Int8DynamicActivationInt8WeightConfig,
        )
        from torchao.dtypes import Int4CPULayout  # noqa: F401
    except ImportError as e:
        print(f"❌ Import error: {e}")
        return False

    # Reached only when every import above resolved.
    print("✅ All quantization imports successful")
    return True
def test_quantizer_initialization():
    """Check that ModelQuantizer can be constructed for a minimal model directory.

    A temporary directory holding a stub ``config.json`` and
    ``pytorch_model.bin`` is created and passed to the constructor. The test
    passes when construction does not raise.

    Returns:
        bool: True when construction succeeds, False when anything raises.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Create a dummy model directory
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()

            # Create minimal model files
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            # Only successful construction is under test, so the instance
            # is not kept.
            ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            print("✅ Quantizer initialization successful")
            return True
    except Exception as e:
        print(f"❌ Quantizer initialization failed: {e}")
        return False
def test_quantization_config_creation():
    """Check that quantization configs can be created for int8 and int4.

    For each of ``int8_weight_only`` and ``int4_weight_only`` the quantizer's
    ``create_quantization_config`` is called with 128 as the second positional
    argument (presumably the group size -- confirm against ModelQuantizer).

    A ``None`` result is treated as a failure so the test cannot pass
    vacuously. NOTE(review): this assumes ``create_quantization_config``
    returns a config object on success -- confirm against ModelQuantizer.

    Returns:
        bool: True when both configs are created, False when a call raises
        or yields ``None``.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            # (label used in the output, quantization type passed to the API)
            cases = (
                ("int8", "int8_weight_only"),
                ("int4", "int4_weight_only"),
            )
            for label, quant_type in cases:
                config = quantizer.create_quantization_config(quant_type, 128)
                if config is None:
                    print(f"❌ {label} config creation returned no config")
                    return False
                print(f"✅ {label} config creation successful")
            return True
    except Exception as e:
        print(f"❌ Config creation failed: {e}")
        return False
def test_model_validation():
    """Check ``validate_model_path`` against a populated and an empty directory.

    A directory containing stub ``config.json`` and ``pytorch_model.bin`` files
    is expected to validate; an empty directory is expected not to.

    Returns:
        bool: True when both expectations hold, False when either is violated
        or anything raises.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Test with valid model
            model_dir = Path(temp_dir) / "valid_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            if not quantizer.validate_model_path():
                print("❌ Valid model validation failed")
                return False
            print("✅ Valid model validation successful")

            # Test with invalid model: the directory exists but holds none
            # of the model files.
            invalid_dir = Path(temp_dir) / "invalid_model"
            invalid_dir.mkdir()

            quantizer_invalid = ModelQuantizer(
                model_path=str(invalid_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )
            if quantizer_invalid.validate_model_path():
                print("❌ Invalid model validation failed")
                return False
            print("✅ Invalid model validation successful")
            return True
    except Exception as e:
        print(f"❌ Model validation test failed: {e}")
        return False
def test_quantized_model_card_creation():
    """Check the generated model cards for the int8 and int4 variants.

    Each card returned by ``create_quantized_model_card`` must contain the
    quantization type it was created for and a hardware marker: ``"GPU"`` for
    ``int8_weight_only`` and ``"CPU"`` for ``int4_weight_only``.

    Returns:
        bool: True when both cards contain the expected text, False when a
        check fails or anything raises.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            # (label used in the output, quantization type, required marker)
            cases = (
                ("int8", "int8_weight_only", "GPU"),
                ("int4", "int4_weight_only", "CPU"),
            )
            for label, quant_type, marker in cases:
                card = quantizer.create_quantized_model_card(quant_type, "test/model")
                if quant_type not in card or marker not in card:
                    print(f"❌ {label} model card creation failed")
                    return False
                print(f"✅ {label} model card creation successful")
            return True
    except Exception as e:
        print(f"❌ Model card creation test failed: {e}")
        return False
def test_quantized_readme_creation():
    """Check the generated READMEs for the int8 and int4 variants.

    Each README returned by ``create_quantized_readme`` must contain the
    quantization type it was created for and a hardware phrase:
    ``"GPU optimized"`` for ``int8_weight_only`` and ``"CPU optimized"`` for
    ``int4_weight_only``.

    Returns:
        bool: True when both READMEs contain the expected text, False when a
        check fails or anything raises.
    """
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            model_dir = Path(temp_dir) / "dummy_model"
            model_dir.mkdir()
            (model_dir / "config.json").write_text('{"model_type": "test"}')
            (model_dir / "pytorch_model.bin").write_text('dummy')

            quantizer = ModelQuantizer(
                model_path=str(model_dir),
                repo_name="test/test-quantized",
                token="dummy_token",
            )

            # (label used in the output, quantization type, required phrase)
            cases = (
                ("int8", "int8_weight_only", "GPU optimized"),
                ("int4", "int4_weight_only", "CPU optimized"),
            )
            for label, quant_type, phrase in cases:
                readme = quantizer.create_quantized_readme(quant_type, "test/model")
                if quant_type not in readme or phrase not in readme:
                    print(f"❌ {label} README creation failed")
                    return False
                print(f"✅ {label} README creation successful")
            return True
    except Exception as e:
        print(f"❌ README creation test failed: {e}")
        return False
def main():
    """Run all quantization tests and print a summary.

    Each test function is called in turn. A truthy return counts as a pass;
    a falsy return or a raised exception counts as a failure, and the
    remaining tests still run.

    Returns:
        int: 0 when every test passed, 1 otherwise (used as the exit status).
    """
    print("🧪 Running Quantization Tests")
    print("=" * 40)

    tests = [
        ("Import Test", test_quantization_imports),
        ("Initialization Test", test_quantizer_initialization),
        ("Config Creation Test", test_quantization_config_creation),
        ("Model Validation Test", test_model_validation),
        ("Model Card Test", test_quantized_model_card_creation),
        ("README Test", test_quantized_readme_creation),
    ]

    passed = 0
    total = len(tests)

    for test_name, test_func in tests:
        print(f"\n📋 Running {test_name}...")
        try:
            if test_func():
                passed += 1
                print(f"✅ {test_name} passed")
            else:
                print(f"❌ {test_name} failed")
        except Exception as e:
            # A crashing test must not stop the rest of the run.
            print(f"❌ {test_name} failed with exception: {e}")

    print("\n" + "=" * 40)
    print(f"📊 Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All quantization tests passed!")
        return 0

    print("⚠️ Some tests failed. Check the output above.")
    return 1
if __name__ == "__main__":
    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    # sys.exit rather than the builtin exit(): the latter is injected by the
    # site module for interactive use and is missing under `python -S`.
    sys.exit(main())