SmolFactory / trainer.py
Tonic's picture
try to resolve the issue with sftt trainer or trackio
d4bee15 verified
raw
history blame
12.1 kB
"""
SmolLM3 Trainer
Handles the training loop and integrates with Hugging Face Trainer
"""
import os
import torch
import logging
from typing import Optional, Dict, Any
from transformers import Trainer, TrainingArguments
from trl import SFTTrainer
import json
# Import monitoring
from monitoring import create_monitor_from_config
# Module-level logger used by both trainer wrappers defined below.
logger = logging.getLogger(__name__)
class SmolLM3Trainer:
    """Trainer wrapper for SmolLM3 fine-tuning.

    Wires the project's model/dataset/config objects into a Hugging Face
    ``Trainer`` (falling back to ``trl.SFTTrainer`` if constructing the
    standard ``Trainer`` raises), attaches a console logging callback, and
    reports to the monitor returned by ``create_monitor_from_config`` when
    tracking is enabled.

    Args:
        model: Project model wrapper exposing ``.model``, ``.tokenizer``,
            ``get_training_arguments(...)`` and ``load_checkpoint(path)``.
        dataset: Project dataset wrapper exposing ``get_train_dataset()``,
            ``get_eval_dataset()`` and ``get_data_collator()``.
        config: Training config providing ``save_steps``, ``eval_steps``,
            ``logging_steps`` and ``max_iters``.
        output_dir: Directory for checkpoints, result JSON files and the
            final model.
        init_from: ``"scratch"``, ``"pretrained"`` or ``"resume"``; only
            ``"resume"`` triggers checkpoint loading in :meth:`train`.
        use_sft_trainer: Stored on the instance but NOT currently consulted
            by :meth:`_setup_trainer`, which always tries the standard
            ``Trainer`` first. NOTE(review): confirm whether this flag should
            select the trainer class.
    """

    def __init__(
        self,
        model,
        dataset,
        config,
        output_dir: str,
        init_from: str = "scratch",
        use_sft_trainer: bool = True,
    ):
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.init_from = init_from
        self.use_sft_trainer = use_sft_trainer

        # May be falsy; every use below is guarded by
        # ``self.monitor and self.monitor.enable_tracking``.
        self.monitor = create_monitor_from_config(config)

        # Build the underlying HF trainer eagerly.
        self.trainer = self._setup_trainer()

    @staticmethod
    def _format_log_line(step, logs: Dict[str, Any]) -> str:
        """Return the console line printed for an ``on_log`` event.

        ``logs`` does not always contain ``loss`` (e.g. evaluation logs or the
        end-of-training summary), so the 4-decimal format is applied only to
        numeric values; anything else (including the ``'N/A'`` placeholder)
        is printed as-is instead of raising ``ValueError``.
        """
        loss = logs.get('loss', 'N/A')
        lr = logs.get('learning_rate', 'N/A')
        if isinstance(loss, (int, float)):
            loss = f"{loss:.4f}"
        return f"Step {step}: loss={loss}, lr={lr}"

    @staticmethod
    def _tokenizer_kwargs(trainer_cls, tokenizer) -> Dict[str, Any]:
        """Return the tokenizer keyword argument accepted by *trainer_cls*.

        Newer ``transformers`` releases replaced ``Trainer(tokenizer=...)``
        with ``processing_class``. Passing the keyword the installed class
        expects keeps the primary ``Trainer`` path from failing and silently
        dropping to the SFTTrainer fallback.
        """
        import inspect

        try:
            params = inspect.signature(trainer_cls.__init__).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: keep the historical keyword.
            params = {}
        if "processing_class" in params:
            return {"processing_class": tokenizer}
        return {"tokenizer": tokenizer}

    def _build_training_summary(self, train_result) -> Dict[str, Any]:
        """Build the end-of-training summary dict sent to the monitor.

        ``total_steps`` comes from ``train_result.global_step`` (the HF
        ``TrainOutput`` field); it previously duplicated ``train_runtime``.
        """
        metrics = train_result.metrics
        return {
            'final_loss': metrics.get('train_loss', 0),
            'total_steps': getattr(train_result, 'global_step', 0),
            'training_time': metrics.get('train_runtime', 0),
            'output_dir': self.output_dir,
            'model_name': getattr(self.config, 'model_name', 'unknown'),
        }

    def _setup_trainer(self):
        """Build and return the HF trainer with console monitoring attached."""
        from transformers import TrainerCallback

        logger.info("Setting up trainer")

        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            logging_steps=self.config.logging_steps,
            max_steps=self.config.max_iters,
        )

        train_dataset = self.dataset.get_train_dataset()
        eval_dataset = self.dataset.get_eval_dataset()
        data_collator = self.dataset.get_data_collator()

        callbacks = []

        # Must subclass TrainerCallback: the HF CallbackHandler invokes every
        # event hook on each registered callback, so a plain object lacking
        # e.g. ``on_step_end`` raises AttributeError mid-training.
        class SimpleConsoleCallback(TrainerCallback):
            """Print basic training progress to stdout."""

            def on_init_end(self, args, state, control, **kwargs):
                """Called when training initialization is complete"""
                print("🔧 Training initialization completed")

            def on_log(self, args, state, control, logs=None, **kwargs):
                """Log metrics to console"""
                if logs and isinstance(logs, dict):
                    step = getattr(state, 'global_step', 'unknown')
                    print(SmolLM3Trainer._format_log_line(step, logs))

            def on_train_begin(self, args, state, control, **kwargs):
                print("🚀 Training started!")

            def on_train_end(self, args, state, control, **kwargs):
                print("✅ Training completed!")

            def on_save(self, args, state, control, **kwargs):
                step = getattr(state, 'global_step', 'unknown')
                print(f"💾 Checkpoint saved at step {step}")

            def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                if metrics and isinstance(metrics, dict):
                    step = getattr(state, 'global_step', 'unknown')
                    eval_loss = metrics.get('eval_loss', 'N/A')
                    print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}")

        callbacks.append(SimpleConsoleCallback())
        logger.info("Added simple console monitoring callback")

        # TODO: re-enable the Trackio callback once the training issue is
        # debugged. Previous wiring, kept for reference:
        # if self.monitor and self.monitor.enable_tracking:
        #     try:
        #         trackio_callback = self.monitor.create_monitoring_callback()
        #         if trackio_callback:
        #             callbacks.append(trackio_callback)
        #             logger.info("Added Trackio monitoring callback")
        #         else:
        #             logger.warning("Failed to create Trackio callback")
        #     except Exception as e:
        #         logger.error(f"Error creating Trackio callback: {e}")
        #         logger.info("Continuing with console monitoring only")
        logger.info("Skipping Trackio callback to debug training issue")

        # Try the standard Trainer first (more stable with callbacks); any
        # construction error falls back to SFTTrainer.
        try:
            trainer = Trainer(
                model=self.model.model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
                callbacks=callbacks,
                **self._tokenizer_kwargs(Trainer, self.model.tokenizer),
            )
            logger.info("Using standard Hugging Face Trainer")
        except Exception as e:
            logger.warning(f"Standard Trainer failed: {e}")
            trainer = SFTTrainer(
                model=self.model.model,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                args=training_args,
                data_collator=data_collator,
                callbacks=callbacks,
            )
            logger.info("Using SFTTrainer")

        return trainer

    def load_checkpoint(self, checkpoint_path: str):
        """Load a checkpoint for resuming training.

        Only acts when ``init_from == "resume"``: it delegates to
        ``self.model.load_checkpoint`` and points the trainer at the reloaded
        model. For other modes it only logs.
        """
        logger.info(f"Loading checkpoint from {checkpoint_path}")

        if self.init_from == "resume":
            self.model.load_checkpoint(checkpoint_path)
            # The trainer holds its own reference; refresh it.
            self.trainer.model = self.model.model
            logger.info("Checkpoint loaded successfully")
        elif self.init_from == "pretrained":
            logger.info("Using pretrained model")
        else:
            logger.info("Starting from scratch")

    def train(self):
        """Run training, save the final model and record the results.

        When ``init_from == "resume"`` the checkpoint at
        ``/input-checkpoint`` is loaded first (skipped with a warning if the
        path does not exist). Metrics are written to
        ``<output_dir>/train_results.json``. With tracking enabled, the
        monitor receives the config up front and a summary at the end, and
        it is closed on both success and failure.

        Raises:
            Exception: any error from training or saving is logged with its
                traceback and re-raised.
        """
        logger.info("Starting training")

        if self.monitor and self.monitor.enable_tracking:
            config_dict = {k: v for k, v in self.config.__dict__.items()
                           if not k.startswith('_')}
            self.monitor.log_config(config_dict)

            experiment_url = self.monitor.get_experiment_url()
            if experiment_url:
                logger.info(f"Trackio experiment URL: {experiment_url}")

        if self.init_from == "resume":
            # NOTE(review): hard-coded mount point for the resume checkpoint.
            checkpoint_path = "/input-checkpoint"
            if os.path.exists(checkpoint_path):
                self.load_checkpoint(checkpoint_path)
            else:
                logger.warning(f"Checkpoint path {checkpoint_path} not found, starting from scratch")

        try:
            train_result = self.trainer.train()

            self.trainer.save_model()

            with open(os.path.join(self.output_dir, "train_results.json"), "w") as f:
                json.dump(train_result.metrics, f, indent=2)

            if self.monitor and self.monitor.enable_tracking:
                self.monitor.log_training_summary(self._build_training_summary(train_result))
                self.monitor.close()

            logger.info("Training completed successfully!")
            logger.info(f"Training metrics: {train_result.metrics}")

        except Exception as e:
            logger.exception(f"Training failed: {e}")
            if self.monitor and self.monitor.enable_tracking:
                self.monitor.close()
            raise

    def evaluate(self):
        """Evaluate the model and write ``<output_dir>/eval_results.json``.

        Returns:
            The metrics dict returned by ``self.trainer.evaluate()``.

        Raises:
            Exception: evaluation/IO errors are logged and re-raised.
        """
        logger.info("Starting evaluation")

        try:
            eval_results = self.trainer.evaluate()

            with open(os.path.join(self.output_dir, "eval_results.json"), "w") as f:
                json.dump(eval_results, f, indent=2)

            logger.info(f"Evaluation completed: {eval_results}")
            return eval_results

        except Exception as e:
            logger.exception(f"Evaluation failed: {e}")
            raise

    def save_model(self, path: Optional[str] = None):
        """Save model, tokenizer and training config.

        Args:
            path: Target directory; defaults to ``self.output_dir``.

        Raises:
            Exception: save errors are logged and re-raised.
        """
        save_path = path or self.output_dir
        logger.info(f"Saving model to {save_path}")

        try:
            self.trainer.save_model(save_path)
            self.model.tokenizer.save_pretrained(save_path)

            if self.config:
                config_dict = {k: v for k, v in self.config.__dict__.items()
                               if not k.startswith('_')}
                # default=str: config values may not be JSON-native.
                with open(os.path.join(save_path, 'training_config.json'), 'w') as f:
                    json.dump(config_dict, f, indent=2, default=str)

            logger.info("Model saved successfully!")

        except Exception as e:
            logger.exception(f"Failed to save model: {e}")
            raise
class SmolLM3DPOTrainer:
    """Preference-optimisation (DPO) wrapper for SmolLM3.

    A ``trl.DPOTrainer`` is constructed from the project's model, dataset and
    config objects at init time. :meth:`train` runs it, saves the final model
    and writes the metrics to ``<output_dir>/dpo_train_results.json``.

    Args:
        model: Project model wrapper exposing ``.model``, ``.tokenizer`` and
            ``get_training_arguments(...)``.
        dataset: Project dataset wrapper exposing ``get_train_dataset()`` and
            ``get_eval_dataset()`` (presumably preference data — confirm).
        config: Config providing ``save_steps``, ``eval_steps``,
            ``logging_steps``, ``max_iters`` and ``max_seq_length``.
        output_dir: Destination for checkpoints and the results JSON.
        ref_model: Optional reference model, passed through unchanged as
            ``DPOTrainer(ref_model=...)``.
    """

    def __init__(self, model, dataset, config, output_dir: str, ref_model=None):
        # Store collaborators first; the trainer setup below reads them.
        self.model = model
        self.dataset = dataset
        self.config = config
        self.output_dir = output_dir
        self.ref_model = ref_model

        self.trainer = self._setup_dpo_trainer()

    def _setup_dpo_trainer(self):
        """Instantiate and return the ``trl.DPOTrainer``."""
        # Local import: DPOTrainer is only needed when this class is used.
        from trl import DPOTrainer

        cfg = self.config
        training_args = self.model.get_training_arguments(
            output_dir=self.output_dir,
            save_steps=cfg.save_steps,
            eval_steps=cfg.eval_steps,
            logging_steps=cfg.logging_steps,
            max_steps=cfg.max_iters,
        )

        train_split = self.dataset.get_train_dataset()
        eval_split = self.dataset.get_eval_dataset()

        # Half of the total sequence budget is allotted to the prompt.
        # NOTE(review): newer TRL releases take the tokenizer as
        # ``processing_class`` and the length limits on ``DPOConfig``;
        # confirm these kwargs against the pinned TRL version.
        full_length = cfg.max_seq_length
        return DPOTrainer(
            model=self.model.model,
            ref_model=self.ref_model,
            args=training_args,
            train_dataset=train_split,
            eval_dataset=eval_split,
            tokenizer=self.model.tokenizer,
            max_prompt_length=full_length // 2,
            max_length=full_length,
        )

    def train(self):
        """Run DPO training, save the model and dump metrics to JSON.

        Raises:
            Exception: any failure is logged and re-raised unchanged.
        """
        logger.info("Starting DPO training")

        try:
            outcome = self.trainer.train()
            self.trainer.save_model()

            results_file = os.path.join(self.output_dir, "dpo_train_results.json")
            with open(results_file, "w") as handle:
                json.dump(outcome.metrics, handle, indent=2)

            logger.info("DPO training completed successfully!")
            logger.info(f"Training metrics: {outcome.metrics}")
        except Exception as err:
            logger.error(f"DPO training failed: {err}")
            raise