#!/usr/bin/env python3
"""
GPT-OSS Training Script
Specialized training script for OpenAI's GPT-OSS models
Based on the GPT-OSS fine-tuning tutorial
"""
import os
import sys
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset


def load_gpt_oss_model_and_tokenizer(config):
    """Load the GPT-OSS model and tokenizer with the proper quantization configuration."""
    print("Loading GPT-OSS tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    print("Loading GPT-OSS model with quantization...")
    from transformers import BitsAndBytesConfig

    # Select a quantization config based on the training config
    if config.quantization_config and config.quantization_config.get("load_in_4bit"):
        # BitsAndBytesConfig for 4-bit NF4 quantization (memory optimized)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif config.quantization_config and config.quantization_config.get("dequantize"):
        # Prefer Mxfp4Config if this transformers version provides it (as per the tutorial)
        try:
            from transformers import Mxfp4Config
            quantization_config = Mxfp4Config(dequantize=True)
        except ImportError:
            print("Warning: Mxfp4Config not available, loading without quantization")
            quantization_config = None
    else:
        # No quantization
        quantization_config = None

    # Model kwargs as per the tutorial
    model_kwargs = {
        "attn_implementation": "eager",
        "torch_dtype": torch.bfloat16,
        "use_cache": False,
        "device_map": "auto",
    }
    # Only pass quantization_config when one was selected
    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config

    model = AutoModelForCausalLM.from_pretrained(config.model_name, **model_kwargs)
    return model, tokenizer
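
# Usage sketch (hypothetical `cfg`): any object exposing `model_name` and an
# optional `quantization_config` dict can be passed here, e.g.
#   model, tokenizer = load_gpt_oss_model_and_tokenizer(cfg)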


def setup_lora_for_gpt_oss(model, config):
    """Set up LoRA adapters for the GPT-OSS model."""
    print("Setting up LoRA for GPT-OSS...")
    # Default expert projections from the tutorial, used when the config
    # does not override them
    default_target_parameters = [
        "7.mlp.experts.gate_up_proj",
        "7.mlp.experts.down_proj",
        "15.mlp.experts.gate_up_proj",
        "15.mlp.experts.down_proj",
        "23.mlp.experts.gate_up_proj",
        "23.mlp.experts.down_proj",
    ]
    lora_cfg = config.lora_config or {}
    lora_config = LoraConfig(
        r=lora_cfg.get("r", 8),
        lora_alpha=lora_cfg.get("lora_alpha", 16),
        target_modules=lora_cfg.get("target_modules", "all-linear"),
        target_parameters=lora_cfg.get("target_parameters", default_target_parameters),
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()
    return peft_model
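
# Note: `target_parameters` is a relatively recent PEFT option for attaching
# LoRA to individual named parameters; the defaults above target the MoE expert
# projections in decoder layers 7, 15, and 23, following the GPT-OSS tutorial.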


def load_multilingual_thinking_dataset():
    """Load the Multilingual-Thinking dataset."""
    print("Loading Multilingual-Thinking dataset...")
    dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")
    print(f"Dataset loaded: {len(dataset)} examples")
    return dataset


def setup_trackio_tracking(config):
    """Set up Trackio tracking if enabled."""
    if not config.enable_tracking or not config.trackio_url:
        print("Trackio tracking disabled or URL not provided")
        return None

    print(f"Setting up Trackio tracking: {config.trackio_url}")
    # Make the sibling trackio_tonic package importable, then import the client
    sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'trackio_tonic'))
    from trackio_api_client import TrackioAPIClient

    # Initialize the Trackio client with the Space ID and HF token
    trackio_client = TrackioAPIClient(
        space_id=config.trackio_url,
        hf_token=config.trackio_token,
    )
    return trackio_client
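
# Note: the returned client is not wired into SFTTrainer below; run metrics are
# reported via `report_to="trackio"` in the SFTConfig. The client is returned so
# callers can do additional manual logging if they want to.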


def create_sft_config(config):
    """Create the SFTConfig for GPT-OSS training."""
    print("Creating SFT configuration...")
    sft_config = SFTConfig(
        learning_rate=config.learning_rate,
        gradient_checkpointing=True,
        num_train_epochs=1,  # single epoch, as per the tutorial
        logging_steps=config.logging_steps,
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_length=config.max_seq_length,
        warmup_ratio=0.03,
        # Cosine decay with a floor: the LR never drops below
        # min_lr_rate * learning_rate (10% of peak here)
        lr_scheduler_type="cosine_with_min_lr",
        lr_scheduler_kwargs={"min_lr_rate": 0.1},
        output_dir="gpt-oss-20b-multilingual-reasoner",
        # "none" (the string, not None) disables reporting; passing None would
        # fall back to all installed integrations
        report_to="trackio" if config.enable_tracking else "none",
        push_to_hub=True,
    )
    return sft_config


def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer_type="sft"):
    """Main training function for GPT-OSS."""
    print("=== GPT-OSS Training Pipeline ===")
    print(f"Config: {config_path}")
    print(f"Experiment: {experiment_name}")
    print(f"Output: {output_dir}")
    print(f"Trackio: {trackio_url}")
    print(f"Trainer: {trainer_type}")

    # Load the configuration module from the given path
    if os.path.exists(config_path):
        import importlib.util
        spec = importlib.util.spec_from_file_location("config_module", config_path)
        config_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(config_module)
        if hasattr(config_module, 'config'):
            config = config_module.config
        else:
            # Fall back to scanning the module for a GPT-OSS config object
            for attr_name in dir(config_module):
                attr = getattr(config_module, attr_name)
                if hasattr(attr, 'model_name') and ('gpt_oss' in attr.model_name.lower() or 'GPTOSS' in attr_name):
                    config = attr
                    break
            else:
                # Loop completed without a break: no matching config was found
                raise ValueError(f"No GPT-OSS configuration found in {config_path}")
    else:
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    # Update the config with runtime parameters
    config.experiment_name = experiment_name
    config.trackio_url = trackio_url
    config.trainer_type = trainer_type

    # Load model and tokenizer
    model, tokenizer = load_gpt_oss_model_and_tokenizer(config)

    # Set up LoRA
    peft_model = setup_lora_for_gpt_oss(model, config)

    # Load the dataset
    dataset = load_multilingual_thinking_dataset()

    # Set up Trackio tracking
    trackio_client = setup_trackio_tracking(config)

    # Create the SFT configuration
    sft_config = create_sft_config(config)

    # Create the trainer
    print("Creating SFT trainer...")
    trainer = SFTTrainer(
        model=peft_model,
        args=sft_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    # Start training
    print("Starting GPT-OSS training...")
    trainer.train()

    # Save the trained model
    print("Saving trained model...")
    trainer.save_model(output_dir)

    # Push to the Hub if enabled
    if sft_config.push_to_hub:
        print("Pushing model to Hugging Face Hub...")
        trainer.push_to_hub(dataset_name="HuggingFaceH4/Multilingual-Thinking")

    print("GPT-OSS training completed successfully!")
    return trainer
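
# Sketch of a config module this loader can consume (values are illustrative;
# the attribute names mirror what train_gpt_oss() and the helpers above read):
#
#   class GPTOSSConfig:
#       model_name = "openai/gpt-oss-20b"
#       quantization_config = {"dequantize": True}
#       lora_config = {"r": 8, "lora_alpha": 16, "target_modules": "all-linear"}
#       learning_rate = 2e-4
#       logging_steps = 1
#       batch_size = 4
#       gradient_accumulation_steps = 4
#       max_seq_length = 2048
#       enable_tracking = False
#       trackio_token = None
#
#   config = GPTOSSConfig()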


def main():
    parser = argparse.ArgumentParser(description="GPT-OSS Training Script")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument("--experiment-name", required=True, help="Experiment name")
    parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
    parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
    parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")
    args = parser.parse_args()

    # Validate arguments
    if not os.path.exists(args.config):
        print(f"Error: Configuration file not found: {args.config}")
        sys.exit(1)

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        train_gpt_oss(
            config_path=args.config,
            experiment_name=args.experiment_name,
            output_dir=args.output_dir,
            trackio_url=args.trackio_url,
            trainer_type=args.trainer_type,
        )
    except Exception as e:
        print(f"Error during training: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()