haijunlv committed on
Commit d018696 · verified · 1 Parent(s): 4564b11

upload model

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ misc/intro.jpeg filter=lfs diff=lfs merge=lfs -text
37
+ misc/result.png filter=lfs diff=lfs merge=lfs -text
38
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,212 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023-2025 Shanghai AI Laboratory
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
202
+
203
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
204
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
205
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
206
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
207
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
208
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
209
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
210
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
211
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
212
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,3 +1,372 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ <div align="center">
2
+
3
+ <img src="./misc/logo.png" width="400"/><br>
4
+
5
+
6
+ [![license](https://img.shields.io/github/license/InternLM/xtuner.svg)](./LICENSE)
7
+ [![xtuner](https://img.shields.io/badge/support-xtuner-blue)](https://github.com/InternLM/xtuner/)
8
+ [![lmdeploy](https://img.shields.io/badge/lmdeploy-blue)](https://github.com/InternLM/lmdeploy/)
9
+ [![sglang](https://img.shields.io/badge/sglang-blue)](https://github.com/sgl-project/sglang/)
10
+ [![vllm](https://img.shields.io/badge/vllm-blue)](https://github.com/vllm-project/vllm/)
11
+
12
+
13
+ [💻 Github](https://github.com/InternLM/POLAR) |
14
+ [📜 Paper](https://arxiv.org/abs/xxxxxx)<br>
15
+
16
+ [English](./README.md) |
17
+ [简体中文](./README_zh-CN.md)
18
+
19
+ </div>
20
+
21
+ # Introduction
22
+
23
+ POLAR represents a significant breakthrough in scalar-based reward models, achieved through large-scale pre-training. It leverages the innovative **POL**icy Discrimin**A**tive Lea**R**ning (**POLAR**) paradigm, a scalable, high-level optimization objective, to effectively discriminate between policies using a large-scale synthetic corpus. Following pre-training, POLAR RMs are fine-tuned with minimal preference data, rapidly aligning with human preferences. Key features of POLAR include:
24
+
25
+ * **Innovative Pre-training Paradigm:** POLAR trains a reward model to discern identical policies and discriminate different ones. Unlike traditional reward modeling methods relying on absolute preferences, POLAR captures the relative difference between two policies, which is a scalable, high-level optimization objective suitable for modeling generic ranking relationships.
26
+
27
+ * **Tailored for Reinforcement Fine-tuning:** POLAR assigns rewards to LLM trajectories based on given references, perfectly aligning with the Reinforcement Fine-tuning (RFT) framework. POLAR provides a promising solution for applying RFT in generic scenarios.
28
+
29
+ * **Superior Performance and Generalization:** POLAR achieves state-of-the-art results on downstream reinforcement learning tasks, consistently delivers accurate and reliable reward signals that generalize effectively to unseen scenarios, and significantly reduces reward hacking.
30
+
31
+ * **Easy to Customize:** Pre-trained checkpoints of POLAR are available, enabling researchers to conveniently fine-tune the RM for various customized scenarios, thus facilitating straightforward adaptation and expansion tailored to specific applications and experimental requirements.
32
+
33
+ <img src="./misc/intro.jpeg"/><br>
34
+
35
+
36
+ # POLAR-1.8B-Base
37
+
38
+ **POLAR-1.8B-Base** refers to the pre-trained-only checkpoint, ideal for customized fine-tuning according to specific preferences. The "ready-to-use" checkpoint **POLAR-1.8B** has already been fine-tuned on general preference data, making it suitable for immediate use in most scenarios.
39
+
40
+ We conducted a comprehensive evaluation of POLAR-1.8B via the Proximal Policy Optimization (PPO) algorithm. We evaluated the downstream RL performance of four different policy models using [OpenCompass](https://github.com/internLM/OpenCompass/). More details are available in our [Paper](https://arxiv.org/abs/xxxxxx).
41
+
42
+ <img src="./misc/result.png"/><br>
43
+
44
+ # Quick Start
45
+
46
+ ## Installation
47
+
48
+ You can use the latest [xtuner](https://github.com/InternLM/xtuner) to fine-tune and use POLAR. Xtuner is an efficient, flexible, and full-featured toolkit for fine-tuning LLMs.
49
+
50
+ - It is recommended to create a Python 3.10 virtual environment using conda
51
+
52
+ ```bash
53
+ conda create --name xtuner-env python=3.10 -y
54
+ conda activate xtuner-env
55
+ ```
56
+
57
+ - Install xtuner via pip
58
+
59
+ ```shell
60
+ pip install 'git+https://github.com/InternLM/xtuner.git@main#egg=xtuner[deepspeed]'
61
+ ```
62
+
63
+ ## Inference
64
+
65
+ We support reward inference through [lmdeploy](https://github.com/InternLM/lmdeploy/), [sglang](https://github.com/sgl-project/sglang/), and [vllm](https://github.com/vllm-project/vllm/). We recommend setting up a virtual environment with conda when using these inference engines to prevent potential dependency conflicts.
66
+
67
+ ### Data format
68
+
69
+ Unlike traditional reward models, POLAR requires an additional reference trajectory as a demonstration and evaluates candidate trajectories by measuring their consistency with the provided reference.
70
+
71
+ ```python
72
+ data = [
73
+ {
74
+ "prompt": [{"role": "user", "content": "What is the capital of China?"}],
75
+ "reference": [{"role": "assistant", "content": "Beijing."}],
76
+ "output": [{"role": "assistant", "content": "Beijing."}]
77
+ },
78
+ {
79
+ "prompt": [{"role": "user", "content": "What is the capital of China?"}],
80
+ "reference": [{"role": "assistant", "content": "Beijing."}],
81
+ "output": [{"role": "assistant", "content": "Shanghai."}]
82
+ }
83
+ ]
84
+ ```
85
+
86
+ ### Inference with transformers
87
+
88
+ #### Reward request
89
+ To load the POLAR model using transformers, use the following code to get rewards:
90
+
91
+ ```python
92
+ from transformers import AutoModel, AutoTokenizer
93
+ from xtuner.utils import RewardModelClient
94
+
95
+ model_name = 'internlm/POLAR-1_8B'
96
+
97
+ model = AutoModel.from_pretrained(
98
+ model_name,
99
+ device_map="cuda",
100
+ trust_remote_code=True
101
+ )
102
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
103
+
104
+ client = RewardModelClient(model_name)
105
+ encoded_data = client.encode(data)
106
+ batch = tokenizer(encoded_data, return_tensors='pt', padding=True).to('cuda')
107
+ outputs = model(**batch)
108
+ rewards = outputs[0].squeeze(-1).cpu().tolist()
109
+ print(rewards)
110
+ ```
111
+
112
+ ### Inference with lmdeploy
113
+
114
+ [LMDeploy](https://github.com/InternLM/lmdeploy) is a toolkit for compressing, deploying, and serving LLMs.
115
+
116
+ #### Requirements
117
+
118
+ - lmdeploy >= 0.9.1
119
+
120
+ #### Server Launch
121
+
122
+ ```bash
123
+ lmdeploy serve api_server internlm/POLAR-1_8B --backend pytorch --server-port 30000
124
+ ```
125
+ #### Client Request
126
+
127
+ ```python
128
+ from xtuner.utils import RewardModelClient
129
+
130
+ client = RewardModelClient("internlm/POLAR-1_8B",
131
+ server_type="lmdeploy",
132
+ server_address="127.0.0.1:30000")
133
+
134
+ # Request rewards directly
135
+ rewards = client(data)
136
+ print(rewards)
137
+
138
+ # First encode data and then get rewards via the request function.
139
+ encoded_data = client.encode(data)
140
+ rewards = client.lmdeploy_request_reward(encoded_data)
141
+ print(rewards)
142
+ ```
143
+
144
+ ### Inference with sglang
145
+
146
+ #### Requirements
147
+
148
+ - sglang >= 0.4.3.post4
149
+
150
+ #### Server Launch
151
+
152
+ ```bash
153
+ python3 -m sglang.launch_server --model internlm/POLAR-1_8B --trust-remote-code --is-embedding --dp 4 --tp 2 --mem-fraction-static 0.9 --port 30000
154
+ ```
155
+
156
+ #### Client Request
157
+
158
+ ```python
159
+ from xtuner.utils import RewardModelClient
160
+
161
+ client = RewardModelClient("internlm/POLAR-1_8B",
162
+ server_type="sglang",
163
+ server_address="127.0.0.1:30000")
164
+
165
+ # Request rewards directly
166
+ rewards = client(data)
167
+ print(rewards)
168
+
169
+ # First encode data and then get rewards via the request function.
170
+ encoded_data = client.encode(data)
171
+ rewards = client.sglang_request_reward(encoded_data)
172
+ print(rewards)
173
+ ```
174
+
175
+ ### Inference with vllm
176
+
177
+ #### Requirements
178
+
179
+ - vllm >= 0.8.0
180
+
181
+ #### Server Launch
182
+
183
+ ```bash
184
+ vllm serve internlm/POLAR-1_8B --task=reward --trust-remote-code --tensor-parallel-size=2 --port 30000
185
+ ```
186
+
187
+ #### Client Request
188
+
189
+ ```python
190
+ from xtuner.utils import RewardModelClient
191
+
192
+ client = RewardModelClient("internlm/POLAR-1_8B",
193
+ server_type="vllm",
194
+ server_address="127.0.0.1:30000")
195
+
196
+ # Request rewards directly
197
+ rewards = client(data)
198
+ print(rewards)
199
+
200
+ # First encode data and then get rewards via the request function.
201
+ encoded_data = client.encode(data)
202
+ rewards = client.vllm_request_reward(encoded_data)
203
+ print(rewards)
204
+ ```
205
+
206
+ ## Fine-tune
207
+
208
+ ### Requirements
209
+
210
+ - flash_attn
211
+ - tensorboard
212
+
213
+ ### Data format
214
+
215
+ Unlike traditional reward models, POLAR requires an additional reference trajectory as a demonstration during fine-tuning, along with a chosen trajectory and a rejected trajectory. You can construct your fine-tuning data in a `train.jsonl` file, formatted as follows:
216
+
217
+ ```json
218
+ {
219
+ "prompt": [{"role": "user", "content": "What is the capital of China?"}],
220
+ "reference": [{"role": "assistant", "content": "Beijing."}],
221
+ "chosen": [{"role": "assistant", "content": "Beijing."}],
222
+ "rejected": [{"role": "assistant", "content": "Shanghai."}]
223
+ }
224
+ ```
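+ 
+ As a minimal sketch of how such a file could be assembled (the sample triplet and file name below are illustrative), write one JSON object per line:
+ 
+ ```python
+ import json
+ 
+ # Hypothetical preference triplets; replace them with your own data.
+ samples = [
+     {
+         "prompt": [{"role": "user", "content": "What is the capital of China?"}],
+         "reference": [{"role": "assistant", "content": "Beijing."}],
+         "chosen": [{"role": "assistant", "content": "Beijing."}],
+         "rejected": [{"role": "assistant", "content": "Shanghai."}],
+     },
+ ]
+ 
+ # Each line of train.jsonl is a standalone JSON object (JSON Lines).
+ with open("train.jsonl", "w", encoding="utf-8") as f:
+     for sample in samples:
+         f.write(json.dumps(sample, ensure_ascii=False) + "\n")
+ ```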
225
+
226
+ ### Training steps
227
+
228
+ - **Step 0:** Prepare the config. We provide example ready-to-use configs [here](https://github.com/InternLM/POLAR/blob/main/examples/xtuner_configs/POLAR_1_8B_full_varlenattn_custom_dataset.py). If the provided configs do not meet your requirements, copy one of them and modify it following the [xtuner guideline](https://github.com/InternLM/xtuner/blob/main/docs/en/get_started/quickstart.md). For more details on reward model training settings, please see the xtuner [reward model guideline](https://github.com/InternLM/xtuner/blob/main/docs/en/reward_model/modify_settings.md).
229
+
230
+ - **Step 1:** Start fine-tuning.
231
+
232
+ ```shell
233
+ xtuner train ${CONFIG_FILE_PATH}
234
+ ```
235
+
236
+ For example, you can start fine-tuning POLAR-1_8B-Base with:
237
+
238
+ ```shell
239
+ # On a single GPU
240
+ xtuner train /path/to/POLAR_1_8B_full_varlenattn_custom_dataset.py --deepspeed deepspeed_zero2
241
+
242
+ # On multiple GPUs
243
+ NPROC_PER_NODE=${GPU_NUM} xtuner train /path/to/POLAR_1_8B_full_varlenattn_custom_dataset.py --deepspeed deepspeed_zero2
244
+ ```
245
+
246
+ Here, `--deepspeed` enables [DeepSpeed](https://github.com/microsoft/DeepSpeed) to optimize training. Xtuner ships with several integrated strategies, including ZeRO-1, ZeRO-2, and ZeRO-3. If you wish to disable this feature, simply remove this argument.
247
+
248
+ - **Step 2:** Convert the saved PTH model (if using DeepSpeed, it will be a directory) to a Hugging Face model:
249
+
250
+ ```shell
251
+ xtuner convert pth_to_hf ${CONFIG_FILE_PATH} ${PTH} ${SAVE_PATH}
252
+ ```
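+ 
+ The converted directory can then be loaded in the same way as the released checkpoint in the transformers example above. A minimal sanity check (assuming `${SAVE_PATH}` points at the converted model) might look like:
+ 
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+ 
+ # Directory produced by `xtuner convert pth_to_hf`; adjust to your own SAVE_PATH.
+ save_path = "/path/to/converted_hf_model"
+ 
+ model = AutoModel.from_pretrained(save_path, device_map="cuda", trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(save_path, trust_remote_code=True)
+ print(model.config.architectures)  # expect ["InternLM2ForRewardModel"]
+ ```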
253
+
254
+ # Examples
255
+
256
+ ## Closed-ended questions
257
+
258
+ ```python
259
+ from xtuner.utils import RewardModelClient
260
+
261
+ prompt = "How many 'r's are there in the word 'strawberry'?"
262
+ reference = "There are 3 'r's in the word 'strawberry'. Here's how we can count them: 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. So, the answer is 3."
263
+ outputs = [
264
+ # Same as the reference response.
265
+ "There are 3 'r's in the word 'strawberry'. Here's how we can count them: 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. So, the answer is 3.",
266
+ # Correct answer with correct thoughts.
267
+ "Let's count the 'r's in 'strawberry': 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. There are three 'r's, so the answer is three.",
268
+ # Wrong answer with wrong thoughts.
269
+ "Let's count the 'r's in 'strawberry': 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. There are two 'r's, so the answer is two.",
270
+ # Wrong answer with correct thoughts.
271
+ "Let's count the 'r's in 'strawberry': 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. There are three 'r's, so the answer is two.",
272
+ # Correct answer with wrong thoughts.
273
+ "Let's count the 'r's in 'strawberry': 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. There are two 'r's, so the answer is three.",
274
+ # Correct answer without thoughts.
275
+ "There are 3 'r's in the word 'strawberry'.",
276
+ # Wrong answer without thoughts.
277
+ "There are 2 'r's in the word 'strawberry'.",
278
+ ]
279
+ data = [{"prompt": prompt, "reference": reference, "output": output} for output in outputs]
280
+
281
+ client = RewardModelClient("internlm/POLAR-7B", server_type="sglang", server_address="127.0.0.1:30000")
282
+ rewards = client(data)
283
+
284
+ sorted_res = sorted(zip(outputs, rewards), key=lambda x: x[1], reverse=True)
285
+
286
+ for output, reward in sorted_res:
287
+ print(f"Output: {output}\nReward: {reward}\n")
288
+ ```
289
+
290
+ ```txt
291
+ Output: There are 3 'r's in the word 'strawberry'. Here's how we can count them: 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. So, the answer is 3.
292
+ Reward: 0.054595947265625
293
+
294
+ Output: Let's count the 'r's in 'strawberry': 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. There are three 'r's, so the answer is three.
295
+ Reward: -2.005859375
296
+
297
+ Output: There are 3 'r's in the word 'strawberry'.
298
+ Reward: -6.70703125
299
+
300
+ Output: Let's count the 'r's in 'strawberry': 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. There are two 'r's, so the answer is three.
301
+ Reward: -7.10546875
302
+
303
+ Output: Let's count the 'r's in 'strawberry': 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. There are three 'r's, so the answer is two.
304
+ Reward: -7.1328125
305
+
306
+ Output: Let's count the 'r's in 'strawberry': 's', 't', 'r', 'a', 'w', 'b', 'e', 'r', 'r', 'y'. There are two 'r's, so the answer is two.
307
+ Reward: -8.46875
308
+
309
+ Output: There are 2 'r's in the word 'strawberry'.
310
+ Reward: -10.8203125
311
+ ```
312
+
313
+ ## Open-ended questions
314
+ ```python
315
+ from xtuner.utils import RewardModelClient
316
+
317
+ prompt = "Summarize the first book of Frank Herbert’s Dune in one witty short sentence."
318
+ reference = "Royal teen discovers that life’s a beach—minus the ocean, plus spice, giant sandworms and deadly politics."
319
+ outputs = [
320
+ # Same as the reference response.
321
+ "Royal teen discovers that life’s a beach—minus the ocean, plus spice, giant sandworms and deadly politics.",
322
+ # Closely resembles the reference response but includes factual errors.
323
+ "Royal teen discovers that life’s a beach—minus the ocean, plus magic, dark wizards and deadly politics.",
324
+ # A distinct yet concise and witty summary that draws analogies from other dramas—markedly different from the reference response.
325
+ "Young noble’s move to desert planet turns into galactic Game of Thrones with fewer dragons, more worms.",
326
+ # A concise summary, but lacking wit—fails to meet the requirement.
327
+ "A noble family’s fall sparks a young heir’s rise as a leader on a harsh desert planet governed by prophecy and survival.",
328
+ # A witty summary, but overly long—fails to meet the requirement.
329
+ "Paul Atreides loses his father, gains prophetic powers, learns to ride a sandworm, leads a holy war, and discovers that being the chosen one comes with a lot of blood, sand, and questionable decisions.",
330
+ # A concise and witty summary that draws from multiple Dune books rather than just the first—fails to follow the instruction.
331
+ "Boy gets planet, becomes god, loses soul — family drama ensues across galaxies."
332
+ ]
333
+ data = [{"prompt": prompt, "reference": reference, "output": output} for output in outputs]
334
+
335
+ client = RewardModelClient("internlm/POLAR-7B", server_type="sglang", server_address="127.0.0.1:30000")
336
+ rewards = client(data)
337
+
338
+ sorted_res = sorted(zip(outputs, rewards), key=lambda x: x[1], reverse=True)
339
+
340
+ for output, reward in sorted_res:
341
+ print(f"Output: {output}\nReward: {reward}\n")
342
+ ```
343
+
344
+ ```txt
345
+ Output: Royal teen discovers that life’s a beach—minus the ocean, plus spice, giant sandworms and deadly politics.
346
+ Reward: 0.466552734375
347
+
348
+ Output: Young noble’s move to desert planet turns into galactic Game of Thrones with fewer dragons, more worms.
349
+ Reward: -6.91796875
350
+
351
+ Output: Royal teen discovers that life’s a beach—minus the ocean, plus magic, dark wizards and deadly politics.
352
+ Reward: -7.70703125
353
+
354
+ Output: Paul Atreides loses his father, gains prophetic powers, learns to ride a sandworm, leads a holy war, and discovers that being the chosen one comes with a lot of blood, sand, and questionable decisions.
355
+ Reward: -8.4296875
356
+
357
+ Output: A noble family’s fall sparks a young heir’s rise as a leader on a harsh desert planet governed by prophecy and survival.
358
+ Reward: -8.6484375
359
+
360
+ Output: Boy gets planet, becomes god, loses soul — family drama ensues across galaxies.
361
+ Reward: -10.359375
362
+ ```
363
+
364
+ # License
365
+
366
+ Code and model weights are licensed under Apache-2.0.
367
+
368
+ # Citation
369
+
370
+ ```
371
+ TBC
372
+ ```
README_zh-CN.md ADDED
@@ -0,0 +1,386 @@
1
+ <div align="center">
2
+
3
+ <img src="./misc/logo.png" width="400"/><br>
4
+
5
+
6
+ [![license](https://img.shields.io/github/license/InternLM/xtuner.svg)](./LICENSE)
7
+ [![xtuner](https://img.shields.io/badge/support-xtuner-blue)](https://github.com/InternLM/xtuner/)
8
+ [![lmdeploy](https://img.shields.io/badge/lmdeploy-blue)](https://github.com/InternLM/lmdeploy/)
9
+ [![sglang](https://img.shields.io/badge/sglang-blue)](https://github.com/sgl-project/sglang/)
10
+ [![vllm](https://img.shields.io/badge/vllm-blue)](https://github.com/vllm-project/vllm/)
11
+
12
+
13
+ [💻 Github](https://github.com/InternLM/POLAR) |
14
+ [📜 论文](https://arxiv.org/abs/xxxxxx)<br>
15
+
16
+ [English](./README.md) |
17
+ [简体中文](./README_zh-CN.md)
18
+
19
+ </div>
20
+
21
+ # 简介
22
+
23
+ POLAR 是一个经过大规模预训练的奖励模型,在训练范式和模型性能上取得了重大突破。我们利用全新的策略判别学习方法(Policy Discriminative Learning,POLAR),使用大规模合成语料进行高效扩展预训练,使奖励模型能够有效区分不同的语言模型和策略分布。经过预训练的 POLAR 可通过少量的偏好数据进行微调,以快速对齐人类偏好。POLAR 的主要特点包括:
24
+
25
+ * **全新的预训练范式**:POLAR 让奖励模型学会识别相同的策略并区分不同的策略。与传统的依赖绝对偏好的奖励建模方法不同,POLAR 能够学习两个策略之间的相对差异,是一种可扩展的、高层次的优化目标。
26
+
27
+ * **专为强化学习微调(RFT)设计:** POLAR 根据给定的参考答案为语言模型的输出打分,完美契合强化学习微调(RFT)框架,为强化学习微调在通用场景的应用提供了一种有效解决方案。
28
+
29
+ * **卓越的性能与泛化能力:** POLAR 在下游强化学习任务中展现出领先的水平,可稳定地提供准确可靠的奖励信号。POLAR 具有极强的泛化能力,可有效泛化到分布外场景,并显著减少奖励黑客(Reward Hacking)的现象。
30
+
31
+ * **易于定制化:** 我们提供了 POLAR 的预训练权重(POLAR-Base)。研究人员可以根据自身需求,便捷地对其进行微调以适配各种定制化场景。
32
+
33
+ <br><img src="./misc/intro.jpeg"/><br>
34
+
35
+
36
+ # POLAR-1.8B-Base
37
+
38
+ **POLAR-1.8B-Base** 是仅经过预训练阶段的权重,适合根据特定需求进行微调。**POLAR-1.8B** 是经过偏好微调的奖励模型,可开箱即用,适用于大部分通用场景。
39
+
40
+ 我们通过 Proximal Policy Optimization(PPO)算法对 POLAR 的使用效果进行了验证,评测了四种语言模型的下游强化学习性能,评测工具是 [OpenCompass](https://github.com/internLM/OpenCompass/) 。详细信息请参阅[论文](https://arxiv.org/abs/xxxxxx)。
41
+
42
+ <img src="./misc/result.png"/><br>
43
+
44
+ # 快速开始
45
+
46
+ ## 安装
47
+
48
+ 推荐使用最新的 [xtuner](https://github.com/InternLM/xtuner) 来微调和使用 POLAR。xtuner 是一个高效、灵活、具有多种使用特性的语言模型微调工具。
49
+
50
+ - 建议使用 conda 创建 Python-3.10 虚拟环境:
51
+
52
+ ```bash
53
+ conda create --name xtuner-env python=3.10 -y
54
+ conda activate xtuner-env
55
+ ```
56
+
57
+ - 通过 pip 安装 xtuner:
58
+
59
+ ```shell
60
+ pip install 'git+https://github.com/InternLM/xtuner.git@main#egg=xtuner[deepspeed]'
61
+ ```
62
+
63
+ ## 推理
64
+
65
+ 我们支持通过 [lmdeploy](https://github.com/InternLM/lmdeploy/)、[sglang](https://github.com/sgl-project/sglang/)、[vllm](https://github.com/vllm-project/vllm/) 对 POLAR 进行推理并获取奖励信号。建议在使用这些推理引擎时,创建 conda 虚拟环境,以避免可能出现的依赖冲突问题。
66
+
67
+ ### 数据格式
68
+
69
+ 与传统奖励模型不同,POLAR 需要额外的参考答案。POLAR 对模型输出轨迹与参考答案的一致性进行评估,并给出奖励分数。
70
+
71
+ ```python
72
+ data = [
73
+ {
74
+ "prompt": [{"role": "user", "content": "What is the capital of China?"}],
75
+ "reference": [{"role": "assistant", "content": "Beijing."}],
76
+ "output": [{"role": "assistant", "content": "Beijing."}]
77
+ },
78
+ {
79
+ "prompt": [{"role": "user", "content": "What is the capital of China?"}],
80
+ "reference": [{"role": "assistant", "content": "Beijing."}],
81
+ "output": [{"role": "assistant", "content": "Shanghai."}]
82
+ }
83
+ ]
84
+ ```
85
+
86
+ ### 使用 transformers 进行推理
87
+
88
+ #### 示例代码
89
+
90
+ ```python
91
+ from transformers import AutoModel, AutoTokenizer
92
+ from xtuner.utils import RewardModelClient
93
+
94
+ model_name = 'internlm/POLAR-1_8B'
95
+
96
+ model = AutoModel.from_pretrained(
97
+ model_name,
98
+ device_map="cuda",
99
+ trust_remote_code=True
100
+ )
101
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
102
+
103
+ client = RewardModelClient(model_name)
104
+ encoded_data = client.encode(data)
105
+ batch = tokenizer(encoded_data, return_tensors='pt', padding=True).to('cuda')
106
+ outputs = model(**batch)
107
+ rewards = outputs[0].squeeze(-1).cpu().tolist()
108
+ print(rewards)
109
+ ```
110
+
111
+ ### 使用 lmdeploy 进行推理
112
+
113
+ [LMDeploy](https://github.com/InternLM/lmdeploy) 是一个高效压缩、部署语言模型的工具。
114
+
115
+ #### 环境依赖
116
+
117
+ - lmdeploy >= 0.9.1
118
+
119
+ #### 启动服务端
120
+
121
+ ```bash
122
+ lmdeploy serve api_server internlm/POLAR-1_8B --backend pytorch --server-port 30000
123
+ ```
124
+ #### 客户端请求示例
125
+
126
+ ```python
127
+ from xtuner.utils import RewardModelClient
128
+
129
+ client = RewardModelClient("internlm/POLAR-1_8B",
130
+ server_type="lmdeploy",
131
+ server_address="127.0.0.1:30000")
132
+
133
+ # Request rewards directly
134
+ rewards = client(data)
135
+ print(rewards)
136
+
137
+ # First encode data and then get rewards via the request function.
138
+ encoded_data = client.encode(data)
139
+ rewards = client.lmdeploy_request_reward(encoded_data)
140
+ print(rewards)
141
+ ```
142
+
143
+ ### 使用 sglang 进行推理
144
+
145
+ #### 环境依赖
146
+
147
+ - sglang >= 0.4.3.post4
148
+
149
+ #### 启动服务端
150
+
151
+ ```bash
152
+ python3 -m sglang.launch_server --model internlm/POLAR-1_8B --trust-remote-code --is-embedding --dp 4 --tp 2 --mem-fraction-static 0.9 --port 30000
153
+ ```
154
+
155
+ #### 客户端请求示例
156
+
157
+ ```python
158
+ from xtuner.utils import RewardModelClient
159
+
160
+ client = RewardModelClient("internlm/POLAR-1_8B",
161
+ server_type="sglang",
162
+ server_address="127.0.0.1:30000")
163
+
164
+ # Request rewards directly
165
+ rewards = client(data)
166
+ print(rewards)
167
+
168
+ # First encode data and then get rewards via the request function.
169
+ encoded_data = client.encode(data)
170
+ rewards = client.sglang_request_reward(encoded_data)
171
+ print(rewards)
172
+ ```
173
+
174
+ ### 使用 vllm 进行推理
175
+
176
+ #### 环境依赖
177
+
178
+ - vllm >= 0.8.0
179
+
180
+ #### 启动服务端
181
+
182
+ ```bash
183
+ vllm serve internlm/POLAR-1_8B --task=reward --trust-remote-code --tensor-parallel-size=2 --port 30000
184
+ ```
185
+
186
+ #### 客户端请求示例
187
+
188
+ ```python
189
+ from xtuner.utils import RewardModelClient
190
+
191
+ client = RewardModelClient("internlm/POLAR-1_8B",
192
+ server_type="vllm",
193
+ server_address="127.0.0.1:30000")
194
+
195
+ # Request rewards directly
196
+ rewards = client(data)
197
+ print(rewards)
198
+
199
+ # First encode data and then get rewards via the request function.
200
+ encoded_data = client.encode(data)
201
+ rewards = client.vllm_request_reward(encoded_data)
202
+ print(rewards)
203
+ ```
204
+
205
+ ## 偏好微调
206
+
207
+ ### 环境依赖
208
+
209
+ - flash_attn
210
+ - tensorboard
211
+
212
+ ### 数据格式
213
+
214
+ 与传统的奖励模型不同,除了 chosen 轨迹和 rejected 轨迹,POLAR 在微调过程中还需要一个额外的参考答案作为示范。你可以通过构建一个 `train.jsonl` 的文件来准备微调数据,格式如下:
215
+
216
+ ```json
217
+ {
218
+ "prompt": [{"role": "user", "content": "What is the capital of China?"}],
219
+ "reference": [{"role": "assistant", "content": "Beijing."}],
220
+ "chosen": [{"role": "assistant", "content": "Beijing."}],
221
+ "rejected": [{"role": "assistant", "content": "Shanghai."}]
222
+ }
223
+ ```
224
+
225
+ ### 训练步骤
226
+
227
+ - **第一步:** 准备配置文件。我们提供了可直接使用的[示例配置](https://github.com/InternLM/POLAR/blob/main/examples/xtuner_configs/POLAR_1_8B_full_varlenattn_custom_dataset.py)。如果需要进一步对超参进行修改,请复制一份示例配置文件,并根据 [xtuner 使用指南](https://github.com/InternLM/xtuner/blob/main/docs/en/get_started/quickstart.md) 进行修改。有关奖励模型训练设置的更多信息,请参考 [xtuner 奖励模型](https://github.com/InternLM/xtuner/blob/main/docs/en/reward_model/modify_settings.md)。
228
+
229
+ - **第二步:** 启动微调。
230
+
231
+ ```shell
232
+ xtuner train ${CONFIG_FILE_PATH}
233
+ ```
234
+
235
+ 例如,你可以按照如下的方式微调 POLAR-1_8B-Base:
236
+ ```shell
237
+ # On a single GPU
238
+ xtuner train /path/to/POLAR_1_8B_full_varlenattn_custom_dataset.py --deepspeed deepspeed_zero2
239
+
240
+ # On multiple GPUs
241
+ NPROC_PER_NODE=${GPU_NUM} xtuner train /path/to/POLAR_1_8B_full_varlenattn_custom_dataset.py --deepspeed deepspeed_zero2
242
+ ```
243
+
244
+ 这里,`--deepspeed` 表示使用 [DeepSpeed](https://github.com/microsoft/DeepSpeed) 来加速训练。xtuner 内置了多种 DeepSpeed 策略,包括 ZeRO-1、ZeRO-2 和 ZeRO-3。如果您想禁用此功能,只需移除此参数即可。
245
+
246
+ - **第三步:** 将保存的 PTH 模型(若使用 DeepSpeed,则保存结果会是一个目录)转换为 HuggingFace 模型,命令如下:
247
+
248
+ ```shell
249
+ xtuner convert pth_to_hf ${CONFIG_FILE_PATH} ${PTH} ${SAVE_PATH}
250
+ ```
251
+ <br>
252
+
253
+ # 效果示例
254
+
255
+ ## 客观问答
256
+
257
+ ```python
258
+ from xtuner.utils import RewardModelClient
259
+
260
+ prompt = "单词“strawberry”中有几个“r”?"
261
+ reference = "单词“strawberry”中包含3个字母“r”。我们可以逐字母数一下:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。因此,答案是3。"
262
+ outputs = [
263
+ # 与参考完全一致
264
+ "单词“strawberry”中包含3个字母“r”。我们可以逐字母数一下:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。因此,答案是3。",
265
+ # 思路正确,答案正确
266
+ "我们来数一数单词“strawberry”中有几个“r”:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。这里一共有三个“r”,因此答案是三。",
267
+ # 思路错误,答案错误
268
+ "我们来数一数单词“strawberry”中有几个“r”:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。这里一共有两个“r”,因此答案是二。",
269
+ # 思路错误,答案正确
270
+ "我们来数一数单词“strawberry”中有几个“r”:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。这里一共有两个“r”,因此答案是三。",
271
+ # 思路正确,答案错误
272
+ "我们来数一数单词“strawberry”中有几个“r”:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。这里一共有三个“r”,因此答案是二。",
273
+ # 答案正确
274
+ "单词“strawberry”中有3个“r”",
275
+ # 答案错误
276
+ "单词“strawberry”中有2个“r”"
277
+ ]
278
+ data = [{"prompt": prompt, "reference": reference, "output": output} for output in outputs]
279
+
280
+ client = RewardModelClient("internlm/POLAR-7B", server_type="sglang", server_address="127.0.0.1:30000")
281
+ rewards = client(data)
282
+
283
+ sorted_res = sorted(zip(outputs, rewards), key=lambda x: x[1], reverse=True)
284
+
285
+ for output, reward in sorted_res:
286
+ print(f"Output: {output}\nReward: {reward}\n")
287
+ ```
288
+
289
+ ```txt
290
+ Output: 单词“strawberry”中包含3个字母“r”。我们可以逐字母数一下:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。因此,答案是3。
291
+ Reward: -1.5380859375
292
+
293
+ Output: 我们来数一数单词“strawberry”中有几个“r”:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。这里一共有三个“r”,因此答案是三。
294
+ Reward: -2.767578125
295
+
296
+ Output: 单词“strawberry”中有3个“r”
297
+ Reward: -7.45703125
298
+
299
+ Output: 我们来数一数单词“strawberry”中有几个“r”:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。这里一共有三个“r”,因此答案是二。
300
+ Reward: -7.6328125
301
+
302
+ Output: 我们来数一数单词“strawberry”中有几个“r”:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。这里一共有两个“r”,因此答案是三。
303
+ Reward: -8.65625
304
+
305
+ Output: 我们来数一数单词“strawberry”中有几个“r”:“s”、“t”、“r”、“a”、“w”、“b”、“e”、“r”、“r”、“y”。这里一共有两个“r”,因此答案是二。
306
+ Reward: -9.2890625
307
+
308
+ Output: 单词“strawberry”中有2个“r”
309
+ Reward: -11.921875
310
+ ```
311
+
312
+ ## 主观问答
313
+ ```python
314
+ from xtuner.utils import RewardModelClient
315
+
316
+ prompt = "帮我想3个形容雨很大的成语,要求不能重复。"
317
+ reference = "1. 倾盆大雨 2. 暴雨如注 3. 瓢泼大雨"
318
+ outputs = [
319
+ # 与参考相同
320
+ "1. 倾盆大雨 2. 暴雨如注 3. 瓢泼大雨",
321
+ # 正确回答
322
+ "1. 大雨滂沱 2. 狂风骤雨 3. 大雨如注",
323
+ # 非成语
324
+ "1. 急雨如瀑 2. 豪雨倾天 3. 雨势磅礴",
325
+ # 与参考类似,多一个。
326
+ "1. 倾盆大雨 2. 暴雨如注 3. 瓢泼大雨 4. 大雨滂沱",
327
+ # 与参考类似,重复一个。
328
+ "1. 倾盆大雨 2. 暴雨如注 3. 暴雨如注",
329
+ # 与参考类似,少一个。
330
+ "1. 倾盆大雨 2. 暴雨如注",
331
+ # 成语正确,多一个。
332
+ "1. 大雨滂沱 2. 狂风骤雨 3. 大雨如注 4. 倾盆大雨",
333
+ # 成语正确,重复一个
334
+ "1. 大雨滂沱 2. 狂风骤雨 3. 狂风骤雨",
335
+ # 成语正确,少一个
336
+ "1. 大雨滂沱 2. 狂风骤雨"
337
+ ]
338
+ data = [{"prompt": prompt, "reference": reference, "output": output} for output in outputs]
339
+
340
+ client = RewardModelClient("internlm/POLAR-7B", server_type="sglang", server_address="127.0.0.1:30000")
341
+ rewards = client(data)
342
+
343
+ sorted_res = sorted(zip(outputs, rewards), key=lambda x: x[1], reverse=True)
344
+
345
+ for output, reward in sorted_res:
346
+ print(f"Output: {output}\nReward: {reward}\n")
347
+ ```
348
+
349
+ ```txt
350
+ Output: 1. 倾盆大雨 2. 暴雨如注 3. 瓢泼大雨
351
+ Reward: -1.42578125
352
+
353
+ Output: 1. 大雨滂沱 2. 狂风骤雨 3. 大雨如注
354
+ Reward: -5.234375
355
+
356
+ Output: 1. 倾盆大雨 2. 暴雨如注 3. 瓢泼大雨 4. 大雨滂沱
357
+ Reward: -5.62890625
358
+
359
+ Output: 1. 急雨如瀑 2. 豪雨倾天 3. 雨势磅礴
360
+ Reward: -5.7109375
361
+
362
+ Output: 1. 倾盆大雨 2. 暴雨如注
363
+ Reward: -6.61328125
364
+
365
+ Output: 1. 倾盆大雨 2. 暴雨如注 3. 暴雨如注
366
+ Reward: -6.65234375
367
+
368
+ Output: 1. 大雨滂沱 2. 狂风骤雨
369
+ Reward: -6.828125
370
+
371
+ Output: 1. 大雨滂沱 2. 狂风骤雨 3. 大雨如注 4. 倾盆大雨
372
+ Reward: -7.0234375
373
+
374
+ Output: 1. 大雨滂沱 2. 狂风骤雨 3. 狂风骤雨
375
+ Reward: -7.23046875
376
+ ```
377
+
378
+ # 许可证
379
+
380
+ 代码和模型权重均采用 Apache-2.0 许可证。
381
+
382
+ # 引用
383
+
384
+ ```
385
+ TBC
386
+ ```
config.json ADDED
@@ -0,0 +1,37 @@
1
+ {
2
+ "_name_or_path": "/cpfs01/shared/llm_ddd/zouyicheng/xtuner/model/internlm2_5-1_8b",
3
+ "architectures": [
4
+ "InternLM2ForRewardModel"
5
+ ],
6
+ "attn_implementation": "flash_attention_2",
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_internlm2.InternLM2Config",
9
+ "AutoModel": "modeling_internlm2.InternLM2ForRewardModel"
10
+ },
11
+ "bias": false,
12
+ "bos_token_id": 1,
13
+ "eos_token_id": 2,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 2048,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 8192,
18
+ "max_position_embeddings": 32768,
19
+ "model_type": "internlm2",
20
+ "num_attention_heads": 16,
21
+ "num_hidden_layers": 24,
22
+ "num_key_value_heads": 8,
23
+ "pad_token_id": 2,
24
+ "pretraining_tp": 1,
25
+ "reward_token_id": 92527,
26
+ "rms_norm_eps": 1e-05,
27
+ "rope_scaling": {
28
+ "factor": 2.0,
29
+ "type": "dynamic"
30
+ },
31
+ "rope_theta": 1000000,
32
+ "tie_word_embeddings": false,
33
+ "torch_dtype": "float16",
34
+ "transformers_version": "4.49.0",
35
+ "use_cache": true,
36
+ "vocab_size": 92544
37
+ }
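The configuration above declares the reward-model architecture (`InternLM2ForRewardModel`), a dedicated `reward_token_id`, and dynamic RoPE scaling. As a small, non-authoritative sketch (assuming the `internlm/POLAR-1_8B` repo id used in the README), these fields can be inspected directly from the published checkpoint:

```python
from transformers import AutoConfig

# trust_remote_code resolves the custom InternLM2Config shipped in this repository.
config = AutoConfig.from_pretrained("internlm/POLAR-1_8B", trust_remote_code=True)

print(config.architectures)    # ["InternLM2ForRewardModel"]
print(config.reward_token_id)  # 92527
print(config.rope_scaling)     # {"factor": 2.0, "type": "dynamic"}
```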
configuration_internlm2.py ADDED
@@ -0,0 +1,180 @@
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ InternLM2 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28
+ class InternLM2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`InternLM2Model`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer decoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer decoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by meanpooling all the original heads within that group. For more details checkout [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
61
+ The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens.
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ pad_token_id (`int`, *optional*):
70
+ Padding token id.
71
+ bos_token_id (`int`, *optional*, defaults to 1):
72
+ Beginning of stream token id.
73
+ eos_token_id (`int`, *optional*, defaults to 2):
74
+ End of stream token id.
75
+ pretraining_tp (`int`, *optional*, defaults to 1):
76
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
77
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism)
78
+ to understand more about it. This value is necessary to ensure exact reproducibility
79
+ of the pretraining results. Please refer to [this
80
+ issue](https://github.com/pytorch/pytorch/issues/76232).
81
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
82
+ Whether to tie weight embeddings
83
+ rope_theta (`float`, *optional*, defaults to 10000.0):
84
+ The base period of the RoPE embeddings.
85
+ rope_scaling (`Dict`, *optional*):
86
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
87
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
88
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
89
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
90
+ these scaling strategies behave:
91
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
92
+ experimental feature, subject to breaking API changes in future versions.
93
+ """
94
+ _auto_class = "AutoConfig"
95
+ model_type = "internlm2"
96
+ keys_to_ignore_at_inference = ["past_key_values"]
97
+
98
+ def __init__( # pylint: disable=W0102
99
+ self,
100
+ vocab_size=103168,
101
+ hidden_size=4096,
102
+ intermediate_size=11008,
103
+ num_hidden_layers=32,
104
+ num_attention_heads=32,
105
+ num_key_value_heads=None,
106
+ hidden_act="silu",
107
+ max_position_embeddings=2048,
108
+ initializer_range=0.02,
109
+ rms_norm_eps=1e-6,
110
+ use_cache=True,
111
+ pad_token_id=0,
112
+ bos_token_id=1,
113
+ eos_token_id=2,
114
+ pretraining_tp=1,
115
+ tie_word_embeddings=False,
116
+ bias=True,
117
+ rope_theta=10000,
118
+ rope_scaling=None,
119
+ attn_implementation=None,
120
+ **kwargs,
121
+ ):
122
+ self.vocab_size = vocab_size
123
+ self.max_position_embeddings = max_position_embeddings
124
+ self.hidden_size = hidden_size
125
+ self.intermediate_size = intermediate_size
126
+ self.num_hidden_layers = num_hidden_layers
127
+ self.num_attention_heads = num_attention_heads
128
+ self.bias = bias
129
+
130
+ if num_key_value_heads is None:
131
+ num_key_value_heads = num_attention_heads
132
+ self.num_key_value_heads = num_key_value_heads
133
+
134
+ self.hidden_act = hidden_act
135
+ self.initializer_range = initializer_range
136
+ self.rms_norm_eps = rms_norm_eps
137
+ self.pretraining_tp = pretraining_tp
138
+ self.use_cache = use_cache
139
+ self.rope_theta = rope_theta
140
+ self.rope_scaling = rope_scaling
141
+ self._rope_scaling_validation()
142
+ self.attn_implementation = attn_implementation
143
+ if self.attn_implementation is None:
144
+ self.attn_implementation = "eager"
145
+
146
+ super().__init__(
147
+ pad_token_id=pad_token_id,
148
+ bos_token_id=bos_token_id,
149
+ eos_token_id=eos_token_id,
150
+ tie_word_embeddings=tie_word_embeddings,
151
+ **kwargs,
152
+ )
153
+
154
+ def _rope_scaling_validation(self):
155
+ """
156
+ Validate the `rope_scaling` configuration.
157
+ """
158
+ if self.rope_scaling is None:
159
+ return
160
+
161
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
162
+ raise ValueError(
163
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
164
+ f"got {self.rope_scaling}"
165
+ )
166
+ rope_scaling_type = self.rope_scaling.get("type", None)
167
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
168
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
169
+ raise ValueError(
170
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
171
+ )
172
+ if (
173
+ rope_scaling_factor is None
174
+ or not isinstance(rope_scaling_factor, (float, int))
175
+ or rope_scaling_factor < 1.0
176
+ ):
177
+ raise ValueError(
178
+ f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
179
+ f"of type {type(rope_scaling_factor)}"
180
+ )
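As a usage sketch of the validation above (the values are illustrative and assume this file is importable from the working directory), `InternLM2Config` accepts a `rope_scaling` dict of the documented `{"type": ..., "factor": ...}` form and rejects malformed ones:

```python
from configuration_internlm2 import InternLM2Config

# A well-formed rope_scaling dict passes _rope_scaling_validation.
cfg = InternLM2Config(rope_scaling={"type": "dynamic", "factor": 2.0})
print(cfg.rope_scaling)

# A factor below 1.0 triggers the ValueError raised in _rope_scaling_validation.
try:
    InternLM2Config(rope_scaling={"type": "dynamic", "factor": 0.5})
except ValueError as err:
    print(err)
```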
misc/intro.jpeg ADDED

Git LFS Details

  • SHA256: d81fdc6c6f832d3859279549a3443eed5cf819a868d9d6c821c10928bb0ba022
  • Pointer size: 132 Bytes
  • Size of remote file: 2.3 MB
misc/logo.png ADDED
misc/result.png ADDED

Git LFS Details

  • SHA256: d9aa175f84cea71e11b7c03538a0909a3f733bb6140a0e5fbd3b3bc255413191
  • Pointer size: 131 Bytes
  • Size of remote file: 785 kB
modeling_internlm2.py ADDED
@@ -0,0 +1,1995 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch InternLM2 model."""
17
+ import math
18
+ import queue
19
+ import threading
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from einops import rearrange
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPast,
33
+ CausalLMOutputWithPast,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutputWithPast,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
40
+ from transformers.utils import (
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ is_flash_attn_greater_or_equal_2_10,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+
48
+ try:
49
+ from transformers.generation.streamers import BaseStreamer
50
+ except Exception:
51
+ BaseStreamer = None
52
+
53
+ from .configuration_internlm2 import InternLM2Config
54
+
55
+
56
+ try:
57
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
58
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
59
+ except Exception:
60
+ pass
61
+
62
+
63
+ logger = logging.get_logger(__name__)
64
+
65
+ _CONFIG_FOR_DOC = "InternLM2Config"
66
+
67
+
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # pylint: disable=E1102
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
79
+
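For illustration, a minimal sketch of what _get_unpad_data computes for a toy right-padded batch (two sequences of lengths 3 and 1):

import torch
import torch.nn.functional as F

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]])
seqlens = mask.sum(dim=-1, dtype=torch.int32)                                 # tensor([3, 1])
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()            # tensor([0, 1, 2, 4]): positions of real tokens
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 4])
max_seqlen = seqlens.max().item()                                             # 3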
80
+ class InternLM2RMSNorm(nn.Module):
81
+ """InternLM2RMSNorm is equivalent to T5LayerNorm."""
82
+
83
+ def __init__(self, hidden_size, eps=1e-6):
84
+ super().__init__()
85
+ self.weight = nn.Parameter(torch.ones(hidden_size))
86
+ self.variance_epsilon = eps
87
+
88
+ def forward(self, hidden_states):
89
+ input_dtype = hidden_states.dtype
90
+ hidden_states = hidden_states.to(torch.float32)
91
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
92
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
93
+ return self.weight * hidden_states.to(input_dtype)
94
+
95
+
96
+ ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm)
97
+
98
+
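As a quick numeric sketch of the normalization above: each hidden vector is divided by its root mean square (plus epsilon) and then scaled by the learned weight.

import torch

x = torch.tensor([[3.0, 4.0]])
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)   # sqrt((9 + 16) / 2) ≈ 3.5355
normed = x / rms                                           # ≈ [[0.8485, 1.1314]]; a weight of ones leaves it unchanged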
99
+ class InternLM2RotaryEmbedding(nn.Module):
100
+ """Rotary Position Embedding for the InternLM2 model. Credits to the Reddit user /u/lucidrains."""
101
+
102
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
103
+ super().__init__()
104
+ self.scaling_factor = scaling_factor
105
+ self.dim = dim
106
+ self.max_position_embeddings = max_position_embeddings
107
+ self.base = base
108
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
109
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
110
+ # For BC we register cos and sin cached
111
+ self.max_seq_len_cached = max_position_embeddings
112
+
113
+ @torch.no_grad()
114
+ def forward(self, x, position_ids):
115
+ # x: [bs, num_attention_heads, seq_len, head_size]
116
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
117
+ position_ids_expanded = position_ids[:, None, :].float()
118
+ # Force float32 since bfloat16 loses precision on long contexts
119
+ # See https://github.com/huggingface/transformers/pull/29285
120
+ device_type = x.device.type
121
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
122
+ with torch.autocast(device_type=device_type, enabled=False):
123
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
124
+ emb = torch.cat((freqs, freqs), dim=-1)
125
+ cos = emb.cos()
126
+ sin = emb.sin()
127
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
128
+
129
+
130
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
131
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
132
+
133
+ def forward(self, x, position_ids):
134
+ # difference to the original RoPE: a scaling factor is applied to the position ids
135
+ position_ids = position_ids.float() / self.scaling_factor
136
+ cos, sin = super().forward(x, position_ids)
137
+ return cos, sin
138
+
139
+
140
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
141
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
142
+ Credits to the Reddit users /u/bloc97 and /u/emozilla"""
143
+
144
+ def forward(self, x, position_ids):
145
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
146
+ seq_len = torch.max(position_ids) + 1
147
+ if seq_len > self.max_position_embeddings:
148
+ base = self.base * (
149
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
150
+ ) ** (self.dim / (self.dim - 2))
151
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim))
152
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
153
+
154
+ cos, sin = super().forward(x, position_ids)
155
+ return cos, sin
156
+
157
+
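As a sketch with illustrative numbers, the dynamic NTK branch above enlarges the RoPE base once the sequence grows beyond max_position_embeddings, and inv_freq is then rebuilt from the new base exactly as in __init__:

base, dim, max_pos, factor, seq_len = 10000.0, 128, 2048, 2.0, 4096   # illustrative values
new_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
# new_base is roughly 3.05e4 here, which stretches the rotary period for long contexts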
158
+ def rotate_half(x):
159
+ """Rotates half the hidden dims of the input."""
160
+ x1 = x[..., : x.shape[-1] // 2]
161
+ x2 = x[..., x.shape[-1] // 2 :]
162
+ return torch.cat((-x2, x1), dim=-1)
163
+
164
+
165
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument
166
+ """Applies Rotary Position Embedding to the query and key tensors.
167
+ Args:
168
+ q (`torch.Tensor`): The query tensor.
169
+ k (`torch.Tensor`): The key tensor.
170
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
171
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
172
+ position_ids (`torch.Tensor`, *optional*):
173
+ Deprecated and unused.
174
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
175
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
176
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
177
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
178
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
179
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
180
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
181
+ Returns:
182
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
183
+ """
184
+ cos = cos.unsqueeze(unsqueeze_dim)
185
+ sin = sin.unsqueeze(unsqueeze_dim)
186
+ q_embed = (q * cos) + (rotate_half(q) * sin)
187
+ k_embed = (k * cos) + (rotate_half(k) * sin)
188
+ return q_embed, k_embed
189
+
190
+
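A shape-only sketch of how the cos/sin tables produced by the rotary classes broadcast against q and k with the default unsqueeze_dim=1, using the helpers defined above (sizes are illustrative):

import torch

bsz, n_heads, seq_len, head_dim = 1, 8, 5, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
cos = torch.randn(bsz, seq_len, head_dim)   # shape returned by InternLM2RotaryEmbedding.forward
sin = torch.randn(bsz, seq_len, head_dim)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)   # cos/sin are unsqueezed to (bsz, 1, seq_len, head_dim)
assert q_embed.shape == q.shape and k_embed.shape == k.shape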
191
+ class InternLM2MLP(nn.Module):
192
+ """MLP for InternLM2 model."""
193
+
194
+ def __init__(self, config):
195
+ super().__init__()
196
+ self.config = config
197
+ self.hidden_size = config.hidden_size
198
+ self.intermediate_size = config.intermediate_size
199
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
200
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
201
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
202
+ self.act_fn = ACT2FN[config.hidden_act]
203
+
204
+ def forward(self, x):
205
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
206
+
207
+ return down_proj
208
+
209
+
210
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
211
+ """
212
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
213
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
214
+ """
215
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
216
+ if n_rep == 1:
217
+ return hidden_states
218
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
219
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
220
+
221
+
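The repeat_interleave equivalence stated in the docstring can be checked directly, for example:

import torch

x = torch.randn(2, 4, 7, 16)   # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))
assert repeat_kv(x, 3).shape == (2, 12, 7, 16)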
222
+ class InternLM2Attention(nn.Module):
223
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
224
+
225
+ def __init__(self, config: InternLM2Config, layer_idx: Optional[int] = None):
226
+ super().__init__()
227
+ self.config = config
228
+ self.layer_idx = layer_idx
229
+ if layer_idx is None:
230
+ logger.warning_once(
231
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
232
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
233
+ "when creating this class."
234
+ )
235
+
236
+ self.hidden_size = config.hidden_size
237
+ self.num_heads = config.num_attention_heads
238
+ self.head_dim = self.hidden_size // self.num_heads
239
+ self.num_key_value_heads = config.num_key_value_heads
240
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
241
+ self.max_position_embeddings = config.max_position_embeddings
242
+ self.rope_theta = config.rope_theta
243
+ self.is_causal = True
244
+
245
+ if (self.head_dim * self.num_heads) != self.hidden_size:
246
+ raise ValueError(
247
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
248
+ f" and `num_heads`: {self.num_heads})."
249
+ )
250
+
251
+ self.wqkv = nn.Linear(
252
+ self.hidden_size,
253
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
254
+ bias=config.bias,
255
+ )
256
+ self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
257
+
258
+ self._init_rope()
259
+
260
+ def _init_rope(self):
261
+ if self.config.rope_scaling is None:
262
+ self.rotary_emb = InternLM2RotaryEmbedding(
263
+ self.head_dim,
264
+ max_position_embeddings=self.max_position_embeddings,
265
+ base=self.rope_theta,
266
+ )
267
+ else:
268
+ scaling_type = self.config.rope_scaling["type"]
269
+ scaling_factor = self.config.rope_scaling["factor"]
270
+ if scaling_type == "linear":
271
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
272
+ self.head_dim,
273
+ max_position_embeddings=self.max_position_embeddings,
274
+ scaling_factor=scaling_factor,
275
+ base=self.rope_theta,
276
+ )
277
+ elif scaling_type == "dynamic":
278
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
279
+ self.head_dim,
280
+ max_position_embeddings=self.max_position_embeddings,
281
+ scaling_factor=scaling_factor,
282
+ base=self.rope_theta,
283
+ )
284
+ else:
285
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
286
+
287
+ def forward(
288
+ self,
289
+ hidden_states: torch.Tensor,
290
+ attention_mask: Optional[torch.Tensor] = None,
291
+ position_ids: Optional[torch.LongTensor] = None,
292
+ past_key_value: Optional[Cache] = None,
293
+ output_attentions: bool = False,
294
+ use_cache: bool = False, # pylint: disable=unused-argument
295
+ cache_position: Optional[torch.LongTensor] = None,
296
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
297
+ bsz, q_len, _ = hidden_states.size()
298
+
299
+ if self.config.pretraining_tp > 1:
300
+ # split qkv_states by tp size
301
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
302
+ qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
303
+ qkv_states = torch.cat(
304
+ [F.linear(hidden_states, qkv_slice) for qkv_slice in qkv_slices], dim=-1 # pylint: disable=E1102
305
+ )
306
+ else:
307
+ qkv_states = self.wqkv(hidden_states)
308
+
309
+ qkv_states = rearrange(
310
+ qkv_states,
311
+ "b q (h gs d) -> b q h gs d",
312
+ gs=2 + self.num_key_value_groups,
313
+ d=self.head_dim,
314
+ )
315
+
316
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
317
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d").transpose(1, 2)
318
+ key_states = qkv_states[..., -2, :].transpose(1, 2)
319
+ value_states = qkv_states[..., -1, :].transpose(1, 2)
320
+
321
+ cos, sin = self.rotary_emb(value_states, position_ids)
322
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
323
+
324
+ if past_key_value is not None:
325
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
326
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
327
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
328
+
329
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
330
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
331
+
332
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
333
+
334
+ if attention_mask is not None: # no matter the length, we just slice it
335
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
336
+ attn_weights = attn_weights + causal_mask
337
+
338
+ # upcast attention to fp32
339
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
340
+ attn_output = torch.matmul(attn_weights, value_states)
341
+
342
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
343
+ raise ValueError(
344
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
345
+ f" {attn_output.size()}"
346
+ )
347
+
348
+ attn_output = attn_output.transpose(1, 2).contiguous()
349
+
350
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
351
+
352
+ if self.config.pretraining_tp > 1:
353
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
354
+ o_proj_slices = self.wo.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
355
+ attn_output = sum(
356
+ [
357
+ F.linear(attn_output[i], o_proj_slices[i]) # pylint: disable=E1102
358
+ for i in range(self.config.pretraining_tp)
359
+ ]
360
+ )
361
+ else:
362
+ attn_output = self.wo(attn_output)
363
+
364
+ if not output_attentions:
365
+ attn_weights = None
366
+
367
+ return attn_output, attn_weights, past_key_value
368
+
369
+
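For illustration, the fused wqkv projection packs queries, keys and values per key/value head, and the rearrange pattern above splits them apart. A sketch with illustrative sizes (32 query heads, 8 key/value heads, head_dim 128):

import torch
from einops import rearrange

num_heads, num_kv_heads, head_dim = 32, 8, 128
groups = num_heads // num_kv_heads                                     # num_key_value_groups = 4
qkv = torch.randn(1, 5, (num_heads + 2 * num_kv_heads) * head_dim)    # output of wqkv for a length-5 sequence
qkv = rearrange(qkv, "b q (h gs d) -> b q h gs d", gs=2 + groups, d=head_dim)
q = rearrange(qkv[..., :groups, :], "b q h gs d -> b q (h gs) d")     # (1, 5, 32, 128)
k = qkv[..., -2, :]                                                    # (1, 5, 8, 128)
v = qkv[..., -1, :]                                                    # (1, 5, 8, 128)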
370
+ class InternLM2FlashAttention2(InternLM2Attention):
371
+ """
372
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stay
373
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
374
+ flash attention and deal with padding tokens in case the input contains any of them.
375
+ """
376
+
377
+ def __init__(self, *args, **kwargs):
378
+ super().__init__(*args, **kwargs)
379
+
380
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
381
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
382
+ # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
383
+ # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
384
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
385
+ # produces a wrong mask (top-left).
386
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
387
+
388
+ def forward(
389
+ self,
390
+ hidden_states: torch.Tensor,
391
+ attention_mask: Optional[torch.LongTensor] = None,
392
+ position_ids: Optional[torch.LongTensor] = None,
393
+ past_key_value: Optional[Cache] = None,
394
+ output_attentions: bool = False,
395
+ use_cache: bool = False,
396
+ cache_position: Optional[torch.LongTensor] = None,
397
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
398
+ if isinstance(past_key_value, StaticCache):
399
+ raise ValueError(
400
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; "
401
+ "make sure to use `sdpa` in the meantime, and open an issue at "
402
+ "https://github.com/huggingface/transformers"
403
+ )
404
+
405
+ output_attentions = False
406
+
407
+ bsz, q_len, _ = hidden_states.size()
408
+
409
+ qkv_states = self.wqkv(hidden_states)
410
+
411
+ qkv_states = rearrange(
412
+ qkv_states,
413
+ "b q (h gs d) -> b q h gs d",
414
+ gs=2 + self.num_key_value_groups,
415
+ d=self.head_dim,
416
+ )
417
+
418
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
419
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
420
+ key_states = qkv_states[..., -2, :]
421
+ value_states = qkv_states[..., -1, :]
422
+
423
+ query_states = query_states.transpose(1, 2)
424
+ key_states = key_states.transpose(1, 2)
425
+ value_states = value_states.transpose(1, 2)
426
+
427
+ cos, sin = self.rotary_emb(value_states, position_ids)
428
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
429
+
430
+ if past_key_value is not None:
431
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
432
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
433
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
434
+
435
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
436
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
437
+ # to be able to avoid many of these transpose/reshape/view.
438
+ query_states = query_states.transpose(1, 2)
439
+ key_states = key_states.transpose(1, 2)
440
+ value_states = value_states.transpose(1, 2)
441
+
442
+ # dropout_rate = self.attention_dropout if self.training else 0.0
443
+ dropout_rate = 0.0
444
+
445
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
446
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
447
+ # cast them back in the correct dtype just to be sure everything works as expected.
448
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
449
+ # in fp32. (InternLM2RMSNorm handles it correctly)
450
+
451
+ input_dtype = query_states.dtype
452
+ if input_dtype == torch.float32:
453
+ if torch.is_autocast_enabled():
454
+ target_dtype = torch.get_autocast_gpu_dtype()
455
+ # Handle the case where the model is quantized
456
+ elif hasattr(self.config, "_pre_quantization_dtype"):
457
+ target_dtype = self.config._pre_quantization_dtype
458
+ else:
459
+ target_dtype = self.wqkv.weight.dtype
460
+
461
+ logger.warning_once(
462
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
463
+ f" the fact you have upcast embedding or layer norm layers in float32. We will cast back the input in"
464
+ f" {target_dtype}."
465
+ )
466
+
467
+ query_states = query_states.to(target_dtype)
468
+ key_states = key_states.to(target_dtype)
469
+ value_states = value_states.to(target_dtype)
470
+
471
+ attn_output = self._flash_attention_forward(
472
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
473
+ )
474
+
475
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
476
+ attn_output = self.wo(attn_output)
477
+
478
+ if not output_attentions:
479
+ attn_weights = None
480
+
481
+ return attn_output, attn_weights, past_key_value # pylint: disable=E0606
482
+
483
+ def _flash_attention_forward(
484
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
485
+ ):
486
+ """
487
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
488
+ first unpad the input, then compute the attention scores and pad the final attention scores.
489
+ Args:
490
+ query_states (`torch.Tensor`):
491
+ Input query states to be passed to Flash Attention API
492
+ key_states (`torch.Tensor`):
493
+ Input key states to be passed to Flash Attention API
494
+ value_states (`torch.Tensor`):
495
+ Input value states to be passed to Flash Attention API
496
+ attention_mask (`torch.Tensor`):
497
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
498
+ position of padding tokens and 1 for the position of non-padding tokens.
499
+ dropout (`float`):
500
+ Attention dropout
501
+ softmax_scale (`float`, *optional*):
502
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
503
+ """
504
+ if not self._flash_attn_uses_top_left_mask:
505
+ causal = self.is_causal
506
+ else:
507
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
508
+ # For details, please see the comment in InternLM2FlashAttention2 __init__.
509
+ causal = self.is_causal and query_length != 1
510
+
511
+ # Contains at least one padding token in the sequence
512
+ if attention_mask is not None:
513
+ batch_size = query_states.shape[0]
514
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
515
+ query_states, key_states, value_states, attention_mask, query_length
516
+ )
517
+
518
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
519
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
520
+
521
+ attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606
522
+ query_states,
523
+ key_states,
524
+ value_states,
525
+ cu_seqlens_q=cu_seqlens_q,
526
+ cu_seqlens_k=cu_seqlens_k,
527
+ max_seqlen_q=max_seqlen_in_batch_q,
528
+ max_seqlen_k=max_seqlen_in_batch_k,
529
+ dropout_p=dropout,
530
+ softmax_scale=softmax_scale,
531
+ causal=causal,
532
+ )
533
+
534
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) # pylint: disable=E0606
535
+ else:
536
+ attn_output = flash_attn_func( # pylint: disable=E0606
537
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
538
+ )
539
+
540
+ return attn_output
541
+
542
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
543
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
544
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
545
+
546
+ key_layer = index_first_axis( # pylint: disable=E0606
547
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
548
+ )
549
+ value_layer = index_first_axis( # pylint: disable=E0606
550
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
551
+ )
552
+ if query_length == kv_seq_len:
553
+ query_layer = index_first_axis( # pylint: disable=E0606
554
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
555
+ )
556
+ cu_seqlens_q = cu_seqlens_k
557
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
558
+ indices_q = indices_k
559
+ elif query_length == 1:
560
+ max_seqlen_in_batch_q = 1
561
+ cu_seqlens_q = torch.arange(
562
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
563
+ ) # There is a memcpy here, that is very bad.
564
+ indices_q = cu_seqlens_q[:-1]
565
+ query_layer = query_layer.squeeze(1)
566
+ else:
567
+ # The -q_len: slice assumes left padding.
568
+ attention_mask = attention_mask[:, -query_length:]
569
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606
570
+ query_layer, attention_mask
571
+ )
572
+
573
+ return (
574
+ query_layer,
575
+ key_layer,
576
+ value_layer,
577
+ indices_q,
578
+ (cu_seqlens_q, cu_seqlens_k),
579
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
580
+ )
581
+
582
+
583
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->InternLM2
584
+ class InternLM2SdpaAttention(InternLM2Attention):
585
+ """
586
+ InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
587
+ `InternLM2Attention` as the weights of the module stay untouched. The only changes are on the forward pass
588
+ to adapt to SDPA API.
589
+ """
590
+
591
+ # Adapted from InternLM2Attention.forward
592
+ def forward(
593
+ self,
594
+ hidden_states: torch.Tensor,
595
+ attention_mask: Optional[torch.Tensor] = None,
596
+ position_ids: Optional[torch.LongTensor] = None,
597
+ past_key_value: Optional[Cache] = None,
598
+ output_attentions: bool = False,
599
+ use_cache: bool = False,
600
+ cache_position: Optional[torch.LongTensor] = None,
601
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
602
+ if output_attentions:
603
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
604
+ # once this is implemented.
605
+ logger.warning_once(
606
+ "InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
607
+ "does not support `output_attentions=True`. Falling back to the manual attention implementation, "
608
+ "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
609
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
610
+ )
611
+ return super().forward(
612
+ hidden_states=hidden_states,
613
+ attention_mask=attention_mask,
614
+ position_ids=position_ids,
615
+ past_key_value=past_key_value,
616
+ output_attentions=output_attentions,
617
+ use_cache=use_cache,
618
+ cache_position=cache_position,
619
+ )
620
+
621
+ bsz, q_len, _ = hidden_states.size()
622
+
623
+ qkv_states = self.wqkv(hidden_states)
624
+
625
+ qkv_states = rearrange(
626
+ qkv_states,
627
+ "b q (h gs d) -> b q h gs d",
628
+ gs=2 + self.num_key_value_groups,
629
+ d=self.head_dim,
630
+ )
631
+
632
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
633
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
634
+ key_states = qkv_states[..., -2, :]
635
+ value_states = qkv_states[..., -1, :]
636
+
637
+ query_states = query_states.transpose(1, 2)
638
+ key_states = key_states.transpose(1, 2)
639
+ value_states = value_states.transpose(1, 2)
640
+
641
+ cos, sin = self.rotary_emb(value_states, position_ids)
642
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
643
+
644
+ if past_key_value is not None:
645
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
646
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
647
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
648
+
649
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
650
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
651
+
652
+ causal_mask = attention_mask
653
+ if attention_mask is not None:
654
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
655
+
656
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
657
+ # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
658
+ if query_states.device.type == "cuda" and causal_mask is not None:
659
+ query_states = query_states.contiguous()
660
+ key_states = key_states.contiguous()
661
+ value_states = value_states.contiguous()
662
+
663
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
664
+ # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
665
+ # options. An inline conditional prevents dynamic shapes from compiling.
666
+ is_causal = bool(causal_mask is None and q_len > 1)
667
+
668
+ attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
669
+ query_states,
670
+ key_states,
671
+ value_states,
672
+ attn_mask=causal_mask,
673
+ dropout_p=0.0,
674
+ is_causal=is_causal,
675
+ )
676
+
677
+ attn_output = attn_output.transpose(1, 2).contiguous()
678
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
679
+
680
+ attn_output = self.wo(attn_output)
681
+
682
+ return attn_output, None, past_key_value
683
+
684
+
685
+ INTERNLM2_ATTENTION_CLASSES = {
686
+ "eager": InternLM2Attention,
687
+ "flash_attention_2": InternLM2FlashAttention2,
688
+ "sdpa": InternLM2SdpaAttention,
689
+ }
690
+
691
+
692
+ # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2
693
+ class InternLM2DecoderLayer(nn.Module):
694
+ """InternLM2 Decoder Layer. This module is a single layer of the InternLM2 model."""
695
+
696
+ def __init__(self, config: InternLM2Config, layer_idx: int):
697
+ super().__init__()
698
+ self.hidden_size = config.hidden_size
699
+ self.layer_idx = layer_idx
700
+
701
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config, layer_idx=layer_idx)
702
+
703
+ self.feed_forward = InternLM2MLP(config)
704
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
705
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
706
+
707
+ def forward(
708
+ self,
709
+ hidden_states: torch.Tensor,
710
+ attention_mask: Optional[torch.Tensor] = None,
711
+ position_ids: Optional[torch.LongTensor] = None,
712
+ past_key_value: Optional[Cache] = None,
713
+ output_attentions: Optional[bool] = False,
714
+ use_cache: Optional[bool] = False,
715
+ cache_position: Optional[torch.LongTensor] = None,
716
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
717
+ """
718
+ Args:
719
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
720
+ attention_mask (`torch.FloatTensor`, *optional*):
721
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
722
+ query_sequence_length, key_sequence_length)` if default attention is used.
723
+ output_attentions (`bool`, *optional*):
724
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
725
+ returned tensors for more detail.
726
+ use_cache (`bool`, *optional*):
727
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
728
+ (see `past_key_values`).
729
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
730
+ """
731
+ residual = hidden_states
732
+
733
+ hidden_states = self.attention_norm(hidden_states)
734
+
735
+ # Self Attention
736
+ hidden_states, self_attn_weights, present_key_value = self.attention(
737
+ hidden_states=hidden_states,
738
+ attention_mask=attention_mask,
739
+ position_ids=position_ids,
740
+ past_key_value=past_key_value,
741
+ output_attentions=output_attentions,
742
+ use_cache=use_cache,
743
+ cache_position=cache_position,
744
+ )
745
+ hidden_states = residual + hidden_states
746
+
747
+ # Fully Connected
748
+ residual = hidden_states
749
+ hidden_states = self.ffn_norm(hidden_states)
750
+ hidden_states = self.feed_forward(hidden_states)
751
+ hidden_states = residual + hidden_states
752
+
753
+ outputs = (hidden_states,)
754
+
755
+ if output_attentions:
756
+ outputs += (self_attn_weights,)
757
+
758
+ if use_cache:
759
+ outputs += (present_key_value,)
760
+
761
+ return outputs
762
+
763
+
764
+ InternLM2_START_DOCSTRING = r"""
765
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
766
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
767
+ etc.)
768
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
769
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
770
+ and behavior.
771
+ Parameters:
772
+ config ([`InternLM2Config`]):
773
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
774
+ load the weights associated with the model, only the configuration. Check out the
775
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
776
+ """
777
+
778
+
779
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
780
+ @add_start_docstrings(
781
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
782
+ InternLM2_START_DOCSTRING,
783
+ )
784
+ class InternLM2PreTrainedModel(PreTrainedModel):
785
+ """
786
+ InternLM2 pretrained model's base class.
787
+ """
788
+
789
+ config_class = InternLM2Config
790
+ base_model_prefix = "model"
791
+ supports_gradient_checkpointing = True
792
+ _no_split_modules = ["InternLM2DecoderLayer"]
793
+ _skip_keys_device_placement = ["past_key_values"]
794
+ _supports_flash_attn_2 = True
795
+ _supports_sdpa = True
796
+ _supports_cache_class = True
797
+ _supports_quantized_cache = True
798
+ _supports_static_cache = True
799
+
800
+ def _init_weights(self, module):
801
+ std = self.config.initializer_range
802
+ if isinstance(module, nn.Linear):
803
+ module.weight.data.normal_(mean=0.0, std=std)
804
+ if module.bias is not None:
805
+ module.bias.data.zero_()
806
+ elif isinstance(module, nn.Embedding):
807
+ module.weight.data.normal_(mean=0.0, std=std)
808
+ if module.padding_idx is not None:
809
+ module.weight.data[module.padding_idx].zero_()
810
+
811
+
812
+ InternLM2_INPUTS_DOCSTRING = r"""
813
+ Args:
814
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
815
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
816
+ it.
817
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
818
+ [`PreTrainedTokenizer.__call__`] for details.
819
+ [What are input IDs?](../glossary#input-ids)
820
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
821
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
822
+ - 1 for tokens that are **not masked**,
823
+ - 0 for tokens that are **masked**.
824
+ [What are attention masks?](../glossary#attention-mask)
825
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
826
+ [`PreTrainedTokenizer.__call__`] for details.
827
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
828
+ `past_key_values`).
829
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
830
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
831
+ information on the default strategy.
832
+ - 1 indicates the head is **not masked**,
833
+ - 0 indicates the head is **masked**.
834
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
835
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
836
+ config.n_positions - 1]`.
837
+ [What are position IDs?](../glossary#position-ids)
838
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
839
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
840
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
841
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
842
+ Two formats are allowed:
843
+ - a [`~cache_utils.Cache`] instance;
844
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
845
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
846
+ cache format.
847
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
848
+ legacy cache format will be returned.
849
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
850
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
851
+ of shape `(batch_size, sequence_length)`.
852
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
853
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
854
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
855
+ model's internal embedding lookup matrix.
856
+ use_cache (`bool`, *optional*):
857
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
858
+ `past_key_values`).
859
+ output_attentions (`bool`, *optional*):
860
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
861
+ tensors for more detail.
862
+ output_hidden_states (`bool`, *optional*):
863
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
864
+ more detail.
865
+ return_dict (`bool`, *optional*):
866
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
867
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
868
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
869
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
870
+ the complete sequence length.
871
+ """
872
+
873
+
874
+ # Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2
875
+ @add_start_docstrings(
876
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
877
+ InternLM2_START_DOCSTRING,
878
+ )
879
+ class InternLM2Model(InternLM2PreTrainedModel):
880
+ """
881
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
882
+ Args:
883
+ config: InternLM2Config
884
+ """
885
+
886
+ _auto_class = "AutoModel"
887
+
888
+ def __init__(self, config: InternLM2Config):
889
+ super().__init__(config)
890
+ self.padding_idx = config.pad_token_id
891
+ self.vocab_size = config.vocab_size
892
+ self.config = config
893
+
894
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
895
+
896
+ self.layers = nn.ModuleList(
897
+ [InternLM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
898
+ )
899
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
900
+
901
+ self.gradient_checkpointing = False
902
+ # Initialize weights and apply final processing
903
+ self.post_init()
904
+
905
+ def get_input_embeddings(self):
906
+ return self.tok_embeddings
907
+
908
+ def set_input_embeddings(self, value):
909
+ self.tok_embeddings = value
910
+
911
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
912
+ def forward(
913
+ self,
914
+ input_ids: torch.LongTensor = None,
915
+ attention_mask: Optional[torch.Tensor] = None,
916
+ position_ids: Optional[torch.LongTensor] = None,
917
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
918
+ inputs_embeds: Optional[torch.FloatTensor] = None,
919
+ use_cache: Optional[bool] = None,
920
+ output_attentions: Optional[bool] = None,
921
+ output_hidden_states: Optional[bool] = None,
922
+ return_dict: Optional[bool] = None,
923
+ cache_position: Optional[torch.LongTensor] = None,
924
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
925
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
926
+ output_hidden_states = (
927
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
928
+ )
929
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
930
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
931
+
932
+ if (input_ids is None) ^ (inputs_embeds is not None):
933
+ raise ValueError(
934
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
935
+ )
936
+
937
+ if self.gradient_checkpointing and self.training and use_cache:
938
+ logger.warning_once(
939
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
940
+ )
941
+ use_cache = False
942
+
943
+ if inputs_embeds is None:
944
+ inputs_embeds = self.tok_embeddings(input_ids)
945
+
946
+ return_legacy_cache = False
947
+ if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
948
+ return_legacy_cache = True
949
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
950
+
951
+ if cache_position is None:
952
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
953
+ cache_position = torch.arange(
954
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
955
+ )
956
+ if position_ids is None:
957
+ position_ids = cache_position.unsqueeze(0)
958
+
959
+ causal_mask = self._update_causal_mask(
960
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
961
+ )
962
+
963
+ # embed positions
964
+ hidden_states = inputs_embeds
965
+
966
+ # decoder layers
967
+ all_hidden_states = () if output_hidden_states else None
968
+ all_self_attns = () if output_attentions else None
969
+ next_decoder_cache = None
970
+
971
+ for decoder_layer in self.layers:
972
+ if output_hidden_states:
973
+ all_hidden_states += (hidden_states,)
974
+
975
+ if self.gradient_checkpointing and self.training:
976
+ layer_outputs = self._gradient_checkpointing_func(
977
+ decoder_layer.__call__,
978
+ hidden_states,
979
+ causal_mask,
980
+ position_ids,
981
+ past_key_values,
982
+ output_attentions,
983
+ use_cache,
984
+ cache_position,
985
+ )
986
+ else:
987
+ layer_outputs = decoder_layer(
988
+ hidden_states,
989
+ attention_mask=causal_mask,
990
+ position_ids=position_ids,
991
+ past_key_value=past_key_values,
992
+ output_attentions=output_attentions,
993
+ use_cache=use_cache,
994
+ cache_position=cache_position,
995
+ )
996
+
997
+ hidden_states = layer_outputs[0]
998
+
999
+ if use_cache:
1000
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1001
+
1002
+ if output_attentions:
1003
+ all_self_attns += (layer_outputs[1],)
1004
+
1005
+ hidden_states = self.norm(hidden_states)
1006
+
1007
+ # add hidden states from the last decoder layer
1008
+ if output_hidden_states:
1009
+ all_hidden_states += (hidden_states,)
1010
+
1011
+ next_cache = next_decoder_cache if use_cache else None
1012
+ if return_legacy_cache:
1013
+ next_cache = next_cache.to_legacy_cache()
1014
+
1015
+ if not return_dict:
1016
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1017
+ return BaseModelOutputWithPast(
1018
+ last_hidden_state=hidden_states,
1019
+ past_key_values=next_cache,
1020
+ hidden_states=all_hidden_states,
1021
+ attentions=all_self_attns,
1022
+ )
1023
+
1024
+ def _update_causal_mask(
1025
+ self,
1026
+ attention_mask: torch.Tensor,
1027
+ input_tensor: torch.Tensor,
1028
+ cache_position: torch.Tensor,
1029
+ past_key_values: Cache,
1030
+ output_attentions: bool,
1031
+ ):
1032
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length
1033
+ # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at
1034
+ # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is
1035
+ # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`.
1036
+ # See more context in https://github.com/huggingface/transformers/pull/29114
1037
+
1038
+ if self.config.attn_implementation == "flash_attention_2":
1039
+ if attention_mask is not None and 0.0 in attention_mask:
1040
+ return attention_mask
1041
+ return None
1042
+
1043
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1044
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1045
+ # to infer the attention mask.
1046
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1047
+ using_static_cache = isinstance(past_key_values, StaticCache)
1048
+
1049
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1050
+ if self.config.attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1051
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1052
+ attention_mask,
1053
+ inputs_embeds=input_tensor,
1054
+ past_key_values_length=past_seen_tokens,
1055
+ is_training=self.training,
1056
+ ):
1057
+ return None
1058
+
1059
+ dtype, device = input_tensor.dtype, input_tensor.device
1060
+ min_dtype = torch.finfo(dtype).min
1061
+ sequence_length = input_tensor.shape[1]
1062
+ if using_static_cache:
1063
+ target_length = past_key_values.get_max_cache_shape()
1064
+ else:
1065
+ target_length = (
1066
+ attention_mask.shape[-1]
1067
+ if isinstance(attention_mask, torch.Tensor)
1068
+ else past_seen_tokens + sequence_length + 1
1069
+ )
1070
+
1071
+ if attention_mask is not None and attention_mask.dim() == 4:
1072
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
1073
+ if attention_mask.max() != 0:
1074
+ raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0")
1075
+ causal_mask = attention_mask
1076
+ else:
1077
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1078
+ if sequence_length != 1:
1079
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1080
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1081
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
1082
+ if attention_mask is not None:
1083
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1084
+ mask_length = attention_mask.shape[-1]
1085
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1086
+ padding_mask = padding_mask == 0
1087
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1088
+ padding_mask, min_dtype
1089
+ )
1090
+ if (
1091
+ self.config.attn_implementation == "sdpa"
1092
+ and attention_mask is not None
1093
+ and attention_mask.device.type == "cuda"
1094
+ and not output_attentions
1095
+ ):
1096
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1097
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1098
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1099
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) # pylint: disable=E1120
1100
+
1101
+ return causal_mask
1102
+
1103
+
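For illustration, the additive causal mask assembled above can be reproduced for a tiny prefill (sequence_length = target_length = 3, cache_position = [0, 1, 2]); padding columns are then masked in with the same min_dtype value:

import torch

min_dtype = torch.finfo(torch.float32).min
cache_position = torch.arange(3)
mask = torch.full((3, 3), fill_value=min_dtype)
mask = torch.triu(mask, diagonal=1)                        # future positions stay at min_dtype
mask *= torch.arange(3) > cache_position.reshape(-1, 1)    # row i keeps 0.0 for keys <= i, min_dtype beyond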
1104
+ # Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM
1105
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1106
+ """Causal language model (CLM) for InternLM2."""
1107
+
1108
+ _auto_class = "AutoModelForCausalLM"
1109
+ _tied_weights_keys = ["output.weight"]
1110
+
1111
+ def __init__(self, config):
1112
+ super().__init__(config)
1113
+ self.model = InternLM2Model(config)
1114
+ self.vocab_size = config.vocab_size
1115
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1116
+
1117
+ # Initialize weights and apply final processing
1118
+ self.post_init()
1119
+
1120
+ def get_input_embeddings(self):
1121
+ return self.model.tok_embeddings
1122
+
1123
+ def set_input_embeddings(self, value):
1124
+ self.model.tok_embeddings = value
1125
+
1126
+ def get_output_embeddings(self):
1127
+ return self.output
1128
+
1129
+ def set_output_embeddings(self, new_embeddings):
1130
+ self.output = new_embeddings
1131
+
1132
+ def set_decoder(self, decoder):
1133
+ self.model = decoder
1134
+
1135
+ def get_decoder(self):
1136
+ return self.model
1137
+
1138
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1139
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1140
+ def forward(
1141
+ self,
1142
+ input_ids: torch.LongTensor = None,
1143
+ attention_mask: Optional[torch.Tensor] = None,
1144
+ position_ids: Optional[torch.LongTensor] = None,
1145
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1146
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1147
+ labels: Optional[torch.LongTensor] = None,
1148
+ use_cache: Optional[bool] = None,
1149
+ output_attentions: Optional[bool] = None,
1150
+ output_hidden_states: Optional[bool] = None,
1151
+ return_dict: Optional[bool] = None,
1152
+ cache_position: Optional[torch.LongTensor] = None,
1153
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1154
+ r"""
1155
+ Args:
1156
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1157
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1158
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1159
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1160
+ Returns:
1161
+ Example:
1162
+ ```python
1163
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1164
+ >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
1165
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
1166
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1167
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1168
+ >>> # Generate
1169
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1170
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1171
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1172
+ ```"""
1173
+
1174
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1175
+ output_hidden_states = (
1176
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1177
+ )
1178
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1179
+
1180
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1181
+ outputs = self.model(
1182
+ input_ids=input_ids,
1183
+ attention_mask=attention_mask,
1184
+ position_ids=position_ids,
1185
+ past_key_values=past_key_values,
1186
+ inputs_embeds=inputs_embeds,
1187
+ use_cache=use_cache,
1188
+ output_attentions=output_attentions,
1189
+ output_hidden_states=output_hidden_states,
1190
+ return_dict=return_dict,
1191
+ cache_position=cache_position,
1192
+ )
1193
+
1194
+ hidden_states = outputs[0]
1195
+ if self.config.pretraining_tp > 1:
1196
+ output_slices = self.output.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1197
+ logits = [
1198
+ F.linear(hidden_states, output_slices[i]) # pylint: disable=not-callable
1199
+ for i in range(self.config.pretraining_tp)
1200
+ ]
1201
+ logits = torch.cat(logits, dim=-1)
1202
+ else:
1203
+ logits = self.output(hidden_states)
1204
+ logits = logits.float()
1205
+
1206
+ loss = None
1207
+ if labels is not None:
1208
+ # Shift so that tokens < n predict n
1209
+ shift_logits = logits[..., :-1, :].contiguous()
1210
+ shift_labels = labels[..., 1:].contiguous()
1211
+ # Flatten the tokens
1212
+ loss_fct = CrossEntropyLoss()
1213
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1214
+ shift_labels = shift_labels.view(-1)
1215
+ # Enable model parallelism
1216
+ shift_labels = shift_labels.to(shift_logits.device)
1217
+ loss = loss_fct(shift_logits, shift_labels)
1218
+
1219
+ if not return_dict:
1220
+ output = (logits,) + outputs[1:]
1221
+ return (loss,) + output if loss is not None else output
1222
+
1223
+ return CausalLMOutputWithPast(
1224
+ loss=loss,
1225
+ logits=logits,
1226
+ past_key_values=outputs.past_key_values,
1227
+ hidden_states=outputs.hidden_states,
1228
+ attentions=outputs.attentions,
1229
+ )
1230
+
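The loss computation above shifts by one so that the logits at position t are scored against the token at position t+1. A minimal sketch:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)                                 # (batch, seq_len, vocab)
labels = torch.tensor([[2, 5, 7, 1]])
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)   # predictions for positions 0..2
shift_labels = labels[..., 1:].contiguous().view(-1)                   # targets are tokens 1..3
loss = CrossEntropyLoss()(shift_logits, shift_labels)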
1231
+ def prepare_inputs_for_generation(
1232
+ self,
1233
+ input_ids,
1234
+ past_key_values=None,
1235
+ attention_mask=None,
1236
+ inputs_embeds=None,
1237
+ cache_position=None,
1238
+ use_cache=True,
1239
+ **kwargs,
1240
+ ):
1241
+ past_length = 0
1242
+ if past_key_values is not None:
1243
+ if isinstance(past_key_values, Cache):
1244
+ past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
1245
+ max_cache_length = (
1246
+ torch.tensor(past_key_values.get_max_cache_shape(), device=input_ids.device)
1247
+ if past_key_values.get_max_cache_shape() is not None
1248
+ else None
1249
+ )
1250
+ cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
1251
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1252
+ else:
1253
+ cache_length = past_length = past_key_values[0][0].shape[2]
1254
+ max_cache_length = None
1255
+
1256
+ # Keep only the unprocessed tokens:
1257
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1258
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
1259
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1260
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1261
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1262
+ # input_ids based on the past_length.
1263
+ elif past_length < input_ids.shape[1]:
1264
+ input_ids = input_ids[:, past_length:]
1265
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1266
+
1267
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1268
+ if (
1269
+ max_cache_length is not None
1270
+ and attention_mask is not None
1271
+ and cache_length + input_ids.shape[1] > max_cache_length
1272
+ ):
1273
+ attention_mask = attention_mask[:, -max_cache_length:] # pylint: disable=E1130
1274
+
1275
+ position_ids = kwargs.get("position_ids", None)
1276
+ if attention_mask is not None and position_ids is None:
1277
+ # create position_ids on the fly for batch generation
1278
+ position_ids = attention_mask.long().cumsum(-1) - 1
1279
+ position_ids.masked_fill_(attention_mask == 0, 1)
1280
+ if past_key_values:
1281
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1282
+
1283
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1284
+ if inputs_embeds is not None and past_key_values is None:
1285
+ model_inputs = {"inputs_embeds": inputs_embeds}
1286
+ else:
1287
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1288
+ # recompiles graphs as the stride of the inputs is a guard.
1289
+ # Ref: https://github.com/huggingface/transformers/pull/29114
1290
+ # TODO: use `next_tokens` directly instead.
1291
+ model_inputs = {"input_ids": input_ids.contiguous()}
1292
+
1293
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1294
+ if cache_position is None:
1295
+ cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
1296
+ elif use_cache:
1297
+ cache_position = cache_position[-input_length:]
1298
+
1299
+ model_inputs.update(
1300
+ {
1301
+ "position_ids": position_ids,
1302
+ "cache_position": cache_position,
1303
+ "past_key_values": past_key_values,
1304
+ "use_cache": use_cache,
1305
+ "attention_mask": attention_mask,
1306
+ }
1307
+ )
1308
+ return model_inputs
1309
+
1310
+ @staticmethod
1311
+ def _reorder_cache(past_key_values, beam_idx):
1312
+ reordered_past = ()
1313
+ for layer_past in past_key_values:
1314
+ reordered_past += (
1315
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1316
+ )
1317
+ return reordered_past
1318
+
1319
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, meta_instruction=""):
1320
+ if history is None:
1321
+ history = []
1322
+ if tokenizer.add_bos_token:
1323
+ prompt = ""
1324
+ else:
1325
+ prompt = tokenizer.bos_token
1326
+ if meta_instruction:
1327
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1328
+ for record in history:
1329
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1330
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1331
+ return tokenizer([prompt], return_tensors="pt")
1332
+
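For reference, a sketch of the ChatML-style prompt string that `build_inputs` assembles for a single-turn query. This assumes the tokenizer auto-adds BOS (so no explicit `<s>` appears in the text) and abbreviates the meta instruction; the query "Hi" is purely illustrative.

```python
# Illustrative layout of the text built by build_inputs(tokenizer, "Hi", history=[], meta_instruction="...")
prompt = (
    "<|im_start|>system\n...meta instruction...<|im_end|>\n"
    "<|im_start|>user\nHi<|im_end|>\n"
    "<|im_start|>assistant\n"
)
```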
1333
+ @torch.no_grad()
1334
+ def chat(
1335
+ self,
1336
+ tokenizer,
1337
+ query: str,
1338
+ history: Optional[List[Tuple[str, str]]] = None,
1339
+ streamer: Optional[BaseStreamer] = None,
1340
+ max_new_tokens: int = 1024,
1341
+ do_sample: bool = True,
1342
+ temperature: float = 0.8,
1343
+ top_p: float = 0.8,
1344
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1345
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
1346
+ "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1347
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such "
1348
+ "as English and 中文.",
1349
+ **kwargs,
1350
+ ):
1351
+ if history is None:
1352
+ history = []
1353
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1354
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1355
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1356
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1357
+ outputs = self.generate(
1358
+ **inputs,
1359
+ streamer=streamer,
1360
+ max_new_tokens=max_new_tokens,
1361
+ do_sample=do_sample,
1362
+ temperature=temperature,
1363
+ top_p=top_p,
1364
+ eos_token_id=eos_token_id,
1365
+ **kwargs,
1366
+ )
1367
+ outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1368
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1369
+ response = response.split("<|im_end|>")[0]
1370
+ history = history + [(query, response)]
1371
+ return response, history
1372
+
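A hedged usage sketch for `chat` (the checkpoint id, dtype, and device below are illustrative assumptions; any InternLM2 chat checkpoint that ships this modeling file should expose the same interface):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "internlm/internlm2-chat-7b"  # illustrative checkpoint id
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.float16, trust_remote_code=True
).cuda().eval()

response, history = model.chat(tokenizer, "Hello! Who are you?", history=[])
print(response)
response, history = model.chat(tokenizer, "Answer again in one sentence.", history=history)
print(response)
```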
1373
+ @torch.no_grad()
1374
+ def stream_chat(
1375
+ self,
1376
+ tokenizer,
1377
+ query: str,
1378
+ history: List[Tuple[str, str]] = None,
1379
+ max_new_tokens: int = 1024,
1380
+ do_sample: bool = True,
1381
+ temperature: float = 0.8,
1382
+ top_p: float = 0.8,
1383
+ **kwargs,
1384
+ ):
1385
+ """
1386
+ Return a generator in format: (response, history)
1387
+ Eg.
1388
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1389
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1390
+ """
1391
+ if history is None:
1392
+ history = []
1393
+ if BaseStreamer is None:
1394
+ raise ModuleNotFoundError(
1395
+ "The version of `transformers` is too low. Please make sure "
1396
+ "that you have installed `transformers>=4.28.0`."
1397
+ )
1398
+
1399
+ response_queue = queue.Queue(maxsize=20)
1400
+
1401
+ class ChatStreamer(BaseStreamer):
1402
+ """
1403
+ Streamer used in generate to print words one by one.
1404
+ """
1405
+
1406
+ def __init__(self, tokenizer) -> None:
1407
+ super().__init__()
1408
+ self.tokenizer = tokenizer
1409
+ self.queue = response_queue
1410
+ self.query = query
1411
+ self.history = history
1412
+ self.response = ""
1413
+ self.cache = []
1414
+ self.received_inputs = False
1415
+ self.queue.put((self.response, history + [(self.query, self.response)]))
1416
+
1417
+ def put(self, value):
1418
+ if len(value.shape) > 1 and value.shape[0] > 1:
1419
+ raise ValueError("ChatStreamer only supports batch size 1")
1420
+ elif len(value.shape) > 1:
1421
+ value = value[0]
1422
+
1423
+ if not self.received_inputs:
1424
+ # The first received value is input_ids, ignore here
1425
+ self.received_inputs = True
1426
+ return
1427
+
1428
+ self.cache.extend(value.tolist())
1429
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1430
+ if token.strip() != "<|im_end|>":
1431
+ self.response = self.response + token
1432
+ history = self.history + [(self.query, self.response)]
1433
+ self.queue.put((self.response, history))
1434
+ self.cache = []
1435
+ else:
1436
+ self.end()
1437
+
1438
+ def end(self):
1439
+ self.queue.put(None)
1440
+
1441
+ def stream_producer():
1442
+ return self.chat(
1443
+ tokenizer=tokenizer,
1444
+ query=query,
1445
+ streamer=ChatStreamer(tokenizer=tokenizer),
1446
+ history=history,
1447
+ max_new_tokens=max_new_tokens,
1448
+ do_sample=do_sample,
1449
+ temperature=temperature,
1450
+ top_p=top_p,
1451
+ **kwargs,
1452
+ )
1453
+
1454
+ def consumer():
1455
+ producer = threading.Thread(target=stream_producer)
1456
+ producer.start()
1457
+ while True:
1458
+ res = response_queue.get()
1459
+ if res is None:
1460
+ return
1461
+ yield res
1462
+
1463
+ return consumer()
1464
+
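`stream_chat` runs `chat` in a producer thread and yields the accumulated response after every generated token; a brief sketch of consuming the generator (reusing `model` and `tokenizer` from the `chat` sketch above):

```python
for partial_response, _history in model.stream_chat(tokenizer, "Tell me a short joke."):
    print(partial_response, end="\r", flush=True)
print()
```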
1465
+
1466
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1467
+ @add_start_docstrings(
1468
+ """
1469
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1470
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1471
+ (e.g. GPT-2) do.
1472
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1473
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1474
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1475
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1476
+ each row of the batch).
1477
+ """,
1478
+ InternLM2_START_DOCSTRING,
1479
+ )
1480
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1481
+ """Sequence Classification Head for InternLM2 Model."""
1482
+
1483
+ def __init__(self, config):
1484
+ super().__init__(config)
1485
+ self.num_labels = config.num_labels
1486
+ self.model = InternLM2Model(config)
1487
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1488
+
1489
+ # Initialize weights and apply final processing
1490
+ self.post_init()
1491
+
1492
+ def get_input_embeddings(self):
1493
+ return self.model.tok_embeddings
1494
+
1495
+ def set_input_embeddings(self, value):
1496
+ self.model.tok_embeddings = value
1497
+
1498
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1499
+ def forward(
1500
+ self,
1501
+ input_ids: torch.LongTensor = None,
1502
+ attention_mask: Optional[torch.Tensor] = None,
1503
+ position_ids: Optional[torch.LongTensor] = None,
1504
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1505
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1506
+ labels: Optional[torch.LongTensor] = None,
1507
+ use_cache: Optional[bool] = None,
1508
+ output_attentions: Optional[bool] = None,
1509
+ output_hidden_states: Optional[bool] = None,
1510
+ return_dict: Optional[bool] = None,
1511
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1512
+ r"""
1513
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1514
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1515
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1516
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1517
+ """
1518
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1519
+
1520
+ transformer_outputs = self.model(
1521
+ input_ids,
1522
+ attention_mask=attention_mask,
1523
+ position_ids=position_ids,
1524
+ past_key_values=past_key_values,
1525
+ inputs_embeds=inputs_embeds,
1526
+ use_cache=use_cache,
1527
+ output_attentions=output_attentions,
1528
+ output_hidden_states=output_hidden_states,
1529
+ return_dict=return_dict,
1530
+ )
1531
+ hidden_states = transformer_outputs[0]
1532
+ logits = self.score(hidden_states)
1533
+
1534
+ if input_ids is not None:
1535
+ batch_size = input_ids.shape[0]
1536
+ else:
1537
+ batch_size = inputs_embeds.shape[0]
1538
+
1539
+ if self.config.pad_token_id is None and batch_size != 1:
1540
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1541
+ if self.config.pad_token_id is None:
1542
+ sequence_lengths = -1
1543
+ else:
1544
+ if input_ids is not None:
1545
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1546
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1547
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1548
+ sequence_lengths = sequence_lengths.to(logits.device)
1549
+ else:
1550
+ sequence_lengths = -1
1551
+
1552
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1553
+
1554
+ loss = None
1555
+ if labels is not None:
1556
+ labels = labels.to(logits.device)
1557
+ if self.config.problem_type is None:
1558
+ if self.num_labels == 1:
1559
+ self.config.problem_type = "regression"
1560
+ elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
1561
+ self.config.problem_type = "single_label_classification"
1562
+ else:
1563
+ self.config.problem_type = "multi_label_classification"
1564
+
1565
+ if self.config.problem_type == "regression":
1566
+ loss_fct = MSELoss()
1567
+ if self.num_labels == 1:
1568
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1569
+ else:
1570
+ loss = loss_fct(pooled_logits, labels)
1571
+ elif self.config.problem_type == "single_label_classification":
1572
+ loss_fct = CrossEntropyLoss()
1573
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1574
+ elif self.config.problem_type == "multi_label_classification":
1575
+ loss_fct = BCEWithLogitsLoss()
1576
+ loss = loss_fct(pooled_logits, labels)
1577
+ if not return_dict:
1578
+ output = (pooled_logits,) + transformer_outputs[1:]
1579
+ return ((loss,) + output) if loss is not None else output
1580
+
1581
+ return SequenceClassifierOutputWithPast(
1582
+ loss=loss,
1583
+ logits=pooled_logits,
1584
+ past_key_values=transformer_outputs.past_key_values,
1585
+ hidden_states=transformer_outputs.hidden_states,
1586
+ attentions=transformer_outputs.attentions,
1587
+ )
1588
+
1589
+
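The last-token pooling above uses an ONNX-friendly trick rather than reverse indexing; a minimal sketch of what `sequence_lengths` evaluates to (the pad token id and input ids are made up for illustration):

```python
import torch

pad_token_id = 2  # illustrative value
input_ids = torch.tensor([[5, 6, 7, 2, 2],    # three real tokens, then padding
                          [5, 6, 7, 8, 9]])   # no padding at all
# argmax(-1) finds the first pad position; subtracting 1 steps back to the last real token.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# A row without padding yields argmax 0, so -1 wraps around via the modulo to the last position.
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])
```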
1590
+ # Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2
1591
+ @add_start_docstrings(
1592
+ """
1593
+ The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like
1594
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1595
+ """,
1596
+ InternLM2_START_DOCSTRING,
1597
+ )
1598
+ class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel):
1599
+ """Question Answering model for InternLM2."""
1600
+
1601
+ base_model_prefix = "transformer"
1602
+
1603
+ def __init__(self, config):
1604
+ super().__init__(config)
1605
+ self.transformer = InternLM2Model(config)
1606
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1607
+
1608
+ # Initialize weights and apply final processing
1609
+ self.post_init()
1610
+
1611
+ def get_input_embeddings(self):
1612
+ return self.transformer.tok_embeddings
1613
+
1614
+ def set_input_embeddings(self, value):
1615
+ self.transformer.tok_embeddings = value
1616
+
1617
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1618
+ def forward(
1619
+ self,
1620
+ input_ids: Optional[torch.LongTensor] = None,
1621
+ attention_mask: Optional[torch.FloatTensor] = None,
1622
+ position_ids: Optional[torch.LongTensor] = None,
1623
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1624
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1625
+ start_positions: Optional[torch.LongTensor] = None,
1626
+ end_positions: Optional[torch.LongTensor] = None,
1627
+ output_attentions: Optional[bool] = None,
1628
+ output_hidden_states: Optional[bool] = None,
1629
+ return_dict: Optional[bool] = None,
1630
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1631
+ r"""
1632
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1633
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1634
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1635
+ are not taken into account for computing the loss.
1636
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1637
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1638
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1639
+ are not taken into account for computing the loss.
1640
+ """
1641
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1642
+
1643
+ outputs = self.transformer(
1644
+ input_ids,
1645
+ attention_mask=attention_mask,
1646
+ position_ids=position_ids,
1647
+ past_key_values=past_key_values,
1648
+ inputs_embeds=inputs_embeds,
1649
+ output_attentions=output_attentions,
1650
+ output_hidden_states=output_hidden_states,
1651
+ return_dict=return_dict,
1652
+ )
1653
+
1654
+ sequence_output = outputs[0]
1655
+
1656
+ logits = self.qa_outputs(sequence_output)
1657
+ start_logits, end_logits = logits.split(1, dim=-1)
1658
+ start_logits = start_logits.squeeze(-1).contiguous()
1659
+ end_logits = end_logits.squeeze(-1).contiguous()
1660
+
1661
+ total_loss = None
1662
+ if start_positions is not None and end_positions is not None:
1663
+ # If we are on multi-GPU, split add a dimension
1664
+ if len(start_positions.size()) > 1:
1665
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1666
+ if len(end_positions.size()) > 1:
1667
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1668
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1669
+ ignored_index = start_logits.size(1)
1670
+ start_positions = start_positions.clamp(0, ignored_index)
1671
+ end_positions = end_positions.clamp(0, ignored_index)
1672
+
1673
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1674
+ start_loss = loss_fct(start_logits, start_positions)
1675
+ end_loss = loss_fct(end_logits, end_positions)
1676
+ total_loss = (start_loss + end_loss) / 2
1677
+
1678
+ if not return_dict:
1679
+ output = (start_logits, end_logits) + outputs[2:]
1680
+ return ((total_loss,) + output) if total_loss is not None else output
1681
+
1682
+ return QuestionAnsweringModelOutput(
1683
+ loss=total_loss,
1684
+ start_logits=start_logits,
1685
+ end_logits=end_logits,
1686
+ hidden_states=outputs.hidden_states,
1687
+ attentions=outputs.attentions,
1688
+ )
1689
+
1690
+
1691
+ # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2
1692
+ @add_start_docstrings(
1693
+ """
1694
+ The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1695
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1696
+ """,
1697
+ InternLM2_START_DOCSTRING,
1698
+ )
1699
+ class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
1700
+ """Token classification model for InternLM2."""
1701
+
1702
+ def __init__(self, config):
1703
+ super().__init__(config)
1704
+ self.num_labels = config.num_labels
1705
+ self.model = InternLM2Model(config)
1706
+ if getattr(config, "classifier_dropout", None) is not None:
1707
+ classifier_dropout = config.classifier_dropout
1708
+ elif getattr(config, "hidden_dropout", None) is not None:
1709
+ classifier_dropout = config.hidden_dropout
1710
+ else:
1711
+ classifier_dropout = 0.1
1712
+ self.dropout = nn.Dropout(classifier_dropout)
1713
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1714
+
1715
+ # Initialize weights and apply final processing
1716
+ self.post_init()
1717
+
1718
+ def get_input_embeddings(self):
1719
+ return self.model.tok_embeddings
1720
+
1721
+ def set_input_embeddings(self, value):
1722
+ self.model.tok_embeddings = value
1723
+
1724
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1725
+ def forward(
1726
+ self,
1727
+ input_ids: torch.LongTensor = None,
1728
+ attention_mask: Optional[torch.Tensor] = None,
1729
+ position_ids: Optional[torch.LongTensor] = None,
1730
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1731
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1732
+ labels: Optional[torch.LongTensor] = None,
1733
+ use_cache: Optional[bool] = None,
1734
+ output_attentions: Optional[bool] = None,
1735
+ output_hidden_states: Optional[bool] = None,
1736
+ return_dict: Optional[bool] = None,
1737
+ ) -> Union[Tuple, TokenClassifierOutput]:
1738
+ r"""
1739
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1740
+ Labels for computing the token classification loss. Indices should be in
1741
+ `[0, ..., config.num_labels - 1]`.
1742
+ """
1744
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1745
+
1746
+ outputs = self.model(
1747
+ input_ids,
1748
+ attention_mask=attention_mask,
1749
+ position_ids=position_ids,
1750
+ past_key_values=past_key_values,
1751
+ inputs_embeds=inputs_embeds,
1752
+ use_cache=use_cache,
1753
+ output_attentions=output_attentions,
1754
+ output_hidden_states=output_hidden_states,
1755
+ return_dict=return_dict,
1756
+ )
1757
+ sequence_output = outputs[0]
1758
+ sequence_output = self.dropout(sequence_output)
1759
+ logits = self.score(sequence_output)
1760
+
1761
+ loss = None
1762
+ if labels is not None:
1763
+ loss_fct = CrossEntropyLoss()
1764
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1765
+
1766
+ if not return_dict:
1767
+ output = (logits,) + outputs[2:]
1768
+ return ((loss,) + output) if loss is not None else output
1769
+
1770
+ return TokenClassifierOutput(
1771
+ loss=loss,
1772
+ logits=logits,
1773
+ hidden_states=outputs.hidden_states,
1774
+ attentions=outputs.attentions,
1775
+ )
1776
+
1777
+
1778
+ # Modified from transformers.models.llama.modeling_llama.LlamaForTokenClassification
1779
+ class InternLM2ForRewardModel(InternLM2PreTrainedModel):
1780
+ """Reward model for InternLM2, scoring conversations with a scalar value head."""
1781
+ _auto_class = "AutoModel"
1782
+ _tied_weights_keys = ["v_head.weight"]
1783
+
1784
+ def __init__(self, config):
1785
+ super().__init__(config)
1786
+ self.model = InternLM2Model(config)
1787
+ self.vocab_size = config.vocab_size
1788
+ self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
1789
+ self.reward_token_id = config.reward_token_id
1790
+
1791
+ # Initialize weights and apply final processing
1792
+ self.post_init()
1793
+
1794
+ def get_input_embeddings(self):
1795
+ return self.model.tok_embeddings
1796
+
1797
+ def set_input_embeddings(self, value):
1798
+ self.model.tok_embeddings = value
1799
+
1800
+ def get_output_embeddings(self):
1801
+ return self.v_head
1802
+
1803
+ def set_output_embeddings(self, new_embeddings):
1804
+ self.v_head = new_embeddings
1805
+
1806
+ def set_decoder(self, decoder):
1807
+ self.model = decoder
1808
+
1809
+ def get_decoder(self):
1810
+ return self.model
1811
+
1812
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1813
+ @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
1814
+ def forward(
1815
+ self,
1816
+ input_ids: torch.LongTensor = None,
1817
+ attention_mask: Optional[torch.Tensor] = None,
1818
+ position_ids: Optional[torch.LongTensor] = None,
1819
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1820
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1821
+ labels: Optional[torch.LongTensor] = None,
1822
+ use_cache: Optional[bool] = None,
1823
+ output_attentions: Optional[bool] = None,
1824
+ output_hidden_states: Optional[bool] = None,
1825
+ return_dict: Optional[bool] = None,
1826
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1827
+ """
1828
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1829
+ Labels for computing the sequence classification/regression loss. Currently unused: the reward model returns no loss.
1830
+
1831
+ Returns:
1832
+ """
1833
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1834
+ output_hidden_states = (
1835
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1836
+ )
1837
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1838
+
1839
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1840
+ outputs = self.model(
1841
+ input_ids=input_ids,
1842
+ attention_mask=attention_mask,
1843
+ position_ids=position_ids,
1844
+ past_key_values=past_key_values,
1845
+ inputs_embeds=inputs_embeds,
1846
+ use_cache=use_cache,
1847
+ output_attentions=output_attentions,
1848
+ output_hidden_states=output_hidden_states,
1849
+ return_dict=return_dict,
1850
+ )
1851
+
1852
+ hidden_states = outputs[0]
1853
+ hidden_states = self.v_head(hidden_states)
1854
+ # get end reward token's score
1855
+ ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)
1856
+
1857
+ reward_scores = torch.gather(hidden_states.squeeze(-1), 1, ends)
1858
+
1859
+ loss = None
1860
+
1861
+ if not return_dict:
1862
+ output = (reward_scores,) + outputs[1:]
1863
+ return (loss,) + output if loss is not None else output
1864
+
1865
+ return SequenceClassifierOutputWithPast(
1866
+ loss=loss,
1867
+ logits=reward_scores,
1868
+ past_key_values=outputs.past_key_values,
1869
+ hidden_states=outputs.hidden_states,
1870
+ attentions=outputs.attentions,
1871
+ )
1872
+
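The `ends` computation above picks, per row, the index of the last attended position (the appended reward token) in a right-padded batch; a small numeric sketch of the cumsum/argmax idiom:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
# cumsum reaches its maximum at the last attended position; argmax returns the first index of that maximum.
ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)
print(ends)  # tensor([[2], [4]])
```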
1873
+ @torch.no_grad()
1874
+ def get_score(
1875
+ self,
1876
+ tokenizer,
1877
+ conversation: List[dict],
1878
+ **kwargs,
1879
+ ):
1880
+ """
1881
+ Computes the reward score for a given conversation.
1882
+ This function takes a conversation represented as a list of dictionaries, formats it into a string using the chat
1883
+ template from the tokenizer, and passes it through the model to compute the score. A special token representing
1884
+ the reward score is appended to the input sequence. The reward score is then extracted from the model's output.
1885
+ Args:
1886
+ tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
1887
+ conversation (List[dict]): A list of dictionaries where each dictionary represents a message in the conversation.
1888
+ Returns:
1889
+ float: The computed reward score from the model.
1890
+ """
1891
+ conversation_str = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
1892
+ input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_tokens=False)
1893
+ # add reward score token at the end of the input_ids if it is not already there
1894
+ if input_ids[0, -1] != self.reward_token_id:
1895
+ input_ids = torch.cat([input_ids, torch.tensor([[self.reward_token_id]], dtype=torch.long)], dim=1)
1896
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1897
+
1898
+ outputs = self.forward(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), **kwargs)
1899
+ score = outputs[0].cpu().item()
1900
+ return score
1901
+
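A hedged usage sketch for `get_score` (the repository id is a placeholder for whichever reward-model checkpoint ships this file; conversations use the usual list of role/content dicts expected by `apply_chat_template`):

```python
import torch
from transformers import AutoModel, AutoTokenizer

path = "internlm/internlm2-7b-reward"  # placeholder id; substitute the actual reward-model repo
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.float16, trust_remote_code=True
).cuda().eval()

conversation = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 equals 4."},
]
score = model.get_score(tokenizer, conversation)
print(score)  # a single float; higher means the response is rated better
```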
1902
+ @torch.no_grad()
1903
+ def get_scores(
1904
+ self,
1905
+ tokenizer,
1906
+ conversations: List[List[dict]],
1907
+ **kwargs,
1908
+ ):
1909
+ """
1910
+ Computes the reward scores for multiple conversations in a batched manner.
1911
+ This function takes multiple conversations, each represented as a list of dictionaries, formats them into strings using the chat
1912
+ template from the tokenizer, and passes these formatted strings through the model to compute scores for each conversation.
1913
+ Each input sequence has a special token representing the reward score appended before passing to the model.
1914
+ The reward scores are then extracted from the model's output.
1915
+ Args:
1916
+ tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
1917
+ conversations (List[List[dict]]): A list of conversations, with each conversation represented as a list of dictionaries where each dictionary contains a message.
1918
+ Returns:
1919
+ List[float]: A list of computed reward scores for each conversation in the input batch.
1920
+ """
1921
+ conversation_strs = [tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) for conversation in conversations]
1922
+ batch_input_ids = []
1923
+ attention_masks = []
1924
+
1925
+ for conversation_str in conversation_strs:
1926
+ input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_tokens=False)
1927
+ # add reward score token at the end of the input_ids if it is not already there
1928
+ if input_ids[0, -1] != self.reward_token_id:
1929
+ input_ids = torch.cat([input_ids, torch.tensor([[self.reward_token_id]], dtype=torch.long)], dim=1)
1930
+ input_ids = input_ids.squeeze(0)
1931
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.bool)
1932
+ batch_input_ids.append(input_ids)
1933
+ attention_masks.append(attention_mask)
1934
+
1935
+ r_pad_batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
1936
+ r_pad_attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=False)
1937
+
1938
+ outputs = self.forward(input_ids=r_pad_batch_input_ids.to(self.device), attention_mask=r_pad_attention_masks.to(self.device), **kwargs)
1939
+ scores = outputs[0].squeeze().cpu().tolist()
1940
+ return scores
1941
+
1942
+ @torch.no_grad()
1943
+ def compare(
1944
+ self,
1945
+ tokenizer,
1946
+ conversation1: List[dict],
1947
+ conversation2: List[dict],
1948
+ return_logits: bool = False,
1949
+ **kwargs,
1950
+ ):
1951
+ """
1952
+ Compares the reward scores of two conversations and determines which conversation has a higher score.
1953
+ This function computes reward scores for two given conversations using the `get_score` method and compares the scores to determine which conversation has a higher score.
1954
+ The function can optionally return the actual scores (logits) along with the comparison result.
1955
+ Parameters:
1956
+ tokenizer: The tokenizer used for formatting and tokenizing the conversation.
1957
+ conversation1 (List[dict]): The first conversation to compare, represented as a list of dictionaries where each dictionary contains a message.
1958
+ conversation2 (List[dict]): The second conversation to compare, similarly represented.
1959
+ return_logits (bool, optional): If True, the function returns both the comparison result and the actual scores of the two conversations. Defaults to False.
1960
+ Returns:
1961
+ bool: True if the score of the first conversation is greater than the second, otherwise False.
1962
+ List[float] (optional): A list containing the scores of the first and second conversations respectively.
1963
+ Note:
1964
+ - This function is designed for inference, with `@torch.no_grad()` used to disable gradient calculations to optimize performance.
1965
+ """
1966
+ score1 = self.get_score(tokenizer, conversation1, **kwargs)
1967
+ score2 = self.get_score(tokenizer, conversation2, **kwargs)
1968
+ if return_logits:
1969
+ return score1 > score2, [score1, score2]
1970
+ else:
1971
+ return score1 > score2
1972
+
1973
+ @torch.no_grad()
1974
+ def rank(
1975
+ self,
1976
+ tokenizer,
1977
+ conversations: List[List[dict]],
1978
+ return_logits: bool = False,
1979
+ **kwargs,
1980
+ ):
1981
+ """
1982
+ Ranks the conversations based on their scores.
1983
+ Args:
1984
+ tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
1985
+ conversations: A list of conversations, where each conversation is represented as a list of dictionaries. Each dictionary contains the necessary information for the conversation.
1986
+ return_logits: If True, returns the conversation indices along with their logits. Defaults to False.
1987
+ Returns:
1988
+ list: A list of conversation indices sorted by score in descending order, i.e. the first index refers to the highest-scoring conversation.
1989
+ List[float] (optional): If return_logits is True, a list of conversation indices and their corresponding logits.
1990
+ """
1991
+ scores = self.get_scores(tokenizer, conversations, **kwargs)
1992
+ if return_logits:
1993
+ return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True), scores
1994
+ else:
1995
+ return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
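`get_scores`, `compare`, and `rank` reuse the same forward pass; a short sketch continuing from the `get_score` example above (conversation contents are illustrative):

```python
better = [
    {"role": "user", "content": "Name a prime number."},
    {"role": "assistant", "content": "7 is a prime number."},
]
worse = [
    {"role": "user", "content": "Name a prime number."},
    {"role": "assistant", "content": "9 is a prime number."},
]
print(model.compare(tokenizer, better, worse))       # True if `better` scores higher
print(model.get_scores(tokenizer, [better, worse]))  # batched scores, one float per conversation
print(model.rank(tokenizer, [better, worse]))        # e.g. [0, 1]: conversation indices, best first
```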
pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eda59ff79a6112fb9e16a8c7de4b1a7a30b37755a88f7485958855f2e84d0255
3
+ size 1981412716
pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bbe5758a4303a9fe665ce10d6870d3ee44d7705663dd330846f178aed52d19c4
3
+ size 1417809006
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,178 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 3399163904
4
+ },
5
+ "weight_map": {
6
+ "model.layers.0.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
7
+ "model.layers.0.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
8
+ "model.layers.0.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
9
+ "model.layers.0.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
10
+ "model.layers.0.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
11
+ "model.layers.0.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
12
+ "model.layers.0.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
13
+ "model.layers.1.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
14
+ "model.layers.1.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
15
+ "model.layers.1.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
16
+ "model.layers.1.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
17
+ "model.layers.1.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
18
+ "model.layers.1.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
19
+ "model.layers.1.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
20
+ "model.layers.10.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
21
+ "model.layers.10.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
22
+ "model.layers.10.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
23
+ "model.layers.10.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
24
+ "model.layers.10.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
25
+ "model.layers.10.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
26
+ "model.layers.10.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
27
+ "model.layers.11.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
28
+ "model.layers.11.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
29
+ "model.layers.11.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
30
+ "model.layers.11.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
31
+ "model.layers.11.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
32
+ "model.layers.11.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
33
+ "model.layers.11.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
34
+ "model.layers.12.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
35
+ "model.layers.12.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
36
+ "model.layers.12.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
37
+ "model.layers.12.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
38
+ "model.layers.12.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
39
+ "model.layers.12.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
40
+ "model.layers.12.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
41
+ "model.layers.13.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
42
+ "model.layers.13.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
43
+ "model.layers.13.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
44
+ "model.layers.13.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
45
+ "model.layers.13.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
46
+ "model.layers.13.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
47
+ "model.layers.13.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
48
+ "model.layers.14.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
49
+ "model.layers.14.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
50
+ "model.layers.14.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
51
+ "model.layers.14.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
52
+ "model.layers.14.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
53
+ "model.layers.14.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
54
+ "model.layers.14.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
55
+ "model.layers.15.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
56
+ "model.layers.15.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
57
+ "model.layers.15.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
58
+ "model.layers.15.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
59
+ "model.layers.15.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
60
+ "model.layers.15.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
61
+ "model.layers.15.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
62
+ "model.layers.16.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
63
+ "model.layers.16.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
64
+ "model.layers.16.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
65
+ "model.layers.16.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
66
+ "model.layers.16.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
67
+ "model.layers.16.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
68
+ "model.layers.16.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
69
+ "model.layers.17.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
70
+ "model.layers.17.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
71
+ "model.layers.17.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
72
+ "model.layers.17.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
73
+ "model.layers.17.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
74
+ "model.layers.17.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
75
+ "model.layers.17.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
76
+ "model.layers.18.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
77
+ "model.layers.18.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
78
+ "model.layers.18.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
79
+ "model.layers.18.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
80
+ "model.layers.18.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
81
+ "model.layers.18.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
82
+ "model.layers.18.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
83
+ "model.layers.19.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
84
+ "model.layers.19.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
85
+ "model.layers.19.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
86
+ "model.layers.19.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
87
+ "model.layers.19.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
88
+ "model.layers.19.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
89
+ "model.layers.19.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
90
+ "model.layers.2.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
91
+ "model.layers.2.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
92
+ "model.layers.2.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
93
+ "model.layers.2.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
94
+ "model.layers.2.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
95
+ "model.layers.2.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
96
+ "model.layers.2.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
97
+ "model.layers.20.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
98
+ "model.layers.20.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
99
+ "model.layers.20.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
100
+ "model.layers.20.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
101
+ "model.layers.20.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
102
+ "model.layers.20.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
103
+ "model.layers.20.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
104
+ "model.layers.21.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
105
+ "model.layers.21.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
106
+ "model.layers.21.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
107
+ "model.layers.21.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
108
+ "model.layers.21.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
109
+ "model.layers.21.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
110
+ "model.layers.21.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
111
+ "model.layers.22.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
112
+ "model.layers.22.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
113
+ "model.layers.22.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
114
+ "model.layers.22.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
115
+ "model.layers.22.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
116
+ "model.layers.22.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
117
+ "model.layers.22.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
118
+ "model.layers.23.attention.wo.weight": "pytorch_model-00002-of-00002.bin",
119
+ "model.layers.23.attention.wqkv.weight": "pytorch_model-00002-of-00002.bin",
120
+ "model.layers.23.attention_norm.weight": "pytorch_model-00002-of-00002.bin",
121
+ "model.layers.23.feed_forward.w1.weight": "pytorch_model-00002-of-00002.bin",
122
+ "model.layers.23.feed_forward.w2.weight": "pytorch_model-00002-of-00002.bin",
123
+ "model.layers.23.feed_forward.w3.weight": "pytorch_model-00002-of-00002.bin",
124
+ "model.layers.23.ffn_norm.weight": "pytorch_model-00002-of-00002.bin",
125
+ "model.layers.3.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
126
+ "model.layers.3.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
127
+ "model.layers.3.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
128
+ "model.layers.3.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
129
+ "model.layers.3.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
130
+ "model.layers.3.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
131
+ "model.layers.3.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
132
+ "model.layers.4.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
133
+ "model.layers.4.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
134
+ "model.layers.4.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
135
+ "model.layers.4.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
136
+ "model.layers.4.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
137
+ "model.layers.4.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
138
+ "model.layers.4.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
139
+ "model.layers.5.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
140
+ "model.layers.5.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
141
+ "model.layers.5.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
142
+ "model.layers.5.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
143
+ "model.layers.5.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
144
+ "model.layers.5.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
145
+ "model.layers.5.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
146
+ "model.layers.6.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
147
+ "model.layers.6.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
148
+ "model.layers.6.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
149
+ "model.layers.6.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
150
+ "model.layers.6.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
151
+ "model.layers.6.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
152
+ "model.layers.6.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
153
+ "model.layers.7.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
154
+ "model.layers.7.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
155
+ "model.layers.7.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
156
+ "model.layers.7.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
157
+ "model.layers.7.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
158
+ "model.layers.7.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
159
+ "model.layers.7.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
160
+ "model.layers.8.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
161
+ "model.layers.8.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
162
+ "model.layers.8.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
163
+ "model.layers.8.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
164
+ "model.layers.8.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
165
+ "model.layers.8.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
166
+ "model.layers.8.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
167
+ "model.layers.9.attention.wo.weight": "pytorch_model-00001-of-00002.bin",
168
+ "model.layers.9.attention.wqkv.weight": "pytorch_model-00001-of-00002.bin",
169
+ "model.layers.9.attention_norm.weight": "pytorch_model-00001-of-00002.bin",
170
+ "model.layers.9.feed_forward.w1.weight": "pytorch_model-00001-of-00002.bin",
171
+ "model.layers.9.feed_forward.w2.weight": "pytorch_model-00001-of-00002.bin",
172
+ "model.layers.9.feed_forward.w3.weight": "pytorch_model-00001-of-00002.bin",
173
+ "model.layers.9.ffn_norm.weight": "pytorch_model-00001-of-00002.bin",
174
+ "model.norm.weight": "pytorch_model-00002-of-00002.bin",
175
+ "model.tok_embeddings.weight": "pytorch_model-00001-of-00002.bin",
176
+ "v_head.weight": "pytorch_model-00002-of-00002.bin"
177
+ }
178
+ }
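The shard index above can be inspected directly; a small sketch that groups the listed tensors by the shard file that stores them (run from the repository root):

```python
import json
from collections import defaultdict

with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])  # total bytes across both shards

shards = defaultdict(list)
for name, shard_file in index["weight_map"].items():
    shards[shard_file].append(name)
for shard_file, names in sorted(shards.items()):
    print(shard_file, len(names), "tensors")
```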
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenization_internlm2.py ADDED
@@ -0,0 +1,236 @@
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """Tokenization classes for InternLM."""
19
+ import os
20
+ from shutil import copyfile
21
+ from typing import Any, Dict, List, Optional, Tuple
22
+
23
+ import sentencepiece as spm
24
+ from transformers.tokenization_utils import PreTrainedTokenizer
25
+ from transformers.utils import logging
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {}
32
+
33
+
34
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
35
+ class InternLM2Tokenizer(PreTrainedTokenizer):
36
+ """
37
+ Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
38
+
39
+ Args:
40
+ vocab_file (`str`):
41
+ Path to the vocabulary file.
42
+ """
43
+
44
+ vocab_files_names = VOCAB_FILES_NAMES
45
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
46
+ model_input_names = ["input_ids", "attention_mask"]
47
+ _auto_class = "AutoTokenizer"
48
+
49
+ def __init__(
50
+ self,
51
+ vocab_file,
52
+ unk_token="<unk>",
53
+ bos_token="<s>",
54
+ eos_token="</s>",
55
+ pad_token="</s>",
56
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
57
+ add_bos_token=True,
58
+ add_eos_token=False,
59
+ decode_with_prefix_space=False,
60
+ clean_up_tokenization_spaces=False,
61
+ **kwargs,
62
+ ):
63
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
64
+ self.vocab_file = vocab_file
65
+ self.add_bos_token = add_bos_token
66
+ self.add_eos_token = add_eos_token
67
+ self.decode_with_prefix_space = decode_with_prefix_space
68
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
69
+ self.sp_model.Load(vocab_file)
70
+ self._no_prefix_space_tokens = None
71
+ super().__init__(
72
+ bos_token=bos_token,
73
+ eos_token=eos_token,
74
+ unk_token=unk_token,
75
+ pad_token=pad_token,
76
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
77
+ **kwargs,
78
+ )
79
+
80
+ @property
81
+ def no_prefix_space_tokens(self):
82
+ if self._no_prefix_space_tokens is None:
83
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
84
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
85
+ return self._no_prefix_space_tokens
86
+
87
+ @property
88
+ def vocab_size(self):
89
+ """Returns vocab size"""
90
+ return self.sp_model.get_piece_size()
91
+
92
+ @property
93
+ def bos_token_id(self) -> Optional[int]:
94
+ return self.sp_model.bos_id()
95
+
96
+ @property
97
+ def eos_token_id(self) -> Optional[int]:
98
+ return self.sp_model.eos_id()
99
+
100
+ def get_vocab(self):
101
+ """Returns vocab as a dict"""
102
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
103
+ vocab.update(self.added_tokens_encoder)
104
+ return vocab
105
+
106
+ def _tokenize(self, text):
107
+ """Returns a tokenized string."""
108
+ return self.sp_model.encode(text, out_type=str)
109
+
110
+ def _convert_token_to_id(self, token):
111
+ """Converts a token (str) in an id using the vocab."""
112
+ return self.sp_model.piece_to_id(token)
113
+
114
+ def _convert_id_to_token(self, index):
115
+ """Converts an index (integer) in a token (str) using the vocab."""
116
+ token = self.sp_model.IdToPiece(index)
117
+ return token
118
+
119
+ def _maybe_add_prefix_space(self, tokens, decoded):
120
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
121
+ return " " + decoded
122
+ else:
123
+ return decoded
124
+
125
+ def convert_tokens_to_string(self, tokens):
126
+ """Converts a sequence of tokens (string) in a single string."""
127
+ current_sub_tokens = []
128
+ out_string = ""
129
+ prev_is_special = False
130
+ for token in tokens:
131
+ # make sure that special tokens are not decoded using sentencepiece model
132
+ if token in self.all_special_tokens:
133
+ if not prev_is_special:
134
+ out_string += " "
135
+ out_string += self.sp_model.decode(current_sub_tokens) + token
136
+ prev_is_special = True
137
+ current_sub_tokens = []
138
+ else:
139
+ current_sub_tokens.append(token)
140
+ prev_is_special = False
141
+ out_string += self.sp_model.decode(current_sub_tokens)
142
+ out_string = self.clean_up_tokenization(out_string)
143
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
144
+ return out_string[1:]
145
+
146
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
147
+ """
148
+ Save the vocabulary and special tokens file to a directory.
149
+
150
+ Args:
151
+ save_directory (`str`):
152
+ The directory in which to save the vocabulary.
153
+
154
+ Returns:
155
+ `Tuple(str)`: Paths to the files saved.
156
+ """
157
+ if not os.path.isdir(save_directory):
158
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
159
+ return
160
+ out_vocab_file = os.path.join(
161
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
162
+ )
163
+
164
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
165
+ copyfile(self.vocab_file, out_vocab_file)
166
+ elif not os.path.isfile(self.vocab_file):
167
+ with open(out_vocab_file, "wb") as fi:
168
+ content_spiece_model = self.sp_model.serialized_model_proto()
169
+ fi.write(content_spiece_model)
170
+
171
+ return (out_vocab_file,)
172
+
173
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
174
+ if self.add_bos_token:
175
+ bos_token_ids = [self.bos_token_id]
176
+ else:
177
+ bos_token_ids = []
178
+
179
+ output = bos_token_ids + token_ids_0
180
+
181
+ if token_ids_1 is not None:
182
+ output = output + token_ids_1
183
+
184
+ if self.add_eos_token:
185
+ output = output + [self.eos_token_id]
186
+
187
+ return output
188
+
189
+ def get_special_tokens_mask(
190
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
191
+ ) -> List[int]:
192
+ """
193
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
194
+ special tokens using the tokenizer `prepare_for_model` method.
195
+
196
+ Args:
197
+ token_ids_0 (`List[int]`):
198
+ List of IDs.
199
+ token_ids_1 (`List[int]`, *optional*):
200
+ Optional second list of IDs for sequence pairs.
201
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
202
+ Whether or not the token list is already formatted with special tokens for the model.
203
+
204
+ Returns:
205
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
206
+ """
207
+ if already_has_special_tokens:
208
+ return super().get_special_tokens_mask(
209
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
210
+ )
211
+
212
+ if token_ids_1 is None:
213
+ return [1] + ([0] * len(token_ids_0)) + [1]
214
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
215
+
216
+ def create_token_type_ids_from_sequences(
217
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
218
+ ) -> List[int]:
219
+ """
220
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
221
+ use of token type ids, therefore a list of zeros is returned.
222
+
223
+ Args:
224
+ token_ids_0 (`List[int]`):
225
+ List of IDs.
226
+ token_ids_1 (`List[int]`, *optional*):
227
+ Optional second list of IDs for sequence pairs.
228
+
229
+ Returns:
230
+ `List[int]`: List of zeros.
231
+ """
232
+ eos = [self.eos_token_id]
233
+
234
+ if token_ids_1 is None:
235
+ return len(token_ids_0 + eos) * [0]
236
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
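A minimal usage sketch of the special-token logic above (illustrative only, not part of the committed file; the repository path is a placeholder). With the defaults add_bos_token=True and add_eos_token=False, encoding prepends <s> and appends no </s>, and get_special_tokens_mask marks exactly those positions:

from transformers import AutoTokenizer

# Placeholder path: substitute the actual repo id or a local checkout of this model.
tok = AutoTokenizer.from_pretrained("path/to/this/repo", trust_remote_code=True, use_fast=False)

ids = tok.encode("Hello world")                   # defaults: add_bos_token=True, add_eos_token=False
assert ids[0] == tok.bos_token_id                 # <s> prepended, nothing appended
mask = tok.get_special_tokens_mask(ids, already_has_special_tokens=True)
print(mask)                                       # 1 marks special tokens, 0 marks ordinary tokens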
tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,214 @@
+ # coding=utf-8
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Tokenization Fast class for InternLM."""
+ import os
+ from shutil import copyfile
+ from typing import Any, Dict, Optional, Tuple
+
+ from tokenizers import processors, decoders, Tokenizer, normalizers
+ from tokenizers.models import BPE
+
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+ from transformers.utils import logging
+
+ from transformers.convert_slow_tokenizer import (
+     SLOW_TO_FAST_CONVERTERS,
+     SpmConverter,
+     SentencePieceExtractor,
+ )
+
+ from .tokenization_internlm2 import InternLM2Tokenizer
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
+
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
+ class InternLM2Converter(SpmConverter):
+     handle_byte_fallback = True
+
+     def vocab(self, proto):
+         vocab = [
+             ("<unk>", 0.0),
+             ("<s>", 0.0),
+             ("</s>", 0.0),
+         ]
+         vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+         return vocab
+
+     def unk_id(self, proto):
+         unk_id = 0
+         return unk_id
+
+     def decoder(self, replacement, add_prefix_space):
+         decoders_sequence = [
+             decoders.Replace("▁", " "),
+             decoders.ByteFallback(),
+             decoders.Fuse(),
+         ]
+         if self.proto.normalizer_spec.add_dummy_prefix:
+             decoders_sequence.append(decoders.Strip(content=" ", left=1))
+         return decoders.Sequence(decoders_sequence)
+
+     def tokenizer(self, proto):
+         model_type = proto.trainer_spec.model_type
+         vocab_scores = self.vocab(proto)
+         # special tokens
+         added_tokens = self.original_tokenizer.added_tokens_decoder
+         for i in range(len(vocab_scores)):
+             piece, score = vocab_scores[i]
+             if i in added_tokens:
+                 vocab_scores[i] = (added_tokens[i].content, score)
+         if model_type == 1:
+             raise RuntimeError("InternLM2 is supposed to be a BPE model!")
+
+         elif model_type == 2:
+             _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
+             bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
+             tokenizer = Tokenizer(
+                 BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
+             )
+             tokenizer.add_special_tokens(
+                 [added_token for index, added_token in added_tokens.items()]
+             )
+         else:
+             raise Exception(
+                 "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
+             )
+
+         return tokenizer
+
+     def normalizer(self, proto):
+         normalizers_list = []
+         if proto.normalizer_spec.add_dummy_prefix:
+             normalizers_list.append(normalizers.Prepend(prepend="▁"))
+         normalizers_list.append(normalizers.Replace(pattern=" ", content="▁"))
+         return normalizers.Sequence(normalizers_list)
+
+     def pre_tokenizer(self, replacement, add_prefix_space):
+         return None
+
+ SLOW_TO_FAST_CONVERTERS["InternLM2Tokenizer"] = InternLM2Converter
+
+
+ # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
+     vocab_files_names = VOCAB_FILES_NAMES
+     slow_tokenizer_class = InternLM2Tokenizer
+     padding_side = "left"
+     model_input_names = ["input_ids", "attention_mask"]
+     _auto_class = "AutoTokenizer"
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="<unk>",
+         bos_token="<s>",
+         eos_token="</s>",
+         pad_token="</s>",
+         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+         add_bos_token=True,
+         add_eos_token=False,
+         decode_with_prefix_space=False,
+         clean_up_tokenization_spaces=False,
+         **kwargs,
+     ):
+         super().__init__(
+             vocab_file=vocab_file,
+             unk_token=unk_token,
+             bos_token=bos_token,
+             eos_token=eos_token,
+             pad_token=pad_token,
+             sp_model_kwargs=sp_model_kwargs,
+             add_bos_token=add_bos_token,
+             add_eos_token=add_eos_token,
+             decode_with_prefix_space=decode_with_prefix_space,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+         self._add_bos_token = add_bos_token
+         self._add_eos_token = add_eos_token
+         self.update_post_processor()
+         self.vocab_file = vocab_file
+
+     @property
+     def can_save_slow_tokenizer(self) -> bool:
+         return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+     def update_post_processor(self):
+         """
+         Updates the underlying post processor with the current `bos_token` and `eos_token`.
+         """
+         bos = self.bos_token
+         bos_token_id = self.bos_token_id
+         if bos is None and self.add_bos_token:
+             raise ValueError("add_bos_token = True but bos_token = None")
+
+         eos = self.eos_token
+         eos_token_id = self.eos_token_id
+         if eos is None and self.add_eos_token:
+             raise ValueError("add_eos_token = True but eos_token = None")
+
+         single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
+         pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
+
+         special_tokens = []
+         if self.add_bos_token:
+             special_tokens.append((bos, bos_token_id))
+         if self.add_eos_token:
+             special_tokens.append((eos, eos_token_id))
+         self._tokenizer.post_processor = processors.TemplateProcessing(
+             single=single, pair=pair, special_tokens=special_tokens
+         )
+
+     @property
+     def add_eos_token(self):
+         return self._add_eos_token
+
+     @property
+     def add_bos_token(self):
+         return self._add_bos_token
+
+     @add_eos_token.setter
+     def add_eos_token(self, value):
+         self._add_eos_token = value
+         self.update_post_processor()
+
+     @add_bos_token.setter
+     def add_bos_token(self, value):
+         self._add_bos_token = value
+         self.update_post_processor()
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         if not self.can_save_slow_tokenizer:
+             raise ValueError(
+                 "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                 "tokenizer."
+             )
+
+         if not os.path.isdir(save_directory):
+             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+             return
+         out_vocab_file = os.path.join(
+             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+         )
+
+         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+             copyfile(self.vocab_file, out_vocab_file)
+
+         return (out_vocab_file,)
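A minimal sketch of how the add_bos_token/add_eos_token setters above behave (illustrative only, not part of the committed file; the repository path is a placeholder). Each setter rebuilds the TemplateProcessing post-processor, so toggling the flags changes encoding on the fly:

from transformers import AutoTokenizer

# Placeholder path: substitute the actual repo id or a local checkout of this model.
tok = AutoTokenizer.from_pretrained("path/to/this/repo", trust_remote_code=True)  # fast tokenizer

print(tok("hi").input_ids)   # begins with bos_token_id; no eos by default (add_bos=True, add_eos=False)
tok.add_eos_token = True     # setter calls update_post_processor()
print(tok("hi").input_ids)   # now also ends with eos_token_id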
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:20286a7a382e842fe9fbbee3113a916fb407c69320f97fc1dc07ab3a7985db5f
+ size 10540451
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b
+ size 1477754
tokenizer_config.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "add_bos_token": true,
+   "add_eos_token": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "92397": {
+       "content": "<|reward|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "92527": {
+       "content": "[UNUSED_TOKEN_130]",
+       "single_word": false,
+       "lstrip": false,
+       "rstrip": false,
+       "normalized": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "[UNUSED_TOKEN_130]"
+   ],
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_internlm2.InternLM2Tokenizer",
+       "tokenization_internlm2_fast.InternLM2TokenizerFast"
+     ]
+   },
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": false,
+   "decode_with_prefix_space": false,
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "</s>",
+   "padding_side": "left",
+   "sp_model_kwargs": null,
+   "tokenizer_class": "InternLM2Tokenizer",
+   "unk_token": "<unk>"
+ }
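Per added_tokens_decoder above, token id 92527 is the special token "[UNUSED_TOKEN_130]", which is the same id used as reward_token_id in the training config below. A quick check (illustrative only; the repository path is a placeholder):

from transformers import AutoTokenizer

# Placeholder path: substitute the actual repo id or a local checkout of this model.
tok = AutoTokenizer.from_pretrained("path/to/this/repo", trust_remote_code=True)

assert tok.convert_tokens_to_ids("[UNUSED_TOKEN_130]") == 92527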
xtuner_config.py ADDED
@@ -0,0 +1,177 @@
+ accumulative_counts = 1
+ avg_num_per_pack = 5
+ batch_size = 1
+ betas = (
+     0.9,
+     0.95,
+ )
+ custom_hooks = [
+     dict(type='xtuner.engine.hooks.VarlenAttnArgsToMessageHubHook'),
+ ]
+ data_num = 142595996
+ data_path = '/cpfs01/shared/llm_ddd/zouyicheng/rm_pretrain/data/train'
+ dataloader_num_workers = 0
+ default_hooks = dict(
+     checkpoint=dict(
+         by_epoch=False,
+         interval=2000,
+         max_keep_ckpts=-1,
+         type='mmengine.hooks.CheckpointHook'),
+     logger=dict(
+         interval=10,
+         log_metric_by_epoch=False,
+         type='mmengine.hooks.LoggerHook'),
+     param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'),
+     sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'),
+     timer=dict(type='mmengine.hooks.IterTimerHook'))
+ env_cfg = dict(
+     cudnn_benchmark=False,
+     dist_cfg=dict(backend='nccl'),
+     mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
+ launcher = 'pytorch'
+ load_from = None
+ log_level = 'INFO'
+ log_processor = dict(by_epoch=False)
+ loss_type = 'ranking'
+ lr = 1e-05
+ max_epochs = 1
+ max_length = 16384
+ max_norm = 1
+ max_packed_length = 32768
+ max_response_length = 5120
+ model = dict(
+     llm=dict(
+         pretrained_model_name_or_path=
+         '/cpfs01/shared/llm_ddd/zouyicheng/xtuner/model/internlm2_5-1_8b',
+         trust_remote_code=True,
+         type='transformers.AutoModelForCausalLM.from_pretrained'),
+     loss_type='ranking',
+     penalty_type='none',
+     type='xtuner.model.reward.RewardModel',
+     use_varlen_attn=True)
+ optim_type = 'torch.optim.AdamW'
+ optim_wrapper = dict(
+     optimizer=dict(
+         betas=(
+             0.9,
+             0.95,
+         ),
+         lr=1e-05,
+         type='torch.optim.AdamW',
+         weight_decay=0),
+     type='DeepSpeedOptimWrapper')
+ param_scheduler = [
+     dict(
+         begin=0,
+         by_epoch=True,
+         convert_to_iter_based=True,
+         end=0.03,
+         start_factor=1.0000000000000002e-06,
+         type='mmengine.optim.LinearLR'),
+     dict(
+         begin=0.03,
+         by_epoch=True,
+         convert_to_iter_based=True,
+         end=1,
+         eta_min=1.0000000000000002e-06,
+         type='mmengine.optim.CosineAnnealingLR'),
+ ]
+ penalty_type = 'none'
+ pretrained_model_name_or_path = '/cpfs01/shared/llm_ddd/zouyicheng/xtuner/model/internlm2_5-1_8b'
+ randomness = dict(deterministic=False, seed=None)
+ resume = False
+ reward_token_id = 92527
+ runner_type = 'FlexibleRunner'
+ sampler = 'mmengine.dataset.DefaultSampler'
+ save_steps = 2000
+ save_total_limit = -1
+ sequence_parallel_size = 1
+ strategy = dict(
+     config=dict(
+         bf16=dict(enabled=True),
+         fp16=dict(enabled=False, initial_scale_power=16),
+         gradient_accumulation_steps='auto',
+         gradient_clipping='auto',
+         train_micro_batch_size_per_gpu='auto',
+         zero_allow_untested_optimizer=True,
+         zero_force_ds_cpu_optimizer=False,
+         zero_optimization=dict(overlap_comm=True, stage=1)),
+     exclude_frozen_parameters=True,
+     gradient_accumulation_steps=1,
+     gradient_clipping=1,
+     sequence_parallel_size=1,
+     train_micro_batch_size_per_gpu=1,
+     type='xtuner.engine.DeepSpeedStrategy')
+ tokenizer = dict(
+     padding_side='left',
+     pretrained_model_name_or_path=
+     '/cpfs01/shared/llm_ddd/zouyicheng/xtuner/model/internlm2_5-1_8b',
+     trust_remote_code=True,
+     type='transformers.AutoTokenizer.from_pretrained')
+ train_cfg = dict(max_epochs=1, type='xtuner.engine.runner.TrainLoop')
+ train_dataloader = dict(
+     batch_size=1,
+     collate_fn=dict(
+         type=
+         'xtuner.dataset.collate_fns.preference_collate_fn.preference_collate_fn',
+         use_varlen_attn=True),
+     dataset=dict(
+         avg_num_per_pack=5,
+         data_num=142595996,
+         dataset=dict(
+             path='/cpfs01/shared/llm_ddd/zouyicheng/rm_pretrain/data/train',
+             streaming=True,
+             type='datasets.load_dataset'),
+         dataset_map_fn=None,
+         is_dpo=False,
+         is_reward=True,
+         max_length=16384,
+         max_packed_length=32768,
+         max_response_length=5120,
+         num_proc=32,
+         reward_token_id=92527,
+         shuffle_before_pack=True,
+         tokenizer=dict(
+             padding_side='left',
+             pretrained_model_name_or_path=
+             '/cpfs01/shared/llm_ddd/zouyicheng/xtuner/model/internlm2_5-1_8b',
+             trust_remote_code=True,
+             type='transformers.AutoTokenizer.from_pretrained'),
+         type=
+         'xtuner.dataset.preference_dataset.build_preference_dataset_stream',
+         use_varlen_attn=True),
+     drop_last=True,
+     num_workers=0)
+ train_dataset = dict(
+     avg_num_per_pack=5,
+     data_num=142595996,
+     dataset=dict(
+         path='/cpfs01/shared/llm_ddd/zouyicheng/rm_pretrain/data/train',
+         streaming=True,
+         type='datasets.load_dataset'),
+     dataset_map_fn=None,
+     is_dpo=False,
+     is_reward=True,
+     max_length=16384,
+     max_packed_length=32768,
+     max_response_length=5120,
+     num_proc=32,
+     reward_token_id=92527,
+     shuffle_before_pack=True,
+     tokenizer=dict(
+         padding_side='left',
+         pretrained_model_name_or_path=
+         '/cpfs01/shared/llm_ddd/zouyicheng/xtuner/model/internlm2_5-1_8b',
+         trust_remote_code=True,
+         type='transformers.AutoTokenizer.from_pretrained'),
+     type='xtuner.dataset.preference_dataset.build_preference_dataset_stream',
+     use_varlen_attn=True)
+ use_varlen_attn = True
+ visualizer = dict(
+     type='mmengine.visualization.Visualizer',
+     vis_backends=[
+         dict(type='mmengine.visualization.TensorboardVisBackend'),
+     ])
+ warmup_ratio = 0.03
+ weight_decay = 0
+ work_dir = './work_dirs/RM_PT_internlm2_5_1_8b_DATA_140m_single_mix_Node_40_LR_1e_5'
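The config above sets loss_type = 'ranking' for xtuner.model.reward.RewardModel. As a rough sketch only (a generic Bradley-Terry-style pairwise objective, not necessarily xtuner's exact implementation), the ranking loss over scores of a chosen/rejected response pair looks like this:

import torch
import torch.nn.functional as F

def ranking_loss(chosen_scores: torch.Tensor, rejected_scores: torch.Tensor) -> torch.Tensor:
    # Push the reward of the preferred response above that of the rejected one.
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()

# Example with scores read out at the reward-token position of each response.
loss = ranking_loss(torch.tensor([1.3]), torch.tensor([0.2]))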