Update README.md
Browse files
README.md
CHANGED
@@ -69,8 +69,16 @@ We started the second-stage training on top of [EurusPRM-Stage1](https://hugging
|
|
69 |
We show an example leveraging **EurusPRM-Stage2** below:
|
70 |
|
71 |
```python
|
|
|
|
|
72 |
coef=0.001
|
73 |
-
d = {'query':'
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
}
|
75 |
model = AutoModelForCausalLM.from_pretrained('PRIME-RL/EurusPRM-Stage2')
|
76 |
tokenizer = AutoTokenizer.from_pretrained('PRIME-RL/EurusPRM-Stage2')
|
@@ -78,7 +86,7 @@ ref_model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-Math-7B-Instruct'
|
|
78 |
input_ids = tokenizer.apply_chat_template([
|
79 |
{"role": "user", "content": d["query"]},
|
80 |
{"role": "assistant", "content": "\n\n".join(d["answer"])},
|
81 |
-
], tokenize=True, add_generation_prompt=False)
|
82 |
attention_mask = input_ids!=tokenizer.pad_token_id
|
83 |
step_last_tokens = []
|
84 |
for step_num in range(0, len(d["answer"])+1):
|
@@ -92,8 +100,10 @@ for step_num in range(0, len(d["answer"])+1):
|
|
92 |
current_ids = tokenizer.encode(conv,add_special_tokens=False)
|
93 |
step_last_tokens.append(len(current_ids) - 2)
|
94 |
|
|
|
95 |
inputs = {'input_ids':input_ids,'attention_mask':attention_mask,'labels':input_ids}
|
96 |
-
|
|
|
97 |
|
98 |
def get_logps(model,inputs):
|
99 |
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
|
@@ -105,12 +115,12 @@ def get_logps(model,inputs):
|
|
105 |
|
106 |
with torch.no_grad():
|
107 |
per_token_logps = get_logps(model, inputs)
|
108 |
-
|
109 |
|
110 |
raw_reward = per_token_logps - ref_per_token_logps
|
111 |
-
beta_reward = coef * raw_reward
|
112 |
beta_reward = beta_reward.cumsum(-1)
|
113 |
-
beta_reward = beta_reward.gather(dim=-1, index=step_last_tokens[1:])
|
114 |
print(beta_reward)
|
115 |
```
|
116 |
|
|
|
69 |
We show an example leveraging **EurusPRM-Stage2** below:
|
70 |
|
71 |
```python
|
72 |
+
import torch
|
73 |
+
from transformers import AutoTokenizer,AutoModelForCausalLM
|
74 |
coef=0.001
|
75 |
+
d = {'query':'Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$',
|
76 |
+
'answer':[
|
77 |
+
"Step 1: To convert the point (0,3) from rectangular coordinates to polar coordinates, we need to find the radius (r) and the angle theta (\u03b8).",
|
78 |
+
"Step 1: Find the radius (r). The radius is the distance from the origin (0,0) to the point (0,3). Since the x-coordinate is 0, the distance is simply the absolute value of the y-coordinate. So, r = |3| = 3.",
|
79 |
+
"Step 2: Find the angle theta (\u03b8). The angle theta is measured counterclockwise from the positive x-axis. Since the point (0,3) lies on the positive y-axis, the angle theta is 90 degrees or \u03c0\/2 radians.",
|
80 |
+
"Step 3: Write the polar coordinates. The polar coordinates are (r, \u03b8), where r > 0 and 0 \u2264 \u03b8 < 2\u03c0. In this case, r = 3 and \u03b8 = \u03c0\/2.\n\nTherefore, the polar coordinates of the point (0,3) are (3, \u03c0\/2).\n\n\n\\boxed{(3,\\frac{\\pi}{2})}"
|
81 |
+
]
|
82 |
}
|
83 |
model = AutoModelForCausalLM.from_pretrained('PRIME-RL/EurusPRM-Stage2')
|
84 |
tokenizer = AutoTokenizer.from_pretrained('PRIME-RL/EurusPRM-Stage2')
|
|
|
86 |
input_ids = tokenizer.apply_chat_template([
|
87 |
{"role": "user", "content": d["query"]},
|
88 |
{"role": "assistant", "content": "\n\n".join(d["answer"])},
|
89 |
+
], tokenize=True, add_generation_prompt=False,return_tensors='pt')
|
90 |
attention_mask = input_ids!=tokenizer.pad_token_id
|
91 |
step_last_tokens = []
|
92 |
for step_num in range(0, len(d["answer"])+1):
|
|
|
100 |
current_ids = tokenizer.encode(conv,add_special_tokens=False)
|
101 |
step_last_tokens.append(len(current_ids) - 2)
|
102 |
|
103 |
+
|
104 |
inputs = {'input_ids':input_ids,'attention_mask':attention_mask,'labels':input_ids}
|
105 |
+
label_mask = torch.tensor([[0]*step_last_tokens[0]+[1]*(input_ids.shape[-1]-step_last_tokens[0])])
|
106 |
+
step_last_tokens = torch.tensor([step_last_tokens])
|
107 |
|
108 |
def get_logps(model,inputs):
|
109 |
logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
|
|
|
115 |
|
116 |
with torch.no_grad():
|
117 |
per_token_logps = get_logps(model, inputs)
|
118 |
+
ref_per_token_logps = get_logps(ref_model,inputs)
|
119 |
|
120 |
raw_reward = per_token_logps - ref_per_token_logps
|
121 |
+
beta_reward = coef * raw_reward * label_mask[:,1:]
|
122 |
beta_reward = beta_reward.cumsum(-1)
|
123 |
+
beta_reward = beta_reward.gather(dim=-1, index=step_last_tokens[:,1:])
|
124 |
print(beta_reward)
|
125 |
```
|
126 |
|