Update README.md
README.md CHANGED
@@ -24,15 +24,6 @@ datasets:
## Requirements

* `transformers>=4.40.0` for Qwen2.5-Math models. The latest version is recommended.

-> [!Warning]
-> <div align="center">
-> <b>
-> 🚨 This is a must because `transformers` integrated Qwen2.5 codes since `4.37.0`.
-> </b>
-> </div>
-
-For requirements on GPU memory and the respective throughput, see similar results of Qwen2 [here](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html).
-
## Quick Start

> [!Important]
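As a quick sanity check of the `transformers>=4.40.0` requirement above, a minimal sketch (it assumes only that `packaging` is importable, which it is as a dependency of `transformers` itself):

```python
import transformers
from packaging import version

# Fail fast if the installed transformers is older than the README requires.
assert version.parse(transformers.__version__) >= version.parse("4.40.0"), (
    f"found transformers {transformers.__version__}, but >= 4.40.0 is required"
)
```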
@@ -45,7 +36,7 @@ For requirements on GPU memory and the respective throughput, see similar result

### 🤗 Hugging Face Transformers

-Here we show a code snippet to show you how to use our PRM with `transformers`:
+1. Here is a code snippet showing how to use our PRM with `transformers`:

```python
import torch
# … (rest of the snippet is unchanged and omitted from this hunk)
```
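The snippet itself lies outside the changed lines, so the diff truncates it. For orientation, here is a minimal single-response scoring sketch consistent with the BoN evaluation code added below (the `question` and `steps` values are hypothetical, and the step scoring is inlined):

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model_name = "jinachris/PURE-PRM-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()

question = "What is 2 + 2?"  # hypothetical input
steps = ["2 + 2 means adding two and two.", "So the answer is 4."]

# Concatenate question and steps, separated by "\n"; each separator position
# is where one step's reward is read out.
sep_ids = tokenizer("\n", add_special_tokens=False, return_tensors="pt")["input_ids"]
input_ids = tokenizer(question, add_special_tokens=False, return_tensors="pt")["input_ids"]
score_positions = []
for step in steps:
    step_ids = tokenizer(step, add_special_tokens=False, return_tensors="pt")["input_ids"]
    input_ids = torch.cat([input_ids, step_ids, sep_ids], dim=-1)
    score_positions.append(input_ids.size(-1) - 1)

with torch.no_grad():
    logits = model(input_ids.to(model.device)).logits  # (1, seq_len, 2)
probs = logits[0, score_positions].softmax(dim=-1)
print(probs[:, 1] - probs[:, 0])  # one reward per step
```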
@@ -129,17 +120,121 @@ print(step_reward) # [[0.796875, 0.185546875, -0.0625, 0.078125]]

```python
# … (tail of the snippet above)
# torch.tensor(step_reward).sum(dim=-1)
```

+2. Additionally, we share the code for Best-of-N (BoN) evaluation on RLHFlow's data:

```python
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer

ds_names = ["GSM8K", "MATH500"]
ds = [
    load_dataset(
        f"RLHFlow/Deepseek-{ds_name}-Test"
    )['test'] for ds_name in ds_names
]

def make_step_rewards(logits, token_masks):
    all_scores_res = []
    for sample, token_mask in zip(logits, token_masks):
        # sample: (seq_len, num_labels)
        probs = sample[token_mask].softmax(dim=-1)  # (num_steps, 2)
        process_reward = probs[:, 1] - probs[:, 0]  # (num_steps,)
        # weighted sum to approximate the min; highly recommended for BoN
        # evaluation and for fine-tuning LLMs
        weight = torch.softmax(
            -process_reward / 0.1,
            dim=-1,
        )
        process_reward = weight * process_reward
        all_scores_res.append(process_reward.cpu().tolist())
    return all_scores_res
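# For intuition (illustrative note, not part of the commit): softmax(-r / 0.1)
# puts nearly all weight on the smallest step reward, so the weighted sum above
# approximates min(r). For example, r = [0.8, -0.2, 0.5] gives weights
# ≈ [0.0000, 0.9990, 0.0009] and a weighted sum ≈ -0.2 = min(r).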

model_name = "jinachris/PURE-PRM-7B"
device = "auto"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()

step_separator = "\n"
step_separator_token = tokenizer(
    step_separator,
    add_special_tokens=False,
    return_tensors='pt',
)['input_ids']

for ds_item, ds_name in zip(ds, ds_names):
    # sampled_ids = np.random.choice(range(len(ds_item)), size=100, replace=False)
    correct = 0
    total = 0
    for idx in tqdm(range(len(ds_item)), desc=f"Processing questions in {ds_name}"):
        question = ds_item['prompt'][idx]
        answers = ds_item['answers'][idx]
        labels = ds_item['label'][idx]
        outcome_scores = []

        question_ids = tokenizer(
            question,
            add_special_tokens=False,
            return_tensors='pt',
        )['input_ids']
        for answer in tqdm(answers, desc="Processing answers"):
            steps = [i.rstrip() for i in answer.split("\n\n")]
            input_ids = question_ids.clone()

            score_ids = []
            for step in steps:
                step_ids = tokenizer(
                    step,
                    add_special_tokens=False,
                    return_tensors='pt',
                )['input_ids']
                input_ids = torch.cat(
                    [input_ids, step_ids, step_separator_token],
                    dim=-1,
                )
                score_ids.append(input_ids.size(-1) - 1)

            input_ids = input_ids.to(model.device, dtype=torch.long)
            token_masks = torch.zeros_like(input_ids, dtype=torch.bool)
            token_masks[0, score_ids] = True
            assert torch.all(input_ids[token_masks].to("cpu") == step_separator_token)

            with torch.no_grad():
                logits = model(input_ids).logits
                step_reward = make_step_rewards(logits, token_masks)
                outcome_reward = torch.tensor(step_reward).sum(dim=-1)

            # TODO: batch input & output
            outcome_scores.append(outcome_reward.item())

        best_idx = np.argmax(outcome_scores)
        if labels[best_idx] == 1:
            correct += 1
        total += 1
    print(f"Accuracy on {ds_name}: {correct / total}")
```

## Citation

If you find our work useful, we would appreciate it if you could cite our work:

```
-@misc{cheng2025stop,
-  title={Stop Gamma Decay: Min-Form Credit Assignment Is All Process Reward Model Needs for Reasoning},
-  author={Jie Cheng and ...},
-  howpublished={\url{https://tungsten-ink-510.notion.site/Stop-Gamma-Decay-Min-Form-Credit-Assignment-Is-All-Process-Reward-Model-Needs-for-Reasoning-19fcb6ed0184804eb07fd310b38af155?pvs=4}},
-  note={Notion Blog}
+@article{cheng2025stop,
+  title={Stop Summation: Min-Form Credit Assignment Is All Process Reward Model Needs for Reasoning},
+  author={Cheng, Jie and Qiao, Ruixi and Li, Lijun and Guo, Chao and Wang, Junle and Xiong, Gang and Lv, Yisheng and Wang, Fei-Yue},
+  journal={arXiv preprint arXiv:2504.15275},
  year={2025}
}
```
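One note on the added script: it scores answers one at a time and leaves a `# TODO: batch input & output`. A rough sketch of what that batching could look like, reusing `make_step_rewards` from the script (the `score_batch` helper and its right-padding scheme are our own illustration, not part of the commit):

```python
import torch

def score_batch(model, tokenizer, input_ids_list, score_ids_list):
    # input_ids_list: one (1, seq_len) tensor per answer, built as in the script;
    # score_ids_list: the separator positions recorded for each answer.
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    max_len = max(ids.size(-1) for ids in input_ids_list)
    batch = torch.full((len(input_ids_list), max_len), pad_id, dtype=torch.long)
    attn = torch.zeros_like(batch)
    masks = torch.zeros((len(input_ids_list), max_len), dtype=torch.bool)
    for i, (ids, score_ids) in enumerate(zip(input_ids_list, score_ids_list)):
        n = ids.size(-1)
        batch[i, :n] = ids[0]  # right padding keeps the recorded positions valid
        attn[i, :n] = 1
        masks[i, score_ids] = True
    with torch.no_grad():
        logits = model(
            batch.to(model.device),
            attention_mask=attn.to(model.device),
        ).logits
    # make_step_rewards is defined in the script above
    return make_step_rewards(logits, masks.to(model.device))
```

Each returned entry is one answer's weighted step rewards, so `outcome_scores` could be filled with a single forward pass per question rather than one per answer.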
|