#!/usr/bin/env python3 """ GPT-OSS Training Script Specialized training script for OpenAI's GPT-OSS models Based on the GPT-OSS fine-tuning tutorial """ import os import sys import argparse import inspect import torch from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments from peft import LoraConfig, get_peft_model from trl import SFTTrainer try: from trl import DPOTrainer except Exception: # pragma: no cover - optional import depending on TRL version DPOTrainer = None from datasets import load_dataset from pathlib import Path # Ensure project root and config package are importable for configs that do `from config...` imports project_root = Path(__file__).resolve().parents[2] if str(project_root) not in sys.path: sys.path.insert(0, str(project_root)) config_dir = project_root / "config" if str(config_dir) not in sys.path: sys.path.insert(0, str(config_dir)) def load_gpt_oss_model_and_tokenizer(config): """Load GPT-OSS model and tokenizer with proper configuration""" print("Loading GPT-OSS tokenizer...") tokenizer = AutoTokenizer.from_pretrained(config.model_name) print("Loading GPT-OSS model with quantization...") # Import quantization config from transformers import BitsAndBytesConfig # Set up quantization config based on config if config.quantization_config and config.quantization_config.get("load_in_4bit"): # Use BitsAndBytesConfig for 4-bit quantization (memory optimized) quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) elif config.quantization_config and config.quantization_config.get("dequantize"): # Try to use Mxfp4Config if available (as per tutorial) try: from transformers import Mxfp4Config quantization_config = Mxfp4Config(dequantize=True) except ImportError: # Fallback to no quantization if Mxfp4Config not available print("Warning: Mxfp4Config not available, using no quantization") quantization_config = None else: # No quantization quantization_config = None # Model kwargs as per tutorial model_kwargs = { "attn_implementation": "eager", "torch_dtype": torch.bfloat16, "use_cache": False, "device_map": "auto", } # Only add quantization_config if it's not None if quantization_config is not None: model_kwargs["quantization_config"] = quantization_config model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs) return model, tokenizer def setup_lora_for_gpt_oss(model, config): """Setup LoRA for GPT-OSS model""" print("Setting up LoRA for GPT-OSS...") # LoRA configuration as per tutorial lora_config = LoraConfig( r=config.lora_config.get("r", 8) if config.lora_config else 8, lora_alpha=config.lora_config.get("lora_alpha", 16) if config.lora_config else 16, target_modules=config.lora_config.get("target_modules", "all-linear") if config.lora_config else "all-linear", target_parameters=config.lora_config.get("target_parameters", [ "7.mlp.experts.gate_up_proj", "7.mlp.experts.down_proj", "15.mlp.experts.gate_up_proj", "15.mlp.experts.down_proj", "23.mlp.experts.gate_up_proj", "23.mlp.experts.down_proj", ]) if config.lora_config else [ "7.mlp.experts.gate_up_proj", "7.mlp.experts.down_proj", "15.mlp.experts.gate_up_proj", "15.mlp.experts.down_proj", "23.mlp.experts.gate_up_proj", "23.mlp.experts.down_proj", ], ) peft_model = get_peft_model(model, lora_config) peft_model.print_trainable_parameters() return peft_model def load_dataset_from_config(config): """Load dataset based on configuration""" dataset_name = getattr(config, 'dataset_name', 'HuggingFaceH4/Multilingual-Thinking') dataset_split = getattr(config, 'dataset_split', 'train') dataset_config = getattr(config, 'dataset_config', None) print(f"Loading dataset: {dataset_name}") print(f"Dataset split: {dataset_split}") if dataset_config: print(f"Dataset config: {dataset_config}") # Load the dataset if dataset_config: dataset = load_dataset(dataset_name, dataset_config, split=dataset_split) else: dataset = load_dataset(dataset_name, split=dataset_split) print(f"Original dataset size: {len(dataset)} examples") # Apply filtering based on configuration dataset = apply_dataset_filtering(dataset, config) # Apply dataset processing based on format dataset = process_dataset_format(dataset, config) print(f"Final dataset size: {len(dataset)} examples") return dataset def apply_dataset_filtering(dataset, config): """Apply filtering based on configuration""" # Filter bad entries if specified if getattr(config, 'filter_bad_entries', False): bad_entry_field = getattr(config, 'bad_entry_field', 'bad_entry') bad_prompt_field = getattr(config, 'bad_prompt_field', 'bad_prompt_detected') bad_response_field = getattr(config, 'bad_response_field', 'bad_response_detected') original_size = len(dataset) # Filter out bad entries if bad_entry_field in dataset.column_names: dataset = dataset.filter(lambda x: not x.get(bad_entry_field, False)) print(f"Filtered {original_size - len(dataset)} bad entries") # Filter out bad prompts if bad_prompt_field in dataset.column_names: dataset = dataset.filter(lambda x: not x.get(bad_prompt_field, False)) print(f"Filtered bad prompts, remaining: {len(dataset)} examples") # Filter out bad responses if bad_response_field in dataset.column_names: dataset = dataset.filter(lambda x: not x.get(bad_response_field, False)) print(f"Filtered bad responses, remaining: {len(dataset)} examples") # Apply length filtering min_length = getattr(config, 'min_length', 10) max_length = getattr(config, 'max_length', None) input_field = getattr(config, 'input_field', 'prompt') target_field = getattr(config, 'target_field', 'accepted_completion') if min_length > 0 or max_length: def length_filter(example): input_len = len(example.get(input_field, '')) target_len = len(example.get(target_field, '')) total_len = input_len + target_len if total_len < min_length: return False if max_length and total_len > max_length: return False return True original_size = len(dataset) dataset = dataset.filter(length_filter) print(f"Length filtering: {original_size} -> {len(dataset)} examples") # Apply sampling if specified max_samples = getattr(config, 'max_samples', None) if max_samples and len(dataset) > max_samples: dataset = dataset.shuffle(seed=42).select(range(max_samples)) print(f"Sampled {max_samples} examples from dataset") return dataset def format_gpt_oss_harmony(prompt, completion, add_eos_token=True): """ Format data for GPT-OSS Harmony format following the exact template structure. Based on: https://huggingface.co/openai/gpt-oss-20b/raw/main/chat_template.jinja """ # GPT-OSS Harmony format structure (exact template compliance) # User message: <|start|>user<|message|>content<|end|> # Assistant message: <|start|>assistant<|channel|>final<|message|>content<|end|> (inference) # Assistant message: <|start|>assistant<|channel|>final<|message|>content<|return|> (training) harmony_text = f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>{completion}" if add_eos_token: # Use <|return|> for training as per template specification # This indicates the end of generation in training harmony_text += "<|return|>" else: # Use <|end|> for inference harmony_text += "<|end|>" return harmony_text def format_gpt_oss_harmony_prompt(prompt: str) -> str: """Prefix-only Harmony prompt up to assistant content marker for DPO.""" return f"<|start|>user<|message|>{prompt}<|end|><|start|>assistant<|channel|>final<|message|>" def process_dataset_format(dataset, config): """Process dataset based on format configuration with exact GPT-OSS Harmony compliance""" dataset_format = getattr(config, 'dataset_format', 'openhermes_fr') input_field = getattr(config, 'input_field', 'prompt') target_field = getattr(config, 'target_field', 'accepted_completion') concatenate_fields = getattr(config, 'concatenate_fields', True) field_separator = getattr(config, 'field_separator', '\n\n### Response:\n') add_eos_token = getattr(config, 'add_eos_token', True) use_harmony_format = getattr(config, 'use_harmony_format', True) trainer_type = getattr(config, 'trainer_type', 'sft') print(f"Processing dataset format: {dataset_format}") print(f"Input field: {input_field}, Target field: {target_field}") print(f"GPT-OSS Harmony Format: {'Enabled' if use_harmony_format else 'Disabled'}") # Preference-format for DPO training (chosen/rejected pairs) if trainer_type == 'dpo': chosen_field = getattr(config, 'chosen_field', None) rejected_field = getattr(config, 'rejected_field', None) if dataset_format == 'preference': # Expect columns present; optionally reformat to ensure only necessary columns def id_map(example): prompt_val = example.get(input_field, '') chosen_val = example.get('chosen', example.get(chosen_field or 'chosen', '')) rejected_val = example.get('rejected', example.get(rejected_field or 'rejected', '')) if use_harmony_format: prompt_text = format_gpt_oss_harmony_prompt(prompt_val) chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '') rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '') return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text} return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val} keep_cols = [c for c in ['prompt', 'chosen', 'rejected'] if c in dataset.column_names] dataset = dataset.map(id_map, remove_columns=dataset.column_names if keep_cols else dataset.column_names) return dataset # Custom preference mapping via configured field names if chosen_field and rejected_field: def to_pref(example): prompt_val = example.get(input_field, '') chosen_val = example.get(chosen_field, '') rejected_val = example.get(rejected_field, '') if use_harmony_format: prompt_text = format_gpt_oss_harmony_prompt(prompt_val) chosen_text = (chosen_val or '') + ("<|return|>" if add_eos_token else '') rejected_text = (rejected_val or '') + ("<|return|>" if add_eos_token else '') return {"prompt": prompt_text, "chosen": chosen_text, "rejected": rejected_text} return {"prompt": prompt_val, "chosen": chosen_val, "rejected": rejected_val} dataset = dataset.map(to_pref, remove_columns=dataset.column_names) return dataset # If we reach here, we don't have required fields for DPO raise ValueError("DPO training requires preference data. Please set dataset_format='preference' with 'prompt', 'chosen', 'rejected' columns, or specify 'chosen_field' and 'rejected_field' in the config.") if dataset_format == "openhermes_fr": # Process OpenHermes-FR format: prompt + accepted_completion def format_openhermes_fr(example): prompt = example.get(input_field, '') completion = example.get(target_field, '') if concatenate_fields: if use_harmony_format: # Use exact GPT-OSS Harmony format from template text = format_gpt_oss_harmony(prompt, completion, add_eos_token) else: # Fallback to standard format with separator text = prompt + field_separator + completion if add_eos_token: text += "" return {"text": text} else: # Keep separate for more advanced training setups return { "input": prompt, "output": completion } dataset = dataset.map(format_openhermes_fr, remove_columns=dataset.column_names) elif dataset_format == "messages": # Process messages format (like HuggingFaceH4/Multilingual-Thinking) def format_messages(example): messages = example.get(input_field, []) if use_harmony_format and len(messages) >= 2: # Extract user and assistant messages for harmony format user_message = "" assistant_message = "" for message in messages: role = message.get("role", "") content = message.get("content", "") if role == "user": user_message = content elif role == "assistant": assistant_message = content if user_message and assistant_message: # Use GPT-OSS Harmony format text = format_gpt_oss_harmony(user_message, assistant_message, add_eos_token) else: # Fallback to simple concatenation text = "" for message in messages: role = message.get("role", "") content = message.get("content", "") text += f"{role}: {content}\n" if add_eos_token: text += "" else: # Standard format - convert messages to simple text text = "" for message in messages: role = message.get("role", "") content = message.get("content", "") text += f"{role}: {content}\n" if add_eos_token: text += "" return {"text": text} dataset = dataset.map(format_messages, remove_columns=dataset.column_names) elif dataset_format == "text": # Process plain text format text_field = input_field def format_text(example): text = example.get(text_field, '') if add_eos_token: text += "" return {"text": text} dataset = dataset.map(format_text, remove_columns=dataset.column_names) elif dataset_format == "custom": # Custom format - user handles this in their config print("Using custom dataset format - no automatic processing") return dataset def split_dataset(dataset, config): """Create train/validation/test splits from a single dataset. Defaults to 1% eval and 1% test if not specified. """ from datasets import Dataset if not isinstance(dataset, Dataset): # If it's already a DatasetDict, try to use its splits try: train_split = dataset["train"] eval_split = dataset.get("validation") or dataset.get("eval") test_split = dataset.get("test") return train_split, eval_split, test_split except Exception: pass eval_ratio = getattr(config, 'eval_ratio', 0.01) test_ratio = getattr(config, 'test_ratio', 0.01) # Clamp ratios to sane bounds try: eval_ratio = max(0.0, float(eval_ratio)) test_ratio = max(0.0, float(test_ratio)) if eval_ratio + test_ratio >= 0.9: # Avoid extreme splits; cap combined at 0.2 scale = 0.2 / max(1e-9, (eval_ratio + test_ratio)) eval_ratio *= scale test_ratio *= scale except Exception: eval_ratio, test_ratio = 0.01, 0.01 # No eval/test requested if eval_ratio <= 0 and test_ratio <= 0: return dataset, None, None ds_shuffled = dataset.shuffle(seed=42) # First carve out test split if test_ratio > 0: split1 = ds_shuffled.train_test_split(test_size=test_ratio, seed=42) train_part = split1["train"] test_split = split1["test"] else: train_part = ds_shuffled test_split = None # Then carve out eval from remaining train if eval_ratio > 0: remaining_fraction = 1.0 - test_ratio # Convert global eval fraction to fraction of remaining pool relative_eval = eval_ratio / remaining_fraction if remaining_fraction > 0 else eval_ratio split2 = train_part.train_test_split(test_size=relative_eval, seed=42) train_split = split2["train"] eval_split = split2["test"] else: train_split = train_part eval_split = None # Log sizes try: print(f"Created splits -> train: {len(train_split)}, eval: {len(eval_split) if eval_split else 0}, test: {len(test_split) if test_split else 0}") except Exception: pass return train_split, eval_split, test_split def setup_trackio_tracking(config): """Setup Trackio tracking if enabled""" if not config.enable_tracking or not config.trackio_url: print("Trackio tracking disabled or URL not provided") return None print(f"Setting up Trackio tracking: {config.trackio_url}") # Import the correct TrackioAPIClient import sys import os sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic')) from trackio_api_client import TrackioAPIClient # Initialize Trackio client using the correct API trackio_client = TrackioAPIClient( space_id=config.trackio_url, hf_token=config.trackio_token ) return trackio_client def create_sft_config(config, output_dir): """Create enhanced SFTConfig for GPT-OSS training""" print("Creating enhanced SFT configuration...") # Helper coercion utilities to guarantee numeric types def _as_int(value, default): if value is None: return int(default) try: return int(value) except Exception: return int(default) def _as_float(value, default): if value is None: return float(default) try: return float(value) except Exception: return float(default) # Extract training parameters from config with enhanced defaults and coercion num_train_epochs = _as_float(getattr(config, 'num_train_epochs', 1.0), 1.0) # Transformers expects max_steps default -1 (disabled). Some code compares > 0 raw_max_steps = getattr(config, 'max_steps', None) max_steps = _as_int(raw_max_steps if raw_max_steps is not None else -1, -1) warmup_ratio = _as_float(getattr(config, 'warmup_ratio', 0.03), 0.03) # Ensure warmup_steps is an int; default 0 to avoid None comparisons in schedulers warmup_steps = _as_int(getattr(config, 'warmup_steps', 0), 0) # Learning rate configuration learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4) lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr') # Batch configuration per_device_train_batch_size = _as_int(getattr(config, 'batch_size', 2), 2) per_device_eval_batch_size = _as_int(getattr(config, 'eval_batch_size', per_device_train_batch_size), per_device_train_batch_size) gradient_accumulation_steps = _as_int(getattr(config, 'gradient_accumulation_steps', 1), 1) # Evaluation and logging eval_strategy = getattr(config, 'eval_strategy', 'steps') eval_steps = _as_int(getattr(config, 'eval_steps', 100), 100) eval_accumulation_steps = _as_int(getattr(config, 'eval_accumulation_steps', 1), 1) logging_steps = _as_int(getattr(config, 'logging_steps', 10), 10) # Saving configuration save_strategy = getattr(config, 'save_strategy', 'steps') save_steps = _as_int(getattr(config, 'save_steps', 500), 500) save_total_limit = _as_int(getattr(config, 'save_total_limit', 3), 3) # Mixed precision fp16 = bool(getattr(config, 'fp16', False)) bf16 = bool(getattr(config, 'bf16', True)) tf32 = bool(getattr(config, 'tf32', False)) # Regularization weight_decay = _as_float(getattr(config, 'weight_decay', 0.01), 0.01) max_grad_norm = _as_float(getattr(config, 'max_grad_norm', 1.0), 1.0) # HuggingFace Hub integration push_to_hub = getattr(config, 'push_to_hub', False) print(f" • Epochs: {num_train_epochs}") print(f" • Learning rate: {learning_rate}") print(f" • Batch size: {per_device_train_batch_size}") print(f" • Gradient accumulation: {gradient_accumulation_steps}") print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}") # Build kwargs dynamically to be compatible across transformers versions ta_kwargs = { # Training duration "num_train_epochs": num_train_epochs, "max_steps": max_steps, # Learning rate "learning_rate": learning_rate, "lr_scheduler_type": lr_scheduler_type, "warmup_ratio": warmup_ratio, "warmup_steps": warmup_steps, # Batch configuration "per_device_train_batch_size": per_device_train_batch_size, "per_device_eval_batch_size": per_device_eval_batch_size, "gradient_accumulation_steps": gradient_accumulation_steps, # Model configuration "gradient_checkpointing": getattr(config, 'use_gradient_checkpointing', True), # Mixed precision "fp16": fp16, "bf16": bf16, # Some versions support tf32 "tf32": tf32 if 'tf32' in TrainingArguments.__init__.__code__.co_varnames else None, # Regularization "weight_decay": weight_decay, "max_grad_norm": max_grad_norm, # Evaluation (name may vary across versions) "evaluation_strategy": eval_strategy, "eval_steps": eval_steps, "eval_accumulation_steps": eval_accumulation_steps, # Logging "logging_steps": logging_steps, # Saving "save_strategy": save_strategy, "save_steps": save_steps, "save_total_limit": save_total_limit, # Output "output_dir": output_dir, # Data loading "dataloader_num_workers": _as_int(getattr(config, 'dataloader_num_workers', 4), 4), "dataloader_pin_memory": getattr(config, 'dataloader_pin_memory', True), # Optional in some versions "dataloader_prefetch_factor": _as_int(getattr(config, 'dataloader_prefetch_factor', 2), 2), # Performance "group_by_length": getattr(config, 'group_by_length', True), "remove_unused_columns": getattr(config, 'remove_unused_columns', True), # HuggingFace Hub "push_to_hub": push_to_hub, # Monitoring "report_to": ("trackio" if getattr(config, 'enable_tracking', False) else None), } # Drop any None-valued kwargs ta_kwargs = {k: v for k, v in ta_kwargs.items() if v is not None} # Adapt to transformers versions where 'evaluation_strategy' was renamed try: ta_sig = inspect.signature(TrainingArguments.__init__) param_names = set(ta_sig.parameters.keys()) except Exception: param_names = set() if "evaluation_strategy" not in param_names and "eval_strategy" in param_names: # Move value to 'eval_strategy' ta_kwargs["eval_strategy"] = ta_kwargs.pop("evaluation_strategy") elif "evaluation_strategy" not in param_names: # If neither is supported, drop it ta_kwargs.pop("evaluation_strategy", None) # Remove any kwargs not supported by current transformers version if param_names: unsupported = [k for k in ta_kwargs.keys() if k not in param_names] for k in unsupported: ta_kwargs.pop(k, None) sft_config = TrainingArguments(**ta_kwargs) return sft_config def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"): """Main training function for GPT-OSS""" print("=== GPT-OSS Training Pipeline ===") print(f"Config: {config_path}") print(f"Experiment: {experiment_name}") print(f"Output: {output_dir}") print(f"Trackio: {trackio_url}") print(f"Trainer: {trainer_type}") # Load configuration if os.path.exists(config_path): import importlib.util spec = importlib.util.spec_from_file_location("config_module", config_path) config_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(config_module) if hasattr(config_module, 'config'): config = config_module.config else: # Try to find a config class for attr_name in dir(config_module): attr = getattr(config_module, attr_name) if hasattr(attr, 'model_name') and ('gpt_oss' in attr.model_name.lower() or 'GPTOSS' in attr_name): config = attr break else: raise ValueError(f"No GPT-OSS configuration found in {config_path}") else: raise FileNotFoundError(f"Configuration file not found: {config_path}") # Update config with runtime parameters config.experiment_name = experiment_name config.trackio_url = trackio_url config.trainer_type = trainer_type # Load model and tokenizer model, tokenizer = load_gpt_oss_model_and_tokenizer(config) # Setup LoRA peft_model = setup_lora_for_gpt_oss(model, config) # Load dataset dataset = load_dataset_from_config(config) # Split into train/eval/test train_dataset, eval_dataset, test_dataset = split_dataset(dataset, config) # Setup Trackio tracking trackio_client = setup_trackio_tracking(config) # Create SFT configuration sft_config = create_sft_config(config, output_dir) # Create trainer with version-robust kwargs if trainer_type == 'dpo': if DPOTrainer is None: raise RuntimeError("DPOTrainer is not available in this TRL version. Please upgrade 'trl'.") print("Creating DPO trainer...") try: dpo_sig = inspect.signature(DPOTrainer.__init__) dpo_params = set(dpo_sig.parameters.keys()) except Exception: dpo_params = {"model", "args", "train_dataset", "tokenizer", "beta", "prompt_column", "chosen_column", "rejected_column"} dpo_kwargs = { "model": peft_model, "args": sft_config, "train_dataset": train_dataset, "beta": getattr(config, 'dpo_beta', 0.1), } if "tokenizer" in dpo_params: dpo_kwargs["tokenizer"] = tokenizer elif "processing_class" in dpo_params: dpo_kwargs["processing_class"] = tokenizer if "prompt_column" in dpo_params: dpo_kwargs["prompt_column"] = "prompt" if "chosen_column" in dpo_params: dpo_kwargs["chosen_column"] = "chosen" if "rejected_column" in dpo_params: dpo_kwargs["rejected_column"] = "rejected" # Remove Nones dpo_kwargs = {k: v for k, v in dpo_kwargs.items() if v is not None} # Pass eval dataset if supported if "eval_dataset" in dpo_params and eval_dataset is not None: dpo_kwargs["eval_dataset"] = eval_dataset trainer = DPOTrainer(**dpo_kwargs) else: print("Creating SFT trainer...") try: sft_sig = inspect.signature(SFTTrainer.__init__) sft_params = set(sft_sig.parameters.keys()) except Exception: sft_params = {"model", "args", "train_dataset", "tokenizer", "dataset_text_field", "max_seq_length"} sft_kwargs = { "model": peft_model, "args": sft_config, "train_dataset": train_dataset, } # Prefer passing tokenizer if supported; otherwise try processing_class if "tokenizer" in sft_params: sft_kwargs["tokenizer"] = tokenizer elif "processing_class" in sft_params: sft_kwargs["processing_class"] = tokenizer # Pass dataset text field if supported (we produced a 'text' column) if "dataset_text_field" in sft_params: sft_kwargs["dataset_text_field"] = "text" # Pass max sequence length if supported if "max_seq_length" in sft_params: sft_kwargs["max_seq_length"] = getattr(config, 'max_seq_length', 2048) # Remove any None values sft_kwargs = {k: v for k, v in sft_kwargs.items() if v is not None} # Attach eval_dataset if supported if "eval_dataset" in sft_params and eval_dataset is not None: sft_kwargs["eval_dataset"] = eval_dataset trainer = SFTTrainer(**sft_kwargs) # Start training print("Starting GPT-OSS training...") trainer.train() # Save model print("Saving trained model...") trainer.save_model(output_dir) # Push to hub if enabled if sft_config.push_to_hub: print("Pushing model to Hugging Face Hub...") trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking") print("GPT-OSS training completed successfully!") return trainer def main(): parser = argparse.ArgumentParser(description="GPT-OSS Training Script") parser.add_argument("--config", required=True, help="Path to configuration file") parser.add_argument("--experiment-name", required=True, help="Experiment name") parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints") parser.add_argument("--trackio-url", help="Trackio URL for monitoring") parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type") args = parser.parse_args() # Validate arguments if not os.path.exists(args.config): print(f"Error: Configuration file not found: {args.config}") sys.exit(1) # Create output directory os.makedirs(args.output_dir, exist_ok=True) try: train_gpt_oss( config_path=args.config, experiment_name=args.experiment_name, output_dir=args.output_dir, trackio_url=args.trackio_url, trainer_type=args.trainer_type ) except Exception as e: print(f"Error during training: {e}") sys.exit(1) if __name__ == "__main__": main()