""" SmolLM3 Trainer Handles the training loop and integrates with Hugging Face Trainer """ import os import torch import logging from typing import Optional, Dict, Any from transformers import Trainer, TrainingArguments from trl import SFTTrainer import json # Import monitoring from monitoring import create_monitor_from_config logger = logging.getLogger(__name__) class SmolLM3Trainer: """Trainer for SmolLM3 fine-tuning""" def __init__( self, model, dataset, config, output_dir: str, init_from: str = "scratch", use_sft_trainer: bool = True ): self.model = model self.dataset = dataset self.config = config self.output_dir = output_dir self.init_from = init_from self.use_sft_trainer = use_sft_trainer # Initialize monitoring self.monitor = create_monitor_from_config(config) # Setup trainer self.trainer = self._setup_trainer() def _setup_trainer(self): """Setup the trainer""" logger.info("Setting up trainer") # Get training arguments training_args = self.model.get_training_arguments( output_dir=self.output_dir, save_steps=self.config.save_steps, eval_steps=self.config.eval_steps, logging_steps=self.config.logging_steps, max_steps=self.config.max_iters, ) # Debug: Print training arguments logger.info("Training arguments keys: %s", list(training_args.__dict__.keys())) logger.info("Training arguments type: %s", type(training_args)) # Get datasets logger.info("Getting train dataset...") train_dataset = self.dataset.get_train_dataset() logger.info("Train dataset: %s with %d samples", type(train_dataset), len(train_dataset)) logger.info("Getting eval dataset...") eval_dataset = self.dataset.get_eval_dataset() logger.info("Eval dataset: %s with %d samples", type(eval_dataset), len(eval_dataset)) # Get data collator logger.info("Getting data collator...") data_collator = self.dataset.get_data_collator() logger.info("Data collator: %s", type(data_collator)) # Add monitoring callbacks callbacks = [] # Add simple console callback for basic monitoring from transformers import TrainerCallback class SimpleConsoleCallback(TrainerCallback): def on_init_end(self, args, state, control, **kwargs): """Called when training initialization is complete""" print("🔧 Training initialization completed") def on_log(self, args, state, control, logs=None, **kwargs): """Log metrics to console""" if logs and isinstance(logs, dict): step = state.global_step if hasattr(state, 'global_step') else 'unknown' loss = logs.get('loss', 'N/A') lr = logs.get('learning_rate', 'N/A') # Fix format string error by ensuring proper type conversion if isinstance(loss, (int, float)): loss_str = f"{loss:.4f}" else: loss_str = str(loss) if isinstance(lr, (int, float)): lr_str = f"{lr:.2e}" else: lr_str = str(lr) print(f"Step {step}: loss={loss_str}, lr={lr_str}") def on_train_begin(self, args, state, control, **kwargs): print("🚀 Training started!") def on_train_end(self, args, state, control, **kwargs): print("✅ Training completed!") def on_save(self, args, state, control, **kwargs): step = state.global_step if hasattr(state, 'global_step') else 'unknown' print(f"💾 Checkpoint saved at step {step}") def on_evaluate(self, args, state, control, metrics=None, **kwargs): if metrics and isinstance(metrics, dict): step = state.global_step if hasattr(state, 'global_step') else 'unknown' eval_loss = metrics.get('eval_loss', 'N/A') print(f"📊 Evaluation at step {step}: eval_loss={eval_loss}") # Add console callback callbacks.append(SimpleConsoleCallback()) logger.info("Added simple console monitoring callback") # Add Trackio callback if available if self.monitor and self.monitor.enable_tracking: try: trackio_callback = self.monitor.create_monitoring_callback() if trackio_callback: callbacks.append(trackio_callback) logger.info("Added Trackio monitoring callback") else: logger.warning("Failed to create Trackio callback") except Exception as e: logger.error("Error creating Trackio callback: %s", e) logger.info("Continuing with console monitoring only") logger.info("Total callbacks: %d", len(callbacks)) # Initialize trackio for TRL compatibility try: import trackio # Initialize trackio with our configuration and use the same experiment ID if self.monitor and self.monitor.experiment_id: # Use the experiment ID from our monitor experiment_id = self.monitor.experiment_id logger.info(f"Using existing experiment ID: {experiment_id}") else: # Initialize trackio with our configuration experiment_id = trackio.init( project_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'), experiment_name=getattr(self.config, 'experiment_name', 'smollm3_experiment'), trackio_url=getattr(self.config, 'trackio_url', None), trackio_token=getattr(self.config, 'trackio_token', None), hf_token=getattr(self.config, 'hf_token', None), dataset_repo=getattr(self.config, 'dataset_repo', None) ) logger.info(f"Trackio initialized with experiment ID: {experiment_id}") # Update our monitor with the same experiment ID if self.monitor: self.monitor.experiment_id = experiment_id logger.info(f"Updated monitor with experiment ID: {experiment_id}") except Exception as e: logger.warning(f"Failed to initialize trackio: {e}") logger.info("Continuing without trackio integration") # Try SFTTrainer first (better for instruction tuning) logger.info("Creating SFTTrainer with training arguments...") logger.info("Training args type: %s", type(training_args)) try: trainer = SFTTrainer( model=self.model.model, train_dataset=train_dataset, eval_dataset=eval_dataset, args=training_args, data_collator=data_collator, callbacks=callbacks, ) logger.info("Using SFTTrainer (optimized for instruction tuning)") except Exception as e: logger.warning("SFTTrainer failed: %s", e) logger.error("SFTTrainer creation error details: %s: %s", type(e).__name__, str(e)) # Fallback to standard Trainer try: trainer = Trainer( model=self.model.model, tokenizer=self.model.tokenizer, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator, callbacks=callbacks, ) logger.info("Using standard Hugging Face Trainer (fallback)") except Exception as e2: logger.error("Standard Trainer also failed: %s", e2) raise e2 return trainer def load_checkpoint(self, checkpoint_path: str): """Load checkpoint for resuming training""" logger.info("Loading checkpoint from %s", checkpoint_path) if self.init_from == "resume": # Load the model from checkpoint self.model.load_checkpoint(checkpoint_path) # Update trainer with loaded model self.trainer.model = self.model.model logger.info("Checkpoint loaded successfully") elif self.init_from == "pretrained": # Model is already loaded from pretrained logger.info("Using pretrained model") else: logger.info("Starting from scratch") def train(self): """Start training""" logger.info("Starting training") # Log configuration to Trackio if self.monitor and self.monitor.enable_tracking: config_dict = {k: v for k, v in self.config.__dict__.items() if not k.startswith('_')} self.monitor.log_config(config_dict) # Log experiment URL experiment_url = self.monitor.get_experiment_url() if experiment_url: logger.info("Trackio experiment URL: %s", experiment_url) # Load checkpoint if resuming if self.init_from == "resume": checkpoint_path = "/input-checkpoint" if os.path.exists(checkpoint_path): self.load_checkpoint(checkpoint_path) else: logger.warning("Checkpoint path %s not found, starting from scratch", checkpoint_path) # Start training try: logger.info("About to start trainer.train()") train_result = self.trainer.train() # Save the final model self.trainer.save_model() # Save training results with open(os.path.join(self.output_dir, "train_results.json"), "w") as f: json.dump(train_result.metrics, f, indent=2) # Log training summary to Trackio if self.monitor and self.monitor.enable_tracking: summary = { 'final_loss': train_result.metrics.get('train_loss', 0), 'total_steps': train_result.metrics.get('train_runtime', 0), 'training_time': train_result.metrics.get('train_runtime', 0), 'output_dir': self.output_dir, 'model_name': getattr(self.config, 'model_name', 'unknown'), } self.monitor.log_training_summary(summary) self.monitor.close() # Finish trackio experiment try: import trackio trackio.finish() logger.info("Trackio experiment finished") except Exception as e: logger.warning(f"Failed to finish trackio experiment: {e}") logger.info("Training completed successfully!") logger.info("Training metrics: %s", train_result.metrics) except Exception as e: logger.error("Training failed: %s", e) # Close monitoring on error if self.monitor and self.monitor.enable_tracking: self.monitor.close() # Finish trackio experiment on error try: import trackio trackio.finish() except Exception as finish_error: logger.warning(f"Failed to finish trackio experiment on error: {finish_error}") raise def evaluate(self): """Evaluate the model""" logger.info("Starting evaluation") try: eval_results = self.trainer.evaluate() # Save evaluation results with open(os.path.join(self.output_dir, "eval_results.json"), "w") as f: json.dump(eval_results, f, indent=2) logger.info("Evaluation completed: %s", eval_results) return eval_results except Exception as e: logger.error("Evaluation failed: %s", e) raise def save_model(self, path: Optional[str] = None): """Save the trained model""" save_path = path or self.output_dir logger.info("Saving model to %s", save_path) try: self.trainer.save_model(save_path) self.model.tokenizer.save_pretrained(save_path) # Save training configuration if self.config: config_dict = {k: v for k, v in self.config.__dict__.items() if not k.startswith('_')} with open(os.path.join(save_path, 'training_config.json'), 'w') as f: json.dump(config_dict, f, indent=2, default=str) logger.info("Model saved successfully!") except Exception as e: logger.error("Failed to save model: %s", e) raise class SmolLM3DPOTrainer: """DPO Trainer for SmolLM3 preference optimization""" def __init__( self, model, dataset, config, output_dir: str, ref_model=None ): self.model = model self.dataset = dataset self.config = config self.output_dir = output_dir self.ref_model = ref_model # Setup DPO trainer self.trainer = self._setup_dpo_trainer() def _setup_dpo_trainer(self): """Setup DPO trainer""" from trl import DPOTrainer # Get training arguments training_args = self.model.get_training_arguments( output_dir=self.output_dir, save_steps=self.config.save_steps, eval_steps=self.config.eval_steps, logging_steps=self.config.logging_steps, max_steps=self.config.max_iters, ) # Get preference dataset train_dataset = self.dataset.get_train_dataset() eval_dataset = self.dataset.get_eval_dataset() # Setup DPO trainer trainer = DPOTrainer( model=self.model.model, ref_model=self.ref_model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=self.model.tokenizer, max_prompt_length=self.config.max_seq_length // 2, max_length=self.config.max_seq_length, ) return trainer def train(self): """Start DPO training""" logger.info("Starting DPO training") try: train_result = self.trainer.train() # Save the final model self.trainer.save_model() # Save training results with open(os.path.join(self.output_dir, "dpo_train_results.json"), "w") as f: json.dump(train_result.metrics, f, indent=2) logger.info("DPO training completed successfully!") logger.info("Training metrics: %s", train_result.metrics) except Exception as e: logger.error("DPO training failed: %s", e) raise