---
license: apache-2.0
language:
- en
base_model:
- meta-llama/Llama-3.2-3B-Instruct
pipeline_tag: text-generation
tags:
- Reward
- Best-of-N
---

### Introduction

This repository contains the released reasoning reward models for the paper [GRAM-R^2: Self-Training Generative Foundation Reward Models for Reward Reasoning 📝]().

<img src="https://raw.githubusercontent.com/wangclnlp/GRAM/refs/heads/main/gram-rr.png" width="1000px"></img>

We propose a self-training approach that enables reward models to elicit reward reasoning from both rationale-free labeled data and unlabeled data. This approach avoids the need for costly rationale-based annotations, enabling scalability in building foundation reward models. Specifically, we first train a preference-proving model that, given an input, a response pair, and a preference label, generates a proof explaining why the labeled preference holds. For rationale-free labeled data, this model is used to synthesize rationales for each example. For unlabeled data, the reward model improves its reasoning capability through an iterative self-training loop: (1) predicting preference labels for unlabeled examples, (2) generating corresponding rationales with the preference-proving model, and (3) updating the reward model using the synthesized data. This process scales reward reasoning by leveraging large amounts of unlabeled data. The dataset is available at this [link](https://huggingface.co/datasets/wangclnlp/GRAM-RR-TrainingData).

This reward model is fine-tuned from [LLaMA-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct).
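
The unlabeled-data part of this procedure can be summarized with a short sketch. This is illustrative pseudocode only, not the released training recipe (the actual implementation is in the [training code](https://github.com/NiuTrans/GRAM)); the function `self_training_loop`, the three callables it receives, the example fields, and the iteration count are hypothetical placeholders for the components described above.

```python
# Minimal sketch of the iterative self-training loop on unlabeled data.
# All names below are hypothetical placeholders, not part of the released code.
def self_training_loop(reward_model, proving_model, unlabeled_pairs,
                       predict_preference, prove_preference, update_reward_model,
                       num_iterations=3):
    """Iteratively turn unlabeled response pairs into rationale-annotated training data."""
    for _ in range(num_iterations):
        synthesized = []
        for example in unlabeled_pairs:
            # (1) the current reward model predicts a preference label for the unlabeled pair
            label = predict_preference(reward_model, example)
            # (2) the preference-proving model generates a rationale explaining that label
            rationale = prove_preference(proving_model, example, label)
            synthesized.append({**example, "label": label, "rationale": rationale})
        # (3) the reward model is updated on the synthesized labels and rationales
        reward_model = update_reward_model(reward_model, synthesized)
    return reward_model
```

Rationale-free labeled data is handled the same way, except that step (1) is skipped and the existing preference label is passed to the preference-proving model directly.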

### Evaluation Results

We evaluate our model on two challenging reward benchmarks, [RM-Bench](https://github.com/THU-KEG/RM-Bench) and [JudgeBench](https://huggingface.co/datasets/ScalerLab/JudgeBench). We compare its performance against four categories of baselines: (1) LLM-as-a-Judge approaches that prompt large language models to generate preferences, (2) open-source reward models, (3) reasoning reward models, and (4) reward models trained using unlabeled data.

- Results on RM-Bench.

| **Model** | **Params.** | **Chat** | **Math** | **Code** | **Safety** | **Overall** |
|:-|-:|:-:|:-:|:-:|:-:|:-:|
| **LLM-as-a-Judge** | | | | | | |
| GPT-4o | - | 67.2 | 67.5 | 63.6 | 91.7 | 72.5 |
| Claude-3.5-Sonnet | - | 62.5 | 62.6 | 54.4 | 64.4 | 61.0 |
| DeepSeek-R1-0528 | 671B | 76.7 | 74.3 | 51.0 | 89.2 | 72.8 |
| **Open-Source Reward Models** | | | | | | |
| Llama-3.1-Nemotron-70B-Reward | 70B | 70.7 | 64.3 | 57.4 | 90.3 | 70.7 |
| Skywork-Reward-Gemma-2-27B | 27B | 71.8 | 59.2 | 56.6 | 94.3 | 70.5 |
| Skywork-Reward-Llama-3.1-8B | 8B | 69.5 | 60.6 | 54.5 | 95.7 | 70.1 |
| Nemotron-Super | 49B | 73.7 | 91.4 | 75.0 | 90.6 | 82.7 |
| Nemotron-Super-Multilingual | 49B | **77.2** | **91.9** | 74.7 | 92.9 | 84.2 |
| **Reasoning Reward Models** | | | | | | |
| RM-R1-Distilled-Qwen-32B | 32B | 74.2 | 91.8 | 74.1 | 95.4 | 83.9 |
| RM-R1-Distilled-Qwen-14B | 14B | 71.8 | 90.5 | 69.5 | 94.1 | 81.5 |
| RRM-32B | 32B | 66.6 | 81.4 | 65.2 | 79.4 | 73.1 |
| **Training with Unlabeled Preference Data** | | | | | | |
| GRAM-Qwen3-14B | 14B | 67.4 | 55.2 | 62.8 | 94.3 | 69.9 |
| GRAM-Qwen3-8B | 8B | 63.5 | 53.9 | 62.9 | 92.8 | 68.3 |
| **Ours** | | | | | | |
| GRAM-RR-LLaMA-3.2-3B-RewardModel | 3B | 74.4 | 88.8 | 76.6 | 95.5 | 83.8 |
| +voting@16 | 3B | 74.8 | 89.4 | 78.4 | 95.7 | 84.6 |
| GRAM-RR-LLaMA-3.1-8B-RewardModel | 8B | 76.0 | 89.8 | 80.6 | 96.2 | 85.7 |
| +voting@16 | 8B | 76.3 | 90.4 | **81.2** | **96.4** | **86.1** |

- Results on JudgeBench.

| **Model** | **Params.** | **Chat** | **Math** | **Code** | **Safety** | **Overall** |
|:-|-:|:-:|:-:|:-:|:-:|:-:|
| **LLM-as-a-Judge** | | | | | | |
| GPT-4o | - | 50.6 | 54.1 | 75.0 | 59.5 | 59.8 |
| Claude-3.5-Sonnet | - | 62.3 | 66.3 | 66.1 | 64.3 | 64.8 |
| DeepSeek-R1-0528 | 671B | 59.1 | 82.7 | 80.4 | 92.9 | 78.8 |
| **Open-Source Reward Models** | | | | | | |
| Llama-3.1-Nemotron-70B-Reward | 70B | 62.3 | 72.5 | 76.8 | 57.1 | 67.2 |
| Skywork-Reward-Gemma-2-27B | 27B | 59.7 | 66.3 | 83.9 | 50.0 | 65.0 |
| Skywork-Reward-Llama-3.1-8B | 8B | 59.1 | 64.3 | 76.8 | 50.0 | 62.5 |
| Nemotron-Super | 49B | 71.4 | 73.5 | 87.5 | 76.2 | 77.2 |
| Nemotron-Super-Multilingual | 49B | 64.9 | 74.5 | 87.5 | 73.8 | 75.2 |
| **Reasoning Reward Models** | | | | | | |
| RM-R1-Distilled-Qwen-32B | 32B | 76.0 | 80.6 | 88.1 | 70.5 | 78.8 |
| RM-R1-Distilled-Qwen-14B | 14B | 68.1 | 72.4 | 87.8 | **84.2** | 78.1 |
| RRM-32B | 32B | 79.9 | 70.4 | 87.5 | 65.0 | 75.7 |
| **Training with Unlabeled Preference Data** | | | | | | |
| GRAM-Qwen3-14B | 14B | 63.0 | 64.3 | **89.3** | 69.1 | 71.4 |
| GRAM-Qwen3-8B | 8B | 62.3 | 64.3 | 80.4 | 64.3 | 67.8 |
| **Ours** | | | | | | |
| GRAM-RR-LLaMA-3.2-3B-RewardModel | 3B | 93.0 | 78.1 | 81.6 | 68.5 | 80.3 |
| +voting@16 | 3B | **93.5** | 78.6 | 82.1 | 69.0 | 80.8 |
| GRAM-RR-LLaMA-3.1-8B-RewardModel | 8B | 90.9 | 83.7 | 87.5 | 61.9 | 81.0 |
| +voting@16 | 8B | 91.2 | **84.3** | 88.1 | 62.8 | **81.6** |

### Usage

You can directly run the GRAM-RR model using the demo provided below. You can also train GRAM-RR using the code available [here](https://github.com/NiuTrans/GRAM).

```python
import torch
import accelerate  # accelerate must be installed for device_map="auto"; importing it surfaces a missing install early
from functools import cmp_to_key
from transformers import AutoTokenizer, AutoModelForCausalLM


pairwise_prompt = '''
You are a helpful assistant in evaluating the quality of the responses for a given instruction. Your goal is to select the better response for the given user question.
Select Response A or Response B, that is better for the given instruction. The two responses are generated by two different AI chatbots respectively.
Do NOT say both / neither are good.

Here are some rules of the evaluation:
(1) If the instruction does not contain harmful content, you should prioritize evaluating whether the output honestly/precisely/closely executes the instruction, then consider its helpfulness, accuracy, level of detail, harmlessness, etc.
(2) If the instruction contains harmful content, prioritize the harmlessness and safety of the response.
(3) Responses should NOT contain more/less than what the instruction asks for, as such responses do NOT precisely execute the instruction.
(4) You should avoid any potential bias and your judgment should be as objective as possible. Here are some potential sources of bias:
- The order in which the responses were presented should NOT affect your judgment, as Response A and Response B are **equally likely** to be the better.
- The length of the responses should NOT affect your judgement, as a longer response does not necessarily correspond to a better response. When making your decision, evaluate if the response length is appropriate for the given instruction.

Your reply should strictly follow this format:
<think>
Follow this format:
Feedback:
<provide free-text feedback on the overall helpfulness of the assistant response>

Comparision:
<give a brief analysis on which is better>

Conclusion:
<make your conclusion>
</think>
<answer>
A or B
</answer>

Here is the data.

[User Question]
{user_input}

[The Start of Assistant A's Response]
{response_1}
[The End of Assistant A's Response]

[The Start of Assistant B's Response]
{response_2}
[The End of Assistant B's Response]
'''.strip()

# an input example
user_input = '10 words to apologize for being late.'
responses = [
    "My sincere apologies for being late today.",
    "Apologies for making you wait; punctuality isn't my strong suit.",
    "I'm sorry I couldn’t be on time today; unexpected issues delayed me, and I appreciate your patience."
]
print('='*25 + '\n' + 'The user input is:\n\n' + user_input + '\n\n' + '='*25 + '\n')
for idx, response in enumerate(responses):
    print('='*25 + '\n' + f'The response {idx} is:\n\n' + response + '\n\n' + '='*25 + '\n')

# init model
model_name = "/path/to/the/model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)

# pairwise ranking
# returns 1 if response_1 is better, -1 if response_2 is better, 0 for no answer
def pairwise_ranking(user_input, response_1, response_2):
    messages = [
        {
            "role": "user",
            "content": pairwise_prompt.format(
                user_input=user_input,
                response_1=response_1,
                response_2=response_2
            )
        }
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=16384
    )
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

    model_res = tokenizer.decode(output_ids, skip_special_tokens=True)
    # the prompt requires the verdict inside <answer>...</answer>; extract it
    model_res = model_res.rsplit("<answer>", 1)[-1].split("</answer>")[0].strip().upper()
    if len(model_res) == 0:
        return 0

    return 1 if model_res.startswith("A") else -1

# the better one between responses[0] and responses[1]
better_response = 0 if pairwise_ranking(user_input, responses[0], responses[1]) > 0 else 1
print(f'Response {better_response} is better between response 0 and response 1.')

# listwise ranking (best first)
responses_id = [idx for idx, _ in enumerate(responses)]
responses_id = sorted(
    responses_id,
    key=cmp_to_key(lambda idx_1, idx_2: pairwise_ranking(user_input, responses[idx_1], responses[idx_2])),
    reverse=True  # a positive comparison means the first response is better
)
print(f"The ranking among responses: {' > '.join([str(i) for i in responses_id])}")

# best-of-n
best = 0
for idx in range(1, len(responses)):
    best = idx if pairwise_ranking(user_input, responses[idx], responses[best]) > 0 else best

print(f"The best response is response {best}.")

# vote in k (take pairwise ranking as an example);
# enable sampling in model.generate (e.g., do_sample=True) so the k judgments can differ
k = 8
res = [pairwise_ranking(user_input, responses[0], responses[1]) for _ in range(k)]
better_response = 0 if max(set(res), key=res.count) > 0 else 1
print(f"The better response is response {better_response} in {k} votes.")

```

### Citation
```bibtex
coming soon
```