#!/usr/bin/env python3
"""
GPT-OSS Training Script
Specialized training script for OpenAI's GPT-OSS models
Based on the GPT-OSS fine-tuning tutorial
"""
import argparse
import os
import sys

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
def load_gpt_oss_model_and_tokenizer(config):
    """Load the GPT-OSS tokenizer and model described by *config*.

    The model's quantization mode is selected from ``config.quantization_config``:
    4-bit bitsandbytes when ``load_in_4bit`` is set, MXFP4 dequantization when
    ``dequantize`` is set (and transformers provides ``Mxfp4Config``), or no
    quantization otherwise.

    Returns:
        (model, tokenizer) tuple.
    """
    print("Loading GPT-OSS tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    print("Loading GPT-OSS model with quantization...")
    from transformers import BitsAndBytesConfig

    quant_cfg = None
    q_settings = config.quantization_config
    if q_settings and q_settings.get("load_in_4bit"):
        # 4-bit NF4 quantization via bitsandbytes (memory optimized).
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif q_settings and q_settings.get("dequantize"):
        # Prefer Mxfp4Config (as in the GPT-OSS tutorial) when available.
        try:
            from transformers import Mxfp4Config
            quant_cfg = Mxfp4Config(dequantize=True)
        except ImportError:
            # Older transformers releases lack Mxfp4Config; fall back gracefully.
            print("Warning: Mxfp4Config not available, using no quantization")

    # Model kwargs as per tutorial.
    model_kwargs = {
        "attn_implementation": "eager",
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
        "device_map": "auto",
    }
    # Only pass quantization_config when one was actually selected.
    if quant_cfg is not None:
        model_kwargs["quantization_config"] = quant_cfg

    model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)
    return model, tokenizer
def setup_lora_for_gpt_oss(model, config):
    """Wrap *model* in a LoRA adapter configured for GPT-OSS.

    Values are read from ``config.lora_config`` (a dict) when present,
    otherwise tutorial defaults are used: r=8, alpha=16, all linear layers,
    plus the MoE expert projections of layers 7, 15 and 23.

    Args:
        model: The base GPT-OSS model to adapt.
        config: Training config; ``config.lora_config`` may be None.

    Returns:
        The PEFT-wrapped model with trainable-parameter summary printed.
    """
    print("Setting up LoRA for GPT-OSS...")

    # Expert projections targeted in the GPT-OSS tutorial (layers 7/15/23).
    # Defined once instead of being repeated in every fallback branch.
    default_target_parameters = [
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ]

    # Treat a missing lora_config as an empty dict so every .get() falls
    # through to its tutorial default (same behavior as the original
    # per-kwarg `if config.lora_config else` chains, without duplication).
    lora_settings = config.lora_config or {}

    lora_config = LoraConfig(
        r=lora_settings.get("r", 8),
        lora_alpha=lora_settings.get("lora_alpha", 16),
        target_modules=lora_settings.get("target_modules", "all-linear"),
        target_parameters=lora_settings.get("target_parameters", default_target_parameters),
    )

    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model
def load_multilingual_thinking_dataset():
    """Download and return the train split of HuggingFaceH4/Multilingual-Thinking."""
    print("Loading Multilingual-Thinking dataset...")
    ds = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
    print(f"Dataset loaded: {len(ds)} examples")
    return ds
def setup_trackio_tracking(config):
    """Create a TrackioAPIClient when tracking is enabled, else return None.

    Requires both ``config.enable_tracking`` and ``config.trackio_url`` to be
    truthy; otherwise tracking is skipped with a notice.
    """
    if not (config.enable_tracking and config.trackio_url):
        print("Trackio tracking disabled or URL not provided")
        return None

    print(f"Setting up Trackio tracking: {config.trackio_url}")

    # The client module lives in the sibling trackio_tonic directory;
    # extend sys.path so it can be imported from here.
    import os
    import sys
    sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic'))
    from trackio_api_client import TrackioAPIClient

    return TrackioAPIClient(space_id=config.trackio_url, hf_token=config.trackio_token)
def create_sft_config(config):
    """Build the SFTConfig for GPT-OSS supervised fine-tuning.

    Hyperparameters (learning rate, batch size, accumulation, max length,
    logging cadence) come from *config*; schedule and warmup follow the
    GPT-OSS tutorial (cosine-with-min-lr, 3% warmup, one epoch).

    Returns:
        A populated trl.SFTConfig instance.
    """
    print("Creating SFT configuration...")
    sft_config = SFTConfig(
        learning_rate=config.learning_rate,
        gradient_checkpointing=True,
        num_train_epochs=1,  # Single epoch as per tutorial
        logging_steps=config.logging_steps,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_length=config.max_seq_length,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine_with_min_lr",
        lr_scheduler_kwargs={"min_lr_rate": 0.1},
        output_dir="gpt-oss-20b-multilingual-reasoner",
        # BUGFIX: report_to=None is transformers' "use default" sentinel and
        # enables ALL installed logging integrations; the string "none" is
        # the value that actually disables reporting.
        report_to="trackio" if config.enable_tracking else "none",
        push_to_hub=True,
    )
    return sft_config
def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"):
    """Run the end-to-end GPT-OSS fine-tuning pipeline.

    Loads the config module from *config_path*, prepares the quantized model
    with a LoRA adapter, trains on the Multilingual-Thinking dataset via
    SFTTrainer, saves to *output_dir* and optionally pushes to the Hub.

    Args:
        config_path: Path to a Python config file exposing either a module-level
            ``config`` object or a GPT-OSS config attribute.
        experiment_name: Name recorded on the config for this run.
        output_dir: Directory the trained model is saved to.
        trackio_url: Optional Trackio endpoint for experiment tracking.
        trainer_type: Trainer flavor tag stored on the config (default "sft").

    Raises:
        FileNotFoundError: If *config_path* does not exist.
        ValueError: If no GPT-OSS configuration is found in the module.

    Returns:
        The SFTTrainer instance after training.
    """
    print("=== GPT-OSS Training Pipeline ===")
    print(f"Config: {config_path}")
    print(f"Experiment: {experiment_name}")
    print(f"Output: {output_dir}")
    print(f"Trackio: {trackio_url}")
    print(f"Trainer: {trainer_type}")

    # Guard clause: fail fast when the config file is missing.
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    # Load the configuration module directly from its file path.
    import importlib.util
    spec = importlib.util.spec_from_file_location("config_module", config_path)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)

    if hasattr(config_module, 'config'):
        config = config_module.config
    else:
        # Fall back to scanning the module for a GPT-OSS-looking config object.
        for attr_name in dir(config_module):
            attr = getattr(config_module, attr_name)
            if hasattr(attr, 'model_name') and ('gpt_oss' in attr.model_name.lower() or 'GPTOSS' in attr_name):
                config = attr
                break
        else:
            raise ValueError(f"No GPT-OSS configuration found in {config_path}")

    # Apply runtime overrides on top of the file-based configuration.
    config.experiment_name = experiment_name
    config.trackio_url = trackio_url
    config.trainer_type = trainer_type

    # Assemble the training components.
    model, tokenizer = load_gpt_oss_model_and_tokenizer(config)
    peft_model = setup_lora_for_gpt_oss(model, config)
    dataset = load_multilingual_thinking_dataset()
    trackio_client = setup_trackio_tracking(config)
    sft_config = create_sft_config(config)

    print("Creating SFT trainer...")
    trainer = SFTTrainer(
        model=peft_model,
        args=sft_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    print("Starting GPT-OSS training...")
    trainer.train()

    print("Saving trained model...")
    trainer.save_model(output_dir)

    if sft_config.push_to_hub:
        print("Pushing model to Hugging Face Hub...")
        trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking")

    print("GPT-OSS training completed successfully!")
    return trainer
def main():
    """CLI entry point: parse and validate arguments, then launch training."""
    parser = argparse.ArgumentParser(description="GPT-OSS Training Script")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument("--experiment-name", required=True, help="Experiment name")
    parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
    parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
    parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")
    args = parser.parse_args()

    # Bail out early when the config file cannot be found.
    if not os.path.exists(args.config):
        print(f"Error: Configuration file not found: {args.config}")
        sys.exit(1)

    # Make sure the checkpoint directory exists before training starts.
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        train_gpt_oss(
            config_path=args.config,
            experiment_name=args.experiment_name,
            output_dir=args.output_dir,
            trackio_url=args.trackio_url,
            trainer_type=args.trainer_type,
        )
    except Exception as exc:
        # Surface the failure on stdout and exit non-zero for callers/CI.
        print(f"Error during training: {exc}")
        sys.exit(1)


if __name__ == "__main__":
    main()