jinachris committed (verified)
Commit b0cce67 · 1 Parent(s): 0a96d22

Update README.md

Files changed (1):
  1. README.md +1 -105
README.md CHANGED
@@ -120,111 +120,7 @@ print(step_reward) # [[0.796875, 0.185546875, -0.0625, 0.078125]]
  # torch.tensor(step_reward).sum(dim=-1)
  ```
 
- 2. Additionally, we share the code for BoN evaluation on RLHFlow's data:
-
- ```python
- import numpy as np
- import torch
- from datasets import load_dataset
- from tqdm import tqdm
- from transformers import AutoModelForTokenClassification, AutoTokenizer
-
- ds_names = ["GSM8K", "MATH500"]
- ds = [
-     load_dataset(
-         f"RLHFlow/Deepseek-{ds_name}-Test"
-     )['test'] for ds_name in ds_names
- ]
-
- def make_step_rewards(logits, token_masks):
-     all_scores_res = []
-     for sample, token_mask in zip(logits, token_masks):
-         # sample: (seq_len, num_labels)
-         probs = sample[token_mask].softmax(dim=-1)  # (num_steps, 2)
-         process_reward = probs[:, 1] - probs[:, 0]  # (num_steps,)
-         # weighted sum to approximate the min; highly recommended for BoN evaluation and LLM fine-tuning
-         weight = torch.softmax(
-             -process_reward / 0.1,
-             dim=-1,
-         )
-         process_reward = weight * process_reward
-         all_scores_res.append(process_reward.cpu().tolist())
-     return all_scores_res
-
-
- model_name = "jinachris/PURE-PRM-7B"
- device = "auto"
-
- tokenizer = AutoTokenizer.from_pretrained(
-     model_name,
-     trust_remote_code=True,
- )
- model = AutoModelForTokenClassification.from_pretrained(
-     model_name,
-     device_map=device,
-     torch_dtype=torch.bfloat16,
-     trust_remote_code=True,
- ).eval()
-
- step_separator = "\n"
- step_separator_token = tokenizer(
-     step_separator,
-     add_special_tokens=False,
-     return_tensors='pt',
- )['input_ids']
-
-
- for ds_item, ds_name in zip(ds, ds_names):
-     # sampled_ids = np.random.choice(range(len(ds_item)), size=100, replace=False)
-     correct = 0
-     total = 0
-     for idx in tqdm(range(len(ds_item)), desc=f"Processing questions in {ds_name}"):
-         question = ds_item['prompt'][idx]
-         answers = ds_item['answers'][idx]
-         labels = ds_item['label'][idx]
-         outcome_scores = []
-
-         question_ids = tokenizer(
-             question,
-             add_special_tokens=False,
-             return_tensors='pt',
-         )['input_ids']
-         for answer in tqdm(answers, desc="Processing answers"):
-             steps = [i.rstrip() for i in answer.split("\n\n")]
-             input_ids = question_ids.clone()
-
-             score_ids = []
-             for step in steps:
-                 step_ids = tokenizer(
-                     step,
-                     add_special_tokens=False,
-                     return_tensors='pt',
-                 )['input_ids']
-                 input_ids = torch.cat(
-                     [input_ids, step_ids, step_separator_token],
-                     dim=-1,
-                 )
-                 score_ids.append(input_ids.size(-1) - 1)
-
-             input_ids = input_ids.to(model.device, dtype=torch.long)
-             token_masks = torch.zeros_like(input_ids, dtype=torch.bool)
-             token_masks[0, score_ids] = True
-             assert torch.all(input_ids[token_masks].to("cpu") == step_separator_token)
-
-             with torch.no_grad():
-                 logits = model(input_ids).logits
-             step_reward = make_step_rewards(logits, token_masks)
-             outcome_reward = torch.tensor(step_reward).sum(dim=-1)
-
-             # TODO: batch input & output
-             outcome_scores.append(outcome_reward.item())
-
-         best_idx = np.argmax(outcome_scores)
-         if labels[best_idx] == 1:
-             correct += 1
-         total += 1
-     print(f"Accuracy on {ds_name}: {correct / total}")
- ```
+ 2. For evaluation with the Best-of-N method, or on ProcessBench and PRMBench, refer to [our GitHub repository](https://github.com/CJReinforce/PURE/tree/verl/PRM/eval).
 
  ## Citation
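Note on the removed snippet: `make_step_rewards` aggregates per-step rewards with softmax weights computed on the negated rewards, so the weighted sum acts as a soft minimum over the steps. Below is a minimal, self-contained sketch of just that aggregation; the example reward values are made up, while the 0.1 temperature comes from the removed code:

```python
import torch

# Made-up per-step rewards, shaped like the PRM's P(correct) - P(incorrect) per step.
process_reward = torch.tensor([0.80, 0.19, -0.06, 0.08])

# Softmax over the negated rewards (temperature 0.1) puts most weight on the
# lowest-reward step, so the weighted sum softly approximates the minimum.
weight = torch.softmax(-process_reward / 0.1, dim=-1)
outcome_reward = (weight * process_reward).sum()

print(weight)          # most of the mass sits on the -0.06 step
print(outcome_reward)  # a soft approximation of process_reward.min()
```

As the temperature shrinks toward zero, the weighted sum approaches the hard minimum; as it grows, the weights flatten and the sum approaches the mean of the step rewards.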