firstAI / test_deployment_fallbacks.py
ndc8
upd
cb5d5f8
raw
history blame
5.12 kB
#!/usr/bin/env python3
"""
Test script to verify deployment fallback mechanisms work correctly.
"""
import sys
import logging
# Configure logging once at import time with the threshold set to INFO, so
# the progress and summary messages emitted below are not filtered out.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
def _evaluate_quantization_case(expected_type, quant_config):
    """Classify a single detection result and log the verdict.

    Args:
        expected_type: ``None`` when the model is expected to be a standard
            (non-quantized) model; any other value means a quantization
            config is expected.
        quant_config: Value returned by ``get_quantization_config`` for the
            model under test. Only ``None`` vs. not-``None`` is checked.

    Returns:
        A ``(status, details)`` tuple where ``status`` is ``"PASS"`` or
        ``"FAIL"`` and ``details`` is a human-readable explanation.
    """
    detected = quant_config is not None

    if expected_type is None:
        # Standard model: any config at all counts as a false positive.
        if not detected:
            logger.info("✅ PASS: No quantization detected (as expected)")
            return "PASS", "Correctly detected standard model"
        logger.error(f"❌ FAIL: Unexpected quantization config: {quant_config}")
        return "FAIL", f"Unexpected quantization: {quant_config}"

    # Quantized model: the contents of the config are not inspected, only
    # that something was returned.
    if detected:
        logger.info(f"✅ PASS: Quantization detected: {quant_config}")
        return "PASS", f"Correctly detected quantization: {quant_config}"
    logger.error("❌ FAIL: Expected quantization but got None")
    return "FAIL", "Expected quantization but got None"


def test_quantization_detection():
    """Test quantization detection logic without actual model loading.

    Calls ``backend_service.get_quantization_config`` with a fixed list of
    model names and checks only whether each result is ``None`` (standard
    model) or not ``None`` (quantized model). Every verdict is logged and a
    summary is logged at the end.

    Returns:
        ``True`` when every case passes, ``False`` when at least one case
        fails or raises.

    Raises:
        ImportError: If ``backend_service`` cannot be imported. The import
            is deliberately unguarded here; ``test_imports`` is the guarded
            check that is run first by the script entry point.
    """
    # Imported lazily so merely importing this module does not pull in the
    # backend service.
    from backend_service import get_quantization_config

    # (model name, expected type, description). The expected type is either
    # None or a non-None marker; its actual value is never compared.
    test_cases = [
        # Standard models - should return None
        ("microsoft/DialoGPT-medium", None, "Standard model, no quantization"),
        ("deepseek-ai/DeepSeek-R1-0528-Qwen3-8B", None, "Standard model, no quantization"),
        # Quantized models - should return quantization config
        ("unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit", "quantized", "4-bit quantized model"),
        ("unsloth/DeepSeek-R1-0528-Qwen3-8B-GGUF", "quantized", "GGUF quantized model"),
        ("something-4bit-test", "quantized", "Generic 4-bit model"),
        ("test-bnb-model", "quantized", "BitsAndBytes model"),
    ]
    separator = "=" * 60

    logger.info("🧪 Testing quantization detection logic...")
    logger.info(separator)

    results = []
    for model_name, expected_type, description in test_cases:
        logger.info(f"\n📝 Testing: {model_name}")
        logger.info(f" Expected: {description}")
        try:
            quant_config = get_quantization_config(model_name)
            status, details = _evaluate_quantization_case(expected_type, quant_config)
        except Exception as exc:
            # Deliberately broad: one misbehaving case must be recorded as an
            # ERROR without aborting the remaining cases.
            logger.error(f"❌ ERROR: Exception during test: {exc}")
            status, details = "ERROR", str(exc)
        results.append((model_name, status, details))

    # Print summary
    logger.info("\n" + separator)
    logger.info("📊 QUANTIZATION DETECTION TEST SUMMARY")
    logger.info(separator)

    # Any status other than PASS/FAIL (i.e. ERROR) gets the warning marker.
    status_emoji = {"PASS": "✅", "FAIL": "❌"}
    for model_name, status, details in results:
        logger.info(f"{status_emoji.get(status, '⚠️')} {model_name}: {status}")
        if status != "PASS":
            logger.info(f" Details: {details}")

    pass_count = sum(status == "PASS" for _, status, _ in results)
    total_count = len(results)
    logger.info(f"\n📈 Results: {pass_count}/{total_count} tests passed")

    if pass_count == total_count:
        logger.info("🎉 All quantization detection tests passed!")
        return True

    logger.warning("⚠️ Some quantization detection tests failed")
    return False
def test_imports():
    """Test that we can import required modules.

    Checks, in order, that ``get_quantization_config`` can be imported from
    ``backend_service``, that ``AutoTokenizer`` and ``AutoModelForCausalLM``
    can be imported from ``transformers``, and whether ``BitsAndBytesConfig``
    is importable. A missing ``BitsAndBytesConfig`` is only logged and does
    not fail the check.

    Returns:
        ``True`` when the required imports succeed, ``False`` when any of
        them raises.
    """
    logger.info("🧪 Testing imports...")
    try:
        # The imported names are intentionally unused: the import statements
        # themselves are the availability probe.
        from backend_service import get_quantization_config  # noqa: F401
        logger.info("✅ Successfully imported get_quantization_config")

        # Test that transformers is available
        from transformers import AutoTokenizer, AutoModelForCausalLM  # noqa: F401
        logger.info("✅ Successfully imported transformers")

        # Test bitsandbytes import handling
        try:
            from transformers import BitsAndBytesConfig  # noqa: F401
            logger.info("✅ BitsAndBytesConfig import successful")
        except ImportError as exc:
            # Optional dependency: report it, but keep the overall result.
            logger.info(f"📝 BitsAndBytesConfig import failed (expected in some environments): {exc}")

        return True
    except Exception as exc:
        # Deliberately broad: any failure raised while importing (not only
        # ImportError) is reported as a failed check instead of a crash.
        logger.error(f"❌ Import test failed: {exc}")
        return False
def main():
    """Run the import check, then the quantization detection check.

    Returns:
        Process exit code: ``0`` when both checks pass, ``1`` when the import
        check fails (detection is then skipped) or the detection check fails.
    """
    logger.info("🚀 Starting deployment fallback mechanism tests...")

    # Test imports first: the detection test imports backend_service without
    # a guard, so there is no point continuing when this already failed.
    if not test_imports():
        logger.error("❌ Import tests failed, cannot continue")
        return 1

    # Test quantization detection
    if test_quantization_detection():
        logger.info("\n🎉 All deployment fallback tests passed!")
        logger.info("💡 Your deployment should handle quantized models gracefully")
        return 0

    logger.error("\n❌ Some tests failed")
    return 1


if __name__ == "__main__":
    sys.exit(main())