# Spaetzle-v8-7b / trl / test_orpo_trainer_demo.py
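"""ORPO fine-tuning demo built on the TRL example code.

Loads a causal LM (microsoft/phi-2 by default), formats the
argilla/dpo-mix-7k preference pairs with a Zephyr-style chat template,
and trains with ORPOTrainer. Hyperparameters are exposed as CLI flags
via HfArgumentParser (see ScriptArguments below).
"""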
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import ORPOConfig, ORPOTrainer
# This script builds on the example code from the Hugging Face TRL team.
@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="microsoft/phi-2", metadata={"help": "the model name"})
    optim: Optional[str] = field(default="adamw_torch", metadata={"help": "the optimizer"})
    data_name: Optional[str] = field(default="argilla/dpo-mix-7k", metadata={"help": "the dataset name"})
    cache_dir: Optional[str] = field(default="", metadata={"help": "the cache directory for models and datasets"})
    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
    output_dir: Optional[str] = field(default="", metadata={"help": "the output directory"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the learning rate scheduler"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the per-device train batch size"})
    num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the number of training epochs"})
    beta: Optional[float] = field(default=0.25, metadata={"help": "weighting hyperparameter for L_OR (lambda in the ORPO paper)"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
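# Example invocation (illustrative; the flag values below are assumptions,
# and any ScriptArguments field can be overridden the same way):
#
#   python test_orpo_trainer_demo.py \
#       --model_name microsoft/phi-2 \
#       --data_name argilla/dpo-mix-7k \
#       --output_dir ./orpo-phi2 \
#       --per_device_train_batch_size 4 \
#       --num_train_epochs 5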
config = ORPOConfig(
    output_dir=script_args.output_dir,
    max_prompt_length=1024,
    max_length=2048,
    max_completion_length=2048,
    logging_steps=100,
    save_strategy='no',
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    remove_unused_columns=False,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    learning_rate=script_args.learning_rate,
    optim=script_args.optim,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': True},
    evaluation_strategy='epoch',
    do_eval=True,
    beta=script_args.beta,
    report_to=script_args.log_with,
    num_train_epochs=script_args.num_train_epochs,
    bf16=True,
)
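# For reference, ORPO (Hong et al., 2024) adds an odds-ratio penalty to the
# plain SFT loss, so no separate reference model is needed:
#
#   L_ORPO = L_SFT + beta * L_OR
#   L_OR   = -log sigmoid(log(odds(y_chosen|x) / odds(y_rejected|x)))
#   odds(y|x) = P(y|x) / (1 - P(y|x))
#
# `beta` here corresponds to the paper's lambda; 0.25 is this script's default.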
model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    cache_dir=script_args.cache_dir,
    attn_implementation='flash_attention_2',  # requires the flash-attn package
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name,
                                          cache_dir=script_args.cache_dir)
tokenizer.pad_token_id = tokenizer.eos_token_id  # phi-2 has no dedicated pad token
tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
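# Quick sanity check of the template above (illustrative; the exact newlines
# and eos_token depend on the tokenizer, e.g. '<|endoftext|>' for phi-2):
#
#   tokenizer.apply_chat_template(
#       [{"role": "user", "content": "Hi"}],
#       tokenize=False, add_generation_prompt=True,
#   )
#   # -> approximately "<|user|>\nHi<|endoftext|>\n<|assistant|>"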
def build_dataset(tokenizer):
    ds_train = load_dataset(script_args.data_name, split="train",
                            cache_dir=script_args.cache_dir)
    ds_test = load_dataset(script_args.data_name, split="test",
                           cache_dir=script_args.cache_dir)

    def chat_template_to_text(sample):
        # Build the prompt from the first (user) message *before* the chosen
        # column is overwritten with flattened text below; iterating the
        # already-flattened strings would index single characters instead.
        sample["prompt"] = [
            tokenizer.apply_chat_template([item[0]], tokenize=False, add_generation_prompt=True)
            for item in sample["chosen"]
        ]
        sample["chosen"] = [tokenizer.apply_chat_template(item, tokenize=False) for item in sample["chosen"]]
        sample["rejected"] = [tokenizer.apply_chat_template(item, tokenize=False) for item in sample["rejected"]]
        return sample

    ds_train = ds_train.map(chat_template_to_text, batched=True, num_proc=8)
    ds_test = ds_test.map(chat_template_to_text, batched=True, num_proc=8)
    return ds_train, ds_test
train, test = build_dataset(tokenizer=tokenizer)
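# Each split now carries string columns 'prompt', 'chosen', and 'rejected',
# which is the preference-pair format ORPOTrainer expects.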
trainer = ORPOTrainer(
    model=model,
    args=config,
    tokenizer=tokenizer,
    train_dataset=train,
    eval_dataset=test,
)
trainer.train()
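# save_strategy='no' above means no checkpoints are written during training.
# A minimal sketch to persist the final weights (assumes --output_dir was set):
if script_args.output_dir:
    trainer.save_model(script_args.output_dir)
    tokenizer.save_pretrained(script_args.output_dir)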