Spaces:
Running
Running
File size: 3,496 Bytes
d9f7e1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
#!/usr/bin/env python3
"""
Test script to verify that training arguments are properly created
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from config.train_smollm3_openhermes_fr_a100_balanced import SmolLM3ConfigOpenHermesFRBalanced
from model import SmolLM3Model
from trainer import SmolLM3Trainer
from data import SmolLM3Dataset
import logging
# Configure the root logger when this script is loaded: records at INFO and
# above go to stderr using logging.BASIC_FORMAT ("LEVEL:logger:message"),
# which is the same layout basicConfig selects when no format is given.
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(format=logging.BASIC_FORMAT, level=logging.INFO)
def test_training_arguments():
    """Check that SmolLM3Model builds a training-arguments object.

    Creates the balanced OpenHermes-FR config, wraps it in a SmolLM3Model,
    asks the model for training arguments targeting ``/tmp/test_output`` and
    prints the settings that previously caused problems, for manual review.

    Returns:
        True if every step succeeded; False if any step (including config
        creation) raised. On failure the traceback is printed instead of
        propagated, so the script's runner can still execute the other tests.
    """
    print("Testing training arguments creation...")
    # Training-argument attributes echoed for inspection; these are the
    # settings this script was written to double-check.
    inspected_fields = (
        "report_to",
        "dataloader_pin_memory",
        "group_by_length",
        "prediction_loss_only",
        "ignore_data_skip",
        "remove_unused_columns",
        "fp16",
        "bf16",
        "load_best_model_at_end",
        "greater_is_better",
    )
    try:
        # Config creation lives inside the try so a broken config is reported
        # as a failed test (return False) rather than crashing the whole run.
        config = SmolLM3ConfigOpenHermesFRBalanced()
        print(f"Config created: {type(config)}")

        # NOTE(review): whether this loads model weights is decided by
        # SmolLM3Model's constructor, not by this test - confirm if it matters.
        model = SmolLM3Model(
            model_name=config.model_name,
            max_seq_length=config.max_seq_length,
            config=config,
        )
        print("Model created successfully")

        training_args = model.get_training_arguments("/tmp/test_output")
        print(f"Training arguments created: {type(training_args)}")
        print(f"Training arguments keys: {list(vars(training_args))}")
        for field in inspected_fields:
            print(f"{field}: {getattr(training_args, field)}")

        print("✅ Training arguments test passed!")
        return True
    except Exception as e:
        # Top-level boundary of this check: report and signal failure.
        print(f"❌ Training arguments test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_callback_creation():
    """Check that the monitoring layer produces a trainer callback.

    Creates the balanced OpenHermes-FR config, builds a monitor from it with
    ``monitoring.create_monitor_from_config`` and asks that monitor for a
    monitoring callback.

    Returns:
        True if a truthy callback was returned; False if the callback was
        falsy or any step raised. On failure the traceback is printed instead
        of propagated.
    """
    print("\nTesting callback creation...")
    try:
        # Imported inside the try so that a missing or broken monitoring
        # module is reported as a failed test instead of an import-time crash.
        from monitoring import create_monitor_from_config

        config = SmolLM3ConfigOpenHermesFRBalanced()
        monitor = create_monitor_from_config(config)
        callback = monitor.create_monitoring_callback()
        if not callback:
            print("❌ Callback creation failed")
            return False
        print(f"✅ Callback created successfully: {type(callback)}")
        return True
    except Exception as e:
        # Top-level boundary of this check: report and signal failure.
        print(f"❌ Callback creation test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    print("Running training fixes tests...")
    # Both checks always run, so each reports its own result even if the
    # first one fails.
    test1_passed = test_training_arguments()
    test2_passed = test_callback_creation()
    all_passed = test1_passed and test2_passed
    if all_passed:
        print("\n✅ All tests passed! The fixes should work.")
    else:
        print("\n❌ Some tests failed. Please check the errors above.")
    # Surface the outcome as the process exit status so shells and CI can
    # detect a failing run.
    sys.exit(0 if all_passed else 1)