SmolFactory / scripts /training /train_gpt_oss.py
Tonic's picture
adds better launch.sh and eval / test splits auto
0fa6045
raw
history blame
31.8 kB
#!/usr/bin/env python3
"""
GPT-OSS Training Script
Specialized training script for OpenAI's GPT-OSS models
Based on the GPT-OSS fine-tuning tutorial
"""
import os
import sys
import argparse
import inspect
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
try:
from trl import DPOTrainer
except Exception: # pragma: no cover - optional import depending on TRL version
DPOTrainer = None
from datasets import load_dataset
from pathlib import Path
# Config files may do `from config...` imports, so expose both the project
# root (parents[2]: the directory that contains scripts/) and the config
# package directory itself on sys.path. Each one is prepended only when it is
# not already present; config_dir ends up ahead of project_root.
project_root = Path(__file__).resolve().parents[2]
config_dir = project_root / "config"
for _extra_path in (str(project_root), str(config_dir)):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
del _extra_path
def load_gpt_oss_model_and_tokenizer(config):
    """Load the GPT-OSS tokenizer and causal-LM model described by ``config``.

    Args:
        config: Training configuration object. ``model_name`` is required.
            ``quantization_config`` is optional (a dict): a truthy
            ``load_in_4bit`` selects bitsandbytes NF4 4-bit loading, otherwise
            a truthy ``dequantize`` selects ``Mxfp4Config(dequantize=True)``
            (when the installed transformers provides it). Absent/``None``
            means no quantization config is passed.

    Returns:
        ``(model, tokenizer)``.
    """
    print("Loading GPT-OSS tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    print("Loading GPT-OSS model with quantization...")

    # Import quantization config
    from transformers import BitsAndBytesConfig

    # Config files are arbitrary user modules and may not define this
    # attribute at all; treat that exactly like "no quantization".
    quant_settings = getattr(config, "quantization_config", None) or {}

    quantization_config = None
    if quant_settings.get("load_in_4bit"):
        # 4-bit NF4 with double quantization, computing in bfloat16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif quant_settings.get("dequantize"):
        # Mxfp4Config only exists in newer transformers releases.
        try:
            from transformers import Mxfp4Config
            quantization_config = Mxfp4Config(dequantize=True)
        except ImportError:
            print("Warning: Mxfp4Config not available, using no quantization")
            quantization_config = None

    # Model kwargs as per the GPT-OSS fine-tuning tutorial.
    model_kwargs = {
        "attn_implementation": "eager",
        "torch_dtype": torch.bfloat16,
        "use_cache": False,  # KV cache is not needed during training
        "device_map": "auto",
    }
    # Only pass quantization_config when one was selected.
    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config

    model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)

    return model, tokenizer
def setup_lora_for_gpt_oss(model, config):
    """Wrap ``model`` with a PEFT LoRA adapter configured for GPT-OSS.

    ``config.lora_config`` is optional (a dict). Recognised keys and the
    defaults used when a key is absent:

    * ``r`` -> 8
    * ``lora_alpha`` -> 16
    * ``target_modules`` -> ``"all-linear"``
    * ``target_parameters`` -> the MoE expert ``gate_up_proj`` / ``down_proj``
      parameters of layers 7, 15 and 23 (as in the GPT-OSS tutorial)

    A missing, ``None`` or empty ``lora_config`` uses all defaults.

    Returns:
        The PEFT model from ``get_peft_model`` (trainable-parameter summary
        is printed).
    """
    print("Setting up LoRA for GPT-OSS...")

    # getattr: configs loaded from arbitrary files may not define lora_config.
    lora_settings = getattr(config, "lora_config", None) or {}

    # Single source of truth for the default expert parameters (previously
    # duplicated for the "configured" and "not configured" cases).
    default_target_parameters = [
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ]

    lora_config = LoraConfig(
        r=lora_settings.get("r", 8),
        lora_alpha=lora_settings.get("lora_alpha", 16),
        target_modules=lora_settings.get("target_modules", "all-linear"),
        target_parameters=lora_settings.get("target_parameters", default_target_parameters),
    )

    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()

    return peft_model
def load_dataset_from_config(config):
    """Load the dataset named by ``config``, then filter and format it.

    Attributes read (with defaults): ``dataset_name``
    (``'HuggingFaceH4/Multilingual-Thinking'``), ``dataset_split``
    (``'train'``) and ``dataset_config`` (``None``; passed to
    ``load_dataset`` as the config/subset name only when truthy).

    Returns:
        The dataset after ``apply_dataset_filtering`` and
        ``process_dataset_format`` have been applied.
    """
    name = getattr(config, 'dataset_name', 'HuggingFaceH4/Multilingual-Thinking')
    split = getattr(config, 'dataset_split', 'train')
    subset = getattr(config, 'dataset_config', None)

    print(f"Loading dataset: {name}")
    print(f"Dataset split: {split}")
    if subset:
        print(f"Dataset config: {subset}")

    # The subset name is a positional argument of load_dataset; omit it
    # entirely when none is configured.
    load_args = (name, subset) if subset else (name,)
    dataset = load_dataset(*load_args, split=split)
    print(f"Original dataset size: {len(dataset)} examples")

    filtered = apply_dataset_filtering(dataset, config)
    formatted = process_dataset_format(filtered, config)

    print(f"Final dataset size: {len(formatted)} examples")
    return formatted
def _field_text_length(value):
    """Return the character length of one dataset field value.

    * ``None`` -> 0
    * ``str`` -> number of characters
    * chat message dict with a ``"content"`` key -> length of that content
    * list/tuple (e.g. a ``messages`` conversation) -> sum over its items
    * anything else -> ``len(str(value))``

    Measuring chat lists by content characters (instead of ``len(list)``,
    i.e. the number of turns) keeps ``min_length`` / ``max_length``
    meaningful for the ``messages`` dataset format.
    """
    if value is None:
        return 0
    if isinstance(value, str):
        return len(value)
    if isinstance(value, dict):
        if "content" in value:
            return _field_text_length(value["content"])
        return len(str(value))
    if isinstance(value, (list, tuple)):
        return sum(_field_text_length(item) for item in value)
    return len(str(value))


def apply_dataset_filtering(dataset, config):
    """Drop flagged and out-of-range examples, then optionally subsample.

    Steps, each driven by ``config`` attributes:

    1. If ``filter_bad_entries`` is true, remove rows whose
       ``bad_entry_field`` / ``bad_prompt_field`` / ``bad_response_field``
       column (only when that column exists) is truthy.
    2. Keep rows whose combined ``input_field`` + ``target_field`` length, in
       characters, is at least ``min_length`` (default 10) and, when
       ``max_length`` is set, at most ``max_length``.
    3. If ``max_samples`` is set and smaller than the dataset, keep a
       seeded (seed 42) random subset of that size.

    Returns:
        The filtered dataset.
    """
    # Filter bad entries if specified
    if getattr(config, 'filter_bad_entries', False):
        bad_entry_field = getattr(config, 'bad_entry_field', 'bad_entry')
        bad_prompt_field = getattr(config, 'bad_prompt_field', 'bad_prompt_detected')
        bad_response_field = getattr(config, 'bad_response_field', 'bad_response_detected')

        original_size = len(dataset)

        if bad_entry_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_entry_field, False))
            print(f"Filtered {original_size - len(dataset)} bad entries")

        if bad_prompt_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_prompt_field, False))
            print(f"Filtered bad prompts, remaining: {len(dataset)} examples")

        if bad_response_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_response_field, False))
            print(f"Filtered bad responses, remaining: {len(dataset)} examples")

    # Length filtering; a None min_length in the config means "no minimum".
    min_length = getattr(config, 'min_length', 10) or 0
    max_length = getattr(config, 'max_length', None)
    input_field = getattr(config, 'input_field', 'prompt')
    target_field = getattr(config, 'target_field', 'accepted_completion')

    if min_length > 0 or max_length:
        def length_filter(example):
            # Character lengths; None values and missing fields count as 0.
            total_len = (_field_text_length(example.get(input_field))
                         + _field_text_length(example.get(target_field)))
            if total_len < min_length:
                return False
            if max_length and total_len > max_length:
                return False
            return True

        original_size = len(dataset)
        dataset = dataset.filter(length_filter)
        print(f"Length filtering: {original_size} -> {len(dataset)} examples")

    # Apply sampling if specified
    max_samples = getattr(config, 'max_samples', None)
    if max_samples and len(dataset) > max_samples:
        dataset = dataset.shuffle(seed=42).select(range(max_samples))
        print(f"Sampled {max_samples} examples from dataset")

    return dataset
def format_gpt_oss_harmony(prompt, completion, add_eos_token=True):
    """Render one user/assistant exchange in GPT-OSS Harmony format.

    Template reference:
    https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja

    Output layout::

        <|start|>user<|message|>PROMPT<|end|>
        <|start|>assistant<|channel|>final<|message|>COMPLETION<TERMINATOR>

    (shown on two lines for readability; the real string has no newline).
    ``TERMINATOR`` is ``<|return|>`` when ``add_eos_token`` is truthy, the
    form used for the final assistant turn in training data, and
    ``<|end|>`` otherwise.
    """
    # The exchange is the prompt prefix followed by the completion and the
    # chosen end-of-turn marker.
    terminator = "<|return|>" if add_eos_token else "<|end|>"
    return format_gpt_oss_harmony_prompt(prompt) + f"{completion}" + terminator


def format_gpt_oss_harmony_prompt(prompt: str) -> str:
    """Return the Harmony prefix up to (and including) the assistant marker.

    This is the prompt half of :func:`format_gpt_oss_harmony`, used as the
    ``prompt`` column for DPO, where the completion is supplied separately.
    """
    user_turn = f"<|start|>user<|message|>{prompt}<|end|>"
    assistant_header = "<|start|>assistant<|channel|>final<|message|>"
    return user_turn + assistant_header
def _make_preference_mapper(prompt_field, chosen_keys, rejected_keys,
                            use_harmony_format, add_eos_token):
    """Return a ``dataset.map`` function producing prompt/chosen/rejected.

    ``chosen_keys`` / ``rejected_keys`` are candidate column names tried in
    order; the first one present in the example supplies the value,
    otherwise ``''``. With Harmony enabled, the prompt becomes the Harmony
    prefix (ending at the assistant ``final`` message marker), ``None``
    values become ``''``, and ``<|return|>`` is appended to both
    completions when ``add_eos_token`` is truthy. Without Harmony the raw
    values are passed through unchanged.
    """
    terminator = "<|return|>" if add_eos_token else ''

    def first_present(example, keys):
        for key in keys:
            if key in example:
                return example[key]
        return ''

    def to_preference(example):
        prompt_val = example.get(prompt_field, '')
        chosen_val = first_present(example, chosen_keys)
        rejected_val = first_present(example, rejected_keys)
        if use_harmony_format:
            return {
                "prompt": format_gpt_oss_harmony_prompt(prompt_val or ''),
                "chosen": (chosen_val or '') + terminator,
                "rejected": (rejected_val or '') + terminator,
            }
        return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}

    return to_preference


def _messages_to_plain_text(messages, add_eos_token):
    """Render chat messages as one "role: content" line each.

    ``None`` content renders as an empty string; the literal ``</s>`` is
    appended when ``add_eos_token`` is truthy.
    """
    text = "".join(
        f"{message.get('role', '')}: {message.get('content') or ''}\n"
        for message in messages
    )
    if add_eos_token:
        text += "</s>"
    return text


def process_dataset_format(dataset, config):
    """Convert raw examples into the columns the selected trainer expects.

    SFT (``config.dataset_format``):

    * ``openhermes_fr`` (default): ``input_field`` prompt + ``target_field``
      completion, rendered as Harmony or joined with ``field_separator``
      into a ``text`` column; with ``concatenate_fields=False`` the result
      is ``input``/``output`` columns instead.
    * ``messages``: a chat list in ``input_field``. When Harmony is enabled
      and both exist, the last user turn and last assistant turn become one
      Harmony exchange; otherwise every turn becomes a "role: content" line.
    * ``text``: ``input_field`` used verbatim as ``text``.
    * ``custom``: returned unchanged. Unknown formats are also returned
      unchanged, with a warning.

    DPO (``config.trainer_type == 'dpo'``) produces ``prompt``/``chosen``/
    ``rejected`` columns. With ``dataset_format='preference'``, configured
    ``chosen_field``/``rejected_field`` columns take precedence over the
    conventional ``chosen``/``rejected`` columns.

    Missing or ``None`` text fields are treated as empty strings. Wherever
    Harmony formatting is not applied, ``add_eos_token`` appends the literal
    ``</s>``.

    Raises:
        ValueError: for DPO when neither ``dataset_format='preference'`` nor
            both ``chosen_field`` and ``rejected_field`` are configured.
    """
    dataset_format = getattr(config, 'dataset_format', 'openhermes_fr')
    input_field = getattr(config, 'input_field', 'prompt')
    target_field = getattr(config, 'target_field', 'accepted_completion')
    concatenate_fields = getattr(config, 'concatenate_fields', True)
    field_separator = getattr(config, 'field_separator', '\n\n### Response:\n')
    add_eos_token = getattr(config, 'add_eos_token', True)
    use_harmony_format = getattr(config, 'use_harmony_format', True)
    trainer_type = getattr(config, 'trainer_type', 'sft')

    print(f"Processing dataset format: {dataset_format}")
    print(f"Input field: {input_field}, Target field: {target_field}")
    print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}")

    # Preference-format for DPO training (chosen/rejected pairs)
    if trainer_type == 'dpo':
        chosen_field = getattr(config, 'chosen_field', None)
        rejected_field = getattr(config, 'rejected_field', None)
        if dataset_format == 'preference':
            # An explicitly configured column name wins; the conventional
            # 'chosen'/'rejected' columns are only the fallback.
            chosen_keys = (chosen_field or 'chosen', 'chosen')
            rejected_keys = (rejected_field or 'rejected', 'rejected')
        elif chosen_field and rejected_field:
            chosen_keys = (chosen_field,)
            rejected_keys = (rejected_field,)
        else:
            raise ValueError("DPO training requires preference data. Please set dataset_format='preference' with 'prompt', 'chosen', 'rejected' columns, or specify 'chosen_field' and 'rejected_field' in the config.")
        to_preference = _make_preference_mapper(
            input_field, chosen_keys, rejected_keys, use_harmony_format, add_eos_token
        )
        return dataset.map(to_preference, remove_columns=dataset.column_names)

    if dataset_format == "openhermes_fr":
        def format_openhermes_fr(example):
            prompt = example.get(input_field) or ''
            completion = example.get(target_field) or ''
            if not concatenate_fields:
                # Keep separate for more advanced training setups
                return {"input": prompt, "output": completion}
            if use_harmony_format:
                return {"text": format_gpt_oss_harmony(prompt, completion, add_eos_token)}
            # Fallback to standard format with separator
            text = prompt + field_separator + completion
            if add_eos_token:
                text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names)

    elif dataset_format == "messages":
        def format_messages(example):
            messages = example.get(input_field) or []
            if use_harmony_format and len(messages) >= 2:
                # Later turns overwrite earlier ones: the last user and last
                # assistant messages form the Harmony exchange.
                user_message = ""
                assistant_message = ""
                for message in messages:
                    role = message.get("role", "")
                    if role == "user":
                        user_message = message.get("content", "")
                    elif role == "assistant":
                        assistant_message = message.get("content", "")
                if user_message and assistant_message:
                    return {"text": format_gpt_oss_harmony(user_message, assistant_message, add_eos_token)}
            return {"text": _messages_to_plain_text(messages, add_eos_token)}

        dataset = dataset.map(format_messages, remove_columns=dataset.column_names)

    elif dataset_format == "text":
        def format_text(example):
            text = example.get(input_field) or ''
            if add_eos_token:
                text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_text, remove_columns=dataset.column_names)

    elif dataset_format == "custom":
        # Custom format - user handles this in their config
        print("Using custom dataset format - no automatic processing")

    else:
        print(f"Warning: unknown dataset_format '{dataset_format}' - dataset left unprocessed")

    return dataset
def split_dataset(dataset, config):
    """Create ``(train, eval, test)`` splits from the loaded dataset.

    If ``dataset`` is not a single ``datasets.Dataset`` but has a ``"train"``
    split (e.g. a ``DatasetDict``), its existing ``train``,
    ``validation``/``eval`` and ``test`` splits are returned; a missing or
    empty eval/test split is returned as ``None``/falsy as found.

    Otherwise ``config.eval_ratio`` and ``config.test_ratio`` (default 0.01
    each, i.e. 1%) are global fractions carved out of a seeded (seed 42)
    shuffle: test first, then eval from the remainder. Negative ratios are
    treated as 0; unparseable ratios reset both to 0.01; if together they
    reach 0.9 or more they are rescaled proportionally to a combined 0.2.
    A hold-out is skipped (with a warning) when the pool is too small to
    leave at least one training example.

    Returns:
        ``(train, eval_or_None, test_or_None)``.
    """
    import math

    from datasets import Dataset

    if not isinstance(dataset, Dataset):
        # Already split (e.g. a DatasetDict). Note: an empty Dataset is falsy,
        # so an empty "validation" split falls through to "eval".
        try:
            train_split = dataset["train"]
            eval_split = dataset.get("validation") or dataset.get("eval")
            test_split = dataset.get("test")
        except (KeyError, TypeError, AttributeError):
            pass
        else:
            return train_split, eval_split, test_split

    try:
        eval_ratio = max(0.0, float(getattr(config, 'eval_ratio', 0.01)))
        test_ratio = max(0.0, float(getattr(config, 'test_ratio', 0.01)))
    except (TypeError, ValueError, OverflowError):
        eval_ratio, test_ratio = 0.01, 0.01

    combined = eval_ratio + test_ratio
    if combined >= 0.9:
        # A 90%+ hold-out is treated as a misconfiguration: shrink both
        # ratios proportionally so that together they take 20% of the data.
        scale = 0.2 / combined
        eval_ratio *= scale
        test_ratio *= scale

    # No eval/test requested
    if eval_ratio <= 0 and test_ratio <= 0:
        return dataset, None, None

    def can_hold_out(pool, fraction):
        # train_test_split keeps n - ceil(fraction * n) rows for training and
        # raises when that would be 0 (tiny or empty pools).
        num_rows = len(pool)
        return num_rows - math.ceil(fraction * num_rows) >= 1

    ds_shuffled = dataset.shuffle(seed=42)

    # First carve out test split
    train_part, test_split = ds_shuffled, None
    if test_ratio > 0:
        if can_hold_out(ds_shuffled, test_ratio):
            split1 = ds_shuffled.train_test_split(test_size=test_ratio, seed=42)
            train_part, test_split = split1["train"], split1["test"]
        else:
            print(f"Warning: dataset too small ({len(ds_shuffled)} examples) for a test split; skipping it")

    # Then carve out eval from the remaining pool
    train_split, eval_split = train_part, None
    if eval_ratio > 0:
        # Convert the global eval fraction into a fraction of what remains.
        remaining_fraction = 1.0 - test_ratio if test_split is not None else 1.0
        relative_eval = eval_ratio / remaining_fraction if remaining_fraction > 0 else eval_ratio
        if can_hold_out(train_part, relative_eval):
            split2 = train_part.train_test_split(test_size=relative_eval, seed=42)
            train_split, eval_split = split2["train"], split2["test"]
        else:
            print(f"Warning: dataset too small ({len(train_part)} examples) for an eval split; skipping it")

    print(f"Created splits -> train: {len(train_split)}, eval: {len(eval_split) if eval_split else 0}, test: {len(test_split) if test_split else 0}")
    return train_split, eval_split, test_split
def setup_trackio_tracking(config):
    """Create a Trackio API client when tracking is enabled.

    Reads ``enable_tracking``, ``trackio_url`` and ``trackio_token`` from
    ``config``; missing attributes are treated as disabled / unset (a
    missing token is passed to the client as ``None``).

    Returns:
        A ``TrackioAPIClient`` for ``trackio_url``, or ``None`` when tracking
        is disabled or no URL is configured.
    """
    trackio_url = getattr(config, 'trackio_url', None)
    if not getattr(config, 'enable_tracking', False) or not trackio_url:
        print("Trackio tracking disabled or URL not provided")
        return None

    print(f"Setting up Trackio tracking: {trackio_url}")

    # The client module lives in the sibling `trackio_tonic` directory rather
    # than an installed package; add that directory to sys.path only once.
    client_dir = os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic')
    if client_dir not in sys.path:
        sys.path.append(client_dir)
    from trackio_api_client import TrackioAPIClient

    return TrackioAPIClient(
        space_id=trackio_url,
        hf_token=getattr(config, 'trackio_token', None),
    )
def create_sft_config(config, output_dir):
    """Build the ``TrainingArguments`` shared by the SFT and DPO trainers.

    Values are read from ``config`` via getattr with defaults and coerced to
    the numeric types transformers expects. Notable behaviours:

    * ``scheduler`` (default ``cosine_with_min_lr``) plus optional
      ``lr_scheduler_kwargs``. That scheduler requires ``min_lr`` or
      ``min_lr_rate``, so ``{"min_lr_rate": 0.1}`` (the GPT-OSS tutorial
      value) is supplied when the config gives no kwargs.
    * ``report_to`` is ``"trackio"`` when ``enable_tracking`` is set and
      ``"none"`` otherwise, so no integration is enabled implicitly.
    * ``dataloader_prefetch_factor`` is only passed when
      ``dataloader_num_workers > 0`` (TrainingArguments rejects it otherwise).
    * Keywords unknown to the installed transformers are dropped, and
      ``evaluation_strategy`` is renamed to ``eval_strategy`` where needed.

    Args:
        config: Training configuration object.
        output_dir: Directory for checkpoints / final model.

    Returns:
        A ``transformers.TrainingArguments`` instance.
    """
    print("Creating enhanced SFT configuration...")

    def _as_int(value, default):
        """Coerce ``value`` to int; ``None`` or unparseable values give ``default``."""
        if value is None:
            return int(default)
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return int(default)

    def _as_float(value, default):
        """Coerce ``value`` to float; ``None`` or unparseable values give ``default``."""
        if value is None:
            return float(default)
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return float(default)

    # Training duration; -1 is transformers' "no step limit" sentinel.
    num_train_epochs = _as_float(getattr(config, 'num_train_epochs', 1.0), 1.0)
    max_steps = _as_int(getattr(config, 'max_steps', None), -1)
    warmup_ratio = _as_float(getattr(config, 'warmup_ratio', 0.03), 0.03)
    # Always an int (never None) so schedulers can compare it safely.
    warmup_steps = _as_int(getattr(config, 'warmup_steps', 0), 0)

    # Learning rate configuration
    learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)
    lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')
    lr_scheduler_kwargs = getattr(config, 'lr_scheduler_kwargs', None)
    if lr_scheduler_kwargs is None and lr_scheduler_type == 'cosine_with_min_lr':
        # Without min_lr / min_lr_rate this scheduler fails when it is built.
        lr_scheduler_kwargs = {"min_lr_rate": 0.1}

    # Batch configuration
    per_device_train_batch_size = _as_int(getattr(config, 'batch_size', 2), 2)
    per_device_eval_batch_size = _as_int(getattr(config, 'eval_batch_size', per_device_train_batch_size), per_device_train_batch_size)
    gradient_accumulation_steps = _as_int(getattr(config, 'gradient_accumulation_steps', 1), 1)

    # Evaluation and logging
    eval_strategy = getattr(config, 'eval_strategy', 'steps')
    eval_steps = _as_int(getattr(config, 'eval_steps', 100), 100)
    eval_accumulation_steps = _as_int(getattr(config, 'eval_accumulation_steps', 1), 1)
    logging_steps = _as_int(getattr(config, 'logging_steps', 10), 10)

    # Saving configuration
    save_strategy = getattr(config, 'save_strategy', 'steps')
    save_steps = _as_int(getattr(config, 'save_steps', 500), 500)
    save_total_limit = _as_int(getattr(config, 'save_total_limit', 3), 3)

    # Mixed precision
    fp16 = bool(getattr(config, 'fp16', False))
    bf16 = bool(getattr(config, 'bf16', True))
    tf32 = bool(getattr(config, 'tf32', False))

    # Regularization
    weight_decay = _as_float(getattr(config, 'weight_decay', 0.01), 0.01)
    max_grad_norm = _as_float(getattr(config, 'max_grad_norm', 1.0), 1.0)

    # Data loading: prefetch_factor is only valid with worker processes.
    dataloader_num_workers = _as_int(getattr(config, 'dataloader_num_workers', 4), 4)
    dataloader_prefetch_factor = (
        _as_int(getattr(config, 'dataloader_prefetch_factor', 2), 2)
        if dataloader_num_workers > 0
        else None
    )

    # HuggingFace Hub integration
    push_to_hub = getattr(config, 'push_to_hub', False)

    print(f" • Epochs: {num_train_epochs}")
    print(f" • Learning rate: {learning_rate}")
    print(f" • Batch size: {per_device_train_batch_size}")
    print(f" • Gradient accumulation: {gradient_accumulation_steps}")
    print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")

    ta_kwargs = {
        # Training duration
        "num_train_epochs": num_train_epochs,
        "max_steps": max_steps,
        # Learning rate
        "learning_rate": learning_rate,
        "lr_scheduler_type": lr_scheduler_type,
        "lr_scheduler_kwargs": lr_scheduler_kwargs,
        "warmup_ratio": warmup_ratio,
        "warmup_steps": warmup_steps,
        # Batch configuration
        "per_device_train_batch_size": per_device_train_batch_size,
        "per_device_eval_batch_size": per_device_eval_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        # Model configuration
        "gradient_checkpointing": getattr(config, 'use_gradient_checkpointing', True),
        # Mixed precision
        "fp16": fp16,
        "bf16": bf16,
        "tf32": tf32,
        # Regularization
        "weight_decay": weight_decay,
        "max_grad_norm": max_grad_norm,
        # Evaluation (renamed below if this transformers uses eval_strategy)
        "evaluation_strategy": eval_strategy,
        "eval_steps": eval_steps,
        "eval_accumulation_steps": eval_accumulation_steps,
        # Logging
        "logging_steps": logging_steps,
        # Saving
        "save_strategy": save_strategy,
        "save_steps": save_steps,
        "save_total_limit": save_total_limit,
        # Output
        "output_dir": output_dir,
        # Data loading
        "dataloader_num_workers": dataloader_num_workers,
        "dataloader_pin_memory": getattr(config, 'dataloader_pin_memory', True),
        "dataloader_prefetch_factor": dataloader_prefetch_factor,
        # Performance
        "group_by_length": getattr(config, 'group_by_length', True),
        "remove_unused_columns": getattr(config, 'remove_unused_columns', True),
        # HuggingFace Hub
        "push_to_hub": push_to_hub,
        # Monitoring: explicit "none" so disabled tracking never falls back
        # to transformers' default integrations.
        "report_to": "trackio" if getattr(config, 'enable_tracking', False) else "none",
    }
    # None means "use the transformers default".
    ta_kwargs = {k: v for k, v in ta_kwargs.items() if v is not None}

    try:
        param_names = set(inspect.signature(TrainingArguments.__init__).parameters)
    except (TypeError, ValueError):
        param_names = set(TrainingArguments.__init__.__code__.co_varnames)

    # 'evaluation_strategy' was renamed to 'eval_strategy' in newer releases.
    eval_value = ta_kwargs.pop("evaluation_strategy", None)
    if eval_value is not None:
        if "evaluation_strategy" in param_names:
            ta_kwargs["evaluation_strategy"] = eval_value
        elif "eval_strategy" in param_names:
            ta_kwargs["eval_strategy"] = eval_value

    # Drop anything the installed transformers version does not accept.
    if param_names:
        ta_kwargs = {k: v for k, v in ta_kwargs.items() if k in param_names}

    return TrainingArguments(**ta_kwargs)
def _load_training_config(config_path):
    """Import the Python config file at ``config_path`` and return its config object.

    A module-level ``config`` attribute is preferred. Otherwise the first
    attribute (in ``dir()`` order) with a ``model_name`` is used if that
    model name mentions GPT-OSS (``gpt-oss`` or ``gpt_oss``, any case) or
    the attribute's own name contains ``GPTOSS``.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        ValueError: if no suitable configuration object is found.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    import importlib.util

    spec = importlib.util.spec_from_file_location("config_module", config_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)

    if hasattr(config_module, 'config'):
        return config_module.config

    for attr_name in dir(config_module):
        attr = getattr(config_module, attr_name)
        if not hasattr(attr, 'model_name'):
            continue
        model_name = attr.model_name
        # Hub ids use hyphens ("openai/gpt-oss-20b"), so normalise before
        # matching; non-string model names are ignored instead of crashing.
        normalized = model_name.lower().replace('-', '_') if isinstance(model_name, str) else ''
        if 'gpt_oss' in normalized or 'GPTOSS' in attr_name:
            return attr

    raise ValueError(f"No GPT-OSS configuration found in {config_path}")


def _init_param_names(cls):
    """Return the keyword names accepted by ``cls.__init__`` (empty set if not introspectable)."""
    try:
        return set(inspect.signature(cls.__init__).parameters)
    except (TypeError, ValueError):
        return set()


def _convert_training_args(training_args, config_cls, overrides):
    """Rebuild ``training_args`` as ``config_cls`` (a TrainingArguments subclass).

    ``overrides`` entries are applied only when ``config_cls`` accepts that
    keyword (or when its signature cannot be inspected).
    """
    args_dict = training_args.to_dict()
    # to_dict() masks *_token values: restore the real hub token and drop the
    # masked legacy push_to_hub_token so the placeholder is not re-applied.
    args_dict["hub_token"] = training_args.hub_token
    args_dict.pop("push_to_hub_token", None)
    accepted = _init_param_names(config_cls)
    args_dict.update({k: v for k, v in overrides.items() if not accepted or k in accepted})
    return config_cls(**args_dict)


def _build_dpo_trainer(config, model, tokenizer, training_args, train_dataset, eval_dataset):
    """Create a TRL ``DPOTrainer`` using only keywords the installed TRL accepts.

    ``config.dpo_beta`` (default 0.1) is passed to the trainer when its
    ``__init__`` takes ``beta``; otherwise (newer TRL) the arguments are
    rebuilt as a ``DPOConfig`` carrying ``beta``.
    """
    if DPOTrainer is None:
        raise RuntimeError("DPOTrainer is not available in this TRL version. Please upgrade 'trl'.")
    print("Creating DPO trainer...")

    dpo_params = _init_param_names(DPOTrainer) or {
        "model", "args", "train_dataset", "tokenizer", "beta",
        "prompt_column", "chosen_column", "rejected_column",
    }
    beta = getattr(config, 'dpo_beta', 0.1)

    dpo_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
    }

    dpo_config_cls = None
    if "beta" not in dpo_params:
        try:
            from trl import DPOConfig as dpo_config_cls
        except ImportError:
            dpo_config_cls = None
    if dpo_config_cls is not None:
        dpo_kwargs["args"] = _convert_training_args(training_args, dpo_config_cls, {"beta": beta})
    else:
        # Older TRL takes beta as a trainer keyword.
        dpo_kwargs["beta"] = beta

    if "tokenizer" in dpo_params:
        dpo_kwargs["tokenizer"] = tokenizer
    elif "processing_class" in dpo_params:
        dpo_kwargs["processing_class"] = tokenizer
    if "prompt_column" in dpo_params:
        dpo_kwargs["prompt_column"] = "prompt"
    if "chosen_column" in dpo_params:
        dpo_kwargs["chosen_column"] = "chosen"
    if "rejected_column" in dpo_params:
        dpo_kwargs["rejected_column"] = "rejected"

    dpo_kwargs = {k: v for k, v in dpo_kwargs.items() if v is not None}
    if "eval_dataset" in dpo_params and eval_dataset is not None:
        dpo_kwargs["eval_dataset"] = eval_dataset

    return DPOTrainer(**dpo_kwargs)


def _build_sft_trainer(config, model, tokenizer, training_args, train_dataset, eval_dataset):
    """Create a TRL ``SFTTrainer`` that trains on the ``text`` column.

    ``config.max_seq_length`` (default 2048) is passed to the trainer when it
    accepts ``max_seq_length``. Otherwise (newer TRL, where these settings
    live on ``SFTConfig``) the arguments are rebuilt as an ``SFTConfig``
    carrying it as ``max_length`` or ``max_seq_length``, whichever that
    SFTConfig defines, together with ``dataset_text_field="text"``.
    """
    print("Creating SFT trainer...")

    sft_params = _init_param_names(SFTTrainer) or {
        "model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length",
    }
    max_seq_length = getattr(config, 'max_seq_length', 2048)

    sft_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
    }
    if "tokenizer" in sft_params:
        sft_kwargs["tokenizer"] = tokenizer
    elif "processing_class" in sft_params:
        sft_kwargs["processing_class"] = tokenizer
    # We produced a 'text' column in process_dataset_format.
    if "dataset_text_field" in sft_params:
        sft_kwargs["dataset_text_field"] = "text"

    if "max_seq_length" in sft_params:
        sft_kwargs["max_seq_length"] = max_seq_length
    elif max_seq_length is not None:
        try:
            from trl import SFTConfig as sft_config_cls
        except ImportError:
            sft_config_cls = None
        if sft_config_cls is not None:
            # Without this, the configured max_seq_length would be ignored
            # and TRL's own default length would apply.
            length_key = "max_length" if "max_length" in _init_param_names(sft_config_cls) else "max_seq_length"
            sft_kwargs["args"] = _convert_training_args(
                training_args,
                sft_config_cls,
                {length_key: max_seq_length, "dataset_text_field": "text"},
            )

    sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None}
    if "eval_dataset" in sft_params and eval_dataset is not None:
        sft_kwargs["eval_dataset"] = eval_dataset

    return SFTTrainer(**sft_kwargs)


def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"):
    """Run the GPT-OSS fine-tuning pipeline and return the trainer.

    Loads the config file, then:
    1. Loads the model/tokenizer and applies LoRA.
    2. Loads, filters and formats the dataset and splits it into train/eval/test.
    3. Builds TrainingArguments and an SFT or DPO trainer, then trains.
    4. Saves to ``output_dir`` and, if ``push_to_hub`` is set, pushes to the Hub.

    If there is no eval split, ``config.eval_strategy`` is forced to ``'no'``,
    because the Trainer rejects a periodic eval strategy without an eval
    dataset.

    Args:
        config_path: Path to a Python configuration file.
        experiment_name: Stored on the config as ``experiment_name``.
        output_dir: Where the trained model is saved.
        trackio_url: Stored on the config as ``trackio_url``.
        trainer_type: ``"sft"`` (default) or ``"dpo"``.

    Raises:
        FileNotFoundError / ValueError: from config loading.
        RuntimeError: for ``"dpo"`` when TRL has no DPOTrainer.
    """
    print("=== GPT-OSS Training Pipeline ===")
    print(f"Config: {config_path}")
    print(f"Experiment: {experiment_name}")
    print(f"Output: {output_dir}")
    print(f"Trackio: {trackio_url}")
    print(f"Trainer: {trainer_type}")

    config = _load_training_config(config_path)

    # Runtime parameters override the file (trainer_type drives dataset formatting).
    config.experiment_name = experiment_name
    config.trackio_url = trackio_url
    config.trainer_type = trainer_type

    model, tokenizer = load_gpt_oss_model_and_tokenizer(config)
    peft_model = setup_lora_for_gpt_oss(model, config)

    dataset = load_dataset_from_config(config)
    train_dataset, eval_dataset, test_dataset = split_dataset(dataset, config)

    if eval_dataset is None and getattr(config, 'eval_strategy', 'steps') != 'no':
        print("No evaluation split available; setting eval_strategy='no'")
        config.eval_strategy = 'no'

    # The returned client is not used further in this function.
    trackio_client = setup_trackio_tracking(config)

    sft_config = create_sft_config(config, output_dir)

    if trainer_type == 'dpo':
        trainer = _build_dpo_trainer(config, peft_model, tokenizer, sft_config, train_dataset, eval_dataset)
    else:
        trainer = _build_sft_trainer(config, peft_model, tokenizer, sft_config, train_dataset, eval_dataset)

    print("Starting GPT-OSS training...")
    trainer.train()

    print("Saving trained model...")
    trainer.save_model(output_dir)

    if sft_config.push_to_hub:
        print("Pushing model to Hugging Face Hub...")
        # Record the dataset actually trained on (same default as load_dataset_from_config).
        trainer.push_to_hub(dataset_name=getattr(config, 'dataset_name', 'HuggingFaceH4/Multilingual-Thinking'))

    print("GPT-OSS training completed successfully!")
    return trainer
def main():
    """Command-line entry point: parse arguments and run ``train_gpt_oss``.

    Exits with status 1 if the configuration file does not exist, or if
    training raises. In the latter case the full traceback is printed (to
    stderr) after the one-line error message, so failures deep inside model
    or dataset loading remain diagnosable.
    """
    import traceback

    parser = argparse.ArgumentParser(description="GPT-OSS Training Script")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument("--experiment-name", required=True, help="Experiment name")
    parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
    parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
    parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")

    args = parser.parse_args()

    # Validate arguments
    if not os.path.exists(args.config):
        print(f"Error: Configuration file not found: {args.config}")
        sys.exit(1)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        train_gpt_oss(
            config_path=args.config,
            experiment_name=args.experiment_name,
            output_dir=args.output_dir,
            trackio_url=args.trackio_url,
            trainer_type=args.trainer_type,
        )
    except Exception as e:
        print(f"Error during training: {e}")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()