Model Details

We employ Mistral-Base (7B) as one of the base models for evaluating our proposed method, Reward-Driven Selective Penalization for Preference Alignment Optimization (RSPO). This model is trained with RSPO for one epoch on the UltraFeedback Binarized dataset.

How to use

Transformers AutoModelForCausalLM

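from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "li11111/Mistral-7B-Base-RSPO"

# Load the tokenizer and the weights in bfloat16 (the dtype used for training).
# device_map="auto" places the weights on the available device(s) and
# requires the `accelerate` package.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Conversation in the standard role/content chat-message format.
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Render the conversation with the tokenizer's chat template and append the
# assistant-turn prefix so the model continues as the assistant.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# Stop generation at the end-of-sequence token.
terminators = [
    tokenizer.eos_token_id
]

# Sample a reply with temperature 0.6 and nucleus (top-p) sampling at 0.9.
outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)
# Drop the prompt tokens and decode only the newly generated reply.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))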

Experiment Parameters

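| Parameter | Mistral-Base (7B) |
|---|---|
| Accelerators | 8× Ascend 910B (NPU) |
| beta | 0.01 |
| batch size | 128 |
| learning_rate | 5e-7 |
| max_prompt_length | 512 |
| max_length | 1024 |
| num_train_epochs | 1 |
| torch_dtype | bfloat16 |
| warmup_ratio | 0.1 |
| β_w | 0.01 |
| β_l | 0.1 |
| λ | 0.1 |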
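The table lists values but not how they combine into a training schedule. The sketch below is a minimal illustration, not the authors' training code, of one way to hold these settings and derive step counts and sequence budgets from them. Assumptions: the class and function names are hypothetical; the dataset size is passed in because the card does not state it; warmup steps are computed as ceil(warmup_ratio × total steps), as Hugging Face `TrainingArguments` does; and the prompt/response truncation rule is a simple convention in the spirit of DPO-style trainers, not necessarily what RSPO uses. The roles of β_w, β_l and λ in the RSPO objective are not described here, so they are only stored.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class RSPOHyperparams:
    """Settings from the table; field names are hypothetical."""
    beta: float = 0.01
    batch_size: int = 128          # global batch size
    learning_rate: float = 5e-7
    max_prompt_length: int = 512   # tokens
    max_length: int = 1024         # prompt + response tokens
    num_train_epochs: int = 1
    warmup_ratio: float = 0.1
    beta_w: float = 0.01
    beta_l: float = 0.1
    lam: float = 0.1               # λ ("lambda" is a Python keyword)


def schedule(num_examples: int, hp: RSPOHyperparams) -> tuple[int, int]:
    """Return (total optimizer steps, warmup steps).

    Assumes the last partial batch is kept and warmup is
    ceil(warmup_ratio * total_steps), as in Hugging Face TrainingArguments.
    """
    steps_per_epoch = math.ceil(num_examples / hp.batch_size)
    total_steps = steps_per_epoch * hp.num_train_epochs
    warmup_steps = math.ceil(hp.warmup_ratio * total_steps)
    return total_steps, warmup_steps


def warmup_lr(step: int, warmup_steps: int, hp: RSPOHyperparams) -> float:
    """Linear warmup from 0 towards learning_rate.

    The decay after warmup is not given in the card, so it is not modelled.
    """
    if not 0 <= step < warmup_steps:
        raise ValueError("step is outside the warmup phase")
    return hp.learning_rate * step / warmup_steps


def truncate(prompt_ids: list[int], response_ids: list[int],
             hp: RSPOHyperparams) -> tuple[list[int], list[int]]:
    """Fit a prompt/response pair into max_length tokens (hypothetical rule).

    If the pair is too long, keep only the last max_prompt_length prompt
    tokens; then cut the response to whatever budget remains.
    """
    if len(prompt_ids) + len(response_ids) > hp.max_length:
        prompt_ids = prompt_ids[-hp.max_prompt_length:]
    budget = hp.max_length - len(prompt_ids)
    return prompt_ids, response_ids[:budget]


hp = RSPOHyperparams()
total, warm = schedule(10_000, hp)  # 10,000 is a made-up dataset size
print(total, warm)
```

With the made-up 10,000 preference pairs, one epoch at a global batch of 128 gives 79 optimizer steps, 8 of which are warmup. Swap in the real split size to get the actual counts.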

Training Data

We use the HuggingFaceH4/ultrafeedback_binarized dataset to train the Mistral Base model.
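Each row of the preference data pairs one prompt with a preferred and a dispreferred reply. As a sketch, assuming the public `ultrafeedback_binarized` layout (a string `prompt` plus `chosen` and `rejected` message lists that end with the assistant turn), the helper below extracts the text triplet that a pairwise preference method consumes. The function name `to_preference_triplet` is hypothetical, and the toy record stands in for real data so the example runs offline.

```python
from typing import Any


def to_preference_triplet(record: dict[str, Any]) -> tuple[str, str, str]:
    """Return (prompt, chosen reply, rejected reply) for one preference record.

    Assumes `chosen` and `rejected` are chat-message lists ending with the
    assistant turn, as in HuggingFaceH4/ultrafeedback_binarized.
    """
    replies = []
    for key in ("chosen", "rejected"):
        last = record[key][-1]
        if last["role"] != "assistant":
            raise ValueError(f"last message in {key!r} is not an assistant turn")
        replies.append(last["content"])
    return record["prompt"], replies[0], replies[1]


# Loading the real split needs network access and the `datasets` package, e.g.:
#   from datasets import load_dataset
#   ds = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split="train_prefs")
#   prompt, chosen, rejected = to_preference_triplet(ds[0])

# Toy record in the same layout so the example runs offline.
example = {
    "prompt": "What is 2 + 2?",
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 4."},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "2 + 2 = 5."},
    ],
}
print(to_preference_triplet(example))
```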

Benchmarks

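AlpacaEval 2.0 (LC: length-controlled win rate; WR: raw win rate; Avg. Len: average response length)

| Method | LC (%) | WR (%) | Avg. Len |
|---|---|---|---|
| RSPO | 25.4 | 23.7 | 1873 |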
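Downstream tasks (TQA: TruthfulQA; Avg.: unweighted mean of the five scores)

| Method | GSM8K | ARC | TQA | MMLU | IFEval | Avg. |
|---|---|---|---|---|---|---|
| SFT | 42.61 | 55.97 | 28.15 | 57.17 | 36.59 | 44.10 |
| DPO | 33.13 | 59.64 | 46.14 | 57.46 | 50.48 | 49.37 |
| R-DPO | 30.10 | 56.06 | 40.64 | 58.48 | 53.24 | 47.70 |
| SimPO | 33.59 | 60.15 | 43.45 | 58.25 | 52.98 | 49.68 |
| WPO | 30.63 | 57.00 | 40.51 | 58.54 | 55.64 | 48.46 |
| RSPO | 37.45 | 57.94 | 47.25 | 58.58 | 55.04 | 51.25 |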
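The Avg. column is the unweighted mean of the five benchmark scores. The short check below recomputes it from the table values and prints it next to the reported figure; it only re-derives numbers already shown and adds no new results. The names `RESULTS` and `unweighted_mean` are illustrative.

```python
# Per-method scores from the table: (GSM8K, ARC, TQA, MMLU, IFEval), reported Avg.
RESULTS: dict[str, tuple[list[float], float]] = {
    "SFT": ([42.61, 55.97, 28.15, 57.17, 36.59], 44.10),
    "DPO": ([33.13, 59.64, 46.14, 57.46, 50.48], 49.37),
    "R-DPO": ([30.10, 56.06, 40.64, 58.48, 53.24], 47.70),
    "SimPO": ([33.59, 60.15, 43.45, 58.25, 52.98], 49.68),
    "WPO": ([30.63, 57.00, 40.51, 58.54, 55.64], 48.46),
    "RSPO": ([37.45, 57.94, 47.25, 58.58, 55.04], 51.25),
}


def unweighted_mean(scores: list[float]) -> float:
    """Arithmetic mean rounded to two decimals, matching the table's precision."""
    return round(sum(scores) / len(scores), 2)


for method, (scores, reported) in RESULTS.items():
    computed = unweighted_mean(scores)
    print(f"{method:<6} computed={computed:.2f} reported={reported:.2f}")
```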
Model size: 7.24B parameters (Safetensors, tensor type BF16)