#!/usr/bin/env python3
"""
Test script to verify that all training fixes work correctly.

Run this from the repository root: the tests resolve paths such as
launch.sh, src/monitoring.py, and scripts/training/train.py relative
to the current working directory.
"""

import sys
from pathlib import Path

def test_trainer_type_fix():
    """Test that trainer type names are normalized to lowercase (the
    same conversion performed by launch.sh)."""
    print("πŸ” Testing Trainer Type Fix")
    print("=" * 50)
    
    # Test cases
    test_cases = [
        ("SFT", "sft"),
        ("DPO", "dpo"),
        ("sft", "sft"),
        ("dpo", "dpo")
    ]
    
    all_passed = True
    for input_type, expected_output in test_cases:
        converted = input_type.lower()
        if converted == expected_output:
            print(f"βœ… '{input_type}' -> '{converted}' (expected: '{expected_output}')")
        else:
            print(f"❌ '{input_type}' -> '{converted}' (expected: '{expected_output}')")
            all_passed = False
    
    return all_passed

def test_trackio_conflict_fix():
    """Test that trackio package conflicts are handled"""
    print("\nπŸ” Testing Trackio Conflict Fix")
    print("=" * 50)
    
    try:
        # Test monitoring import
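        # Make src/ importable; assumes this script lives one directory
        # below the repository root (e.g. in tests/)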
        sys.path.append(str(Path(__file__).parent.parent / "src"))
        from monitoring import SmolLM3Monitor
        
        # Test monitor creation
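        # No dataset repo is passed explicitly, so the monitor should fall
        # back to its default/environment-based configuration (checked below)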
        monitor = SmolLM3Monitor("test-experiment")
        print("βœ… Monitor created successfully")
        print(f"   Dataset repo: {monitor.dataset_repo}")
        print(f"   Enable tracking: {monitor.enable_tracking}")
        
        # Check that dataset repo is not empty
        if monitor.dataset_repo and monitor.dataset_repo.strip() != '':
            print("βœ… Dataset repository is properly set")
        else:
            print("❌ Dataset repository is empty")
            return False
        
        return True
        
    except Exception as e:
        print(f"❌ Trackio conflict fix failed: {e}")
        return False

def test_dataset_repo_fix():
    """Test that dataset repository is properly set"""
    print("\nπŸ” Testing Dataset Repository Fix")
    print("=" * 50)
    
    # Test environment variable handling
    test_cases = [
        ("user/test-dataset", "user/test-dataset"),
        ("", "tonic/trackio-experiments"),  # Default fallback
        (None, "tonic/trackio-experiments"),  # Default fallback
    ]
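    # "tonic/trackio-experiments" is assumed here to be the hard-coded
    # default used when no TRACKIO_DATASET_REPO is provided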
    
    all_passed = True
    for input_repo, expected_repo in test_cases:
        # Mirror the fallback logic used in src/monitoring.py
        if input_repo and input_repo.strip() != '':
            actual_repo = input_repo
        else:
            actual_repo = "tonic/trackio-experiments"
        
        if actual_repo == expected_repo:
            print(f"βœ… '{input_repo}' -> '{actual_repo}' (expected: '{expected_repo}')")
        else:
            print(f"❌ '{input_repo}' -> '{actual_repo}' (expected: '{expected_repo}')")
            all_passed = False
    
    return all_passed

def test_launch_script_fixes():
    """Test that launch script fixes are in place"""
    print("\nπŸ” Testing Launch Script Fixes")
    print("=" * 50)
    
    # Check if launch.sh exists
    launch_script = Path("launch.sh")
    if not launch_script.exists():
        print("❌ launch.sh not found")
        return False
    
    # Read launch script and check for fixes
    script_content = launch_script.read_text(encoding='utf-8')
    
    # Check for trainer type conversion
    if 'TRAINER_TYPE_LOWER=$(echo "$TRAINER_TYPE" | tr \'[:upper:]\' \'[:lower:]\')' in script_content:
        print("βœ… Trainer type conversion found")
    else:
        print("❌ Trainer type conversion missing")
        return False
    
    # Check for trainer type usage
    if '--trainer-type "$TRAINER_TYPE_LOWER"' in script_content:
        print("βœ… Trainer type usage updated")
    else:
        print("❌ Trainer type usage not updated")
        return False
    
    # Check for dataset repository default
    if 'TRACKIO_DATASET_REPO="$HF_USERNAME/trackio-experiments"' in script_content:
        print("βœ… Dataset repository default found")
    else:
        print("❌ Dataset repository default missing")
        return False
    
    # Check for dataset repository validation
    if 'if [ -z "$TRACKIO_DATASET_REPO" ]' in script_content:
        print("βœ… Dataset repository validation found")
    else:
        print("❌ Dataset repository validation missing")
        return False
    
    return True

def test_monitoring_fixes():
    """Test that monitoring fixes are in place"""
    print("\nπŸ” Testing Monitoring Fixes")
    print("=" * 50)
    
    # Check if monitoring.py exists
    monitoring_file = Path("src/monitoring.py")
    if not monitoring_file.exists():
        print("❌ monitoring.py not found")
        return False
    
    # Read monitoring file and check for fixes
    script_content = monitoring_file.read_text(encoding='utf-8')
    
    # Check for trackio conflict handling
    if 'import trackio' in script_content:
        print("βœ… Trackio conflict handling found")
    else:
        print("❌ Trackio conflict handling missing")
        return False
    
    # Check for dataset repository validation
    if 'if not self.dataset_repo or self.dataset_repo.strip() == \'\'' in script_content:
        print("βœ… Dataset repository validation found")
    else:
        print("❌ Dataset repository validation missing")
        return False
    
    # Check for improved error handling
    if 'Trackio Space not accessible' in script_content:
        print("βœ… Improved Trackio error handling found")
    else:
        print("❌ Improved Trackio error handling missing")
        return False
    
    return True

def test_training_script_validation():
    """Test that training script accepts correct parameters"""
    print("\nπŸ” Testing Training Script Validation")
    print("=" * 50)
    
    # Check if training script exists
    training_script = Path("scripts/training/train.py")
    if not training_script.exists():
        print("❌ Training script not found")
        return False
    
    # Read training script and check for argument validation
    script_content = training_script.read_text(encoding='utf-8')
    
    # Check for trainer type argument
    if '--trainer-type' in script_content:
        print("βœ… Trainer type argument found")
    else:
        print("❌ Trainer type argument missing")
        return False
    
    # Check for valid choices
    if 'choices=[\'sft\', \'dpo\']' in script_content:
        print("βœ… Valid trainer type choices found")
    else:
        print("❌ Valid trainer type choices missing")
        return False
    
    return True

def main():
    """Run all training fix tests"""
    print("πŸš€ Training Fixes Verification")
    print("=" * 50)
    
    tests = [
        test_trainer_type_fix,
        test_trackio_conflict_fix,
        test_dataset_repo_fix,
        test_launch_script_fixes,
        test_monitoring_fixes,
        test_training_script_validation
    ]
    
    all_passed = True
    for test in tests:
        try:
            if not test():
                all_passed = False
        except Exception as e:
            print(f"❌ Test failed with error: {e}")
            all_passed = False
    
    print("\n" + "=" * 50)
    if all_passed:
        print("πŸŽ‰ ALL TRAINING FIXES PASSED!")
        print("βœ… Trainer type conversion: Working")
        print("βœ… Trackio conflict handling: Working")
        print("βœ… Dataset repository fixes: Working")
        print("βœ… Launch script fixes: Working")
        print("βœ… Monitoring fixes: Working")
        print("βœ… Training script validation: Working")
        print("\nAll training issues have been resolved!")
    else:
        print("❌ SOME TRAINING FIXES FAILED!")
        print("Please check the failed tests above.")
    
    return all_passed

if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)