#!/usr/bin/env python3 """ Test script to verify the training pipeline fixes """ import os import sys import logging from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) def test_imports(): """Test that all imports work correctly""" print("šŸ” Testing imports...") try: from src.config import get_config print("āœ… config.py imported successfully") except Exception as e: print(f"āŒ config.py import failed: {e}") return False try: from src.model import SmolLM3Model print("āœ… model.py imported successfully") except Exception as e: print(f"āŒ model.py import failed: {e}") return False try: from src.data import SmolLM3Dataset print("āœ… data.py imported successfully") except Exception as e: print(f"āŒ data.py import failed: {e}") return False try: from src.trainer import SmolLM3Trainer print("āœ… trainer.py imported successfully") except Exception as e: print(f"āŒ trainer.py import failed: {e}") return False try: from src.monitoring import create_monitor_from_config print("āœ… monitoring.py imported successfully") except Exception as e: print(f"āŒ monitoring.py import failed: {e}") return False return True def test_config_loading(): """Test configuration loading""" print("\nšŸ” Testing configuration loading...") try: from src.config import get_config # Test loading the H100 lightweight config config = get_config("config/train_smollm3_h100_lightweight.py") print("āœ… Configuration loaded successfully") print(f" Model: {config.model_name}") print(f" Dataset: {config.dataset_name}") print(f" Batch size: {config.batch_size}") print(f" Learning rate: {config.learning_rate}") return True except Exception as e: print(f"āŒ Configuration loading failed: {e}") return False def test_monitoring_setup(): """Test monitoring setup without Trackio Space""" print("\nšŸ” Testing monitoring setup...") try: from src.monitoring import create_monitor_from_config from src.config import get_config # Load config config = get_config("config/train_smollm3_h100_lightweight.py") # Set Trackio URL to a non-existent one to test fallback config.trackio_url = "https://non-existent-space.hf.space" config.experiment_name = "test_experiment" # Create monitor monitor = create_monitor_from_config(config) print("āœ… Monitoring setup successful") print(f" Experiment: {monitor.experiment_name}") print(f" Tracking enabled: {monitor.enable_tracking}") print(f" HF Dataset: {monitor.dataset_repo}") return True except Exception as e: print(f"āŒ Monitoring setup failed: {e}") return False def test_trainer_creation(): """Test trainer creation""" print("\nšŸ” Testing trainer creation...") try: from src.config import get_config from src.model import SmolLM3Model from src.data import SmolLM3Dataset from src.trainer import SmolLM3Trainer # Load config config = get_config("config/train_smollm3_h100_lightweight.py") # Create model (without loading the actual model) model = SmolLM3Model( model_name=config.model_name, max_seq_length=config.max_seq_length, config=config ) print("āœ… Model created successfully") # Create dataset (without loading actual data) dataset = SmolLM3Dataset( data_path=config.dataset_name, tokenizer=model.tokenizer, max_seq_length=config.max_seq_length, config=config ) print("āœ… Dataset created successfully") # Create trainer trainer = SmolLM3Trainer( model=model, dataset=dataset, config=config, output_dir="/tmp/test_output", init_from="scratch" ) print("āœ… Trainer created successfully") return True except Exception as e: print(f"āŒ Trainer creation failed: {e}") return False def test_format_string_fix(): """Test that the format string fix works""" print("\nšŸ” Testing format string fix...") try: from src.trainer import SmolLM3Trainer # Test the SimpleConsoleCallback format string handling from transformers import TrainerCallback class TestCallback(TrainerCallback): def on_log(self, args, state, control, logs=None, **kwargs): if logs and isinstance(logs, dict): step = getattr(state, 'global_step', 'unknown') loss = logs.get('loss', 'N/A') lr = logs.get('learning_rate', 'N/A') # Test the fixed format string logic if isinstance(loss, (int, float)): loss_str = f"{loss:.4f}" else: loss_str = str(loss) if isinstance(lr, (int, float)): lr_str = f"{lr:.2e}" else: lr_str = str(lr) print(f"Step {step}: loss={loss_str}, lr={lr_str}") print("āœ… Format string fix works correctly") return True except Exception as e: print(f"āŒ Format string fix test failed: {e}") return False def main(): """Run all tests""" print("šŸš€ Testing SmolLM3 Training Pipeline Fixes") print("=" * 50) tests = [ test_imports, test_config_loading, test_monitoring_setup, test_trainer_creation, test_format_string_fix ] passed = 0 total = len(tests) for test in tests: try: if test(): passed += 1 except Exception as e: print(f"āŒ Test {test.__name__} crashed: {e}") print(f"\nšŸ“Š Test Results: {passed}/{total} tests passed") if passed == total: print("āœ… All tests passed! The training pipeline should work correctly.") return True else: print("āŒ Some tests failed. Please check the errors above.") return False if __name__ == "__main__": success = main() sys.exit(0 if success else 1)