Spaces:
Running
Running
#!/usr/bin/env python3
"""
Test script to verify that training arguments and the monitoring callback
are properly created.
"""
import logging
import os
import sys

# Append this file's directory to the import path before the project imports
# below run; presumably config/, model, trainer and data live next to this
# script (verify against the repository layout).
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config.train_smollm3_openhermes_fr_a100_balanced import SmolLM3ConfigOpenHermesFRBalanced
from model import SmolLM3Model
from trainer import SmolLM3Trainer
from data import SmolLM3Dataset

# Configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
def test_training_arguments():
    """Check that training arguments can be created from the project config.

    Builds the config, constructs a ``SmolLM3Model`` from it, requests the
    training arguments (passing ``"/tmp/test_output"``) and prints a
    selection of their fields for inspection.

    Returns:
        bool: ``True`` when every step completed without raising, ``False``
        otherwise. On failure the exception message and traceback are printed.
    """
    print("Testing training arguments creation...")

    # Fields printed individually because they "might cause issues"
    # (per the original author's note).
    inspected_fields = (
        "report_to",
        "dataloader_pin_memory",
        "group_by_length",
        "prediction_loss_only",
        "ignore_data_skip",
        "remove_unused_columns",
        "fp16",
        "bf16",
        "load_best_model_at_end",
        "greater_is_better",
    )

    try:
        # Config construction lives inside the ``try`` so that a broken config
        # is reported as a failed check (return False) instead of aborting the
        # whole script before the remaining checks can run.
        config = SmolLM3ConfigOpenHermesFRBalanced()
        print(f"Config created: {type(config)}")

        # NOTE(review): the original comment states the model is created
        # "without actually loading the model" - confirm against SmolLM3Model.
        model = SmolLM3Model(
            model_name=config.model_name,
            max_seq_length=config.max_seq_length,
            config=config,
        )
        print("Model created successfully")

        training_args = model.get_training_arguments("/tmp/test_output")
        print(f"Training arguments created: {type(training_args)}")
        print(f"Training arguments keys: {list(vars(training_args))}")

        for field_name in inspected_fields:
            print(f"{field_name}: {getattr(training_args, field_name)}")

        print("✅ Training arguments test passed!")
        return True
    except Exception as e:
        # Deliberately broad: this is the top-level boundary of the check and
        # any failure must be reported rather than propagated.
        print(f"❌ Training arguments test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_callback_creation():
    """Check that a monitoring callback can be created from the project config.

    Imports the monitoring factory and the config class, builds a monitor from
    a fresh config and asks it for a monitoring callback.

    Returns:
        bool: ``True`` when a truthy callback object was returned, ``False``
        when the callback is falsy or any step raised. On an exception the
        message and traceback are printed.
    """
    print("\nTesting callback creation...")
    try:
        # Imported inside the ``try`` so that a missing or broken module is
        # reported as a failed check instead of an import-time crash.
        from monitoring import create_monitor_from_config
        from config.train_smollm3_openhermes_fr_a100_balanced import SmolLM3ConfigOpenHermesFRBalanced

        config = SmolLM3ConfigOpenHermesFRBalanced()
        monitor = create_monitor_from_config(config)
        callback = monitor.create_monitoring_callback()

        # Guard clause: a falsy result counts as a failed creation.
        if not callback:
            print("❌ Callback creation failed")
            return False

        print(f"✅ Callback created successfully: {type(callback)}")
        return True
    except Exception as e:
        # Deliberately broad: top-level boundary of the check.
        print(f"❌ Callback creation test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    print("Running training fixes tests...")
    # Run both checks unconditionally so every failure is reported in one pass.
    test1_passed = test_training_arguments()
    test2_passed = test_callback_creation()
    if test1_passed and test2_passed:
        print("\n✅ All tests passed! The fixes should work.")
    else:
        print("\n❌ Some tests failed. Please check the errors above.")
        # Report failure through the process exit status so shells and CI
        # pipelines can detect it; success falls through with status 0.
        sys.exit(1)