Commit 0a96d22 by jinachris · verified · 1 parent: df929c4

Update README.md

  1. README.md +111 -16
README.md CHANGED
@@ -24,15 +24,6 @@ datasets:
  ## Requirements
  * `transformers>=4.40.0` for Qwen2.5-Math models. The latest version is recommended.
 
- > [!Warning]
- > <div align="center">
- > <b>
- > 🚨 This is a must because `transformers` integrated Qwen2.5 codes since `4.37.0`.
- > </b>
- > </div>
-
- For requirements on GPU memory and the respective throughput, see similar results of Qwen2 [here](https://qwen.readthedocs.io/en/latest/benchmark/speed_benchmark.html).
-
  ## Quick Start
 
  > [!Important]
@@ -45,7 +36,7 @@ For requirements on GPU memory and the respective throughput, see similar result
 
  ### 🤗 Hugging Face Transformers
 
- Here we show a code snippet to show you how to use our PRM with `transformers`:
 
  ```python
  import torch
@@ -129,17 +120,121 @@ print(step_reward) # [[0.796875, 0.185546875, -0.0625, 0.078125]]
  # torch.tensor(step_reward).sum(dim=-1)
  ```
 
  ## Citation
 
  If you find our work useful, we would appreciate it if you could cite our work:
 
  ```
- @misc{cheng2025pure,
-   title={Stop Gamma Decay: Min-Form Credit Assignment Is All Process Reward Model Needs for Reasoning},
-   author={Jie Cheng and Lijun Li and Gang Xiong and Jing Shao and Yisheng Lv and Fei-Yue Wang},
-   year={2025},
-   howpublished={\url{https://tungsten-ink-510.notion.site/Stop-Gamma-Decay-Min-Form-Credit-Assignment-Is-All-Process-Reward-Model-Needs-for-Reasoning-19fcb6ed0184804eb07fd310b38af155?pvs=4}},
-   note={Notion Blog}
    year={2025}
  }
  ```
 
  ## Requirements
  * `transformers>=4.40.0` for Qwen2.5-Math models. The latest version is recommended.
 
  ## Quick Start
 
  > [!Important]
 
 
  ### 🤗 Hugging Face Transformers
 
+ 1. The following snippet shows how to use our PRM with `transformers`:
 
  ```python
  import torch

  # torch.tensor(step_reward).sum(dim=-1)
  ```
 
+ 2. Additionally, we share the code for BoN evaluation on RLHFlow's data:
+
+ ```python
+ import numpy as np
+ import torch
+ from datasets import load_dataset
+ from tqdm import tqdm
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+
+ ds_names = ["GSM8K", "MATH500"]
+ ds = [
+     load_dataset(
+         f"RLHFlow/Deepseek-{ds_name}-Test"
+     )['test'] for ds_name in ds_names
+ ]
+
+ def make_step_rewards(logits, token_masks):
+     all_scores_res = []
+     for sample, token_mask in zip(logits, token_masks):
+         # sample: (seq_len, num_labels)
+         probs = sample[token_mask].softmax(dim=-1)  # (num_steps, 2)
+         process_reward = probs[:, 1] - probs[:, 0]  # (num_steps,)
+         # weighted sum to approximate the min; highly recommended for BoN eval and for fine-tuning LLMs
+         weight = torch.softmax(
+             -process_reward / 0.1,
+             dim=-1,
+         )
+         process_reward = weight * process_reward
+         all_scores_res.append(process_reward.cpu().tolist())
+     return all_scores_res
+
+
+ model_name = "jinachris/PURE-PRM-7B"
+ device = "auto"
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     model_name,
+     trust_remote_code=True,
+ )
+ model = AutoModelForTokenClassification.from_pretrained(
+     model_name,
+     device_map=device,
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True,
+ ).eval()
+
+ step_separator = "\n"
+ step_separator_token = tokenizer(
+     step_separator,
+     add_special_tokens=False,
+     return_tensors='pt',
+ )['input_ids']
+
+
+ for ds_item, ds_name in zip(ds, ds_names):
+     # sampled_ids = np.random.choice(range(len(ds_item)), size=100, replace=False)
+     correct = 0
+     total = 0
+     for idx in tqdm(range(len(ds_item)), desc=f"Processing questions in {ds_name}"):
+         question = ds_item['prompt'][idx]
+         answers = ds_item['answers'][idx]
+         labels = ds_item['label'][idx]
+         outcome_scores = []
+
+         question_ids = tokenizer(
+             question,
+             add_special_tokens=False,
+             return_tensors='pt',
+         )['input_ids']
+         for answer in tqdm(answers, desc="Processing answers"):
+             steps = [i.rstrip() for i in answer.split("\n\n")]
+             input_ids = question_ids.clone()
+
+             score_ids = []
+             for step in steps:
+                 step_ids = tokenizer(
+                     step,
+                     add_special_tokens=False,
+                     return_tensors='pt',
+                 )['input_ids']
+                 input_ids = torch.cat(
+                     [input_ids, step_ids, step_separator_token],
+                     dim=-1,
+                 )
+                 score_ids.append(input_ids.size(-1) - 1)
+
+             input_ids = input_ids.to(model.device, dtype=torch.long)
+             token_masks = torch.zeros_like(input_ids, dtype=torch.bool)
+             token_masks[0, score_ids] = True
+             assert torch.all(input_ids[token_masks].to("cpu") == step_separator_token)
+
+             with torch.no_grad():
+                 logits = model(input_ids).logits
+                 step_reward = make_step_rewards(logits, token_masks)
+                 outcome_reward = torch.tensor(step_reward).sum(dim=-1)
+
+             # TODO: batch input & output
+             outcome_scores.append(outcome_reward.item())
+
+         best_idx = np.argmax(outcome_scores)
+         if labels[best_idx] == 1:
+             correct += 1
+         total += 1
+     print(f"Accuracy on {ds_name}: {correct / total}")
+ ```
+
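The softmax-weighted sum inside `make_step_rewards` can be illustrated in isolation. The block below is a minimal, dependency-free sketch and not the repository's code: the helper names `soft_min_weights` and `min_form_outcome` are hypothetical, and the sample values are borrowed from the `step_reward` printed by the first snippet and treated as raw step rewards purely for illustration. It assumes the same temperature of `0.1` as the BoN code.

```python
import math


def soft_min_weights(rewards, temperature=0.1):
    """Softmax over -reward / temperature: the lowest reward gets the largest weight."""
    scaled = [-r / temperature for r in rewards]
    peak = max(scaled)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def min_form_outcome(rewards, temperature=0.1):
    """Weighted sum of step rewards; approaches min(rewards) as temperature goes to 0."""
    weights = soft_min_weights(rewards, temperature)
    return sum(w * r for w, r in zip(weights, rewards))


# Illustrative values only (assumption: treated here as raw per-step rewards).
step_rewards = [0.796875, 0.185546875, -0.0625, 0.078125]

print("weighted sum (T=0.1):  ", min_form_outcome(step_rewards))
print("weighted sum (T=0.001):", min_form_outcome(step_rewards, temperature=0.001))
print("exact min:             ", min(step_rewards))
```

At temperature `0.1` the weighted sum lies between the minimum and the mean of the step rewards; lowering the temperature pushes it toward the minimum, so a single poorly scored step dominates the outcome score used for BoN ranking.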
  ## Citation
 
  If you find our work useful, we would appreciate it if you could cite our work:
 
  ```
+ @article{cheng2025stop,
+   title={Stop Summation: Min-Form Credit Assignment Is All Process Reward Model Needs for Reasoning},
+   author={Cheng, Jie and Qiao, Ruixi and Li, Lijun and Guo, Chao and Wang, Junle and Xiong, Gang and Lv, Yisheng and Wang, Fei-Yue},
+   journal={arXiv preprint arXiv:2504.15275},
    year={2025}
  }
  ```