#!/usr/bin/env python3 """ GPT-OSS Training Script Specialized training script for OpenAI's GPT-OSS models Based on the GPT-OSS fine-tuning tutorial """ import os import sys import argparse import torch from transformers import AutoTokenizer, AutoModelForCausalLM from peft import LoraConfig, get_peft_model from trl import SFTTrainer, SFTConfig from datasets import load_dataset from pathlib import Path # Ensure project root and config package are importable for configs that do `from config...` imports project_root = Path(__file__).resolve().parents[2] if str(project_root) not in sys.path: sys.path.insert(0, str(project_root)) config_dir = project_root / "config" if str(config_dir) not in sys.path: sys.path.insert(0, str(config_dir)) def load_gpt_oss_model_and_tokenizer(config): """Load GPT-OSS model and tokenizer with proper configuration""" print("Loading GPT-OSS tokenizer...") tokenizer = AutoTokenizer.from_pretrained(config.model_name) print("Loading GPT-OSS model with quantization...") # Import quantization config from transformers import BitsAndBytesConfig # Set up quantization config based on config if config.quantization_config and config.quantization_config.get("load_in_4bit"): # Use BitsAndBytesConfig for 4-bit quantization (memory optimized) quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) elif config.quantization_config and config.quantization_config.get("dequantize"): # Try to use Mxfp4Config if available (as per tutorial) try: from transformers import Mxfp4Config quantization_config = Mxfp4Config(dequantize=True) except ImportError: # Fallback to no quantization if Mxfp4Config not available print("Warning: Mxfp4Config not available, using no quantization") quantization_config = None else: # No quantization quantization_config = None # Model kwargs as per tutorial model_kwargs = { "attn_implementation": "eager", "torch_dtype": torch.bfloat16, "use_cache": False, "device_map": "auto", } # Only add quantization_config if it's not None if quantization_config is not None: model_kwargs["quantization_config"] = quantization_config model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs) return model, tokenizer def setup_lora_for_gpt_oss(model, config): """Setup LoRA for GPT-OSS model""" print("Setting up LoRA for GPT-OSS...") # LoRA configuration as per tutorial lora_config = LoraConfig( r=config.lora_config.get("r", 8) if config.lora_config else 8, lora_alpha=config.lora_config.get("lora_alpha", 16) if config.lora_config else 16, target_modules=config.lora_config.get("target_modules", "all-linear") if config.lora_config else "all-linear", target_parameters=config.lora_config.get("target_parameters", [ "7.mlp.experts.gate_up_proj", "7.mlp.experts.down_proj", "15.mlp.experts.gate_up_proj", "15.mlp.experts.down_proj", "23.mlp.experts.gate_up_proj", "23.mlp.experts.down_proj", ]) if config.lora_config else [ "7.mlp.experts.gate_up_proj", "7.mlp.experts.down_proj", "15.mlp.experts.gate_up_proj", "15.mlp.experts.down_proj", "23.mlp.experts.gate_up_proj", "23.mlp.experts.down_proj", ], ) peft_model = get_peft_model(model, lora_config) peft_model.print_trainable_parameters() return peft_model def load_dataset_from_config(config): """Load dataset based on configuration""" dataset_name = getattr(config, 'dataset_name', 'HuggingFaceH4/Multilingual-Thinking') dataset_split = getattr(config, 'dataset_split', 'train') dataset_config = getattr(config, 'dataset_config', None) print(f"Loading dataset: {dataset_name}") print(f"Dataset split: {dataset_split}") if dataset_config: print(f"Dataset config: {dataset_config}") # Load the dataset if dataset_config: dataset = load_dataset(dataset_name, dataset_config, split=dataset_split) else: dataset = load_dataset(dataset_name, split=dataset_split) print(f"Original dataset size: {len(dataset)} examples") # Apply filtering based on configuration dataset = apply_dataset_filtering(dataset, config) # Apply dataset processing based on format dataset = process_dataset_format(dataset, config) print(f"Final dataset size: {len(dataset)} examples") return dataset def apply_dataset_filtering(dataset, config): """Apply filtering based on configuration""" # Filter bad entries if specified if getattr(config, 'filter_bad_entries', False): bad_entry_field = getattr(config, 'bad_entry_field', 'bad_entry') bad_prompt_field = getattr(config, 'bad_prompt_field', 'bad_prompt_detected') bad_response_field = getattr(config, 'bad_response_field', 'bad_response_detected') original_size = len(dataset) # Filter out bad entries if bad_entry_field in dataset.column_names: dataset = dataset.filter(lambda x: not x.get(bad_entry_field, False)) print(f"Filtered {original_size - len(dataset)} bad entries") # Filter out bad prompts if bad_prompt_field in dataset.column_names: dataset = dataset.filter(lambda x: not x.get(bad_prompt_field, False)) print(f"Filtered bad prompts, remaining: {len(dataset)} examples") # Filter out bad responses if bad_response_field in dataset.column_names: dataset = dataset.filter(lambda x: not x.get(bad_response_field, False)) print(f"Filtered bad responses, remaining: {len(dataset)} examples") # Apply length filtering min_length = getattr(config, 'min_length', 10) max_length = getattr(config, 'max_length', None) input_field = getattr(config, 'input_field', 'prompt') target_field = getattr(config, 'target_field', 'accepted_completion') if min_length > 0 or max_length: def length_filter(example): input_len = len(example.get(input_field, '')) target_len = len(example.get(target_field, '')) total_len = input_len + target_len if total_len < min_length: return False if max_length and total_len > max_length: return False return True original_size = len(dataset) dataset = dataset.filter(length_filter) print(f"Length filtering: {original_size} -> {len(dataset)} examples") # Apply sampling if specified max_samples = getattr(config, 'max_samples', None) if max_samples and len(dataset) > max_samples: dataset = dataset.shuffle(seed=42).select(range(max_samples)) print(f"Sampled {max_samples} examples from dataset") return dataset def format_gpt_oss_harmony(prompt, completion, add_eos_token=True): """ Format data for GPT-OSS Harmony format following the exact template structure. Based on: https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja """ # GPT-OSS Harmony format structure (exact template compliance) # User message: <|start|>user<|message|>content<|end|> # Assistant message: <|start|>assistant<|channel|>final<|message|>content<|end|> (inference) # Assistant message: <|start|>assistant<|channel|>final<|message|>content<|return|> (training) harmony_text = f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>{completion}" if add_eos_token: # Use <|return|> for training as per template specification # This indicates the end of generation in training harmony_text += "<|return|>" else: # Use <|end|> for inference harmony_text += "<|end|>" return harmony_text def process_dataset_format(dataset, config): """Process dataset based on format configuration with exact GPT-OSS Harmony compliance""" dataset_format = getattr(config, 'dataset_format', 'openhermes_fr') input_field = getattr(config, 'input_field', 'prompt') target_field = getattr(config, 'target_field', 'accepted_completion') concatenate_fields = getattr(config, 'concatenate_fields', True) field_separator = getattr(config, 'field_separator', '\n\n### Response:\n') add_eos_token = getattr(config, 'add_eos_token', True) use_harmony_format = getattr(config, 'use_harmony_format', True) print(f"Processing dataset format: {dataset_format}") print(f"Input field: {input_field}, Target field: {target_field}") print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}") if dataset_format == "openhermes_fr": # Process OpenHermes-FR format: prompt + accepted_completion def format_openhermes_fr(example): prompt = example.get(input_field, '') completion = example.get(target_field, '') if concatenate_fields: if use_harmony_format: # Use exact GPT-OSS Harmony format from template text = format_gpt_oss_harmony(prompt, completion, add_eos_token) else: # Fallback to standard format with separator text = prompt + field_separator + completion if add_eos_token: text += "" return {"text": text} else: # Keep separate for more advanced training setups return { "input": prompt, "output": completion } dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names) elif dataset_format == "messages": # Process messages format (like HuggingFaceH4/Multilingual-Thinking) def format_messages(example): messages = example.get(input_field, []) if use_harmony_format and len(messages) >= 2: # Extract user and assistant messages for harmony format user_message = "" assistant_message = "" for message in messages: role = message.get("role", "") content = message.get("content", "") if role == "user": user_message = content elif role == "assistant": assistant_message = content if user_message and assistant_message: # Use GPT-OSS Harmony format text = format_gpt_oss_harmony(user_message, assistant_message, add_eos_token) else: # Fallback to simple concatenation text = "" for message in messages: role = message.get("role", "") content = message.get("content", "") text += f"{role}: {content}\n" if add_eos_token: text += "" else: # Standard format - convert messages to simple text text = "" for message in messages: role = message.get("role", "") content = message.get("content", "") text += f"{role}: {content}\n" if add_eos_token: text += "" return {"text": text} dataset = dataset.map(format_messages, remove_columns=dataset.column_names) elif dataset_format == "text": # Process plain text format text_field = input_field def format_text(example): text = example.get(text_field, '') if add_eos_token: text += "" return {"text": text} dataset = dataset.map(format_text, remove_columns=dataset.column_names) elif dataset_format == "custom": # Custom format - user handles this in their config print("Using custom dataset format - no automatic processing") return dataset def setup_trackio_tracking(config): """Setup Trackio tracking if enabled""" if not config.enable_tracking or not config.trackio_url: print("Trackio tracking disabled or URL not provided") return None print(f"Setting up Trackio tracking: {config.trackio_url}") # Import the correct TrackioAPIClient import sys import os sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic')) from trackio_api_client import TrackioAPIClient # Initialize Trackio client using the correct API trackio_client = TrackioAPIClient( space_id=config.trackio_url, hf_token=config.trackio_token ) return trackio_client def create_sft_config(config, output_dir): """Create enhanced SFTConfig for GPT-OSS training""" print("Creating enhanced SFT configuration...") # Extract training parameters from config with enhanced defaults num_train_epochs = getattr(config, 'num_train_epochs', 1.0) max_steps = getattr(config, 'max_steps', None) warmup_ratio = getattr(config, 'warmup_ratio', 0.03) warmup_steps = getattr(config, 'warmup_steps', None) # Learning rate configuration learning_rate = config.learning_rate lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr') lr_scheduler_kwargs = getattr(config, 'lr_scheduler_kwargs', {"min_lr_rate": 0.1}) # Batch configuration per_device_train_batch_size = config.batch_size per_device_eval_batch_size = getattr(config, 'eval_batch_size', config.batch_size) gradient_accumulation_steps = config.gradient_accumulation_steps # Evaluation and logging eval_strategy = getattr(config, 'eval_strategy', 'steps') eval_steps = getattr(config, 'eval_steps', 100) logging_steps = getattr(config, 'logging_steps', 10) # Saving configuration save_strategy = getattr(config, 'save_strategy', 'steps') save_steps = getattr(config, 'save_steps', 500) save_total_limit = getattr(config, 'save_total_limit', 3) # Mixed precision fp16 = getattr(config, 'fp16', False) bf16 = getattr(config, 'bf16', True) # Regularization weight_decay = getattr(config, 'weight_decay', 0.01) max_grad_norm = getattr(config, 'max_grad_norm', 1.0) # HuggingFace Hub integration push_to_hub = getattr(config, 'push_to_hub', False) print(f" • Epochs: {num_train_epochs}") print(f" • Learning rate: {learning_rate}") print(f" • Batch size: {per_device_train_batch_size}") print(f" • Gradient accumulation: {gradient_accumulation_steps}") print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}") sft_config = SFTConfig( # Training duration num_train_epochs=num_train_epochs, max_steps=max_steps, # Learning rate learning_rate=learning_rate, lr_scheduler_type=lr_scheduler_type, lr_scheduler_kwargs=lr_scheduler_kwargs, warmup_ratio=warmup_ratio, warmup_steps=warmup_steps, # Batch configuration per_device_train_batch_size=per_device_train_batch_size, per_device_eval_batch_size=per_device_eval_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, # Model configuration gradient_checkpointing=getattr(config, 'use_gradient_checkpointing', True), # Mixed precision fp16=fp16, bf16=bf16, # Regularization weight_decay=weight_decay, max_grad_norm=max_grad_norm, # Evaluation evaluation_strategy=eval_strategy, eval_steps=eval_steps, # Logging logging_steps=logging_steps, # Saving save_strategy=save_strategy, save_steps=save_steps, save_total_limit=save_total_limit, # Output output_dir=output_dir, # Data loading dataloader_num_workers=getattr(config, 'dataloader_num_workers', 4), dataloader_pin_memory=getattr(config, 'dataloader_pin_memory', True), # Performance group_by_length=getattr(config, 'group_by_length', True), remove_unused_columns=getattr(config, 'remove_unused_columns', True), # HuggingFace Hub push_to_hub=push_to_hub, # Monitoring report_to="trackio" if getattr(config, 'enable_tracking', False) else None, ) return sft_config def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"): """Main training function for GPT-OSS""" print("=== GPT-OSS Training Pipeline ===") print(f"Config: {config_path}") print(f"Experiment: {experiment_name}") print(f"Output: {output_dir}") print(f"Trackio: {trackio_url}") print(f"Trainer: {trainer_type}") # Load configuration if os.path.exists(config_path): import importlib.util spec = importlib.util.spec_from_file_location("config_module", config_path) config_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(config_module) if hasattr(config_module, 'config'): config = config_module.config else: # Try to find a config class for attr_name in dir(config_module): attr = getattr(config_module, attr_name) if hasattr(attr, 'model_name') and ('gpt_oss' in attr.model_name.lower() or 'GPTOSS' in attr_name): config = attr break else: raise ValueError(f"No GPT-OSS configuration found in {config_path}") else: raise FileNotFoundError(f"Configuration file not found: {config_path}") # Update config with runtime parameters config.experiment_name = experiment_name config.trackio_url = trackio_url config.trainer_type = trainer_type # Load model and tokenizer model, tokenizer = load_gpt_oss_model_and_tokenizer(config) # Setup LoRA peft_model = setup_lora_for_gpt_oss(model, config) # Load dataset dataset = load_dataset_from_config(config) # Setup Trackio tracking trackio_client = setup_trackio_tracking(config) # Create SFT configuration sft_config = create_sft_config(config, output_dir) # Create trainer print("Creating SFT trainer...") trainer = SFTTrainer( model=peft_model, args=sft_config, train_dataset=dataset, processing_class=tokenizer, dataset_text_field="text", max_seq_length=getattr(config, 'max_seq_length', 2048), ) # Start training print("Starting GPT-OSS training...") trainer.train() # Save model print("Saving trained model...") trainer.save_model(output_dir) # Push to hub if enabled if sft_config.push_to_hub: print("Pushing model to Hugging Face Hub...") trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking") print("GPT-OSS training completed successfully!") return trainer def main(): parser = argparse.ArgumentParser(description="GPT-OSS Training Script") parser.add_argument("--config", required=True, help="Path to configuration file") parser.add_argument("--experiment-name", required=True, help="Experiment name") parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints") parser.add_argument("--trackio-url", help="Trackio URL for monitoring") parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type") args = parser.parse_args() # Validate arguments if not os.path.exists(args.config): print(f"Error: Configuration file not found: {args.config}") sys.exit(1) # Create output directory os.makedirs(args.output_dir, exist_ok=True) try: train_gpt_oss( config_path=args.config, experiment_name=args.experiment_name, output_dir=args.output_dir, trackio_url=args.trackio_url, trainer_type=args.trainer_type ) except Exception as e: print(f"Error during training: {e}") sys.exit(1) if __name__ == "__main__": main()