# SFT Trainer Configuration Usage Guide

## Overview

This guide describes how the SFT (Supervised Fine-Tuning) trainer uses the premade configuration files and how the `trainer_type` field is passed through the system.

## How the SFT Trainer Uses Premade Configs

### 1. Configuration Loading Process

The SFT trainer uses premade configs through the following process:

1. **Config File Selection**: Users specify a config file via the command line or the launch script
2. **Config Loading**: The system loads the config using the `get_config()` function
3. **Config Inheritance**: All configs inherit from the `SmolLM3Config` base class
4. **Trainer Type Detection**: The system checks the config for a `trainer_type` field
5. **Training Arguments Creation**: Config parameters are used to create `TrainingArguments`
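
As a rough illustration of steps 1, 2, and 4, `get_config()` can be thought of as loading the config module from its file path and reading the `trainer_type` field. This is a hedged sketch, not the repo's exact implementation; it assumes each config file exposes a module-level `config` instance:

```python
import importlib.util

def get_config(config_path: str):
    """Load a config module from a file path and return its `config` object."""
    spec = importlib.util.spec_from_file_location("train_config", config_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Assumption: every config file defines a module-level `config` instance.
    return module.config

config = get_config("config/train_smollm3.py")         # steps 1-2: select and load
trainer_type = getattr(config, "trainer_type", "sft")  # step 4: detect trainer type
```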

### 2. Configuration Parameters Used by the SFT Trainer

The SFT trainer uses the following config parameters:

#### Model Configuration

- `model_name`: Model to load (e.g., `"HuggingFaceTB/SmolLM3-3B"`)
- `max_seq_length`: Maximum sequence length for tokenization
- `use_flash_attention`: Whether to use Flash Attention
- `use_gradient_checkpointing`: Whether to use gradient checkpointing

#### Training Configuration

- `batch_size`: Per-device batch size
- `gradient_accumulation_steps`: Number of steps to accumulate gradients before each optimizer update
- `learning_rate`: Learning rate for optimization
- `weight_decay`: Weight decay for the optimizer
- `warmup_steps`: Number of learning-rate warmup steps
- `max_iters`: Maximum number of training steps (mapped to `max_steps`)
- `save_steps`: Save a checkpoint every N steps
- `eval_steps`: Evaluate every N steps
- `logging_steps`: Log every N steps

#### Optimizer Configuration

- `optimizer`: Optimizer type (e.g., `"adamw_torch"`)
- `beta1`, `beta2`, `eps`: Adam-style optimizer hyperparameters (moment decay rates and numerical epsilon)

#### Scheduler Configuration

- `scheduler`: Learning rate scheduler type
- `min_lr`: Minimum learning rate

#### Mixed Precision

- `fp16`: Whether to train in fp16 mixed precision
- `bf16`: Whether to train in bf16 mixed precision (enable at most one of `fp16`/`bf16`)

#### Data Configuration

- `dataset_name`: Hugging Face dataset name
- `data_dir`: Local dataset directory
- `train_file`: Training file name
- `validation_file`: Validation file name

#### Monitoring Configuration

- `enable_tracking`: Whether to enable Trackio tracking
- `trackio_url`: Trackio server URL
- `experiment_name`: Experiment name for tracking

### 3. Training Arguments Creation

The SFT trainer creates `TrainingArguments` from config parameters:

```python
from transformers import TrainingArguments

def get_training_arguments(self, output_dir: str, **kwargs) -> TrainingArguments:
    training_args = {
        "output_dir": output_dir,
        "per_device_train_batch_size": self.config.batch_size,
        "per_device_eval_batch_size": self.config.batch_size,
        "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
        "learning_rate": self.config.learning_rate,
        "weight_decay": self.config.weight_decay,
        "warmup_steps": self.config.warmup_steps,
        "max_steps": self.config.max_iters,  # the config's max_iters maps to max_steps
        "save_steps": self.config.save_steps,
        "eval_steps": self.config.eval_steps,
        "logging_steps": self.config.logging_steps,
        "fp16": self.config.fp16,
        "bf16": self.config.bf16,
        # ... additional parameters
    }
    return TrainingArguments(**training_args)
```
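
A hypothetical call site, assuming `trainer` is the instantiated SFT trainer (the names are illustrative):

```python
training_args = trainer.get_training_arguments(output_dir="outputs/smollm3-sft")
print(training_args.max_steps, training_args.learning_rate)
```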

### 4. Trainer Selection Logic

The system determines which trainer to use based on the `trainer_type` field:

```python
# Determine trainer type (command line overrides config)
trainer_type = args.trainer_type or getattr(config, 'trainer_type', 'sft')

# Initialize trainer based on type
if trainer_type.lower() == 'dpo':
    trainer = SmolLM3DPOTrainer(...)
else:
    trainer = SmolLM3Trainer(...)  # SFT trainer
```

## Configuration Files Structure

### Base Config (`config/train_smollm3.py`)

```python
from dataclasses import dataclass

@dataclass
class SmolLM3Config:
    # Trainer type selection
    trainer_type: str = "sft"  # "sft" or "dpo"

    # Model configuration
    model_name: str = "HuggingFaceTB/SmolLM3-3B"
    max_seq_length: int = 4096
    # ... other fields
```

### DPO Config (`config/train_smollm3_dpo.py`)

```python
@dataclass
class SmolLM3DPOConfig(SmolLM3Config):
    # Trainer type selection
    trainer_type: str = "dpo"  # Override the default to use the DPO trainer

    # DPO-specific configuration
    beta: float = 0.1
    # ... DPO-specific fields
```

### Specialized Configs (e.g., `config/train_smollm3_openhermes_fr_a100_multiple_passes.py`)

```python
@dataclass
class SmolLM3ConfigOpenHermesFRMultiplePasses(SmolLM3Config):
    # Inherits trainer_type = "sft" from the base config

    # Specialized configuration for multiple passes
    batch_size: int = 6
    gradient_accumulation_steps: int = 20
    learning_rate: float = 3e-6
    max_iters: int = 25000
    # ... other specialized fields
```

## Trainer Type Priority

The trainer type is resolved in the following order of priority (see the sketch below):

1. **Command-line argument** (`--trainer_type`): highest priority
2. **Config file** (`trainer_type` field): medium priority
3. **Default value** (`"sft"`): lowest priority
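
In argparse terms, this priority falls out naturally when the CLI flag defaults to `None`. The following is a hedged sketch, not the repo's actual parser; `get_config` is the loader assumed earlier in this guide:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("config", help="Path to a config file, e.g. config/train_smollm3.py")
parser.add_argument("--trainer_type", choices=["sft", "dpo"], default=None,
                    help="Overrides the config's trainer_type when given")
args = parser.parse_args()

config = get_config(args.config)
# Priority: CLI flag > config field > "sft" default.
trainer_type = args.trainer_type or getattr(config, "trainer_type", "sft")
```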

## Usage Examples

### Using the SFT Trainer with Different Configs

```bash
# Basic SFT training (uses the base config)
python src/train.py config/train_smollm3.py

# SFT training with a specialized config
python src/train.py config/train_smollm3_openhermes_fr_a100_multiple_passes.py

# SFT training with an explicit override
python src/train.py config/train_smollm3.py --trainer_type sft

# DPO training (uses the DPO config)
python src/train.py config/train_smollm3_dpo.py

# Override the config's trainer type
python src/train.py config/train_smollm3.py --trainer_type dpo
```

### Launch Script Usage

```bash
./launch.sh
# Select "SFT" when prompted for the trainer type.
# The script then uses the appropriate config for that selection.
```

## Configuration Inheritance

All specialized configs inherit from `SmolLM3Config` and automatically get:

- `trainer_type = "sft"` (the default)
- All base training parameters
- All monitoring configuration
- All data configuration

Specialized configs can override any of these parameters for their specific use case, as in the sketch below.
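
For example, a hypothetical config (not in the repo) that changes only the batch size inherits everything else unchanged:

```python
from dataclasses import dataclass

@dataclass
class SmolLM3ConfigLargeBatch(SmolLM3Config):  # hypothetical name
    batch_size: int = 16  # everything else, including trainer_type = "sft", is inherited
```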

## SFT Trainer Features

The SFT trainer provides:

1. **SFTTrainer Backend**: Uses Hugging Face's `SFTTrainer` for instruction tuning
2. **Fallback Support**: Falls back to the standard `Trainer` if `SFTTrainer` fails (see the sketch after this list)
3. **Config Integration**: Uses all config parameters for training setup
4. **Monitoring**: Integrates with Trackio for experiment tracking
5. **Checkpointing**: Supports model checkpointing and resuming
6. **Mixed Precision**: Supports fp16 and bf16 training
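
A hedged sketch of the fallback in item 2; the repo's actual error handling may differ, and `model`, `training_args`, and `train_dataset` are assumed to already exist:

```python
try:
    from trl import SFTTrainer  # Hugging Face's SFTTrainer ships in the TRL library
    trainer = SFTTrainer(model=model, args=training_args, train_dataset=train_dataset)
except Exception as err:
    from transformers import Trainer
    print(f"SFTTrainer unavailable ({err}); falling back to the standard Trainer")
    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
```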

## Troubleshooting

### Common Issues

1. **Missing `trainer_type` field**: Ensure all configs define the `trainer_type` field (see the check below)
2. **Config inheritance issues**: Check that specialized configs properly inherit from `SmolLM3Config`
3. **Parameter conflicts**: Ensure command-line arguments don't conflict with config values
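
A quick check for issue 1, reusing the `get_config` loader assumed earlier:

```python
# Verify that a config exposes trainer_type before launching a run.
config = get_config("config/train_smollm3_dpo.py")
assert hasattr(config, "trainer_type"), "config is missing the trainer_type field"
print(config.trainer_type)  # expected: "dpo"
```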

### Debugging

Run a training command and inspect the startup logs to confirm which config and trainer type were resolved:

```bash
python src/train.py config/train_smollm3.py --trainer_type sft
```

Look for these log messages:

```
Using trainer type: sft
Initializing SFT trainer...
Creating SFTTrainer with training arguments...
```

## Related Documentation