Spaces:
Running
Running
File size: 7,871 Bytes
c61ed6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 |
#!/usr/bin/env python3
"""
Test script to verify all training fixes work correctly
"""
import os
import sys
import subprocess
from pathlib import Path
def test_trainer_type_fix():
    """Check that trainer type names normalise to lower case.

    Each ``(raw, expected)`` pair is lower-cased with ``str.lower`` and
    compared with the expected value; one result line is printed per pair.

    Returns:
        bool: ``True`` if every pair matched, ``False`` otherwise.
    """
    # NOTE(review): heading emoji is a best-guess repair of mojibake ("π") - confirm.
    print("🔍 Testing Trainer Type Fix")
    print("=" * 50)

    test_cases = [
        ("SFT", "sft"),
        ("DPO", "dpo"),
        ("sft", "sft"),
        ("dpo", "dpo"),
    ]

    all_passed = True
    for input_type, expected_output in test_cases:
        converted = input_type.lower()
        matched = converted == expected_output
        # One print for both outcomes; only the status glyph differs.
        marker = "✅" if matched else "❌"
        print(f"{marker} '{input_type}' -> '{converted}' (expected: '{expected_output}')")
        if not matched:
            all_passed = False
    return all_passed
def test_trackio_conflict_fix():
    """Check that ``SmolLM3Monitor`` can be imported and constructed.

    Adds the ``src`` directory found one level above this file's directory to
    ``sys.path``, imports ``SmolLM3Monitor`` from ``monitoring``, builds a
    monitor named ``"test-experiment"`` and checks that its ``dataset_repo``
    attribute is neither empty nor whitespace-only.

    Returns:
        bool: ``True`` if the monitor was created with a non-blank
        ``dataset_repo``; ``False`` if it was blank or any step raised.
    """
    # NOTE(review): heading emoji is a best-guess repair of mojibake ("π") - confirm.
    print("\n🔍 Testing Trackio Conflict Fix")
    print("=" * 50)

    try:
        # resolve() first: with a relative __file__ (e.g. "test.py"),
        # .parent.parent would collapse to "." instead of the real grandparent.
        src_dir = str(Path(__file__).resolve().parent.parent / "src")
        # Register the directory once so repeated calls do not grow sys.path.
        if src_dir not in sys.path:
            sys.path.append(src_dir)
        from monitoring import SmolLM3Monitor

        monitor = SmolLM3Monitor("test-experiment")
        print("✅ Monitor created successfully")
        print(f" Dataset repo: {monitor.dataset_repo}")
        print(f" Enable tracking: {monitor.enable_tracking}")

        # A missing, empty or whitespace-only repo name counts as "not set".
        if not (monitor.dataset_repo and monitor.dataset_repo.strip()):
            print("❌ Dataset repository is empty")
            return False

        print("✅ Dataset repository is properly set")
        return True
    except Exception as e:
        # Broad on purpose: an import or construction failure of any kind is
        # reported as a failed check rather than aborting the whole run.
        print(f"❌ Trackio conflict fix failed: {e}")
        return False
def test_dataset_repo_fix():
    """Check the dataset-repository fallback rule against sample inputs.

    For each ``(input, expected)`` pair the input is kept when it is a
    non-blank string and replaced by ``"tonic/trackio-experiments"``
    otherwise; the outcome is compared with the expected value and one result
    line is printed per pair.

    Returns:
        bool: ``True`` if every pair matched, ``False`` otherwise.
    """
    # NOTE(review): heading emoji is a best-guess repair of mojibake ("π") - confirm.
    print("\n🔍 Testing Dataset Repository Fix")
    print("=" * 50)

    default_repo = "tonic/trackio-experiments"
    test_cases = [
        ("user/test-dataset", "user/test-dataset"),
        ("", default_repo),     # Default fallback
        ("   ", default_repo),  # Whitespace-only also falls back
        (None, default_repo),   # Default fallback
    ]

    all_passed = True
    for input_repo, expected_repo in test_cases:
        # Local re-statement of the rule; this does not call the monitoring module.
        if input_repo and input_repo.strip():
            actual_repo = input_repo
        else:
            actual_repo = default_repo

        matched = actual_repo == expected_repo
        marker = "✅" if matched else "❌"
        print(f"{marker} '{input_repo}' -> '{actual_repo}' (expected: '{expected_repo}')")
        if not matched:
            all_passed = False
    return all_passed
def test_launch_script_fixes():
    """Check that ``launch.sh`` contains the expected fix snippets.

    The script is looked up in the current working directory first and, if it
    is not there, in the directory one level above this file's directory.
    Its text is then searched for four literal snippets, in order; the first
    missing snippet ends the check.

    Returns:
        bool: ``True`` if the script exists and contains every snippet,
        ``False`` otherwise.
    """
    # NOTE(review): heading emoji is a best-guess repair of mojibake ("π") - confirm.
    print("\n🔍 Testing Launch Script Fixes")
    print("=" * 50)

    # The fallback keeps the check usable when it is not started from the
    # project root. Assumes this file sits one directory below that root,
    # as the src lookup in the Trackio check does - TODO confirm.
    candidates = (
        Path("launch.sh"),
        Path(__file__).resolve().parent.parent / "launch.sh",
    )
    launch_script = next((path for path in candidates if path.exists()), None)
    if launch_script is None:
        print("❌ launch.sh not found")
        return False

    script_content = launch_script.read_text(encoding="utf-8")

    # (snippet that must appear, message when found, message when missing)
    checks = [
        (
            'TRAINER_TYPE_LOWER=$(echo "$TRAINER_TYPE" | tr \'[:upper:]\' \'[:lower:]\')',
            "Trainer type conversion found",
            "Trainer type conversion missing",
        ),
        (
            '--trainer-type "$TRAINER_TYPE_LOWER"',
            "Trainer type usage updated",
            "Trainer type usage not updated",
        ),
        (
            'TRACKIO_DATASET_REPO="$HF_USERNAME/trackio-experiments"',
            "Dataset repository default found",
            "Dataset repository default missing",
        ),
        (
            'if [ -z "$TRACKIO_DATASET_REPO" ]',
            "Dataset repository validation found",
            "Dataset repository validation missing",
        ),
    ]

    for snippet, found_msg, missing_msg in checks:
        if snippet not in script_content:
            print(f"❌ {missing_msg}")
            return False
        print(f"✅ {found_msg}")
    return True
def test_monitoring_fixes():
    """Check that ``src/monitoring.py`` contains the expected fix snippets.

    The file is looked up relative to the current working directory first
    and, if it is not there, relative to the directory one level above this
    file's directory. Its text is then searched for three literal snippets,
    in order; the first missing snippet ends the check.

    Returns:
        bool: ``True`` if the file exists and contains every snippet,
        ``False`` otherwise.
    """
    # NOTE(review): heading emoji is a best-guess repair of mojibake ("π") - confirm.
    print("\n🔍 Testing Monitoring Fixes")
    print("=" * 50)

    # The fallback keeps the check usable when it is not started from the
    # project root. Assumes this file sits one directory below that root,
    # as the src lookup in the Trackio check does - TODO confirm.
    candidates = (
        Path("src/monitoring.py"),
        Path(__file__).resolve().parent.parent / "src" / "monitoring.py",
    )
    monitoring_file = next((path for path in candidates if path.exists()), None)
    if monitoring_file is None:
        print("❌ monitoring.py not found")
        return False

    script_content = monitoring_file.read_text(encoding="utf-8")

    # (snippet that must appear, message when found, message when missing)
    checks = [
        (
            "import trackio",
            "Trackio conflict handling found",
            "Trackio conflict handling missing",
        ),
        (
            "if not self.dataset_repo or self.dataset_repo.strip() == ''",
            "Dataset repository validation found",
            "Dataset repository validation missing",
        ),
        (
            "Trackio Space not accessible",
            "Improved Trackio error handling found",
            "Improved Trackio error handling missing",
        ),
    ]

    for snippet, found_msg, missing_msg in checks:
        if snippet not in script_content:
            print(f"❌ {missing_msg}")
            return False
        print(f"✅ {found_msg}")
    return True
def test_training_script_validation():
    """Check that the training script declares the trainer-type argument.

    ``scripts/training/train.py`` is looked up relative to the current
    working directory first and, if it is not there, relative to the
    directory one level above this file's directory. Its text must contain
    ``--trainer-type`` and a ``choices=[...]`` list naming ``sft`` then
    ``dpo``, written with either single or double quotes.

    Returns:
        bool: ``True`` if the script exists and both checks succeed,
        ``False`` otherwise.
    """
    import re  # local import: the module header does not import re

    # NOTE(review): heading emoji is a best-guess repair of mojibake ("π") - confirm.
    print("\n🔍 Testing Training Script Validation")
    print("=" * 50)

    # The fallback keeps the check usable when it is not started from the
    # project root. Assumes this file sits one directory below that root,
    # as the src lookup in the Trackio check does - TODO confirm.
    relative_path = Path("scripts/training/train.py")
    candidates = (
        relative_path,
        Path(__file__).resolve().parent.parent / relative_path,
    )
    training_script = next((path for path in candidates if path.exists()), None)
    if training_script is None:
        print("❌ Training script not found")
        return False

    script_content = training_script.read_text(encoding="utf-8")

    if "--trainer-type" not in script_content:
        print("❌ Trainer type argument missing")
        return False
    print("✅ Trainer type argument found")

    # Accept both quote styles so a formatter switching ' to " does not
    # produce a false failure; spacing and a trailing comma are tolerated too.
    choices_pattern = r"choices=\[\s*['\"]sft['\"]\s*,\s*['\"]dpo['\"]\s*,?\s*\]"
    if re.search(choices_pattern, script_content) is None:
        print("❌ Valid trainer type choices missing")
        return False
    print("✅ Valid trainer type choices found")
    return True
def main():
    """Run every fix check in this file and print a summary.

    The checks are called one after another. A check that returns a falsy
    value or raises is counted as failed; the remaining checks still run.

    Returns:
        bool: ``True`` if every check returned a truthy value, ``False``
        otherwise.
    """
    # NOTE(review): the rocket and party emoji below are best-guess repairs of
    # mojibake ("π") - confirm.
    print("🚀 Training Fixes Verification")
    print("=" * 50)

    tests = [
        test_trainer_type_fix,
        test_trackio_conflict_fix,
        test_dataset_repo_fix,
        test_launch_script_fixes,
        test_monitoring_fixes,
        test_training_script_validation,
    ]

    all_passed = True
    for test in tests:
        try:
            passed = test()
        except Exception as e:
            # Broad on purpose: one crashing check must not stop the others.
            print(f"❌ Test failed with error: {e}")
            passed = False
        if not passed:
            all_passed = False

    print("\n" + "=" * 50)
    if all_passed:
        print("🎉 ALL TRAINING FIXES PASSED!")
        summary_items = (
            "Trainer type conversion",
            "Trackio conflict handling",
            "Dataset repository fixes",
            "Launch script fixes",
            "Monitoring fixes",
            "Training script validation",
        )
        for item in summary_items:
            print(f"✅ {item}: Working")
        print("\nAll training issues have been resolved!")
    else:
        print("❌ SOME TRAINING FIXES FAILED!")
        print("Please check the failed tests above.")

    return all_passed
if __name__ == "__main__":
    # Exit status 0 when every check passed, 1 otherwise.
    success = main()
    sys.exit(0 if success else 1)