SmolFactory / src /trainer.py
Tonic's picture
matches experiment id for all metrics
08ed534 verified
raw
history blame
15.9 kB
"""
SmolLM3 Trainer
Handles the training loop and integrates with Hugging Face Trainer
"""
import os
import torch
import logging
from typing import Optional, Dict, Any
from transformers import Trainer, TrainingArguments
from trl import SFTTrainer
import json
# Import monitoring
from monitoring import create_monitor_from_config
# Module-level logger used by both trainer classes defined in this file.
logger = logging.getLogger(__name__)
class SmolLM3Trainer:
    """Trainer for SmolLM3 fine-tuning.

    Builds a TRL ``SFTTrainer`` (or a plain Hugging Face ``Trainer``) from the
    project's model wrapper, dataset wrapper and config object, and wires in
    console monitoring plus optional Trackio monitoring.

    Args:
        model: Project model wrapper. Used for ``model.model``,
            ``model.tokenizer``, ``model.get_training_arguments(...)`` and
            ``model.load_checkpoint(path)``.
        dataset: Project dataset wrapper exposing ``get_train_dataset()``,
            ``get_eval_dataset()`` and ``get_data_collator()``.
        config: Training config; read for ``save_steps``, ``eval_steps``,
            ``logging_steps``, ``max_iters`` and optional Trackio settings.
        output_dir: Directory for checkpoints, the final model and JSON results.
        init_from: ``"scratch"``, ``"pretrained"`` or ``"resume"``.
        use_sft_trainer: When True (default) try ``SFTTrainer`` first and fall
            back to ``Trainer`` if it cannot be built; when False build a
            standard ``Trainer`` directly.
    """

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        init_from: str = "scratch",
        use_sft_trainer: bool = True
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.init_from = init_from
        self.use_sft_trainer = use_sft_trainer

        # Initialize monitoring
        self.monitor = create_monitor_from_config(config)

        # Setup trainer
        self.trainer = self._setup_trainer()

    @staticmethod
    def _dataset_size(dataset):
        """Return ``len(dataset)`` for logging; 0 for None, "unknown" if unsized."""
        if dataset is None:
            return 0
        try:
            return len(dataset)
        except TypeError:
            # e.g. streaming / iterable datasets without __len__
            return "unknown"

    def _setup_trainer(self):
        """Build the underlying trainer.

        Returns:
            An ``SFTTrainer`` or ``Trainer`` instance.

        Raises:
            Exception: whatever ``Trainer`` raises when it cannot be constructed
                (after the ``SFTTrainer`` attempt, if enabled, also failed).
        """
        logger.info("Setting up trainer")

        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )
        logger.info("Training arguments keys: %s", list(training_args.__dict__.keys()))
        logger.info("Training arguments type: %s", type(training_args))

        logger.info("Getting train dataset...")
        train_dataset = self.dataset.get_train_dataset()
        logger.info("Train dataset: %s with %s samples",
                    type(train_dataset), self._dataset_size(train_dataset))

        logger.info("Getting eval dataset...")
        eval_dataset = self.dataset.get_eval_dataset()
        # eval_dataset may be None when no eval split exists; don't crash on len().
        logger.info("Eval dataset: %s with %s samples",
                    type(eval_dataset), self._dataset_size(eval_dataset))

        logger.info("Getting data collator...")
        data_collator = self.dataset.get_data_collator()
        logger.info("Data collator: %s", type(data_collator))

        callbacks = self._build_callbacks()
        self._sync_trackio_experiment()
        return self._create_trainer(
            training_args, train_dataset, eval_dataset, data_collator, callbacks
        )

    def _build_callbacks(self):
        """Return the trainer callbacks: console output plus Trackio if enabled."""
        from transformers import TrainerCallback

        class SimpleConsoleCallback(TrainerCallback):
            """Print training progress, saves and evaluations to stdout."""

            def on_init_end(self, args, state, control, **kwargs):
                """Called when training initialization is complete"""
                print("🔧 Training initialization completed")

            def on_log(self, args, state, control, logs=None, **kwargs):
                """Log metrics to console"""
                if logs and isinstance(logs, dict):
                    step = state.global_step if hasattr(state, 'global_step') else 'unknown'
                    loss = logs.get('loss', 'N/A')
                    lr = logs.get('learning_rate', 'N/A')
                    # Only apply numeric format specs to numbers ('N/A' would raise).
                    loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else str(loss)
                    lr_str = f"{lr:.2e}" if isinstance(lr, (int, float)) else str(lr)
                    print(f"Step {step}: loss={loss_str}, lr={lr_str}")

            def on_train_begin(self, args, state, control, **kwargs):
                print("🚀 Training started!")

            def on_train_end(self, args, state, control, **kwargs):
                print("✅ Training completed!")

            def on_save(self, args, state, control, **kwargs):
                step = state.global_step if hasattr(state, 'global_step') else 'unknown'
                print(f"💾 Checkpoint saved at step {step}")

            def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                if metrics and isinstance(metrics, dict):
                    step = state.global_step if hasattr(state, 'global_step') else 'unknown'
                    eval_loss = metrics.get('eval_loss', 'N/A')
                    print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")

        callbacks = [SimpleConsoleCallback()]
        logger.info("Added simple console monitoring callback")

        if self.monitor and self.monitor.enable_tracking:
            try:
                trackio_callback = self.monitor.create_monitoring_callback()
                if trackio_callback:
                    callbacks.append(trackio_callback)
                    logger.info("Added Trackio monitoring callback")
                else:
                    logger.warning("Failed to create Trackio callback")
            except Exception as e:
                # Monitoring is optional; never block training on it.
                logger.error("Error creating Trackio callback: %s", e)
                logger.info("Continuing with console monitoring only")

        logger.info("Total callbacks: %d", len(callbacks))
        return callbacks

    def _sync_trackio_experiment(self) -> None:
        """Initialize trackio (for TRL compatibility) sharing the monitor's experiment ID.

        Reuses ``self.monitor.experiment_id`` when it is set; otherwise calls
        ``trackio.init`` with values from the config and stores the returned ID
        on the monitor. Any failure (including trackio not being installed) is
        logged and ignored.
        """
        try:
            import trackio

            if self.monitor and self.monitor.experiment_id:
                experiment_id = self.monitor.experiment_id
                logger.info("Using existing experiment ID: %s", experiment_id)
            else:
                experiment_name = getattr(self.config, 'experiment_name', 'smollm3_experiment')
                experiment_id = trackio.init(
                    project_name=experiment_name,
                    experiment_name=experiment_name,
                    trackio_url=getattr(self.config, 'trackio_url', None),
                    trackio_token=getattr(self.config, 'trackio_token', None),
                    hf_token=getattr(self.config, 'hf_token', None),
                    dataset_repo=getattr(self.config, 'dataset_repo', None)
                )
            logger.info("Trackio initialized with experiment ID: %s", experiment_id)

            if self.monitor:
                self.monitor.experiment_id = experiment_id
                logger.info("Updated monitor with experiment ID: %s", experiment_id)
        except Exception as e:
            logger.warning("Failed to initialize trackio: %s", e)
            logger.info("Continuing without trackio integration")

    def _create_trainer(self, training_args, train_dataset, eval_dataset,
                        data_collator, callbacks):
        """Construct ``SFTTrainer`` (if enabled) with fallback to ``Trainer``."""
        is_fallback = False
        if self.use_sft_trainer:
            logger.info("Creating SFTTrainer with training arguments...")
            logger.info("Training args type: %s", type(training_args))
            try:
                trainer = SFTTrainer(
                    model=self.model.model,
                    train_dataset=train_dataset,
                    eval_dataset=eval_dataset,
                    args=training_args,
                    data_collator=data_collator,
                    callbacks=callbacks,
                )
                logger.info("Using SFTTrainer (optimized for instruction tuning)")
                return trainer
            except Exception as e:
                logger.warning("SFTTrainer failed: %s", e)
                logger.error("SFTTrainer creation error details: %s: %s", type(e).__name__, str(e))
                is_fallback = True
        else:
            logger.info("SFTTrainer disabled (use_sft_trainer=False); creating standard Trainer")

        try:
            trainer = Trainer(
                model=self.model.model,
                tokenizer=self.model.tokenizer,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
                callbacks=callbacks,
            )
        except Exception as e2:
            if is_fallback:
                logger.error("Standard Trainer also failed: %s", e2)
            else:
                logger.error("Standard Trainer failed: %s", e2)
            raise

        if is_fallback:
            logger.info("Using standard Hugging Face Trainer (fallback)")
        else:
            logger.info("Using standard Hugging Face Trainer")
        return trainer

    def load_checkpoint(self, checkpoint_path: str):
        """Load checkpoint for resuming training.

        Only acts when ``init_from == "resume"``: loads weights through the model
        wrapper and points the trainer at the reloaded model. For other modes
        it only logs.
        """
        logger.info("Loading checkpoint from %s", checkpoint_path)

        if self.init_from == "resume":
            self.model.load_checkpoint(checkpoint_path)
            # The trainer holds its own reference; refresh it after reloading.
            self.trainer.model = self.model.model
            logger.info("Checkpoint loaded successfully")
        elif self.init_from == "pretrained":
            # Model is already loaded from pretrained
            logger.info("Using pretrained model")
        else:
            logger.info("Starting from scratch")

    @staticmethod
    def _finish_trackio(on_error: bool = False) -> None:
        """Best-effort ``trackio.finish()``; any failure (incl. ImportError) is only logged."""
        try:
            import trackio
            trackio.finish()
        except Exception as finish_error:
            if on_error:
                logger.warning("Failed to finish trackio experiment on error: %s", finish_error)
            else:
                logger.warning("Failed to finish trackio experiment: %s", finish_error)
            return
        if not on_error:
            logger.info("Trackio experiment finished")

    def train(self):
        """Run training, save the final model and write ``train_results.json``.

        Logs the config and a training summary to the monitor when tracking is
        enabled. On failure, closes monitoring, finishes trackio and re-raises.
        """
        logger.info("Starting training")

        if self.monitor and self.monitor.enable_tracking:
            config_dict = {k: v for k, v in self.config.__dict__.items()
                           if not k.startswith('_')}
            self.monitor.log_config(config_dict)

            experiment_url = self.monitor.get_experiment_url()
            if experiment_url:
                logger.info("Trackio experiment URL: %s", experiment_url)

        if self.init_from == "resume":
            checkpoint_path = "/input-checkpoint"
            if os.path.exists(checkpoint_path):
                self.load_checkpoint(checkpoint_path)
            else:
                logger.warning("Checkpoint path %s not found, starting from scratch", checkpoint_path)

        try:
            logger.info("About to start trainer.train()")
            train_result = self.trainer.train()

            self.trainer.save_model()

            with open(os.path.join(self.output_dir, "train_results.json"), "w",
                      encoding="utf-8") as f:
                json.dump(train_result.metrics, f, indent=2)

            if self.monitor and self.monitor.enable_tracking:
                summary = {
                    'final_loss': train_result.metrics.get('train_loss', 0),
                    # Step count comes from TrainOutput.global_step (was wrongly
                    # copied from 'train_runtime').
                    'total_steps': getattr(train_result, 'global_step', 0),
                    'training_time': train_result.metrics.get('train_runtime', 0),
                    'output_dir': self.output_dir,
                    'model_name': getattr(self.config, 'model_name', 'unknown'),
                }
                self.monitor.log_training_summary(summary)
                self.monitor.close()

            self._finish_trackio()

            logger.info("Training completed successfully!")
            logger.info("Training metrics: %s", train_result.metrics)

        except Exception as e:
            logger.error("Training failed: %s", e)
            if self.monitor and self.monitor.enable_tracking:
                # Don't let a failing close() mask the original training error.
                try:
                    self.monitor.close()
                except Exception as close_error:
                    logger.warning("Failed to close monitor on error: %s", close_error)
            self._finish_trackio(on_error=True)
            raise

    def evaluate(self):
        """Run ``trainer.evaluate()``, write ``eval_results.json`` and return the metrics."""
        logger.info("Starting evaluation")

        try:
            eval_results = self.trainer.evaluate()

            with open(os.path.join(self.output_dir, "eval_results.json"), "w",
                      encoding="utf-8") as f:
                json.dump(eval_results, f, indent=2)

            logger.info("Evaluation completed: %s", eval_results)
            return eval_results

        except Exception as e:
            logger.error("Evaluation failed: %s", e)
            raise

    def save_model(self, path: Optional[str] = None):
        """Save model, tokenizer and ``training_config.json`` to *path* (default: output_dir)."""
        save_path = path or self.output_dir
        logger.info("Saving model to %s", save_path)

        try:
            self.trainer.save_model(save_path)
            self.model.tokenizer.save_pretrained(save_path)

            if self.config:
                config_dict = {k: v for k, v in self.config.__dict__.items()
                               if not k.startswith('_')}
                with open(os.path.join(save_path, 'training_config.json'), 'w',
                          encoding="utf-8") as f:
                    # default=str: config values may not be JSON-serializable.
                    json.dump(config_dict, f, indent=2, default=str)

            logger.info("Model saved successfully!")

        except Exception as e:
            logger.error("Failed to save model: %s", e)
            raise
class SmolLM3DPOTrainer:
    """Direct Preference Optimization (DPO) wrapper for SmolLM3.

    A ``trl.DPOTrainer`` is constructed eagerly in ``__init__`` from the model
    wrapper, the preference dataset wrapper and the training config.

    Args:
        model: Project model wrapper (uses ``model``, ``tokenizer`` and
            ``get_training_arguments(...)``).
        dataset: Dataset wrapper exposing ``get_train_dataset()`` and
            ``get_eval_dataset()``.
        config: Training config; read for step intervals, ``max_iters`` and
            ``max_seq_length``.
        output_dir: Directory for the final model and ``dpo_train_results.json``.
        ref_model: Optional reference model forwarded to ``DPOTrainer``.
    """

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        ref_model=None
    ):
        self.model, self.dataset = model, dataset
        self.config, self.output_dir = config, output_dir
        self.ref_model = ref_model

        self.trainer = self._setup_dpo_trainer()

    def _setup_dpo_trainer(self):
        """Create and return the ``DPOTrainer`` instance."""
        from trl import DPOTrainer

        cfg = self.config
        wrapper = self.model

        args = wrapper.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=cfg.save_steps,
            eval_steps=cfg.eval_steps,
            logging_steps=cfg.logging_steps,
            max_steps=cfg.max_iters,
        )

        train_split = self.dataset.get_train_dataset()
        eval_split = self.dataset.get_eval_dataset()

        seq_limit = cfg.max_seq_length
        # Half of the sequence budget is reserved for the prompt.
        return DPOTrainer(
            model=wrapper.model,
            ref_model=self.ref_model,
            args=args,
            train_dataset=train_split,
            eval_dataset=eval_split,
            tokenizer=wrapper.tokenizer,
            max_prompt_length=seq_limit // 2,
            max_length=seq_limit,
        )

    def train(self):
        """Run DPO training, save the model and dump metrics; re-raise on failure."""
        logger.info("Starting DPO training")

        try:
            outcome = self.trainer.train()
            self.trainer.save_model()

            results_file = os.path.join(self.output_dir, "dpo_train_results.json")
            with open(results_file, "w") as handle:
                json.dump(outcome.metrics, handle, indent=2)

            logger.info("DPO training completed successfully!")
            logger.info("Training metrics: %s", outcome.metrics)
        except Exception as err:
            logger.error("DPO training failed: %s", err)
            raise