Update README.md
README.md CHANGED
@@ -120,111 +120,7 @@ print(step_reward) # [[0.796875, 0.185546875, -0.0625, 0.078125]]
# torch.tensor(step_reward).sum(dim=-1)
```

-2.
-
-```python
-import numpy as np
-import torch
-from datasets import load_dataset
-from tqdm import tqdm
-from transformers import AutoModelForTokenClassification, AutoTokenizer
-
-ds_names = ["GSM8K", "MATH500"]
-ds = [
-    load_dataset(
-        f"RLHFlow/Deepseek-{ds_name}-Test"
-    )['test'] for ds_name in ds_names
-]
-
-def make_step_rewards(logits, token_masks):
-    all_scores_res = []
-    for sample, token_mask in zip(logits, token_masks):
-        # sample: (seq_len, num_labels)
-        probs = sample[token_mask].softmax(dim=-1)  # (num_steps, 2)
-        process_reward = probs[:, 1] - probs[:, 0]  # (num_steps,)
-        # weighted sum to approx. min, highly recommend when BoN eval and Fine-tuning LLM
-        weight = torch.softmax(
-            -process_reward / 0.1,
-            dim=-1,
-        )
-        process_reward = weight * process_reward
-        all_scores_res.append(process_reward.cpu().tolist())
-    return all_scores_res
-
-
-model_name = "jinachris/PURE-PRM-7B"
-device = "auto"
-
-tokenizer = AutoTokenizer.from_pretrained(
-    model_name,
-    trust_remote_code=True,
-)
-model = AutoModelForTokenClassification.from_pretrained(
-    model_name,
-    device_map=device,
-    torch_dtype=torch.bfloat16,
-    trust_remote_code=True,
-).eval()
-
-step_separator = "\n"
-step_separator_token = tokenizer(
-    step_separator,
-    add_special_tokens=False,
-    return_tensors='pt',
-)['input_ids']
-
-
-for ds_item, ds_name in zip(ds, ds_names):
-    # sampled_ids = np.random.choice(range(len(ds_item)), size=100, replace=False)
-    correct = 0
-    total = 0
-    for idx in tqdm(range(len(ds_item)), desc=f"Processing questions in {ds_name}"):
-        question = ds_item['prompt'][idx]
-        answers = ds_item['answers'][idx]
-        labels = ds_item['label'][idx]
-        outcome_scores = []
-
-        question_ids = tokenizer(
-            question,
-            add_special_tokens=False,
-            return_tensors='pt',
-        )['input_ids']
-        for answer in tqdm(answers, desc="Processing answers"):
-            steps = [i.rstrip() for i in answer.split("\n\n")]
-            input_ids = question_ids.clone()
-
-            score_ids = []
-            for step in steps:
-                step_ids = tokenizer(
-                    step,
-                    add_special_tokens=False,
-                    return_tensors='pt',
-                )['input_ids']
-                input_ids = torch.cat(
-                    [input_ids, step_ids, step_separator_token],
-                    dim=-1,
-                )
-                score_ids.append(input_ids.size(-1) - 1)
-
-            input_ids = input_ids.to(model.device, dtype=torch.long)
-            token_masks = torch.zeros_like(input_ids, dtype=torch.bool)
-            token_masks[0, score_ids] = True
-            assert torch.all(input_ids[token_masks].to("cpu") == step_separator_token)
-
-            with torch.no_grad():
-                logits = model(input_ids).logits
-                step_reward = make_step_rewards(logits, token_masks)
-                outcome_reward = torch.tensor(step_reward).sum(dim=-1)
-
-            # TODO: batch input & output
-            outcome_scores.append(outcome_reward.item())
-
-        best_idx = np.argmax(outcome_scores)
-        if labels[best_idx] == 1:
-            correct += 1
-        total += 1
-    print(f"Accuracy on {ds_name}: {correct / total}")
-```
+2. For evaluation using Best-of-N method or on ProcessBench and PRMBench, refer to [our github repository](https://github.com/CJReinforce/PURE/tree/verl/PRM/eval).

## Citation

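The snippet removed above boils down to a simple Best-of-N rule: sum each candidate answer's per-step rewards into an outcome score, then keep the highest-scoring candidate. Below is a minimal sketch of just that selection step; the step rewards and labels are made-up illustration values in the format returned by `make_step_rewards`, not actual model outputs.

```python
import numpy as np
import torch

# Made-up step rewards for three candidate answers to one question
# (one list of per-step scores per answer).
step_rewards = [
    [0.80, 0.19, -0.06, 0.08],   # answer 0
    [0.55, 0.42, 0.31],          # answer 1
    [-0.20, 0.10, 0.05, 0.01],   # answer 2
]
labels = [0, 1, 0]  # 1 means the answer's final result is correct

# Outcome score of an answer = sum of its step rewards.
outcome_scores = [torch.tensor(r).sum().item() for r in step_rewards]

# Best-of-N: keep the candidate with the highest outcome score.
best_idx = int(np.argmax(outcome_scores))
print(best_idx, labels[best_idx] == 1)
```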