Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Test script to verify all training fixes work correctly | |
""" | |
import os | |
import sys | |
import subprocess | |
from pathlib import Path | |
def test_trainer_type_fix():
    """Check that trainer type names normalise to lowercase.

    Each raw name is lowercased with ``str.lower()`` and compared with the
    value it is expected to become. One report line is printed per case.

    Returns:
        True if every case produced its expected value, False otherwise.
    """
    print("π Testing Trainer Type Fix")
    print("=" * 50)

    # Raw trainer type -> expected normalised form (insertion order is kept).
    expectations = {
        "SFT": "sft",
        "DPO": "dpo",
        "sft": "sft",
        "dpo": "dpo",
    }

    # NOTE(review): the pass and fail messages carry the same status glyph,
    # which looks like a mis-encoded symbol — confirm the intended markers.
    mismatches = 0
    for raw, wanted in expectations.items():
        lowered = raw.lower()
        if lowered != wanted:
            print(f"β '{raw}' -> '{lowered}' (expected: '{wanted}')")
            mismatches += 1
            continue
        print(f"β '{raw}' -> '{lowered}' (expected: '{wanted}')")

    return mismatches == 0
def test_trackio_conflict_fix():
    """Check that the project's monitor can be imported and constructed.

    The ``src`` directory two levels above this file is appended to
    ``sys.path``, ``SmolLM3Monitor`` is imported from ``monitoring`` and
    instantiated, and its ``dataset_repo`` attribute is required to be a
    non-blank string.

    Returns:
        True if the monitor was created and has a non-blank dataset repo;
        False if the repo is empty or any exception was raised on the way.
    """
    print("\nπ Testing Trackio Conflict Fix")
    print("=" * 50)

    try:
        # Make the project's src/ directory importable before the import.
        src_dir = Path(__file__).parent.parent / "src"
        sys.path.append(str(src_dir))
        from monitoring import SmolLM3Monitor

        monitor = SmolLM3Monitor("test-experiment")
        print("β Monitor created successfully")
        print(f" Dataset repo: {monitor.dataset_repo}")
        print(f" Enable tracking: {monitor.enable_tracking}")

        # Guard clause: a missing or whitespace-only repo is a failure.
        repo = monitor.dataset_repo
        if not repo or repo.strip() == '':
            print("β Dataset repository is empty")
            return False

        print("β Dataset repository is properly set")
        return True
    except Exception as e:
        # Any import or construction problem is reported, not propagated.
        print(f"β Trackio conflict fix failed: {e}")
        return False
def test_dataset_repo_fix():
    """Check the dataset-repository fallback rule on sample inputs.

    The rule is simulated locally: a non-blank input is kept as-is, while
    an empty string or ``None`` resolves to ``"tonic/trackio-experiments"``.
    One report line is printed per case.

    Returns:
        True if every simulated case resolved to its expected repository.
    """
    print("\nπ Testing Dataset Repository Fix")
    print("=" * 50)

    fallback = "tonic/trackio-experiments"
    # (configured value, repository it should resolve to)
    scenarios = (
        ("user/test-dataset", "user/test-dataset"),
        ("", fallback),    # empty string falls back to the default
        (None, fallback),  # unset value falls back to the default
    )

    outcomes = []
    for configured, wanted in scenarios:
        usable = bool(configured) and configured.strip() != ''
        resolved = configured if usable else fallback

        matched = resolved == wanted
        if matched:
            print(f"β '{configured}' -> '{resolved}' (expected: '{wanted}')")
        else:
            print(f"β '{configured}' -> '{resolved}' (expected: '{wanted}')")
        outcomes.append(matched)

    return all(outcomes)
def test_launch_script_fixes():
    """Check that ``launch.sh`` contains the expected shell snippets.

    ``launch.sh`` is looked up relative to the current working directory.
    Its text is searched for four literal snippets, in order; the first
    one that is absent is reported and ends the check.

    Returns:
        True if the file exists and contains every snippet, False otherwise.
    """
    print("\nπ Testing Launch Script Fixes")
    print("=" * 50)

    script_path = Path("launch.sh")
    if not script_path.exists():
        print("β launch.sh not found")
        return False

    text = script_path.read_text(encoding='utf-8')

    # (literal snippet, message when present, message when absent)
    required = (
        (
            'TRAINER_TYPE_LOWER=$(echo "$TRAINER_TYPE" | tr \'[:upper:]\' \'[:lower:]\')',
            "β Trainer type conversion found",
            "β Trainer type conversion missing",
        ),
        (
            '--trainer-type "$TRAINER_TYPE_LOWER"',
            "β Trainer type usage updated",
            "β Trainer type usage not updated",
        ),
        (
            'TRACKIO_DATASET_REPO="$HF_USERNAME/trackio-experiments"',
            "β Dataset repository default found",
            "β Dataset repository default missing",
        ),
        (
            'if [ -z "$TRACKIO_DATASET_REPO" ]',
            "β Dataset repository validation found",
            "β Dataset repository validation missing",
        ),
    )

    for snippet, found_msg, missing_msg in required:
        if snippet not in text:
            print(missing_msg)
            return False
        print(found_msg)

    return True
def test_monitoring_fixes():
    """Check that ``src/monitoring.py`` contains the expected code fragments.

    The file is looked up relative to the current working directory and
    searched for three literal fragments, in order; the first one that is
    absent is reported and ends the check.

    Returns:
        True if the file exists and contains every fragment, False otherwise.
    """
    print("\nπ Testing Monitoring Fixes")
    print("=" * 50)

    source_path = Path("src/monitoring.py")
    if not source_path.exists():
        print("β monitoring.py not found")
        return False

    source_text = source_path.read_text(encoding='utf-8')

    # label -> literal fragment that must appear in the file
    fragments = {
        "Trackio conflict handling": 'import trackio',
        "Dataset repository validation":
            'if not self.dataset_repo or self.dataset_repo.strip() == \'\'',
        "Improved Trackio error handling": 'Trackio Space not accessible',
    }

    for label, fragment in fragments.items():
        present = fragment in source_text
        if present:
            print("β " + label + " found")
        else:
            print("β " + label + " missing")
            return False

    return True
def test_training_script_validation():
    """Check that the training script declares the trainer-type option.

    ``scripts/training/train.py`` is looked up relative to the current
    working directory. Its text must mention ``--trainer-type`` and declare
    ``choices`` as the two-element list ``sft``, ``dpo``.

    The choices check is tolerant of quoting style and spacing, so both
    ``choices=['sft', 'dpo']`` and ``choices=["sft", "dpo"]`` are accepted.
    A plain substring match would report an equivalent declaration as
    missing merely because it uses double quotes.

    Returns:
        True if the file exists and both checks succeed, False otherwise.
    """
    import re  # local import: only this check needs pattern matching

    print("\nπ Testing Training Script Validation")
    print("=" * 50)

    training_script = Path("scripts/training/train.py")
    if not training_script.exists():
        print("β Training script not found")
        return False

    script_content = training_script.read_text(encoding='utf-8')

    # The option name itself is matched literally.
    if '--trainer-type' not in script_content:
        print("β Trainer type argument missing")
        return False
    print("β Trainer type argument found")

    # Each item must open and close with the same quote character
    # (back-references \1 and \2); whitespace and a trailing comma are free.
    choices_pattern = re.compile(
        r"""choices\s*=\s*\[\s*(['"])sft\1\s*,\s*(['"])dpo\2\s*,?\s*\]"""
    )
    if choices_pattern.search(script_content) is None:
        print("β Valid trainer type choices missing")
        return False
    print("β Valid trainer type choices found")

    return True
def main():
    """Run every verification check and print an overall summary.

    Each check is called in turn. A check that returns a falsy value or
    raises an exception counts as a failure; an exception is reported and
    does not stop the remaining checks.

    Returns:
        True if every check passed, False otherwise.
    """
    print("π Training Fixes Verification")
    print("=" * 50)

    checks = (
        test_trainer_type_fix,
        test_trackio_conflict_fix,
        test_dataset_repo_fix,
        test_launch_script_fixes,
        test_monitoring_fixes,
        test_training_script_validation,
    )

    failures = 0
    for check in checks:
        try:
            passed = bool(check())
        except Exception as e:
            # Keep going so one crashing check does not hide the others.
            print(f"β Test failed with error: {e}")
            passed = False
        if not passed:
            failures += 1

    print("\n" + "=" * 50)

    if failures:
        print("β SOME TRAINING FIXES FAILED!")
        print("Please check the failed tests above.")
        return False

    success_report = (
        "π ALL TRAINING FIXES PASSED!",
        "β Trainer type conversion: Working",
        "β Trackio conflict handling: Working",
        "β Dataset repository fixes: Working",
        "β Launch script fixes: Working",
        "β Monitoring fixes: Working",
        "β Training script validation: Working",
        "\nAll training issues have been resolved!",
    )
    for line in success_report:
        print(line)
    return True
if __name__ == "__main__":
    # Report the overall result to the shell: 0 when every check passed,
    # 1 when at least one failed.
    exit_status = 0 if main() else 1
    sys.exit(exit_status)