#!/usr/bin/env python3
"""
GPT-OSS Training Script
Specialized training script for OpenAI's GPT-OSS models
Based on the GPT-OSS fine-tuning tutorial
"""

import os
import sys
import argparse
import inspect

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

try:
    from trl import DPOTrainer
except Exception:  # pragma: no cover - optional import depending on TRL version
    DPOTrainer = None

from datasets import load_dataset
from pathlib import Path

# Ensure the project root and config package are importable for configs that
# do `from config...` imports
project_root = Path(__file__).resolve().parents[2]
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
config_dir = project_root / "config"
if str(config_dir) not in sys.path:
    sys.path.insert(0, str(config_dir))
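
# For illustration (the file name is a hypothetical placeholder): with the
# bootstrap above, a config module such as config/train_gpt_oss_custom.py can
# be loaded from any working directory and still resolve its
# `from config...` imports.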

def load_gpt_oss_model_and_tokenizer(config):
    """Load GPT-OSS model and tokenizer with proper configuration"""
    print("Loading GPT-OSS tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    print("Loading GPT-OSS model with quantization...")
    # Import quantization config
    from transformers import BitsAndBytesConfig

    # Set up quantization config based on config
    if config.quantization_config and config.quantization_config.get("load_in_4bit"):
        # Use BitsAndBytesConfig for 4-bit quantization (memory optimized)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif config.quantization_config and config.quantization_config.get("dequantize"):
        # Try to use Mxfp4Config if available (as per the tutorial)
        try:
            from transformers import Mxfp4Config
            quantization_config = Mxfp4Config(dequantize=True)
        except ImportError:
            # Fall back to no quantization if Mxfp4Config is not available
            print("Warning: Mxfp4Config not available, using no quantization")
            quantization_config = None
    else:
        # No quantization
        quantization_config = None

    # Model kwargs as per the tutorial
    model_kwargs = {
        "attn_implementation": "eager",
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
        "device_map": "auto",
    }
    # Only add quantization_config if it is not None
    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config

    model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)
    return model, tokenizer
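
# Minimal usage sketch (the attribute names below are assumptions taken from
# how this script reads the config, not a published interface):
#
#   class _DemoConfig:
#       model_name = "openai/gpt-oss-20b"
#       quantization_config = {"load_in_4bit": True}  # or {"dequantize": True}
#
#   # model, tokenizer = load_gpt_oss_model_and_tokenizer(_DemoConfig())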

def setup_lora_for_gpt_oss(model, config):
    """Setup LoRA for GPT-OSS model"""
    print("Setting up LoRA for GPT-OSS...")
    # Default MoE expert projections to adapt (layers 7, 15, 23), as per the tutorial
    default_target_parameters = [
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ]
    lora_cfg = config.lora_config or {}
    # LoRA configuration as per the tutorial
    lora_config = LoraConfig(
        r=lora_cfg.get("r", 8),
        lora_alpha=lora_cfg.get("lora_alpha", 16),
        target_modules=lora_cfg.get("target_modules", "all-linear"),
        target_parameters=lora_cfg.get("target_parameters", default_target_parameters),
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model
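
# Example config.lora_config (a sketch mirroring the defaults above; the keys
# are exactly the ones read by setup_lora_for_gpt_oss, the values illustrative):
#
#   lora_config = {
#       "r": 8,
#       "lora_alpha": 16,
#       "target_modules": "all-linear",
#       "target_parameters": [
#           "7.mlp.experts.gate_up_proj",
#           "7.mlp.experts.down_proj",
#       ],
#   }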

def load_dataset_from_config(config):
    """Load dataset based on configuration"""
    dataset_name = getattr(config, 'dataset_name', 'HuggingFaceH4/Multilingual-Thinking')
    dataset_split = getattr(config, 'dataset_split', 'train')
    dataset_config = getattr(config, 'dataset_config', None)

    print(f"Loading dataset: {dataset_name}")
    print(f"Dataset split: {dataset_split}")
    if dataset_config:
        print(f"Dataset config: {dataset_config}")

    # Load the dataset
    if dataset_config:
        dataset = load_dataset(dataset_name, dataset_config, split=dataset_split)
    else:
        dataset = load_dataset(dataset_name, split=dataset_split)

    print(f"Original dataset size: {len(dataset)} examples")

    # Apply filtering based on configuration
    dataset = apply_dataset_filtering(dataset, config)
    # Apply dataset processing based on format
    dataset = process_dataset_format(dataset, config)

    print(f"Final dataset size: {len(dataset)} examples")
    return dataset

def apply_dataset_filtering(dataset, config):
    """Apply filtering based on configuration"""
    # Filter bad entries if specified
    if getattr(config, 'filter_bad_entries', False):
        bad_entry_field = getattr(config, 'bad_entry_field', 'bad_entry')
        bad_prompt_field = getattr(config, 'bad_prompt_field', 'bad_prompt_detected')
        bad_response_field = getattr(config, 'bad_response_field', 'bad_response_detected')

        original_size = len(dataset)
        # Filter out bad entries
        if bad_entry_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_entry_field, False))
            print(f"Filtered {original_size - len(dataset)} bad entries")
        # Filter out bad prompts
        if bad_prompt_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_prompt_field, False))
            print(f"Filtered bad prompts, remaining: {len(dataset)} examples")
        # Filter out bad responses
        if bad_response_field in dataset.column_names:
            dataset = dataset.filter(lambda x: not x.get(bad_response_field, False))
            print(f"Filtered bad responses, remaining: {len(dataset)} examples")

    # Apply length filtering
    min_length = getattr(config, 'min_length', 10)
    max_length = getattr(config, 'max_length', None)
    input_field = getattr(config, 'input_field', 'prompt')
    target_field = getattr(config, 'target_field', 'accepted_completion')

    if min_length > 0 or max_length:
        def length_filter(example):
            input_len = len(example.get(input_field, ''))
            target_len = len(example.get(target_field, ''))
            total_len = input_len + target_len
            if total_len < min_length:
                return False
            if max_length and total_len > max_length:
                return False
            return True

        original_size = len(dataset)
        dataset = dataset.filter(length_filter)
        print(f"Length filtering: {original_size} -> {len(dataset)} examples")

    # Apply sampling if specified
    max_samples = getattr(config, 'max_samples', None)
    if max_samples and len(dataset) > max_samples:
        dataset = dataset.shuffle(seed=42).select(range(max_samples))
        print(f"Sampled {max_samples} examples from dataset")

    return dataset
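
# Example filtering config (a sketch; the attribute names match the getattr
# calls above, the values are illustrative):
#
#   config.filter_bad_entries = True  # drop rows flagged bad_entry / bad_prompt_detected / bad_response_detected
#   config.min_length = 10            # drop rows whose prompt+completion is under 10 characters
#   config.max_samples = 10_000       # then keep a seeded (seed=42) random sample of 10k rows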

def format_gpt_oss_harmony(prompt, completion, add_eos_token=True):
    """
    Format data for the GPT-OSS Harmony format, following the exact template structure.
    Based on: https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja
    """
    # GPT-OSS Harmony format structure (exact template compliance)
    # User message:      <|start|>user<|message|>content<|end|>
    # Assistant message: <|start|>assistant<|channel|>final<|message|>content<|end|>    (inference)
    # Assistant message: <|start|>assistant<|channel|>final<|message|>content<|return|> (training)
    harmony_text = f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>{completion}"
    if add_eos_token:
        # Use <|return|> for training as per the template specification;
        # it indicates the end of generation in training
        harmony_text += "<|return|>"
    else:
        # Use <|end|> for inference
        harmony_text += "<|end|>"
    return harmony_text


def format_gpt_oss_harmony_prompt(prompt: str) -> str:
    """Prefix-only Harmony prompt, up to the assistant content marker, for DPO."""
    return f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>"

def process_dataset_format(dataset, config):
    """Process dataset based on format configuration with exact GPT-OSS Harmony compliance"""
    dataset_format = getattr(config, 'dataset_format', 'openhermes_fr')
    input_field = getattr(config, 'input_field', 'prompt')
    target_field = getattr(config, 'target_field', 'accepted_completion')
    concatenate_fields = getattr(config, 'concatenate_fields', True)
    field_separator = getattr(config, 'field_separator', '\n\n### Response:\n')
    add_eos_token = getattr(config, 'add_eos_token', True)
    use_harmony_format = getattr(config, 'use_harmony_format', True)
    trainer_type = getattr(config, 'trainer_type', 'sft')

    print(f"Processing dataset format: {dataset_format}")
    print(f"Input field: {input_field}, Target field: {target_field}")
    print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}")

    # Preference format for DPO training (chosen/rejected pairs)
    if trainer_type == 'dpo':
        chosen_field = getattr(config, 'chosen_field', None)
        rejected_field = getattr(config, 'rejected_field', None)

        if dataset_format == 'preference':
            # Expect the columns to be present; reformat so that only the necessary columns remain
            def id_map(example):
                prompt_val = example.get(input_field, '')
                chosen_val = example.get('chosen', example.get(chosen_field or 'chosen', ''))
                rejected_val = example.get('rejected', example.get(rejected_field or 'rejected', ''))
                if use_harmony_format:
                    prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
                    chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
                    rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
                    return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
                return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}

            dataset = dataset.map(id_map, remove_columns=dataset.column_names)
            return dataset

        # Custom preference mapping via configured field names
        if chosen_field and rejected_field:
            def to_pref(example):
                prompt_val = example.get(input_field, '')
                chosen_val = example.get(chosen_field, '')
                rejected_val = example.get(rejected_field, '')
                if use_harmony_format:
                    prompt_text = format_gpt_oss_harmony_prompt(prompt_val)
                    chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '')
                    rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '')
                    return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text}
                return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val}

            dataset = dataset.map(to_pref, remove_columns=dataset.column_names)
            return dataset

        # If we reach here, we don't have the required fields for DPO
        raise ValueError(
            "DPO training requires preference data. Please set dataset_format='preference' "
            "with 'prompt', 'chosen', 'rejected' columns, or specify 'chosen_field' and "
            "'rejected_field' in the config."
        )

    if dataset_format == "openhermes_fr":
        # Process OpenHermes-FR format: prompt + accepted_completion
        def format_openhermes_fr(example):
            prompt = example.get(input_field, '')
            completion = example.get(target_field, '')
            if concatenate_fields:
                if use_harmony_format:
                    # Use the exact GPT-OSS Harmony format from the template
                    text = format_gpt_oss_harmony(prompt, completion, add_eos_token)
                else:
                    # Fall back to the standard format with a separator
                    text = prompt + field_separator + completion
                    if add_eos_token:
                        text += "</s>"
                return {"text": text}
            else:
                # Keep the fields separate for more advanced training setups
                return {
                    "input": prompt,
                    "output": completion,
                }

        dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names)

    elif dataset_format == "messages":
        # Process messages format (like HuggingFaceH4/Multilingual-Thinking)
        def format_messages(example):
            messages = example.get(input_field, [])
            if use_harmony_format and len(messages) >= 2:
                # Extract the user and assistant messages for the Harmony format
                user_message = ""
                assistant_message = ""
                for message in messages:
                    role = message.get("role", "")
                    content = message.get("content", "")
                    if role == "user":
                        user_message = content
                    elif role == "assistant":
                        assistant_message = content
                if user_message and assistant_message:
                    # Use the GPT-OSS Harmony format
                    text = format_gpt_oss_harmony(user_message, assistant_message, add_eos_token)
                else:
                    # Fall back to simple concatenation
                    text = ""
                    for message in messages:
                        role = message.get("role", "")
                        content = message.get("content", "")
                        text += f"{role}: {content}\n"
                    if add_eos_token:
                        text += "</s>"
            else:
                # Standard format - convert messages to simple text
                text = ""
                for message in messages:
                    role = message.get("role", "")
                    content = message.get("content", "")
                    text += f"{role}: {content}\n"
                if add_eos_token:
                    text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_messages, remove_columns=dataset.column_names)

    elif dataset_format == "text":
        # Process plain text format
        text_field = input_field

        def format_text(example):
            text = example.get(text_field, '')
            if add_eos_token:
                text += "</s>"
            return {"text": text}

        dataset = dataset.map(format_text, remove_columns=dataset.column_names)

    elif dataset_format == "custom":
        # Custom format - the user handles this in their config
        print("Using custom dataset format - no automatic processing")

    return dataset
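
# Example of a mapped DPO row with use_harmony_format=True (values illustrative):
#
#   {
#       "prompt": "<|start|>user<|message|>Translate 'cat'<|end|><|start|>assistant<|channel|>final<|message|>",
#       "chosen": "chat<|return|>",
#       "rejected": "dog<|return|>",
#   }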

def split_dataset(dataset, config):
    """Create train/validation/test splits from a single dataset.

    Defaults to 1% eval and 1% test if not specified.
    """
    from datasets import Dataset

    if not isinstance(dataset, Dataset):
        # If it's already a DatasetDict, try to use its splits
        try:
            train_split = dataset["train"]
            eval_split = dataset.get("validation") or dataset.get("eval")
            test_split = dataset.get("test")
            return train_split, eval_split, test_split
        except Exception:
            pass

    eval_ratio = getattr(config, 'eval_ratio', 0.01)
    test_ratio = getattr(config, 'test_ratio', 0.01)
    # Clamp ratios to sane bounds
    try:
        eval_ratio = max(0.0, float(eval_ratio))
        test_ratio = max(0.0, float(test_ratio))
        if eval_ratio + test_ratio >= 0.9:
            # Avoid extreme splits; cap the combined ratio at 0.2
            scale = 0.2 / max(1e-9, (eval_ratio + test_ratio))
            eval_ratio *= scale
            test_ratio *= scale
    except Exception:
        eval_ratio, test_ratio = 0.01, 0.01

    # No eval/test requested
    if eval_ratio <= 0 and test_ratio <= 0:
        return dataset, None, None

    ds_shuffled = dataset.shuffle(seed=42)

    # First carve out the test split
    if test_ratio > 0:
        split1 = ds_shuffled.train_test_split(test_size=test_ratio, seed=42)
        train_part = split1["train"]
        test_split = split1["test"]
    else:
        train_part = ds_shuffled
        test_split = None

    # Then carve out eval from the remaining train pool
    if eval_ratio > 0:
        remaining_fraction = 1.0 - test_ratio
        # Convert the global eval fraction to a fraction of the remaining pool
        relative_eval = eval_ratio / remaining_fraction if remaining_fraction > 0 else eval_ratio
        split2 = train_part.train_test_split(test_size=relative_eval, seed=42)
        train_split = split2["train"]
        eval_split = split2["test"]
    else:
        train_split = train_part
        eval_split = None

    # Log split sizes
    try:
        print(
            f"Created splits -> train: {len(train_split)}, "
            f"eval: {len(eval_split) if eval_split else 0}, "
            f"test: {len(test_split) if test_split else 0}"
        )
    except Exception:
        pass

    return train_split, eval_split, test_split
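
# Worked example: with 10,000 rows and the default eval_ratio=0.01,
# test_ratio=0.01, the first split carves out 100 test rows; eval then takes
# 0.01 / 0.99 of the remaining 9,900 rows (= 100), leaving 9,800 for training.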

def setup_trackio_tracking(config):
    """Setup Trackio tracking if enabled"""
    if not config.enable_tracking or not config.trackio_url:
        print("Trackio tracking disabled or URL not provided")
        return None

    print(f"Setting up Trackio tracking: {config.trackio_url}")

    # Import the correct TrackioAPIClient (sys and os are already imported at module level)
    sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic'))
    from trackio_api_client import TrackioAPIClient

    # Initialize the Trackio client using the correct API
    trackio_client = TrackioAPIClient(
        space_id=config.trackio_url,
        hf_token=config.trackio_token,
    )
    return trackio_client

def create_sft_config(config, output_dir):
    """Create an enhanced TrainingArguments object (used as the SFT config) for GPT-OSS training"""
    print("Creating enhanced SFT configuration...")

    # Helper coercion utilities to guarantee numeric types
    def _as_int(value, default):
        if value is None:
            return int(default)
        try:
            return int(value)
        except Exception:
            return int(default)

    def _as_float(value, default):
        if value is None:
            return float(default)
        try:
            return float(value)
        except Exception:
            return float(default)

    # Extract training parameters from the config with enhanced defaults and coercion
    num_train_epochs = _as_float(getattr(config, 'num_train_epochs', 1.0), 1.0)
    # Transformers expects max_steps to default to -1 (disabled); some code compares it with > 0
    raw_max_steps = getattr(config, 'max_steps', None)
    max_steps = _as_int(raw_max_steps if raw_max_steps is not None else -1, -1)
    warmup_ratio = _as_float(getattr(config, 'warmup_ratio', 0.03), 0.03)
    # Ensure warmup_steps is an int; default to 0 to avoid None comparisons in schedulers
    warmup_steps = _as_int(getattr(config, 'warmup_steps', 0), 0)

    # Learning rate configuration
    learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)
    lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')

    # Batch configuration
    per_device_train_batch_size = _as_int(getattr(config, 'batch_size', 2), 2)
    per_device_eval_batch_size = _as_int(
        getattr(config, 'eval_batch_size', per_device_train_batch_size),
        per_device_train_batch_size,
    )
    gradient_accumulation_steps = _as_int(getattr(config, 'gradient_accumulation_steps', 1), 1)

    # Evaluation and logging
    eval_strategy = getattr(config, 'eval_strategy', 'steps')
    eval_steps = _as_int(getattr(config, 'eval_steps', 100), 100)
    eval_accumulation_steps = _as_int(getattr(config, 'eval_accumulation_steps', 1), 1)
    logging_steps = _as_int(getattr(config, 'logging_steps', 10), 10)

    # Saving configuration
    save_strategy = getattr(config, 'save_strategy', 'steps')
    save_steps = _as_int(getattr(config, 'save_steps', 500), 500)
    save_total_limit = _as_int(getattr(config, 'save_total_limit', 3), 3)

    # Mixed precision
    fp16 = bool(getattr(config, 'fp16', False))
    bf16 = bool(getattr(config, 'bf16', True))
    tf32 = bool(getattr(config, 'tf32', False))

    # Regularization
    weight_decay = _as_float(getattr(config, 'weight_decay', 0.01), 0.01)
    max_grad_norm = _as_float(getattr(config, 'max_grad_norm', 1.0), 1.0)

    # Hugging Face Hub integration
    push_to_hub = getattr(config, 'push_to_hub', False)

    print(f"  • Epochs: {num_train_epochs}")
    print(f"  • Learning rate: {learning_rate}")
    print(f"  • Batch size: {per_device_train_batch_size}")
    print(f"  • Gradient accumulation: {gradient_accumulation_steps}")
    print(f"  • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")

    # Build kwargs dynamically to stay compatible across transformers versions
    ta_kwargs = {
        # Training duration
        "num_train_epochs": num_train_epochs,
        "max_steps": max_steps,
        # Learning rate
        "learning_rate": learning_rate,
        "lr_scheduler_type": lr_scheduler_type,
        "warmup_ratio": warmup_ratio,
        "warmup_steps": warmup_steps,
        # Batch configuration
        "per_device_train_batch_size": per_device_train_batch_size,
        "per_device_eval_batch_size": per_device_eval_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        # Model configuration
        "gradient_checkpointing": getattr(config, 'use_gradient_checkpointing', True),
        # Mixed precision
        "fp16": fp16,
        "bf16": bf16,
        # Some versions support tf32
        "tf32": tf32 if 'tf32' in TrainingArguments.__init__.__code__.co_varnames else None,
        # Regularization
        "weight_decay": weight_decay,
        "max_grad_norm": max_grad_norm,
        # Evaluation (the kwarg name varies across versions)
        "evaluation_strategy": eval_strategy,
        "eval_steps": eval_steps,
        "eval_accumulation_steps": eval_accumulation_steps,
        # Logging
        "logging_steps": logging_steps,
        # Saving
        "save_strategy": save_strategy,
        "save_steps": save_steps,
        "save_total_limit": save_total_limit,
        # Output
        "output_dir": output_dir,
        # Data loading
        "dataloader_num_workers": _as_int(getattr(config, 'dataloader_num_workers', 4), 4),
        "dataloader_pin_memory": getattr(config, 'dataloader_pin_memory', True),
        # Optional in some versions
        "dataloader_prefetch_factor": _as_int(getattr(config, 'dataloader_prefetch_factor', 2), 2),
        # Performance
        "group_by_length": getattr(config, 'group_by_length', True),
        "remove_unused_columns": getattr(config, 'remove_unused_columns', True),
        # Hugging Face Hub
        "push_to_hub": push_to_hub,
        # Monitoring
        "report_to": ("trackio" if getattr(config, 'enable_tracking', False) else None),
    }
    # Drop any None-valued kwargs
    ta_kwargs = {k: v for k, v in ta_kwargs.items() if v is not None}

    # Adapt to transformers versions where 'evaluation_strategy' was renamed
    try:
        ta_sig = inspect.signature(TrainingArguments.__init__)
        param_names = set(ta_sig.parameters.keys())
    except Exception:
        param_names = set()
    if "evaluation_strategy" not in param_names and "eval_strategy" in param_names:
        # Move the value to 'eval_strategy'
        ta_kwargs["eval_strategy"] = ta_kwargs.pop("evaluation_strategy")
    elif "evaluation_strategy" not in param_names:
        # If neither name is supported, drop it
        ta_kwargs.pop("evaluation_strategy", None)

    # Remove any kwargs not supported by the current transformers version
    if param_names:
        unsupported = [k for k in ta_kwargs.keys() if k not in param_names]
        for k in unsupported:
            ta_kwargs.pop(k, None)

    sft_config = TrainingArguments(**ta_kwargs)
    return sft_config
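
# Example of the resulting arithmetic and version handling (a sketch): with
# batch_size=2 and gradient_accumulation_steps=8, the effective batch size
# printed above is 16. On transformers releases that lack a kwarg such as
# 'tf32' or 'dataloader_prefetch_factor', the signature filter drops it
# instead of letting TrainingArguments raise a TypeError.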
def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"): | |
"""Main training function for GPT-OSS""" | |
print("=== GPT-OSS Training Pipeline ===") | |
print(f"Config: {config_path}") | |
print(f"Experiment: {experiment_name}") | |
print(f"Output: {output_dir}") | |
print(f"Trackio: {trackio_url}") | |
print(f"Trainer: {trainer_type}") | |
# Load configuration | |
if os.path.exists(config_path): | |
import importlib.util | |
spec = importlib.util.spec_from_file_location("config_module", config_path) | |
config_module = importlib.util.module_from_spec(spec) | |
spec.loader.exec_module(config_module) | |
if hasattr(config_module, 'config'): | |
config = config_module.config | |
else: | |
# Try to find a config class | |
for attr_name in dir(config_module): | |
attr = getattr(config_module, attr_name) | |
if hasattr(attr, 'model_name') and ('gpt_oss' in attr.model_name.lower() or 'GPTOSS' in attr_name): | |
config = attr | |
break | |
else: | |
raise ValueError(f"No GPT-OSS configuration found in {config_path}") | |
else: | |
raise FileNotFoundError(f"Configuration file not found: {config_path}") | |
# Update config with runtime parameters | |
config.experiment_name = experiment_name | |
config.trackio_url = trackio_url | |
config.trainer_type = trainer_type | |
# Load model and tokenizer | |
model, tokenizer = load_gpt_oss_model_and_tokenizer(config) | |
# Setup LoRA | |
peft_model = setup_lora_for_gpt_oss(model, config) | |
# Load dataset | |
dataset = load_dataset_from_config(config) | |
# Split into train/eval/test | |
train_dataset, eval_dataset, test_dataset = split_dataset(dataset, config) | |
# Setup Trackio tracking | |
trackio_client = setup_trackio_tracking(config) | |
# Create SFT configuration | |
sft_config = create_sft_config(config, output_dir) | |
# Create trainer with version-robust kwargs | |
if trainer_type == 'dpo': | |
if DPOTrainer is None: | |
raise RuntimeError("DPOTrainer is not available in this TRL version. Please upgrade 'trl'.") | |
print("Creating DPO trainer...") | |
try: | |
dpo_sig = inspect.signature(DPOTrainer.__init__) | |
dpo_params = set(dpo_sig.parameters.keys()) | |
except Exception: | |
dpo_params = {"model", "args", "train_dataset", "tokenizer", "beta", "prompt_column", "chosen_column", "rejected_column"} | |
dpo_kwargs = { | |
"model": peft_model, | |
"args": sft_config, | |
"train_dataset": train_dataset, | |
"beta": getattr(config, 'dpo_beta', 0.1), | |
} | |
if "tokenizer" in dpo_params: | |
dpo_kwargs["tokenizer"] = tokenizer | |
elif "processing_class" in dpo_params: | |
dpo_kwargs["processing_class"] = tokenizer | |
if "prompt_column" in dpo_params: | |
dpo_kwargs["prompt_column"] = "prompt" | |
if "chosen_column" in dpo_params: | |
dpo_kwargs["chosen_column"] = "chosen" | |
if "rejected_column" in dpo_params: | |
dpo_kwargs["rejected_column"] = "rejected" | |
# Remove Nones | |
dpo_kwargs = {k: v for k, v in dpo_kwargs.items() if v is not None} | |
# Pass eval dataset if supported | |
if "eval_dataset" in dpo_params and eval_dataset is not None: | |
dpo_kwargs["eval_dataset"] = eval_dataset | |
trainer = DPOTrainer(**dpo_kwargs) | |
else: | |
print("Creating SFT trainer...") | |
try: | |
sft_sig = inspect.signature(SFTTrainer.__init__) | |
sft_params = set(sft_sig.parameters.keys()) | |
except Exception: | |
sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"} | |
sft_kwargs = { | |
"model": peft_model, | |
"args": sft_config, | |
"train_dataset": train_dataset, | |
} | |
# Prefer passing tokenizer if supported; otherwise try processing_class | |
if "tokenizer" in sft_params: | |
sft_kwargs["tokenizer"] = tokenizer | |
elif "processing_class" in sft_params: | |
sft_kwargs["processing_class"] = tokenizer | |
# Pass dataset text field if supported (we produced a 'text' column) | |
if "dataset_text_field" in sft_params: | |
sft_kwargs["dataset_text_field"] = "text" | |
# Pass max sequence length if supported | |
if "max_seq_length" in sft_params: | |
sft_kwargs["max_seq_length"] = getattr(config, 'max_seq_length', 2048) | |
# Remove any None values | |
sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None} | |
# Attach eval_dataset if supported | |
if "eval_dataset" in sft_params and eval_dataset is not None: | |
sft_kwargs["eval_dataset"] = eval_dataset | |
trainer = SFTTrainer(**sft_kwargs) | |
# Start training | |
print("Starting GPT-OSS training...") | |
trainer.train() | |
# Save model | |
print("Saving trained model...") | |
trainer.save_model(output_dir) | |
# Push to hub if enabled | |
if sft_config.push_to_hub: | |
print("Pushing model to Hugging Face Hub...") | |
trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking") | |
print("GPT-OSS training completed successfully!") | |
return trainer | |

def main():
    parser = argparse.ArgumentParser(description="GPT-OSS Training Script")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument("--experiment-name", required=True, help="Experiment name")
    parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
    parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
    parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")
    args = parser.parse_args()

    # Validate arguments
    if not os.path.exists(args.config):
        print(f"Error: Configuration file not found: {args.config}")
        sys.exit(1)

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        train_gpt_oss(
            config_path=args.config,
            experiment_name=args.experiment_name,
            output_dir=args.output_dir,
            trackio_url=args.trackio_url,
            trainer_type=args.trainer_type,
        )
    except Exception as e:
        print(f"Error during training: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
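
# Example invocation (script, config, and directory names are placeholders):
#
#   python train_gpt_oss.py \
#       --config config/train_gpt_oss_custom.py \
#       --experiment-name gpt-oss-sft-demo \
#       --output-dir ./outputs/gpt-oss-sft-demo \
#       --trainer-type sft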