# J1-7B-RL

Not yet finished!

## Model Description

J1-7B-RL is an LLM-as-a-Judge model trained through a two-stage process of Supervised Fine-Tuning (SFT) followed by Reinforcement Learning (RL). The model is specifically designed to benefit from Simple Test-Time Scaling (STTS) and serves as an improved preference judge for evaluating LLM outputs. It is the model described in the paper "J1: Exploring Simple Test-Time Scaling for LLM-as-a-Judge".

## Key Features

- **Enhanced Reflective Reasoning**: Trained to utilize reflective reasoning tokens optimally through a novel two-stage paradigm
- **STTS Compatibility**: Demonstrates superior scaling behavior under Simple Test-Time Scaling compared to previous LLM-as-a-Judge models
- **Performance Improvement**: Achieves a 4.8% improvement in overall judgment performance and exhibits a 5.1% stronger scaling trend under STTS

## Model Details

- **Base Model**: Qwen2.5-7B-Base
- **Training Procedure**:
  - Stage 1: SFT on the J1-SFT-53K dataset (curated from HelpSteer2, OffsetBias, WildGuard, and Magpie)
  - Stage 2: RL using the Reinforce++ algorithm on the English subset of the RISE dataset
- **Context Length**: 8192 tokens (see the sketch after this list)
- **Parameters**: 7 billion
- **Training Hardware**: NVIDIA H800 cluster
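
Given the 8192-token context, very long question/response pairs may need to be truncated before judging. A minimal sketch using standard `transformers` tokenizer options (the `prompt`, `tokenizer`, and `model` variables are the ones built in the Usage section below):

```python
# Reserve room for the generated judgment, then truncate the judge prompt
# so prompt + generation fit inside the 8192-token context window.
max_new_tokens = 1024
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=8192 - max_new_tokens,
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
```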

## Evaluation Results

J1-7B-RL was evaluated on four diverse preference datasets and outperforms previous state-of-the-art models:

| Model | RewardBench | RewardMath | Anthropic Harmless | CodePrefBench | Overall |
|-------|-------------|------------|--------------------|---------------|---------|
| Llama3.1-8B-Instruct | 70.47 | 61.12 | 46.43 | 67.10 | 61.28 |
| Qwen2.5-7B-Instruct | 78.50 | 69.70 | 49.56 | 67.59 | 66.34 |
| Skywork-Critic-Llama3.1-8B | 88.86 | 66.51 | 58.61 | 60.57 | 68.64 |
| RISE-Judge-Qwen2.5-7B | 87.42 | 81.69 | 56.35 | 59.22 | 71.17 |
| J1-7B (SFT Only) | 85.01 | 82.40 | 53.88 | 49.20 | 67.62 |
| J1-7B (SFT + RL) | 86.91 | **90.15** | 59.05 | **67.80** | **75.98** |

This represents a significant improvement over previous state-of-the-art LLM-as-a-Judge models in the same size class.
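
The Overall column appears to be the unweighted mean of the four benchmark scores; a quick sanity check for the J1-7B (SFT + RL) row:

```python
# Sanity check (assumes Overall is the unweighted mean of the four benchmarks).
scores = {"RewardBench": 86.91, "RewardMath": 90.15, "Anthropic Harmless": 59.05, "CodePrefBench": 67.80}
overall = sum(scores.values()) / len(scores)
print(round(overall, 2))  # 75.98, matching the Overall column
```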

## Usage

J1-7B-RL can be used both conventionally and with STTS for enhanced performance:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("test-time-scaling/J1_7B_RL")
tokenizer = AutoTokenizer.from_pretrained("test-time-scaling/J1_7B_RL")

# Example question and responses
query = "What are the advantages and disadvantages of remote work?"
response_a = "Remote work offers flexibility and eliminates commuting, but can lead to isolation and blurred work-life boundaries."
response_b = "Working remotely is convenient because you can work from anywhere, but sometimes it's hard to communicate with colleagues."

# Standard usage (without STTS)
prompt_template = """Please act as an impartial judge and evaluate the quality of the responses provided
by two AI assistants to the user question displayed below. You should choose the assistant that follows the
user's instructions and answers the user's question better. Your evaluation should consider factors such as
the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your
evaluation by comparing the two responses and provide a short explanation. Avoid any position biases
and ensure that the order in which the responses were presented does not influence your decision. Do not
allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants.
Be as objective as possible. Please first analysis both of the answer step by step, directly point out the
position of error and output why it is an error in detail when finding error in analysis. If the question is
open-ended, directly point out why the rejected answer is worse than the chosen one. After providing your
explanation, output your final verdict by strictly following this format: '[[A]]' if assistant A is better, '[[B]]'
if assistant B is better.
[User Question]
{instruction}
{{The Start of Assistant A's Answer}}
{answer_a}
{{The End of Assistant A's Answer}}
{{The Start of Assistant B's Answer}}
{answer_b}
{{The End of Assistant B's Answer}}
"""

# Fill in the template (double braces render as literal braces)
prompt = prompt_template.format(instruction=query, answer_a=response_a, answer_b=response_b)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=1024)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

# Usage with STTS: force extra reflection by replacing the end-of-thinking token with "wait"
def apply_stts(model, tokenizer, query, response_a, response_b, num_waits=2):
    prompt = f"Question:\n{query}\n\nAnswer A:\n{response_a}\n\nAnswer B:\n{response_b}\n\n<think>"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Initial generation until the </think> token
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        eos_token_id=tokenizer.encode("</think>")[0],  # stop at </think>
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Replace </think> with "wait," and continue generation
    for i in range(num_waits):
        prompt_with_thinking = result.replace("</think>", "") + " wait,"
        inputs = tokenizer(prompt_with_thinking, return_tensors="pt").to(model.device)

        if i == num_waits - 1:
            # Final round: let the model finish its reasoning and emit the verdict
            continued = model.generate(
                **inputs,
                max_new_tokens=1024,
            )
        else:
            # Intermediate rounds: stop at </think> again so it can be replaced once more
            continued = model.generate(
                **inputs,
                max_new_tokens=1024,
                eos_token_id=tokenizer.encode("</think>")[0],
            )

        result = tokenizer.decode(continued[0], skip_special_tokens=True)

    return result

stts_result = apply_stts(model, tokenizer, query, response_a, response_b, num_waits=2)
print(stts_result)
```
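
The judge emits its decision as `[[A]]` or `[[B]]` at the end of the judgment text, so the verdict can be recovered with a simple pattern match. A minimal sketch (`stts_result` is the output from the example above):

```python
import re

# Take the last [[A]]/[[B]] marker in the output as the final verdict.
matches = re.findall(r"\[\[(A|B)\]\]", stts_result)
verdict = matches[-1] if matches else None
print(verdict)  # "A", "B", or None if no verdict was found
```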

## License

CC-BY-NC-4.0

## Acknowledgements

This work builds upon research in LLM-as-a-Judge, test-time scaling techniques, and reinforcement learning methodologies. We acknowledge the creators of the Qwen2.5 model series, the RISE dataset, and the evaluation benchmarks used to assess model performance, as well as the verl and OpenRLHF frameworks used for RL and SFT training.