Adds a better launch.sh and automatic eval/test splits
- config/train_gpt_oss_basic.py +10 -0
- config/train_gpt_oss_custom.py +8 -0
- config/train_gpt_oss_h100_optimized.py +7 -0
- config/train_gpt_oss_memory_optimized.py +8 -0
- config/train_gpt_oss_multilingual_reasoning.py +8 -0
- config/train_gpt_oss_openhermes_fr.py +3 -0
- config/train_gpt_oss_openhermes_fr_memory_optimized.py +3 -0
- launch.sh +9 -1
- scripts/training/train_gpt_oss.py +185 -25
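
Taken together, the new config fields let a run opt into automatic validation/test splitting of a single training split and, optionally, DPO-style preference training. A minimal sketch of how a user config might set them (these field names are the ones added in this commit; the values are illustrative only):

    # Illustrative use of the fields added in this commit (values are examples, not defaults)
    config = GPTOSSEnhancedCustomConfig(
        eval_ratio=0.01,          # carve 1% of the single "train" split into a validation set
        test_ratio=0.01,          # carve another 1% into a held-out test set
        chosen_field="chosen",    # only needed for DPO preference datasets
        rejected_field="rejected",
        dpo_beta=0.1,
    )
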
config/train_gpt_oss_basic.py  CHANGED

@@ -62,6 +62,9 @@ class GPTOSSBasicConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"

@@ -99,6 +102,13 @@ class GPTOSSBasicConfig:
 
     # GPT-OSS specific model kwargs
     model_kwargs: dict = None
+    # Performance and precision extras
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    # DPO preference training fields
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     def __post_init__(self):
         if self.chat_template_kwargs is None:
config/train_gpt_oss_custom.py  CHANGED

@@ -83,6 +83,9 @@ class GPTOSSEnhancedCustomConfig:
     eval_steps: int = 100  # Evaluate every N steps
     eval_delay: float = 0  # Delay evaluation for N steps/epochs
     eval_accumulation_steps: Optional[int] = None  # Accumulate eval outputs
+    # Automatic split ratios when only a single training split is provided
+    eval_ratio: float = 0.01  # Fraction of data for validation (0.0-0.5 typical)
+    test_ratio: float = 0.01  # Fraction of data for test (0.0-0.5 typical)
 
     # Checkpointing
     save_strategy: str = "steps"  # "no", "steps", "epoch"

@@ -167,6 +170,11 @@ class GPTOSSEnhancedCustomConfig:
 
     # Generation Configuration (for evaluation/testing)
     generation_config: Optional[Dict] = None
+
+    # Preference-training (DPO) configuration
+    chosen_field: Optional[str] = None  # Field name for preferred response (for DPO datasets)
+    rejected_field: Optional[str] = None  # Field name for rejected response (for DPO datasets)
+    dpo_beta: float = 0.1  # DPO beta parameter
 
     # ============================================================================
     # MULTILINGUAL & DOMAIN SPECIFIC SETTINGS
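
The three DPO fields above are consumed by scripts/training/train_gpt_oss.py (diff further below): each example is mapped to prompt/chosen/rejected columns, with Harmony markers added when use_harmony_format is enabled. A sketch of one mapped row, assuming input_field="prompt" and an illustrative input record:

    # Illustrative input record (column names match chosen_field/rejected_field above)
    example = {
        "prompt": "Explain LoRA briefly.",
        "chosen": "LoRA adds small low-rank adapter matrices to frozen weights...",
        "rejected": "LoRA is a long-range radio protocol.",
    }

    # After the preference mapping with use_harmony_format=True and add_eos_token=True:
    {
        "prompt": "<|start|>user<|message|>Explain LoRA briefly.<|end|><|start|>assistant<|channel|>final<|message|>",
        "chosen": "LoRA adds small low-rank adapter matrices to frozen weights...<|return|>",
        "rejected": "LoRA is a long-range radio protocol.<|return|>",
    }
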
config/train_gpt_oss_h100_optimized.py  CHANGED

@@ -62,6 +62,9 @@ class GPTOSSH100OptimizedConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"

@@ -104,6 +107,10 @@ class GPTOSSH100OptimizedConfig:
     dataloader_num_workers: int = 8  # More workers for H100
     dataloader_pin_memory: bool = True
     dataloader_prefetch_factor: int = 4  # Increased prefetch
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     # Memory optimizations for H100
     max_grad_norm: float = 1.0
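
The new tf32 flag is only declared in these configs; this diff does not show where it is applied. As a hedged sketch of how such a flag is commonly consumed (an assumption, not code from this commit), it could be forwarded to transformers.TrainingArguments(tf32=...) or used to toggle PyTorch's TF32 paths directly:

    import torch

    # Hypothetical consumer of the new `tf32` field (not part of this commit)
    if getattr(config, "tf32", None):
        torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmuls on Ampere/Hopper GPUs
        torch.backends.cudnn.allow_tf32 = True        # TF32 inside cuDNN kernels
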
config/train_gpt_oss_memory_optimized.py  CHANGED

@@ -43,6 +43,9 @@ class GPTOSSMemoryOptimizedConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"
     dataset_split: str = "train"
     input_field: str = "messages"

@@ -65,6 +68,11 @@ class GPTOSSMemoryOptimizedConfig:
     use_quantization: bool = True
     quantization_config: dict = None
     model_kwargs: dict = None
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
     generation_config: dict = None
     reasoning_languages: list = None
 
config/train_gpt_oss_multilingual_reasoning.py  CHANGED

@@ -62,6 +62,9 @@ class GPTOSSMultilingualReasoningConfig:
     metric_for_best_model: str = "eval_loss"
     greater_is_better: bool = False
     load_best_model_at_end: bool = True
+    eval_accumulation_steps: Optional[int] = None
+    eval_ratio: float = 0.01
+    test_ratio: float = 0.01
 
     # Data configuration - Multilingual-Thinking specific
     dataset_name: str = "HuggingFaceH4/Multilingual-Thinking"

@@ -99,6 +102,11 @@ class GPTOSSMultilingualReasoningConfig:
 
     # GPT-OSS specific model kwargs - as per tutorial
     model_kwargs: dict = None
+    dataloader_prefetch_factor: int = 2
+    tf32: Optional[bool] = None
+    chosen_field: Optional[str] = None
+    rejected_field: Optional[str] = None
+    dpo_beta: float = 0.1
 
     # Multilingual reasoning specific configurations
     # Generation parameters for multilingual reasoning
config/train_gpt_oss_openhermes_fr.py  CHANGED

@@ -119,6 +119,9 @@ config = GPTOSSEnhancedCustomConfig(
     metric_for_best_model="eval_loss",
     greater_is_better=False,
     load_best_model_at_end=True,
+    # Split ratios for automatic validation/test creation
+    eval_ratio=0.01,
+    test_ratio=0.01,
 
     # ============================================================================
     # MULTILINGUAL & FRENCH SPECIFIC SETTINGS
config/train_gpt_oss_openhermes_fr_memory_optimized.py  CHANGED

@@ -144,6 +144,9 @@ config = GPTOSSEnhancedCustomConfig(
     # Evaluation memory optimization
     eval_accumulation_steps=4,  # Accumulate eval outputs to save memory
     eval_batch_size=1,  # Smaller eval batch size
+    # Split ratios for automatic validation/test creation
+    eval_ratio=0.001,
+    test_ratio=0.0005,
 
     # ============================================================================
     # GPT-OSS HARMONY FORMAT OPTIMIZATION
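
These much smaller ratios keep evaluation cheap on the memory-constrained setup: for a hypothetical 800,000-example dataset, eval_ratio=0.001 and test_ratio=0.0005 would carve out roughly 800 validation and 400 test examples.
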
launch.sh  CHANGED

@@ -827,7 +827,15 @@ fi
 print_step "Step 3: Experiment Details"
 echo "=============================="
 
-
+# Derive default experiment name from smolfactory + chosen model family
+if [ "$MODEL_FAMILY" = "GPT-OSS" ]; then
+    FAMILY_SLUG="gpt-oss"
+else
+    FAMILY_SLUG="smollm3"
+fi
+DEFAULT_EXPERIMENT_NAME="smolfactory-${FAMILY_SLUG}_$(date +%Y%m%d_%H%M%S)"
+
+get_input "Experiment name" "$DEFAULT_EXPERIMENT_NAME" EXPERIMENT_NAME
 
 # Configure model repository name (customizable)
 print_info "Setting up model repository name..."
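
As a usage note, for the GPT-OSS family this produces a default such as smolfactory-gpt-oss_20250805_142310 (timestamp illustrative), which the user can accept or override at the get_input prompt.
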
scripts/training/train_gpt_oss.py  CHANGED

@@ -13,6 +13,10 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
 from peft import LoraConfig, get_peft_model
 from trl import SFTTrainer
+try:
+    from trl import DPOTrainer
+except Exception:  # pragma: no cover - optional import depending on TRL version
+    DPOTrainer = None
 from datasets import load_dataset
 from pathlib import Path
 

@@ -214,6 +218,10 @@ def format_gpt_oss_harmony(prompt, completion, add_eos_token=True):
 
     return harmony_text
 
+def format_gpt_oss_harmony_prompt(prompt: str) -> str:
+    """Prefix-only Harmony prompt up to assistant content marker for DPO."""
+    return f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>"
+
 def process_dataset_format(dataset, config):
     """Process dataset based on format configuration with exact GPT-OSS Harmony compliance"""
 

@@ -224,11 +232,53 @@ def process_dataset_format(dataset, config):
     field_separator = getattr(config, 'field_separator', '\n\n### Response:\n')
     add_eos_token = getattr(config, 'add_eos_token', True)
     use_harmony_format = getattr(config, 'use_harmony_format', True)
+    trainer_type = getattr(config, 'trainer_type', 'sft')
 
     print(f"Processing dataset format: {dataset_format}")
     print(f"Input field: {input_field}, Target field: {target_field}")
     print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}")
 
+    # Preference-format for DPO training (chosen/rejected pairs)
+    if trainer_type == 'dpo':
+        chosen_field = getattr(config, 'chosen_field', None)
+        rejected_field = getattr(config, 'rejected_field', None)
+
+        if dataset_format == 'preference':
+            # Expect columns present; optionally reformat to ensure only necessary columns
+            def id_map(example):
+                prompt_val = example.get(input_field, '')
+                chosen_val = example.get('chosen', example.get(chosen_field or 'chosen', ''))
+                rejected_val = example.get('rejected', example.get(rejected_field or 'rejected', ''))
+                if use_harmony_format:
+                    prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
+                    chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
+                    rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
+                    return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
+                return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}
+
+            keep_cols = [c for c in ['prompt', 'chosen', 'rejected'] if c in dataset.column_names]
+            dataset = dataset.map(id_map, remove_columns=dataset.column_names if keep_cols else dataset.column_names)
+            return dataset
+
+        # Custom preference mapping via configured field names
+        if chosen_field and rejected_field:
+            def to_pref(example):
+                prompt_val = example.get(input_field, '')
+                chosen_val = example.get(chosen_field, '')
+                rejected_val = example.get(rejected_field, '')
+                if use_harmony_format:
+                    prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
+                    chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
+                    rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
+                    return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
+                return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}
+
+            dataset = dataset.map(to_pref, remove_columns=dataset.column_names)
+            return dataset
+
+        # If we reach here, we don't have required fields for DPO
+        raise ValueError("DPO training requires preference data. Please set dataset_format='preference' with 'prompt', 'chosen', 'rejected' columns, or specify 'chosen_field' and 'rejected_field' in the config.")
+
     if dataset_format == "openhermes_fr":
         # Process OpenHermes-FR format: prompt + accepted_completion
         def format_openhermes_fr(example):

@@ -317,6 +367,72 @@
 
     return dataset
 
+def split_dataset(dataset, config):
+    """Create train/validation/test splits from a single dataset.
+
+    Defaults to 1% eval and 1% test if not specified.
+    """
+    from datasets import Dataset
+
+    if not isinstance(dataset, Dataset):
+        # If it's already a DatasetDict, try to use its splits
+        try:
+            train_split = dataset["train"]
+            eval_split = dataset.get("validation") or dataset.get("eval")
+            test_split = dataset.get("test")
+            return train_split, eval_split, test_split
+        except Exception:
+            pass
+
+    eval_ratio = getattr(config, 'eval_ratio', 0.01)
+    test_ratio = getattr(config, 'test_ratio', 0.01)
+
+    # Clamp ratios to sane bounds
+    try:
+        eval_ratio = max(0.0, float(eval_ratio))
+        test_ratio = max(0.0, float(test_ratio))
+        if eval_ratio + test_ratio >= 0.9:
+            # Avoid extreme splits; cap combined at 0.2
+            scale = 0.2 / max(1e-9, (eval_ratio + test_ratio))
+            eval_ratio *= scale
+            test_ratio *= scale
+    except Exception:
+        eval_ratio, test_ratio = 0.01, 0.01
+
+    # No eval/test requested
+    if eval_ratio <= 0 and test_ratio <= 0:
+        return dataset, None, None
+
+    ds_shuffled = dataset.shuffle(seed=42)
+
+    # First carve out test split
+    if test_ratio > 0:
+        split1 = ds_shuffled.train_test_split(test_size=test_ratio, seed=42)
+        train_part = split1["train"]
+        test_split = split1["test"]
+    else:
+        train_part = ds_shuffled
+        test_split = None
+
+    # Then carve out eval from remaining train
+    if eval_ratio > 0:
+        remaining_fraction = 1.0 - test_ratio
+        # Convert global eval fraction to fraction of remaining pool
+        relative_eval = eval_ratio / remaining_fraction if remaining_fraction > 0 else eval_ratio
+        split2 = train_part.train_test_split(test_size=relative_eval, seed=42)
+        train_split = split2["train"]
+        eval_split = split2["test"]
+    else:
+        train_split = train_part
+        eval_split = None
+
+    # Log sizes
+    try:
+        print(f"Created splits -> train: {len(train_split)}, eval: {len(eval_split) if eval_split else 0}, test: {len(test_split) if test_split else 0}")
+    except Exception:
+        pass
+
+    return train_split, eval_split, test_split
+
 def setup_trackio_tracking(config):
     """Setup Trackio tracking if enabled"""
 

@@ -530,6 +646,9 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
 
     # Load dataset
     dataset = load_dataset_from_config(config)
+
+    # Split into train/eval/test
+    train_dataset, eval_dataset, test_dataset = split_dataset(dataset, config)
 
     # Setup Trackio tracking
     trackio_client = setup_trackio_tracking(config)

@@ -538,37 +657,78 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     sft_config = create_sft_config(config, output_dir)
 
     # Create trainer with version-robust kwargs
-            sft_params = set(sft_sig.parameters.keys())
-        except Exception:
-            sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"}
-    [... remaining lines of the previous flat SFTTrainer construction removed; contents not shown in the diff view ...]
+    if trainer_type == 'dpo':
+        if DPOTrainer is None:
+            raise RuntimeError("DPOTrainer is not available in this TRL version. Please upgrade 'trl'.")
+
+        print("Creating DPO trainer...")
+        try:
+            dpo_sig = inspect.signature(DPOTrainer.__init__)
+            dpo_params = set(dpo_sig.parameters.keys())
+        except Exception:
+            dpo_params = {"model", "args", "train_dataset", "tokenizer", "beta", "prompt_column", "chosen_column", "rejected_column"}
+
+        dpo_kwargs = {
+            "model": peft_model,
+            "args": sft_config,
+            "train_dataset": train_dataset,
+            "beta": getattr(config, 'dpo_beta', 0.1),
+        }
+
+        if "tokenizer" in dpo_params:
+            dpo_kwargs["tokenizer"] = tokenizer
+        elif "processing_class" in dpo_params:
+            dpo_kwargs["processing_class"] = tokenizer
+
+        if "prompt_column" in dpo_params:
+            dpo_kwargs["prompt_column"] = "prompt"
+        if "chosen_column" in dpo_params:
+            dpo_kwargs["chosen_column"] = "chosen"
+        if "rejected_column" in dpo_params:
+            dpo_kwargs["rejected_column"] = "rejected"
+
+        # Remove Nones
+        dpo_kwargs = {k: v for k, v in dpo_kwargs.items() if v is not None}
+
+        # Pass eval dataset if supported
+        if "eval_dataset" in dpo_params and eval_dataset is not None:
+            dpo_kwargs["eval_dataset"] = eval_dataset
+        trainer = DPOTrainer(**dpo_kwargs)
+    else:
+        print("Creating SFT trainer...")
+        try:
+            sft_sig = inspect.signature(SFTTrainer.__init__)
+            sft_params = set(sft_sig.parameters.keys())
+        except Exception:
+            sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"}
+
+        sft_kwargs = {
+            "model": peft_model,
+            "args": sft_config,
+            "train_dataset": train_dataset,
+        }
+
+        # Prefer passing tokenizer if supported; otherwise try processing_class
+        if "tokenizer" in sft_params:
+            sft_kwargs["tokenizer"] = tokenizer
+        elif "processing_class" in sft_params:
+            sft_kwargs["processing_class"] = tokenizer
+
+        # Pass dataset text field if supported (we produced a 'text' column)
+        if "dataset_text_field" in sft_params:
+            sft_kwargs["dataset_text_field"] = "text"
+
+        # Pass max sequence length if supported
+        if "max_seq_length" in sft_params:
+            sft_kwargs["max_seq_length"] = getattr(config, 'max_seq_length', 2048)
+
+        # Remove any None values
+        sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
+
+        # Attach eval_dataset if supported
+        if "eval_dataset" in sft_params and eval_dataset is not None:
+            sft_kwargs["eval_dataset"] = eval_dataset
+        trainer = SFTTrainer(**sft_kwargs)
 
     # Start training
     print("Starting GPT-OSS training...")
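
To sanity-check the two-stage ratio arithmetic used by split_dataset above, here is a minimal standalone sketch against the datasets API with a toy 10,000-row dataset (illustrative only; the real script operates on the loaded training split):

    from datasets import Dataset

    # Toy dataset to illustrate the two-stage split
    ds = Dataset.from_dict({"text": [f"example {i}" for i in range(10_000)]})

    test_ratio, eval_ratio = 0.01, 0.01
    split1 = ds.shuffle(seed=42).train_test_split(test_size=test_ratio, seed=42)
    # eval_ratio is a fraction of the original dataset, so rescale it to the remaining pool
    relative_eval = eval_ratio / (1.0 - test_ratio)  # 0.01 / 0.99 ≈ 0.0101
    split2 = split1["train"].train_test_split(test_size=relative_eval, seed=42)

    print(len(split2["train"]), len(split2["test"]), len(split1["test"]))
    # -> roughly 9800 train, 100 eval, 100 test
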