# J1-7B-RL

Not yet finished!

## Model Description

J1-7B-RL is an LLM-as-a-Judge model trained through a two-stage process of Supervised Fine-Tuning (SFT) followed by Reinforcement Learning (RL). The model is specifically designed to benefit from Simple Test-Time Scaling (STTS) and serves as an improved preference judge for evaluating LLM outputs. It is the implementation of the model described in the paper "J1: Exploring Simple Test-Time Scaling for LLM-as-a-Judge".

## Key Features

- **Enhanced Reflective Reasoning**: Trained through a novel two-stage paradigm to make optimal use of reflective reasoning tokens
- **STTS Compatibility**: Shows stronger scaling behavior under Simple Test-Time Scaling than previous LLM-as-a-Judge models
- **Performance Improvement**: Achieves a 4.8% improvement in overall judgment performance and a 5.1% stronger scaling trend under STTS

## Model Details

- **Base Model**: Qwen2.5-7B-Base
- **Training Procedure**:
  - Stage 1: SFT on the J1-SFT-53K dataset (curated from HelpSteer2, OffsetBias, WildGuard, and Magpie)
  - Stage 2: RL with the Reinforce++ algorithm on the English subset of the RISE dataset
- **Context Length**: 8192 tokens
- **Parameters**: 7 billion
- **Training Hardware**: NVIDIA H800 cluster

## Evaluation Results

J1-7B-RL was evaluated on four diverse preference datasets and outperforms previous state-of-the-art models:

| Model | RewardBench | RewardMath | Anthropic Harmless | CodePrefBench | Overall |
|-------|-------------|------------|--------------------|---------------|---------|
| Llama3.1-8B-Instruct | 70.47 | 61.12 | 46.43 | 67.10 | 61.28 |
| Qwen2.5-7B-Instruct | 78.50 | 69.70 | 49.56 | 67.59 | 66.34 |
| Skywork-Critic-Llama3.1-8B | 88.86 | 66.51 | 58.61 | 60.57 | 68.64 |
| RISE-Judge-Qwen2.5-7B | 87.42 | 81.69 | 56.35 | 59.22 | 71.17 |
| J1-7B (SFT Only) | 85.01 | 82.40 | 53.88 | 49.20 | 67.62 |
| J1-7B (SFT + RL) | 86.91 | **90.15** | 59.05 | **67.80** | **75.98** |

This represents a significant improvement over previous state-of-the-art LLM-as-a-Judge models in the same size class.
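
For reference, pairwise preference benchmarks like the ones above are scored by asking the judge to compare a chosen and a rejected response. Below is a minimal, illustrative scoring loop; the `judge` callable and the dataset fields are hypothetical and not part of this repository, and querying both presentation orders is a common position-bias mitigation rather than the paper's documented protocol:

```python
def pairwise_accuracy(judge, dataset):
    """Fraction of examples where the judge prefers the chosen response.

    `judge(question, answer_a, answer_b)` is assumed to wrap the prompting
    code from the Usage section below and return "A" or "B".
    """
    correct = 0
    for example in dataset:  # each example: {"question", "chosen", "rejected"}
        # Query both presentation orders to reduce position bias.
        verdict_ab = judge(example["question"], example["chosen"], example["rejected"])
        verdict_ba = judge(example["question"], example["rejected"], example["chosen"])
        # Count as correct only if the chosen response wins in both orders.
        if verdict_ab == "A" and verdict_ba == "B":
            correct += 1
    return correct / len(dataset)
```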

## Usage

J1-7B-RL can be used both conventionally and with STTS for enhanced performance:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("test-time-scaling/J1_7B_RL")
tokenizer = AutoTokenizer.from_pretrained("test-time-scaling/J1_7B_RL")

# Example question and responses
query = "What are the advantages and disadvantages of remote work?"
response_a = "Remote work offers flexibility and eliminates commuting, but can lead to isolation and blurred work-life boundaries."
response_b = "Working remotely is convenient because you can work from anywhere, but sometimes it's hard to communicate with colleagues."

# Standard usage (without STTS)
prompt_template = """Please act as an impartial judge and evaluate the quality of the responses provided
by two AI assistants to the user question displayed below. You should choose the assistant that follows the
user’s instructions and answers the user’s question better. Your evaluation should consider factors such as
the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your
evaluation by comparing the two responses and provide a short explanation. Avoid any position biases
and ensure that the order in which the responses were presented does not influence your decision. Do not
allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants.
Be as objective as possible. Please first analysis both of the answer step by step, directly point out the
position of error and output why it is an error in detail when finding error in analysis. If the question is
open-ended, directly point out why the rejected answer is worse than the chosen one. After providing your
explanation, output your final verdict by strictly following this format: ‘[[A]]’ if assistant A is better, ‘[[B]]’
if assistant B is better.
[User Question]
{instruction}
{{The Start of Assistant A’s Answer}}
{answer_a}
{{The End of Assistant A’s Answer}}
{{The Start of Assistant B’s Answer}}
{answer_b}
{{The End of Assistant B’s Answer}}
"""

# Fill the template (double braces are literal braces, single braces are placeholders)
prompt = prompt_template.format(instruction=query, answer_a=response_a, answer_b=response_b)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=1024)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

# Usage with STTS
def apply_stts(model, tokenizer, query, response_a, response_b, num_waits=2):
    prompt = f"Question:\n{query}\n\nAnswer A:\n{response_a}\n\nAnswer B:\n{response_b}\n\n<think>"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Token id used to stop generation at the end of the thinking block
    think_end_id = tokenizer.encode("</think>", add_special_tokens=False)[0]

    # Initial generation, stopping at the </think> token
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        eos_token_id=think_end_id,
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    for i in range(num_waits):
        # Replace the end-of-thinking marker with "wait," and continue generation
        prompt_with_thinking = result.rstrip()
        if prompt_with_thinking.endswith("</think>"):
            prompt_with_thinking = prompt_with_thinking[: -len("</think>")]
        prompt_with_thinking += " wait,"
        inputs = tokenizer(prompt_with_thinking, return_tensors="pt").to(model.device)

        if i == num_waits - 1:
            # Last round: let the model finish thinking and emit its final verdict
            continued = model.generate(
                **inputs,
                max_new_tokens=1024,
            )
        else:
            # Intermediate rounds: stop again at </think> so another "wait," can be appended
            continued = model.generate(
                **inputs,
                max_new_tokens=1024,
                eos_token_id=think_end_id,
            )

        result = tokenizer.decode(continued[0], skip_special_tokens=True)

    return result

stts_result = apply_stts(model, tokenizer, query, response_a, response_b, num_waits=2)
print(stts_result)
```
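
The final verdict is embedded in the judge's free-form output as `[[A]]` or `[[B]]`. A small helper such as the following (the `extract_verdict` name is ours, not part of the released code) can pull it out:

```python
import re

def extract_verdict(judge_output):
    """Return "A" or "B" from the last [[A]]/[[B]] tag, or None if no tag is found."""
    matches = re.findall(r"\[\[([AB])\]\]", judge_output)
    return matches[-1] if matches else None

print(extract_verdict(result))       # verdict from standard usage
print(extract_verdict(stts_result))  # verdict from STTS usage
```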

## License

CC-BY-NC-4.0

## Acknowledgements

This work builds upon research in LLM-as-a-Judge, test-time scaling techniques, and reinforcement learning methodologies. We acknowledge the creators of the Qwen2.5 model series, the RISE dataset, and the evaluation benchmarks used to assess model performance, as well as the verl and OpenRLHF frameworks used for RL and SFT training.