danielhanchen committed · Commit 238a784 · verified · 1 Parent(s): e235f83

Add files using upload-large-folder tool

Files changed (48)
  1. .gitattributes +2 -0
  2. README.md +384 -0
  3. accuracy_plot.png +0 -0
  4. bias.md +4 -0
  5. block_config.py +118 -0
  6. chat_template.jinja +1 -0
  7. config.json +1486 -0
  8. configuration_decilm.py +65 -0
  9. explainability.md +12 -0
  10. flow.png +3 -0
  11. generation_config.json +11 -0
  12. model-00001-of-00021.safetensors +3 -0
  13. model-00002-of-00021.safetensors +3 -0
  14. model-00003-of-00021.safetensors +3 -0
  15. model-00004-of-00021.safetensors +3 -0
  16. model-00005-of-00021.safetensors +3 -0
  17. model-00006-of-00021.safetensors +3 -0
  18. model-00007-of-00021.safetensors +3 -0
  19. model-00008-of-00021.safetensors +3 -0
  20. model-00009-of-00021.safetensors +3 -0
  21. model-00010-of-00021.safetensors +3 -0
  22. model-00011-of-00021.safetensors +3 -0
  23. model-00012-of-00021.safetensors +3 -0
  24. model-00013-of-00021.safetensors +3 -0
  25. model-00014-of-00021.safetensors +3 -0
  26. model-00015-of-00021.safetensors +3 -0
  27. model-00016-of-00021.safetensors +3 -0
  28. model-00017-of-00021.safetensors +3 -0
  29. model-00018-of-00021.safetensors +3 -0
  30. model-00019-of-00021.safetensors +3 -0
  31. model-00020-of-00021.safetensors +3 -0
  32. model-00021-of-00021.safetensors +3 -0
  33. model.safetensors.index.json +575 -0
  34. modeling_decilm.py +1681 -0
  35. privacy.md +9 -0
  36. safety.md +6 -0
  37. special_tokens_map.json +23 -0
  38. tokenizer.json +3 -0
  39. tokenizer_config.json +2067 -0
  40. transformers_4_44_2__activations.py +239 -0
  41. transformers_4_44_2__cache_utils.py +1347 -0
  42. transformers_4_44_2__configuration_llama.py +203 -0
  43. transformers_4_44_2__modeling_attn_mask_utils.py +482 -0
  44. transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py +348 -0
  45. transformers_4_44_2__modeling_outputs.py +0 -0
  46. transformers_4_44_2__modeling_rope_utils.py +559 -0
  47. transformers_4_44_2__pytorch_utils.py +17 -0
  48. variable_cache.py +139 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ flow.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,384 @@
1
+ ---
2
+ base_model:
3
+ - nvidia/Llama-3_3-Nemotron-Super-49B-v1
4
+ library_name: transformers
5
+ license: other
6
+ license_name: nvidia-open-model-license
7
+ license_link: >-
8
+ https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/
9
+
10
+ pipeline_tag: text-generation
11
+ language:
12
+ - en
13
+ tags:
14
+ - nvidia
15
+ - unsloth
16
+ - llama-3
17
+ - pytorch
18
+ ---
19
+ <div>
20
+ <p style="margin-top: 0;margin-bottom: 0;">
21
+ <em><a href="https://docs.unsloth.ai/basics/unsloth-dynamic-v2.0-gguf">Unsloth Dynamic 2.0</a> achieves superior accuracy & outperforms other leading quants.</em>
22
+ </p>
23
+ <div style="display: flex; gap: 5px; align-items: center; ">
24
+ <a href="https://github.com/unslothai/unsloth/">
25
+ <img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="133">
26
+ </a>
27
+ <a href="https://discord.gg/unsloth">
28
+ <img src="https://github.com/unslothai/unsloth/raw/main/images/Discord%20button.png" width="173">
29
+ </a>
30
+ <a href="https://docs.unsloth.ai/basics/qwen3-how-to-run-and-fine-tune">
31
+ <img src="https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/main/images/documentation%20green%20button.png" width="143">
32
+ </a>
33
+ </div>
34
+ </div>
35
+
36
+
37
+ # Llama-3.3-Nemotron-Super-49B-v1
38
+
39
+ ## Model Overview
40
+
41
+ ![Accuracy Comparison Plot](./accuracy_plot.png)
42
+
43
+ Llama-3.3-Nemotron-Super-49B-v1 is a large language model (LLM) derived from [Meta Llama-3.3-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) (the *reference model*). It is a reasoning model that is post-trained for reasoning, human chat preferences, and tasks such as RAG and tool calling. The model supports a context length of 128K tokens.
44
+
45
+ Llama-3.3-Nemotron-Super-49B-v1 offers a great tradeoff between model accuracy and efficiency. Efficiency (throughput) directly translates to savings. Using a novel Neural Architecture Search (NAS) approach, we greatly reduce the model’s memory footprint, enabling larger workloads and allowing the model to fit on a single GPU at high workloads (H200). This NAS approach enables the selection of a desired point in the accuracy-efficiency tradeoff. For more information on the NAS approach, please refer to [this paper](https://arxiv.org/abs/2411.19146).
46
+
47
+ The model underwent a multi-phase post-training process to enhance both its reasoning and non-reasoning capabilities. This includes a supervised fine-tuning stage for Math, Code, Reasoning, and Tool Calling as well as multiple reinforcement learning (RL) stages using REINFORCE (RLOO) and Online Reward-aware Preference Optimization (RPO) algorithms for both chat and instruction-following. The final model checkpoint is obtained after merging the final SFT and Online RPO checkpoints. For more details on how the model was trained, please see our [technical report](https://arxiv.org/abs/2505.00949) and [blog](https://developer.nvidia.com/blog/build-enterprise-ai-agents-with-advanced-open-nvidia-llama-nemotron-reasoning-models/).
48
+ ![Training Process](flow.png)
49
+
50
+ This model is part of the Llama Nemotron Collection. You can find the other model(s) in this family here:
51
+ - [Llama-3.1-Nemotron-Nano-8B-v1](https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-8B-v1)
52
+ - [Llama-3.1-Nemotron-Ultra-253B-v1](https://huggingface.co/nvidia/Llama-3_1-Nemotron-Ultra-253B-v1)
53
+
54
+ This model is ready for commercial use.
55
+
56
+ ## License/Terms of Use
57
+
58
+ GOVERNING TERMS: Your use of this model is governed by the [NVIDIA Open Model License.](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/) \
59
+ Additional Information: [Llama 3.3 Community License Agreement](https://www.llama.com/llama3_3/license/). Built with Llama.
60
+
61
+ **Model Developer:** NVIDIA
62
+
63
+ **Model Dates:** Trained between November 2024 and February 2025
64
+
65
+ **Data Freshness:** The pretraining data has a cutoff of 2023 per Meta Llama 3.3 70B
66
+
67
+ ### Use Case: <br>
68
+ Developers designing AI Agent systems, chatbots, RAG systems, and other AI-powered applications. Also suitable for typical instruction-following tasks. <br>
69
+
70
+ ### Release Date: <br>
71
+ 3/18/2025 <br>
72
+
73
+ ## References
74
+
75
+ * [\[2505.00949\] Llama-Nemotron: Efficient Reasoning Models](https://arxiv.org/abs/2505.00949)
76
+ * [[2411.19146] Puzzle: Distillation-Based NAS for Inference-Optimized LLMs](https://arxiv.org/abs/2411.19146)
77
+ * [[2502.00203] Reward-aware Preference Optimization: A Unified Mathematical Framework for Model Alignment](https://arxiv.org/abs/2502.00203)
78
+
79
+ ## Model Architecture
80
+ **Architecture Type:** Dense decoder-only Transformer model \
81
+ **Network Architecture:** Llama 3.3 70B Instruct, customized through Neural Architecture Search (NAS)
82
+
83
+ The model is a derivative of Meta’s Llama-3.3-70B-Instruct, using Neural Architecture Search (NAS). The NAS algorithm results in non-standard and non-repetitive blocks. This includes the following:
84
+ * Skip attention: In some blocks, the attention is skipped entirely, or replaced with a single linear layer.
85
+ * Variable FFN: The expansion/compression ratio in the FFN layer is different between blocks.
86
+
87
+ We utilize a block-wise distillation of the reference model, where for each block we create multiple variants providing different tradeoffs of quality vs. computational complexity, discussed in more depth below. We then search over the blocks to create a model which meets the required throughput and memory (optimized for a single H100-80GB GPU) while minimizing the quality degradation. The model then undergoes knowledge distillation (KD), with a focus on English single and multi-turn chat use-cases. The KD step included 40 billion tokens consisting of a mixture of 3 datasets - FineWeb, Buzz-V1.2 and Dolma.
88
+
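+ For illustration, the sketch below (assuming the `block_config.py` and `config.json` files shipped in this repository are importable, e.g. from a local checkout) shows how two such NAS-selected blocks are described: one whose attention sub-block is skipped entirely, and one standard block with grouped-query attention and a larger FFN multiplier.
+
+ ```py
+ # Illustrative sketch: these objects mirror entries of "block_configs" in this repo's config.json.
+ from block_config import BlockConfig
+
+ # A block with skipped attention ("no_op") and a reduced FFN expansion multiplier.
+ skip_attention_block = BlockConfig(
+     attention={"no_op": True},
+     ffn={"ffn_mult": 1.3125},
+ )
+
+ # A standard block: 8 query heads per KV head, and ffn_mult 5.25
+ # (the most common value in this checkpoint's config).
+ full_block = BlockConfig(
+     attention={"n_heads_in_group": 8},
+     ffn={"ffn_mult": 5.25},
+ )
+
+ print(skip_attention_block.attention.no_op, full_block.ffn.ffn_mult)  # True 5.25
+ ```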
89
+ ## Intended use
90
+
91
+ Llama-3.3-Nemotron-Super-49B-v1 is a general purpose reasoning and chat model intended to be used in English and coding languages. Other non-English languages (German, French, Italian, Portuguese, Hindi, Spanish, and Thai) are also supported.
92
+
93
+ ## Input
94
+ - **Input Type:** Text
95
+ - **Input Format:** String
96
+ - **Input Parameters:** One-Dimensional (1D)
97
+ - **Other Properties Related to Input:** Context length up to 131,072 tokens
98
+
99
+ ## Output
100
+ - **Output Type:** Text
101
+ - **Output Format:** String
102
+ - **Output Parameters:** One-Dimensional (1D)
103
+ - **Other Properties Related to Output:** Context length up to 131,072 tokens
104
+
105
+ ## Model Version
106
+ 1.0 (3/18/2025)
107
+
108
+ ## Software Integration
109
+ - **Runtime Engine:** Transformers
110
+ - **Recommended Hardware Microarchitecture Compatibility:**
111
+ - NVIDIA Hopper
112
+ - NVIDIA Ampere
113
+
114
+ ## Quick Start and Usage Recommendations:
115
+
116
+ 1. Reasoning mode (ON/OFF) is controlled via the system prompt, which must be set as shown in the example below. All instructions should be contained within the user prompt
117
+ 2. We recommend setting temperature to `0.6`, and Top P to `0.95` for Reasoning ON mode
118
+ 3. We recommend using greedy decoding for Reasoning OFF mode
119
+ 4. We have provided a list of prompts to use for evaluation for each benchmark where a specific template is required
120
+ 5. The model will include `<think></think>` if no reasoning was necessary in Reasoning ON mode; this is expected behaviour
121
+
122
+ You can try this model out through the preview API, using this link: [Llama-3_3-Nemotron-Super-49B-v1](https://build.nvidia.com/nvidia/llama-3_3-nemotron-super-49b-v1).
123
+
124
+ ### Use It with Transformers
125
+ See the snippet below for usage with the [Hugging Face Transformers](https://huggingface.co/docs/transformers/main/en/index) library. Reasoning mode (ON/OFF) is controlled via the system prompt. Please see the example below.
126
+
127
+ We recommend using the *transformers* package with version 4.48.3.
128
+ Example of reasoning on:
129
+
130
+ ```py
131
+ import torch
132
+ import transformers
133
+
134
+ model_id = "nvidia/Llama-3_3-Nemotron-Super-49B-v1"
135
+ model_kwargs = {"torch_dtype": torch.bfloat16, "trust_remote_code": True, "device_map": "auto"}
136
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
137
+ tokenizer.pad_token_id = tokenizer.eos_token_id
138
+
139
+ pipeline = transformers.pipeline(
+     "text-generation",
+     model=model_id,
+     tokenizer=tokenizer,
+     max_new_tokens=32768,
+     temperature=0.6,
+     top_p=0.95,
+     **model_kwargs
+ )
148
+
149
+ thinking = "on"
150
+
151
+ print(pipeline([{"role": "system", "content": f"detailed thinking {thinking}"},{"role": "user", "content": "Solve x*(sin(x)+2)=0"}]))
152
+ ```
153
+
154
+ Example of reasoning off:
155
+
156
+ ```py
157
+ import torch
158
+ import transformers
159
+
160
+ model_id = "nvidia/Llama-3_3-Nemotron-Super-49B-v1"
161
+ model_kwargs = {"torch_dtype": torch.bfloat16, "trust_remote_code": True, "device_map": "auto"}
162
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
163
+ tokenizer.pad_token_id = tokenizer.eos_token_id
164
+
165
+ pipeline = transformers.pipeline(
+     "text-generation",
+     model=model_id,
+     tokenizer=tokenizer,
+     max_new_tokens=32768,
+     do_sample=False,
+     **model_kwargs
+ )
173
+
174
+ # Thinking can be "on" or "off"
175
+ thinking = "off"
176
+
177
+ print(pipeline([{"role": "system", "content": f"detailed thinking {thinking}"},{"role": "user", "content": "Solve x*(sin(x)+2)=0"}]))
178
+ ```
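+
+ Note that the bundled chat template (`chat_template.jinja`) strips the `<think>…</think>` reasoning block from earlier assistant turns when a multi-turn conversation is re-encoded, so only the final answers of previous turns are fed back to the model. A minimal sketch of that behaviour, reusing the `tokenizer` from the snippets above:
+
+ ```py
+ messages = [
+     {"role": "system", "content": "detailed thinking on"},
+     {"role": "user", "content": "Solve x*(sin(x)+2)=0"},
+     {"role": "assistant", "content": "<think>sin(x)+2 is never zero...</think>The only solution is x = 0."},
+     {"role": "user", "content": "Thanks! Can you verify that?"},
+ ]
+
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ assert "<think>" not in prompt  # the earlier reasoning trace is dropped, only the final answer remains
+ ```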
179
+
180
+ ### Use It with vLLM
181
+
182
+ ```
183
+ pip install vllm==0.8.3
184
+ ```
185
+ An example of how to serve with vLLM:
186
+ ```
187
+ python3 -m vllm.entrypoints.openai.api_server \
188
+ --model "nvidia/Llama-3_3-Nemotron-Super-49B-v1" \
189
+ --trust-remote-code \
190
+ --seed=1 \
191
+ --host="0.0.0.0" \
192
+ --port=5000 \
193
+ --served-model-name "nvidia/Llama-3_3-Nemotron-Super-49B-v1" \
194
+ --tensor-parallel-size=8 \
195
+ --max-model-len=32768 \
196
+ --gpu-memory-utilization 0.95 \
197
+ --enforce-eager
198
+ ```
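+
+ Once the server is running, any OpenAI-compatible client can query it; reasoning mode is still toggled through the `detailed thinking on`/`off` system prompt. The snippet below is a minimal sketch (the `openai` Python package and the placeholder API key are assumptions; the server command above does not configure authentication):
+
+ ```py
+ from openai import OpenAI
+
+ client = OpenAI(base_url="http://0.0.0.0:5000/v1", api_key="dummy")
+
+ completion = client.chat.completions.create(
+     model="nvidia/Llama-3_3-Nemotron-Super-49B-v1",
+     messages=[
+         {"role": "system", "content": "detailed thinking on"},
+         {"role": "user", "content": "Solve x*(sin(x)+2)=0"},
+     ],
+     temperature=0.6,
+     top_p=0.95,
+     max_tokens=4096,
+ )
+ print(completion.choices[0].message.content)
+ ```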
199
+
200
+ ## Inference:
201
+
202
+ **Engine:**
203
+ - Transformers
204
+
205
+ **Test Hardware:**
206
+ - FP8: 1x NVIDIA H100-80GB GPU (Coming Soon!)
207
+ - BF16:
208
+ - 2x NVIDIA H100-80GB
209
+ - 2x NVIDIA A100-80GB GPUs
210
+
211
+ **[Preferred/Supported] Operating System(s):** Linux <br>
212
+
213
+ ## Training Datasets
214
+
215
+ A large variety of training data was used for the knowledge distillation phase that precedes the post-training pipeline, including three datasets: FineWeb, Buzz-V1.2, and Dolma.
216
+
217
+ The data for the multi-stage post-training phases is a compilation of SFT and RL data that improves the math, code, general reasoning, and instruction-following capabilities of the original Llama instruct model.
218
+
219
+ In conjunction with this model release, NVIDIA has publicly released 30M samples of post-training data under a permissive license. Please see [Llama-Nemotron-Post-Training-Dataset-v1](https://huggingface.co/datasets/nvidia/Llama-Nemotron-Post-Training-Dataset-v1).
220
+
221
+ Distribution of the domains is as follows:
222
+
223
+ | Category | Samples |
224
+ |----------|-----------|
225
+ | math | 19,840,970|
226
+ | code | 9,612,677 |
227
+ | science | 708,920 |
228
+ | instruction following | 56,339 |
229
+ | chat | 39,792 |
230
+ | safety | 31,426 |
231
+
232
+ Prompts were sourced from public and open corpora or synthetically generated. Responses were synthetically generated by a variety of models, with some prompts containing responses for both Reasoning On and Reasoning Off modes, to train the model to distinguish between the two modes.
233
+
234
+
235
+ **Data Collection for Training Datasets:**
236
+
237
+ - Hybrid: Automated, Human, Synthetic
238
+
239
+ **Data Labeling for Training Datasets:**
240
+
241
+ - Hybrid: Automated, Human, Synthetic
242
+
243
+ ## Evaluation Datasets
244
+
245
+ We used the datasets listed below to evaluate Llama-3.3-Nemotron-Super-49B-v1.
246
+
247
+ Data Collection for Evaluation Datasets:
248
+
249
+ - Hybrid: Human/Synthetic
250
+
251
+ Data Labeling for Evaluation Datasets:
252
+
253
+ - Hybrid: Human/Synthetic/Automatic
254
+
255
+ ## Evaluation Results
256
+ These results contain both “Reasoning On” and “Reasoning Off” modes. We recommend using temperature=`0.6`, top_p=`0.95` for “Reasoning On” mode, and greedy decoding for “Reasoning Off” mode. All evaluations are done with a 32k sequence length. We run each benchmark up to 16 times and average the scores for accuracy.
257
+
258
+ > NOTE: Where applicable, a Prompt Template will be provided. While completing benchmarks, please ensure that you are parsing for the correct output format as per the provided prompt in order to reproduce the benchmarks seen below.
259
+
260
+ ### Arena-Hard
261
+
262
+ | Reasoning Mode | Score |
263
+ |--------------|------------|
264
+ | Reasoning Off | 88.3 |
265
+
266
+ ### MATH500
267
+
268
+ | Reasoning Mode | pass@1 |
269
+ |--------------|------------|
270
+ | Reasoning Off | 74.0 |
271
+ | Reasoning On | 96.6 |
272
+
273
+ User Prompt Template:
274
+
275
+ ```
276
+ "Below is a math question. I want you to reason through the steps and then give a final answer. Your final answer should be in \boxed{}.\nQuestion: {question}"
277
+ ```
278
+ ### AIME25
279
+
280
+ | Reasoning Mode | pass@1 |
281
+ |--------------|------------|
282
+ | Reasoning Off | 13.33 |
283
+ | Reasoning On | 58.4 |
284
+
285
+ User Prompt Template:
286
+
287
+ ```
288
+ "Below is a math question. I want you to reason through the steps and then give a final answer. Your final answer should be in \boxed{}.\nQuestion: {question}"
289
+ ```
290
+
291
+ ### GPQA
292
+
293
+ | Reasoning Mode | pass@1 |
294
+ |--------------|------------|
295
+ | Reasoning Off | 50 |
296
+ | Reasoning On | 66.67 |
297
+
298
+ User Prompt Template:
299
+
300
+ ```
301
+ "What is the correct answer to this question: {question}\nChoices:\nA. {option_A}\nB. {option_B}\nC. {option_C}\nD. {option_D}\nLet's think step by step, and put the final answer (should be a single letter A, B, C, or D) into a \boxed{}"
302
+ ```
303
+
304
+ ### IFEval
305
+
306
+ | Reasoning Mode | Strict:Instruction |
307
+ |--------------|------------|
308
+ | Reasoning Off | 89.21 |
309
+
310
+ ### BFCL V2 Live
311
+
312
+ | Reasoning Mode | Score |
313
+ |--------------|------------|
314
+ | Reasoning Off | 73.7 |
315
+
316
+ User Prompt Template:
317
+
318
+ ```
319
+ You are an expert in composing functions. You are given a question and a set of possible functions.
320
+ Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
321
+ If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
322
+ also point it out. You should only return the function call in tools call sections.
323
+
324
+ If you decide to invoke any of the function(s), you MUST put it in the format of <TOOLCALL>[func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]</TOOLCALL>
325
+
326
+ You SHOULD NOT include any other text in the response.
327
+ Here is a list of functions in JSON format that you can invoke.
328
+
329
+ <AVAILABLE_TOOLS>{functions}</AVAILABLE_TOOLS>
330
+
331
+ {user_prompt}
332
+ ```
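+
+ When reproducing this benchmark, the model's reply must be parsed for the `<TOOLCALL>[...]</TOOLCALL>` wrapper described above. A minimal extraction sketch (the regex-based parser is an assumption for illustration, not the official harness):
+
+ ```py
+ import re
+
+ def extract_tool_calls(response: str):
+     """Return the raw 'func(arg=value, ...)' list from <TOOLCALL>[...]</TOOLCALL>, or None."""
+     match = re.search(r"<TOOLCALL>\[(.*)\]</TOOLCALL>", response, re.DOTALL)
+     return match.group(1) if match else None
+
+ print(extract_tool_calls("<TOOLCALL>[get_weather(city='Paris')]</TOOLCALL>"))  # get_weather(city='Paris')
+ ```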
333
+
334
+ ### MBPP 0-shot
335
+
336
+ | Reasoning Mode | pass@1 |
337
+ |--------------|------------|
338
+ | Reasoning Off | 84.9|
339
+ | Reasoning On | 91.3 |
340
+
341
+ User Prompt Template:
342
+
343
+ ````
344
+ You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.
345
+
346
+ @@ Instruction
347
+ Here is the given problem and test examples:
348
+ {prompt}
349
+ Please use the python programming language to solve this problem.
350
+ Please make sure that your code includes the functions from the test samples and that the input and output formats of these functions match the test samples.
351
+ Please return all completed codes in one code block.
352
+ This code block should be in the following format:
353
+ ```python
354
+ # Your codes here
355
+ ```
356
+ ````
357
+
358
+ ### MT-Bench
359
+
360
+ | Reasoning Mode | Score |
361
+ |--------------|------------|
362
+ | Reasoning Off | 9.17 |
363
+
364
+ ## Ethical Considerations:
365
+
366
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
367
+
368
+ For more detailed information on ethical considerations for this model, please see the Model Card++ [Explainability](explainability.md), [Bias](bias.md), [Safety & Security](safety.md), and [Privacy](privacy.md) Subcards.
369
+
370
+ Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
371
+
372
+
373
+ ## Citation
374
+ ```
375
+ @misc{bercovich2025llamanemotronefficientreasoningmodels,
376
+ title={Llama-Nemotron: Efficient Reasoning Models},
377
+ author={Akhiad Bercovich and Itay Levy and Izik Golan and Mohammad Dabbah and Ran El-Yaniv and Omri Puny and Ido Galil and Zach Moshe and Tomer Ronen and Najeeb Nabwani and Ido Shahaf and Oren Tropp and Ehud Karpas and Ran Zilberstein and Jiaqi Zeng and Soumye Singhal and Alexander Bukharin and Yian Zhang and Tugrul Konuk and Gerald Shen and Ameya Sunil Mahabaleshwarkar and Bilal Kartal and Yoshi Suhara and Olivier Delalleau and Zijia Chen and Zhilin Wang and David Mosallanezhad and Adi Renduchintala and Haifeng Qian and Dima Rekesh and Fei Jia and Somshubra Majumdar and Vahid Noroozi and Wasi Uddin Ahmad and Sean Narenthiran and Aleksander Ficek and Mehrzad Samadi and Jocelyn Huang and Siddhartha Jain and Igor Gitman and Ivan Moshkov and Wei Du and Shubham Toshniwal and George Armstrong and Branislav Kisacanin and Matvei Novikov and Daria Gitman and Evelina Bakhturina and Jane Polak Scowcroft and John Kamalu and Dan Su and Kezhi Kong and Markus Kliegl and Rabeeh Karimi and Ying Lin and Sanjeev Satheesh and Jupinder Parmar and Pritam Gundecha and Brandon Norick and Joseph Jennings and Shrimai Prabhumoye and Syeda Nahida Akter and Mostofa Patwary and Abhinav Khattar and Deepak Narayanan and Roger Waleffe and Jimmy Zhang and Bor-Yiing Su and Guyue Huang and Terry Kong and Parth Chadha and Sahil Jain and Christine Harvey and Elad Segal and Jining Huang and Sergey Kashirsky and Robert McQueen and Izzy Putterman and George Lam and Arun Venkatesan and Sherry Wu and Vinh Nguyen and Manoj Kilaru and Andrew Wang and Anna Warno and Abhilash Somasamudramath and Sandip Bhaskar and Maka Dong and Nave Assaf and Shahar Mor and Omer Ullman Argov and Scot Junkin and Oleksandr Romanenko and Pedro Larroy and Monika Katariya and Marco Rovinelli and Viji Balas and Nicholas Edelman and Anahita Bhiwandiwalla and Muthu Subramaniam and Smita Ithape and Karthik Ramamoorthy and Yuting Wu and Suguna Varshini Velury and Omri Almog and Joyjit Daw and Denys Fridman and Erick Galinkin and Michael Evans and Katherine Luna and Leon Derczynski and Nikki Pope and Eileen Long and Seth Schneider and Guillermo Siman and Tomasz Grzegorzek and Pablo Ribalta and Monika Katariya and Joey Conway and Trisha Saar and Ann Guan and Krzysztof Pawelec and Shyamala Prayaga and Oleksii Kuchaiev and Boris Ginsburg and Oluwatobi Olabiyi and Kari Briski and Jonathan Cohen and Bryan Catanzaro and Jonah Alben and Yonatan Geifman and Eric Chung and Chris Alexiuk},
378
+ year={2025},
379
+ eprint={2505.00949},
380
+ archivePrefix={arXiv},
381
+ primaryClass={cs.CL},
382
+ url={https://arxiv.org/abs/2505.00949},
383
+ }
384
+ ```
accuracy_plot.png ADDED
bias.md ADDED
@@ -0,0 +1,4 @@
1
+ | Field: | Response: |
2
+ | :---- | :---- |
3
+ | Participation considerations from adversely impacted groups [(protected classes)](https://www.senate.ca.gov/content/protected-classes) in model design and testing: | None |
4
+ | Measures taken to mitigate against unwanted bias: | None |
block_config.py ADDED
@@ -0,0 +1,118 @@
+ import dataclasses
+ import json
+ import warnings
+ from dataclasses import dataclass, MISSING
+ from functools import partial
+ from typing import Optional, Any
+
+
+ @partial(dataclass, frozen=True, kw_only=True)
+ class JsonComparable:
+     def to_json(self) -> str:
+         return json.dumps(dataclasses.asdict(self))
+
+     def __eq__(self, other: "JsonComparable") -> bool:
+         return self.to_json() == other.to_json()
+
+     def __hash__(self) -> int:
+         return hash(self.to_json())
+
+     def __lt__(self, other: "JsonComparable") -> bool:
+         return self.to_json() < other.to_json()
+
+
+ @partial(dataclass, frozen=True, kw_only=True)
+ class SubblockConfig(JsonComparable):
+     no_op: bool = False
+     replace_with_linear: bool = False
+     sparsify: Optional[list[str]] = None
+
+     def __post_init__(self):
+         assert not (self.no_op and self.replace_with_linear)
+
+     def _force_setattr(self, name: str, value: Any) -> None:
+         """
+         Set an attribute even in frozen dataclasses.
+         Use only inside __post_init__!
+         """
+         object.__setattr__(self, name, value)
+
+
+ @partial(dataclass, frozen=True, kw_only=True)
+ class AttentionConfig(SubblockConfig):
+     n_heads_in_group: Optional[int] = None
+     window_length: Optional[int] = None
+     num_sink_tokens: Optional[int] = None
+     use_prefill_window_in_sink_attention: bool = False
+     unshifted_sink: bool = False
+
+     def __post_init__(self):
+         super().__post_init__()
+         assert not (self.no_op and self.replace_with_linear)
+
+         if self.no_op or self.replace_with_linear:
+             for irrelevant_att in ["n_heads_in_group", "window_length", "num_sink_tokens"]:
+                 self._force_setattr(irrelevant_att, None)
+         else:
+             assert self.n_heads_in_group is not None
+
+         if self.is_sink:
+             assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), \
+                 ("Unshifted sink uses its own kind of explicit masking, not standard window. "
+                  "Set use_prefill_window_in_sink_attention to False.")
+             assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), \
+                 "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True"
+
+     @property
+     def prefill_sliding_window(self) -> Optional[int]:
+         if self.window_length is not None:
+             if not self.is_sink or self.use_prefill_window_in_sink_attention:
+                 return self.window_length
+         return None
+
+     @property
+     def is_sliding(self) -> bool:
+         return self.prefill_sliding_window is not None
+
+     @property
+     def is_sink(self) -> bool:
+         return (
+             (self.window_length is not None)
+             and
+             (self.num_sink_tokens is not None)
+         )
+
+
+ @partial(dataclass, frozen=True, kw_only=True)
+ class FFNConfig(SubblockConfig):
+     ffn_mult: Optional[float] = None
+
+     def __post_init__(self):
+         super().__post_init__()
+         if self.no_op or self.replace_with_linear:
+             self._force_setattr("ffn_mult", None)
+         else:
+             assert self.ffn_mult is not None
+             self._force_setattr("ffn_mult", round(self.ffn_mult, 6))
+
+
+ @partial(dataclass, frozen=True, kw_only=True)
+ class BlockConfig(JsonComparable):
+     attention: AttentionConfig = MISSING
+     ffn: FFNConfig = MISSING
+
+     def __post_init__(self):
+         """
+         Init subblock dataclasses from dicts
+         """
+         for subblock_name in dataclasses.fields(self):
+             subblock_config = getattr(self, subblock_name.name)
+             if isinstance(subblock_config, dict):
+                 subblock_fields = [field.name for field in dataclasses.fields(subblock_name.type)]
+                 unsupported_fields = [field_name for field_name in subblock_config.keys()
+                                       if field_name not in subblock_fields]
+                 if len(unsupported_fields) > 0:
+                     warnings.warn(f"Removed unsupported fields {unsupported_fields} from {subblock_name.type.__name__}")
+                 subblock_config = {k: v for k, v in subblock_config.items() if k not in unsupported_fields}
+                 object.__setattr__(self, subblock_name.name,
+                                    subblock_name.type(**subblock_config))  # __setattr__ to overcome frozen=True
chat_template.jinja ADDED
@@ -0,0 +1 @@
1
+ {{- bos_token }}{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content']|trim %}{%- set messages = messages[1:] %}{%- else %}{%- set system_message = "" %}{%- endif %}{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}{{- system_message }}{{- "<|eot_id|>" }}{%- for message in messages %}{%- if message['role'] == 'assistant' and '</think>' in message['content'] %}{%- set content = message['content'].split('</think>')[-1].lstrip() %}{%- else %}{%- set content = message['content'] %}{%- endif %}{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + content | trim + '<|eot_id|>' }}{%- endfor %}{%- if add_generation_prompt %}{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{%- endif %}
config.json ADDED
@@ -0,0 +1,1486 @@
1
+ {
2
+ "architectures": [
3
+ "DeciLMForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_decilm.DeciLMConfig",
9
+ "AutoModelForCausalLM": "modeling_decilm.DeciLMForCausalLM"
10
+ },
11
+ "block_configs": [
12
+ {
13
+ "attention": {
14
+ "n_heads_in_group": 8,
15
+ "no_op": false,
16
+ "num_sink_tokens": null,
17
+ "replace_with_linear": false,
18
+ "sparsify": null,
19
+ "unshifted_sink": false,
20
+ "use_prefill_window_in_sink_attention": false,
21
+ "window_length": null
22
+ },
23
+ "ffn": {
24
+ "ffn_mult": 2.625,
25
+ "no_op": false,
26
+ "replace_with_linear": false,
27
+ "sparsify": null
28
+ }
29
+ },
30
+ {
31
+ "attention": {
32
+ "n_heads_in_group": 8,
33
+ "no_op": false,
34
+ "num_sink_tokens": null,
35
+ "replace_with_linear": false,
36
+ "sparsify": null,
37
+ "unshifted_sink": false,
38
+ "use_prefill_window_in_sink_attention": false,
39
+ "window_length": null
40
+ },
41
+ "ffn": {
42
+ "ffn_mult": 5.25,
43
+ "no_op": false,
44
+ "replace_with_linear": false,
45
+ "sparsify": null
46
+ }
47
+ },
48
+ {
49
+ "attention": {
50
+ "n_heads_in_group": 8,
51
+ "no_op": false,
52
+ "num_sink_tokens": null,
53
+ "replace_with_linear": false,
54
+ "sparsify": null,
55
+ "unshifted_sink": false,
56
+ "use_prefill_window_in_sink_attention": false,
57
+ "window_length": null
58
+ },
59
+ "ffn": {
60
+ "ffn_mult": 5.25,
61
+ "no_op": false,
62
+ "replace_with_linear": false,
63
+ "sparsify": null
64
+ }
65
+ },
66
+ {
67
+ "attention": {
68
+ "n_heads_in_group": 8,
69
+ "no_op": false,
70
+ "num_sink_tokens": null,
71
+ "replace_with_linear": false,
72
+ "sparsify": null,
73
+ "unshifted_sink": false,
74
+ "use_prefill_window_in_sink_attention": false,
75
+ "window_length": null
76
+ },
77
+ "ffn": {
78
+ "ffn_mult": 5.25,
79
+ "no_op": false,
80
+ "replace_with_linear": false,
81
+ "sparsify": null
82
+ }
83
+ },
84
+ {
85
+ "attention": {
86
+ "n_heads_in_group": 8,
87
+ "no_op": false,
88
+ "num_sink_tokens": null,
89
+ "replace_with_linear": false,
90
+ "sparsify": null,
91
+ "unshifted_sink": false,
92
+ "use_prefill_window_in_sink_attention": false,
93
+ "window_length": null
94
+ },
95
+ "ffn": {
96
+ "ffn_mult": 5.25,
97
+ "no_op": false,
98
+ "replace_with_linear": false,
99
+ "sparsify": null
100
+ }
101
+ },
102
+ {
103
+ "attention": {
104
+ "n_heads_in_group": 8,
105
+ "no_op": false,
106
+ "num_sink_tokens": null,
107
+ "replace_with_linear": false,
108
+ "sparsify": null,
109
+ "unshifted_sink": false,
110
+ "use_prefill_window_in_sink_attention": false,
111
+ "window_length": null
112
+ },
113
+ "ffn": {
114
+ "ffn_mult": 5.25,
115
+ "no_op": false,
116
+ "replace_with_linear": false,
117
+ "sparsify": null
118
+ }
119
+ },
120
+ {
121
+ "attention": {
122
+ "n_heads_in_group": null,
123
+ "no_op": true,
124
+ "num_sink_tokens": null,
125
+ "replace_with_linear": false,
126
+ "sparsify": null,
127
+ "unshifted_sink": false,
128
+ "use_prefill_window_in_sink_attention": false,
129
+ "window_length": null
130
+ },
131
+ "ffn": {
132
+ "ffn_mult": 2.625,
133
+ "no_op": false,
134
+ "replace_with_linear": false,
135
+ "sparsify": null
136
+ }
137
+ },
138
+ {
139
+ "attention": {
140
+ "n_heads_in_group": null,
141
+ "no_op": true,
142
+ "num_sink_tokens": null,
143
+ "replace_with_linear": false,
144
+ "sparsify": null,
145
+ "unshifted_sink": false,
146
+ "use_prefill_window_in_sink_attention": false,
147
+ "window_length": null
148
+ },
149
+ "ffn": {
150
+ "ffn_mult": 2.625,
151
+ "no_op": false,
152
+ "replace_with_linear": false,
153
+ "sparsify": null
154
+ }
155
+ },
156
+ {
157
+ "attention": {
158
+ "n_heads_in_group": 8,
159
+ "no_op": false,
160
+ "num_sink_tokens": null,
161
+ "replace_with_linear": false,
162
+ "sparsify": null,
163
+ "unshifted_sink": false,
164
+ "use_prefill_window_in_sink_attention": false,
165
+ "window_length": null
166
+ },
167
+ "ffn": {
168
+ "ffn_mult": 5.25,
169
+ "no_op": false,
170
+ "replace_with_linear": false,
171
+ "sparsify": null
172
+ }
173
+ },
174
+ {
175
+ "attention": {
176
+ "n_heads_in_group": 8,
177
+ "no_op": false,
178
+ "num_sink_tokens": null,
179
+ "replace_with_linear": false,
180
+ "sparsify": null,
181
+ "unshifted_sink": false,
182
+ "use_prefill_window_in_sink_attention": false,
183
+ "window_length": null
184
+ },
185
+ "ffn": {
186
+ "ffn_mult": 5.25,
187
+ "no_op": false,
188
+ "replace_with_linear": false,
189
+ "sparsify": null
190
+ }
191
+ },
192
+ {
193
+ "attention": {
194
+ "n_heads_in_group": 8,
195
+ "no_op": false,
196
+ "num_sink_tokens": null,
197
+ "replace_with_linear": false,
198
+ "sparsify": null,
199
+ "unshifted_sink": false,
200
+ "use_prefill_window_in_sink_attention": false,
201
+ "window_length": null
202
+ },
203
+ "ffn": {
204
+ "ffn_mult": 5.25,
205
+ "no_op": false,
206
+ "replace_with_linear": false,
207
+ "sparsify": null
208
+ }
209
+ },
210
+ {
211
+ "attention": {
212
+ "n_heads_in_group": null,
213
+ "no_op": true,
214
+ "num_sink_tokens": null,
215
+ "replace_with_linear": false,
216
+ "sparsify": null,
217
+ "unshifted_sink": false,
218
+ "use_prefill_window_in_sink_attention": false,
219
+ "window_length": null
220
+ },
221
+ "ffn": {
222
+ "ffn_mult": 3.28125,
223
+ "no_op": false,
224
+ "replace_with_linear": false,
225
+ "sparsify": null
226
+ }
227
+ },
228
+ {
229
+ "attention": {
230
+ "n_heads_in_group": 8,
231
+ "no_op": false,
232
+ "num_sink_tokens": null,
233
+ "replace_with_linear": false,
234
+ "sparsify": null,
235
+ "unshifted_sink": false,
236
+ "use_prefill_window_in_sink_attention": false,
237
+ "window_length": null
238
+ },
239
+ "ffn": {
240
+ "ffn_mult": 5.25,
241
+ "no_op": false,
242
+ "replace_with_linear": false,
243
+ "sparsify": null
244
+ }
245
+ },
246
+ {
247
+ "attention": {
248
+ "n_heads_in_group": 8,
249
+ "no_op": false,
250
+ "num_sink_tokens": null,
251
+ "replace_with_linear": false,
252
+ "sparsify": null,
253
+ "unshifted_sink": false,
254
+ "use_prefill_window_in_sink_attention": false,
255
+ "window_length": null
256
+ },
257
+ "ffn": {
258
+ "ffn_mult": 5.25,
259
+ "no_op": false,
260
+ "replace_with_linear": false,
261
+ "sparsify": null
262
+ }
263
+ },
264
+ {
265
+ "attention": {
266
+ "n_heads_in_group": 8,
267
+ "no_op": false,
268
+ "num_sink_tokens": null,
269
+ "replace_with_linear": false,
270
+ "sparsify": null,
271
+ "unshifted_sink": false,
272
+ "use_prefill_window_in_sink_attention": false,
273
+ "window_length": null
274
+ },
275
+ "ffn": {
276
+ "ffn_mult": 5.25,
277
+ "no_op": false,
278
+ "replace_with_linear": false,
279
+ "sparsify": null
280
+ }
281
+ },
282
+ {
283
+ "attention": {
284
+ "n_heads_in_group": 8,
285
+ "no_op": false,
286
+ "num_sink_tokens": null,
287
+ "replace_with_linear": false,
288
+ "sparsify": null,
289
+ "unshifted_sink": false,
290
+ "use_prefill_window_in_sink_attention": false,
291
+ "window_length": null
292
+ },
293
+ "ffn": {
294
+ "ffn_mult": 5.25,
295
+ "no_op": false,
296
+ "replace_with_linear": false,
297
+ "sparsify": null
298
+ }
299
+ },
300
+ {
301
+ "attention": {
302
+ "n_heads_in_group": 8,
303
+ "no_op": false,
304
+ "num_sink_tokens": null,
305
+ "replace_with_linear": false,
306
+ "sparsify": null,
307
+ "unshifted_sink": false,
308
+ "use_prefill_window_in_sink_attention": false,
309
+ "window_length": null
310
+ },
311
+ "ffn": {
312
+ "ffn_mult": 5.25,
313
+ "no_op": false,
314
+ "replace_with_linear": false,
315
+ "sparsify": null
316
+ }
317
+ },
318
+ {
319
+ "attention": {
320
+ "n_heads_in_group": 8,
321
+ "no_op": false,
322
+ "num_sink_tokens": null,
323
+ "replace_with_linear": false,
324
+ "sparsify": null,
325
+ "unshifted_sink": false,
326
+ "use_prefill_window_in_sink_attention": false,
327
+ "window_length": null
328
+ },
329
+ "ffn": {
330
+ "ffn_mult": 5.25,
331
+ "no_op": false,
332
+ "replace_with_linear": false,
333
+ "sparsify": null
334
+ }
335
+ },
336
+ {
337
+ "attention": {
338
+ "n_heads_in_group": 8,
339
+ "no_op": false,
340
+ "num_sink_tokens": null,
341
+ "replace_with_linear": false,
342
+ "sparsify": null,
343
+ "unshifted_sink": false,
344
+ "use_prefill_window_in_sink_attention": false,
345
+ "window_length": null
346
+ },
347
+ "ffn": {
348
+ "ffn_mult": 5.25,
349
+ "no_op": false,
350
+ "replace_with_linear": false,
351
+ "sparsify": null
352
+ }
353
+ },
354
+ {
355
+ "attention": {
356
+ "n_heads_in_group": 8,
357
+ "no_op": false,
358
+ "num_sink_tokens": null,
359
+ "replace_with_linear": false,
360
+ "sparsify": null,
361
+ "unshifted_sink": false,
362
+ "use_prefill_window_in_sink_attention": false,
363
+ "window_length": null
364
+ },
365
+ "ffn": {
366
+ "ffn_mult": 5.25,
367
+ "no_op": false,
368
+ "replace_with_linear": false,
369
+ "sparsify": null
370
+ }
371
+ },
372
+ {
373
+ "attention": {
374
+ "n_heads_in_group": 8,
375
+ "no_op": false,
376
+ "num_sink_tokens": null,
377
+ "replace_with_linear": false,
378
+ "sparsify": null,
379
+ "unshifted_sink": false,
380
+ "use_prefill_window_in_sink_attention": false,
381
+ "window_length": null
382
+ },
383
+ "ffn": {
384
+ "ffn_mult": 5.25,
385
+ "no_op": false,
386
+ "replace_with_linear": false,
387
+ "sparsify": null
388
+ }
389
+ },
390
+ {
391
+ "attention": {
392
+ "n_heads_in_group": 8,
393
+ "no_op": false,
394
+ "num_sink_tokens": null,
395
+ "replace_with_linear": false,
396
+ "sparsify": null,
397
+ "unshifted_sink": false,
398
+ "use_prefill_window_in_sink_attention": false,
399
+ "window_length": null
400
+ },
401
+ "ffn": {
402
+ "ffn_mult": 5.25,
403
+ "no_op": false,
404
+ "replace_with_linear": false,
405
+ "sparsify": null
406
+ }
407
+ },
408
+ {
409
+ "attention": {
410
+ "n_heads_in_group": 8,
411
+ "no_op": false,
412
+ "num_sink_tokens": null,
413
+ "replace_with_linear": false,
414
+ "sparsify": null,
415
+ "unshifted_sink": false,
416
+ "use_prefill_window_in_sink_attention": false,
417
+ "window_length": null
418
+ },
419
+ "ffn": {
420
+ "ffn_mult": 5.25,
421
+ "no_op": false,
422
+ "replace_with_linear": false,
423
+ "sparsify": null
424
+ }
425
+ },
426
+ {
427
+ "attention": {
428
+ "n_heads_in_group": 8,
429
+ "no_op": false,
430
+ "num_sink_tokens": null,
431
+ "replace_with_linear": false,
432
+ "sparsify": null,
433
+ "unshifted_sink": false,
434
+ "use_prefill_window_in_sink_attention": false,
435
+ "window_length": null
436
+ },
437
+ "ffn": {
438
+ "ffn_mult": 5.25,
439
+ "no_op": false,
440
+ "replace_with_linear": false,
441
+ "sparsify": null
442
+ }
443
+ },
444
+ {
445
+ "attention": {
446
+ "n_heads_in_group": 8,
447
+ "no_op": false,
448
+ "num_sink_tokens": null,
449
+ "replace_with_linear": false,
450
+ "sparsify": null,
451
+ "unshifted_sink": false,
452
+ "use_prefill_window_in_sink_attention": false,
453
+ "window_length": null
454
+ },
455
+ "ffn": {
456
+ "ffn_mult": 5.25,
457
+ "no_op": false,
458
+ "replace_with_linear": false,
459
+ "sparsify": null
460
+ }
461
+ },
462
+ {
463
+ "attention": {
464
+ "n_heads_in_group": 8,
465
+ "no_op": false,
466
+ "num_sink_tokens": null,
467
+ "replace_with_linear": false,
468
+ "sparsify": null,
469
+ "unshifted_sink": false,
470
+ "use_prefill_window_in_sink_attention": false,
471
+ "window_length": null
472
+ },
473
+ "ffn": {
474
+ "ffn_mult": 5.25,
475
+ "no_op": false,
476
+ "replace_with_linear": false,
477
+ "sparsify": null
478
+ }
479
+ },
480
+ {
481
+ "attention": {
482
+ "n_heads_in_group": 8,
483
+ "no_op": false,
484
+ "num_sink_tokens": null,
485
+ "replace_with_linear": false,
486
+ "sparsify": null,
487
+ "unshifted_sink": false,
488
+ "use_prefill_window_in_sink_attention": false,
489
+ "window_length": null
490
+ },
491
+ "ffn": {
492
+ "ffn_mult": 5.25,
493
+ "no_op": false,
494
+ "replace_with_linear": false,
495
+ "sparsify": null
496
+ }
497
+ },
498
+ {
499
+ "attention": {
500
+ "n_heads_in_group": 8,
501
+ "no_op": false,
502
+ "num_sink_tokens": null,
503
+ "replace_with_linear": false,
504
+ "sparsify": null,
505
+ "unshifted_sink": false,
506
+ "use_prefill_window_in_sink_attention": false,
507
+ "window_length": null
508
+ },
509
+ "ffn": {
510
+ "ffn_mult": 5.25,
511
+ "no_op": false,
512
+ "replace_with_linear": false,
513
+ "sparsify": null
514
+ }
515
+ },
516
+ {
517
+ "attention": {
518
+ "n_heads_in_group": 8,
519
+ "no_op": false,
520
+ "num_sink_tokens": null,
521
+ "replace_with_linear": false,
522
+ "sparsify": null,
523
+ "unshifted_sink": false,
524
+ "use_prefill_window_in_sink_attention": false,
525
+ "window_length": null
526
+ },
527
+ "ffn": {
528
+ "ffn_mult": 5.25,
529
+ "no_op": false,
530
+ "replace_with_linear": false,
531
+ "sparsify": null
532
+ }
533
+ },
534
+ {
535
+ "attention": {
536
+ "n_heads_in_group": 8,
537
+ "no_op": false,
538
+ "num_sink_tokens": null,
539
+ "replace_with_linear": false,
540
+ "sparsify": null,
541
+ "unshifted_sink": false,
542
+ "use_prefill_window_in_sink_attention": false,
543
+ "window_length": null
544
+ },
545
+ "ffn": {
546
+ "ffn_mult": 5.25,
547
+ "no_op": false,
548
+ "replace_with_linear": false,
549
+ "sparsify": null
550
+ }
551
+ },
552
+ {
553
+ "attention": {
554
+ "n_heads_in_group": 8,
555
+ "no_op": false,
556
+ "num_sink_tokens": null,
557
+ "replace_with_linear": false,
558
+ "sparsify": null,
559
+ "unshifted_sink": false,
560
+ "use_prefill_window_in_sink_attention": false,
561
+ "window_length": null
562
+ },
563
+ "ffn": {
564
+ "ffn_mult": 5.25,
565
+ "no_op": false,
566
+ "replace_with_linear": false,
567
+ "sparsify": null
568
+ }
569
+ },
570
+ {
571
+ "attention": {
572
+ "n_heads_in_group": 8,
573
+ "no_op": false,
574
+ "num_sink_tokens": null,
575
+ "replace_with_linear": false,
576
+ "sparsify": null,
577
+ "unshifted_sink": false,
578
+ "use_prefill_window_in_sink_attention": false,
579
+ "window_length": null
580
+ },
581
+ "ffn": {
582
+ "ffn_mult": 5.25,
583
+ "no_op": false,
584
+ "replace_with_linear": false,
585
+ "sparsify": null
586
+ }
587
+ },
588
+ {
589
+ "attention": {
590
+ "n_heads_in_group": 8,
591
+ "no_op": false,
592
+ "num_sink_tokens": null,
593
+ "replace_with_linear": false,
594
+ "sparsify": null,
595
+ "unshifted_sink": false,
596
+ "use_prefill_window_in_sink_attention": false,
597
+ "window_length": null
598
+ },
599
+ "ffn": {
600
+ "ffn_mult": 5.25,
601
+ "no_op": false,
602
+ "replace_with_linear": false,
603
+ "sparsify": null
604
+ }
605
+ },
606
+ {
607
+ "attention": {
608
+ "n_heads_in_group": 8,
609
+ "no_op": false,
610
+ "num_sink_tokens": null,
611
+ "replace_with_linear": false,
612
+ "sparsify": null,
613
+ "unshifted_sink": false,
614
+ "use_prefill_window_in_sink_attention": false,
615
+ "window_length": null
616
+ },
617
+ "ffn": {
618
+ "ffn_mult": 5.25,
619
+ "no_op": false,
620
+ "replace_with_linear": false,
621
+ "sparsify": null
622
+ }
623
+ },
624
+ {
625
+ "attention": {
626
+ "n_heads_in_group": 8,
627
+ "no_op": false,
628
+ "num_sink_tokens": null,
629
+ "replace_with_linear": false,
630
+ "sparsify": null,
631
+ "unshifted_sink": false,
632
+ "use_prefill_window_in_sink_attention": false,
633
+ "window_length": null
634
+ },
635
+ "ffn": {
636
+ "ffn_mult": 5.25,
637
+ "no_op": false,
638
+ "replace_with_linear": false,
639
+ "sparsify": null
640
+ }
641
+ },
642
+ {
643
+ "attention": {
644
+ "n_heads_in_group": 8,
645
+ "no_op": false,
646
+ "num_sink_tokens": null,
647
+ "replace_with_linear": false,
648
+ "sparsify": null,
649
+ "unshifted_sink": false,
650
+ "use_prefill_window_in_sink_attention": false,
651
+ "window_length": null
652
+ },
653
+ "ffn": {
654
+ "ffn_mult": 5.25,
655
+ "no_op": false,
656
+ "replace_with_linear": false,
657
+ "sparsify": null
658
+ }
659
+ },
660
+ {
661
+ "attention": {
662
+ "n_heads_in_group": 8,
663
+ "no_op": false,
664
+ "num_sink_tokens": null,
665
+ "replace_with_linear": false,
666
+ "sparsify": null,
667
+ "unshifted_sink": false,
668
+ "use_prefill_window_in_sink_attention": false,
669
+ "window_length": null
670
+ },
671
+ "ffn": {
672
+ "ffn_mult": 5.25,
673
+ "no_op": false,
674
+ "replace_with_linear": false,
675
+ "sparsify": null
676
+ }
677
+ },
678
+ {
679
+ "attention": {
680
+ "n_heads_in_group": 8,
681
+ "no_op": false,
682
+ "num_sink_tokens": null,
683
+ "replace_with_linear": false,
684
+ "sparsify": null,
685
+ "unshifted_sink": false,
686
+ "use_prefill_window_in_sink_attention": false,
687
+ "window_length": null
688
+ },
689
+ "ffn": {
690
+ "ffn_mult": 5.25,
691
+ "no_op": false,
692
+ "replace_with_linear": false,
693
+ "sparsify": null
694
+ }
695
+ },
696
+ {
697
+ "attention": {
698
+ "n_heads_in_group": 8,
699
+ "no_op": false,
700
+ "num_sink_tokens": null,
701
+ "replace_with_linear": false,
702
+ "sparsify": null,
703
+ "unshifted_sink": false,
704
+ "use_prefill_window_in_sink_attention": false,
705
+ "window_length": null
706
+ },
707
+ "ffn": {
708
+ "ffn_mult": 5.25,
709
+ "no_op": false,
710
+ "replace_with_linear": false,
711
+ "sparsify": null
712
+ }
713
+ },
714
+ {
715
+ "attention": {
716
+ "n_heads_in_group": 8,
717
+ "no_op": false,
718
+ "num_sink_tokens": null,
719
+ "replace_with_linear": false,
720
+ "sparsify": null,
721
+ "unshifted_sink": false,
722
+ "use_prefill_window_in_sink_attention": false,
723
+ "window_length": null
724
+ },
725
+ "ffn": {
726
+ "ffn_mult": 5.25,
727
+ "no_op": false,
728
+ "replace_with_linear": false,
729
+ "sparsify": null
730
+ }
731
+ },
732
+ {
733
+ "attention": {
734
+ "n_heads_in_group": 8,
735
+ "no_op": false,
736
+ "num_sink_tokens": null,
737
+ "replace_with_linear": false,
738
+ "sparsify": null,
739
+ "unshifted_sink": false,
740
+ "use_prefill_window_in_sink_attention": false,
741
+ "window_length": null
742
+ },
743
+ "ffn": {
744
+ "ffn_mult": 5.25,
745
+ "no_op": false,
746
+ "replace_with_linear": false,
747
+ "sparsify": null
748
+ }
749
+ },
750
+ {
751
+ "attention": {
752
+ "n_heads_in_group": 8,
753
+ "no_op": false,
754
+ "num_sink_tokens": null,
755
+ "replace_with_linear": false,
756
+ "sparsify": null,
757
+ "unshifted_sink": false,
758
+ "use_prefill_window_in_sink_attention": false,
759
+ "window_length": null
760
+ },
761
+ "ffn": {
762
+ "ffn_mult": 5.25,
763
+ "no_op": false,
764
+ "replace_with_linear": false,
765
+ "sparsify": null
766
+ }
767
+ },
768
+ {
769
+ "attention": {
770
+ "n_heads_in_group": null,
771
+ "no_op": true,
772
+ "num_sink_tokens": null,
773
+ "replace_with_linear": false,
774
+ "sparsify": null,
775
+ "unshifted_sink": false,
776
+ "use_prefill_window_in_sink_attention": false,
777
+ "window_length": null
778
+ },
779
+ "ffn": {
780
+ "ffn_mult": 1.3125,
781
+ "no_op": false,
782
+ "replace_with_linear": false,
783
+ "sparsify": null
784
+ }
785
+ },
786
+ {
787
+ "attention": {
788
+ "n_heads_in_group": null,
789
+ "no_op": true,
790
+ "num_sink_tokens": null,
791
+ "replace_with_linear": false,
792
+ "sparsify": null,
793
+ "unshifted_sink": false,
794
+ "use_prefill_window_in_sink_attention": false,
795
+ "window_length": null
796
+ },
797
+ "ffn": {
798
+ "ffn_mult": 2.625,
799
+ "no_op": false,
800
+ "replace_with_linear": false,
801
+ "sparsify": null
802
+ }
803
+ },
804
+ {
805
+ "attention": {
806
+ "n_heads_in_group": null,
807
+ "no_op": true,
808
+ "num_sink_tokens": null,
809
+ "replace_with_linear": false,
810
+ "sparsify": null,
811
+ "unshifted_sink": false,
812
+ "use_prefill_window_in_sink_attention": false,
813
+ "window_length": null
814
+ },
815
+ "ffn": {
816
+ "ffn_mult": 2.625,
817
+ "no_op": false,
818
+ "replace_with_linear": false,
819
+ "sparsify": null
820
+ }
821
+ },
822
+ {
823
+ "attention": {
824
+ "n_heads_in_group": null,
825
+ "no_op": true,
826
+ "num_sink_tokens": null,
827
+ "replace_with_linear": false,
828
+ "sparsify": null,
829
+ "unshifted_sink": false,
830
+ "use_prefill_window_in_sink_attention": false,
831
+ "window_length": null
832
+ },
833
+ "ffn": {
834
+ "ffn_mult": 1.3125,
835
+ "no_op": false,
836
+ "replace_with_linear": false,
837
+ "sparsify": null
838
+ }
839
+ },
840
+ {
841
+ "attention": {
842
+ "n_heads_in_group": null,
843
+ "no_op": true,
844
+ "num_sink_tokens": null,
845
+ "replace_with_linear": false,
846
+ "sparsify": null,
847
+ "unshifted_sink": false,
848
+ "use_prefill_window_in_sink_attention": false,
849
+ "window_length": null
850
+ },
851
+ "ffn": {
852
+ "ffn_mult": 5.25,
853
+ "no_op": false,
854
+ "replace_with_linear": false,
855
+ "sparsify": null
856
+ }
857
+ },
858
+ {
859
+ "attention": {
860
+ "n_heads_in_group": null,
861
+ "no_op": true,
862
+ "num_sink_tokens": null,
863
+ "replace_with_linear": false,
864
+ "sparsify": null,
865
+ "unshifted_sink": false,
866
+ "use_prefill_window_in_sink_attention": false,
867
+ "window_length": null
868
+ },
869
+ "ffn": {
870
+ "ffn_mult": 1.3125,
871
+ "no_op": false,
872
+ "replace_with_linear": false,
873
+ "sparsify": null
874
+ }
875
+ },
876
+ {
877
+ "attention": {
878
+ "n_heads_in_group": null,
879
+ "no_op": true,
880
+ "num_sink_tokens": null,
881
+ "replace_with_linear": false,
882
+ "sparsify": null,
883
+ "unshifted_sink": false,
884
+ "use_prefill_window_in_sink_attention": false,
885
+ "window_length": null
886
+ },
887
+ "ffn": {
888
+ "ffn_mult": 2.625,
889
+ "no_op": false,
890
+ "replace_with_linear": false,
891
+ "sparsify": null
892
+ }
893
+ },
894
+ {
895
+ "attention": {
896
+ "n_heads_in_group": null,
897
+ "no_op": true,
898
+ "num_sink_tokens": null,
899
+ "replace_with_linear": false,
900
+ "sparsify": null,
901
+ "unshifted_sink": false,
902
+ "use_prefill_window_in_sink_attention": false,
903
+ "window_length": null
904
+ },
905
+ "ffn": {
906
+ "ffn_mult": 1.3125,
907
+ "no_op": false,
908
+ "replace_with_linear": false,
909
+ "sparsify": null
910
+ }
911
+ },
912
+ {
913
+ "attention": {
914
+ "n_heads_in_group": null,
915
+ "no_op": true,
916
+ "num_sink_tokens": null,
917
+ "replace_with_linear": false,
918
+ "sparsify": null,
919
+ "unshifted_sink": false,
920
+ "use_prefill_window_in_sink_attention": false,
921
+ "window_length": null
922
+ },
923
+ "ffn": {
924
+ "ffn_mult": 1.3125,
925
+ "no_op": false,
926
+ "replace_with_linear": false,
927
+ "sparsify": null
928
+ }
929
+ },
930
+ {
931
+ "attention": {
932
+ "n_heads_in_group": null,
933
+ "no_op": true,
934
+ "num_sink_tokens": null,
935
+ "replace_with_linear": false,
936
+ "sparsify": null,
937
+ "unshifted_sink": false,
938
+ "use_prefill_window_in_sink_attention": false,
939
+ "window_length": null
940
+ },
941
+ "ffn": {
942
+ "ffn_mult": 1.3125,
943
+ "no_op": false,
944
+ "replace_with_linear": false,
945
+ "sparsify": null
946
+ }
947
+ },
948
+ {
949
+ "attention": {
950
+ "n_heads_in_group": 8,
951
+ "no_op": false,
952
+ "num_sink_tokens": null,
953
+ "replace_with_linear": false,
954
+ "sparsify": null,
955
+ "unshifted_sink": false,
956
+ "use_prefill_window_in_sink_attention": false,
957
+ "window_length": null
958
+ },
959
+ "ffn": {
960
+ "ffn_mult": 5.25,
961
+ "no_op": false,
962
+ "replace_with_linear": false,
963
+ "sparsify": null
964
+ }
965
+ },
966
+ {
967
+ "attention": {
968
+ "n_heads_in_group": null,
969
+ "no_op": true,
970
+ "num_sink_tokens": null,
971
+ "replace_with_linear": false,
972
+ "sparsify": null,
973
+ "unshifted_sink": false,
974
+ "use_prefill_window_in_sink_attention": false,
975
+ "window_length": null
976
+ },
977
+ "ffn": {
978
+ "ffn_mult": 1.3125,
979
+ "no_op": false,
980
+ "replace_with_linear": false,
981
+ "sparsify": null
982
+ }
983
+ },
984
+ {
985
+ "attention": {
986
+ "n_heads_in_group": null,
987
+ "no_op": true,
988
+ "num_sink_tokens": null,
989
+ "replace_with_linear": false,
990
+ "sparsify": null,
991
+ "unshifted_sink": false,
992
+ "use_prefill_window_in_sink_attention": false,
993
+ "window_length": null
994
+ },
995
+ "ffn": {
996
+ "ffn_mult": 1.0,
997
+ "no_op": false,
998
+ "replace_with_linear": false,
999
+ "sparsify": null
1000
+ }
1001
+ },
1002
+ {
1003
+ "attention": {
1004
+ "n_heads_in_group": null,
1005
+ "no_op": true,
1006
+ "num_sink_tokens": null,
1007
+ "replace_with_linear": false,
1008
+ "sparsify": null,
1009
+ "unshifted_sink": false,
1010
+ "use_prefill_window_in_sink_attention": false,
1011
+ "window_length": null
1012
+ },
1013
+ "ffn": {
1014
+ "ffn_mult": 1.0,
1015
+ "no_op": false,
1016
+ "replace_with_linear": false,
1017
+ "sparsify": null
1018
+ }
1019
+ },
1020
+ {
1021
+ "attention": {
1022
+ "n_heads_in_group": null,
1023
+ "no_op": true,
1024
+ "num_sink_tokens": null,
1025
+ "replace_with_linear": false,
1026
+ "sparsify": null,
1027
+ "unshifted_sink": false,
1028
+ "use_prefill_window_in_sink_attention": false,
1029
+ "window_length": null
1030
+ },
1031
+ "ffn": {
1032
+ "ffn_mult": 1.3125,
1033
+ "no_op": false,
1034
+ "replace_with_linear": false,
1035
+ "sparsify": null
1036
+ }
1037
+ },
1038
+ {
1039
+ "attention": {
1040
+ "n_heads_in_group": null,
1041
+ "no_op": true,
1042
+ "num_sink_tokens": null,
1043
+ "replace_with_linear": false,
1044
+ "sparsify": null,
1045
+ "unshifted_sink": false,
1046
+ "use_prefill_window_in_sink_attention": false,
1047
+ "window_length": null
1048
+ },
1049
+ "ffn": {
1050
+ "ffn_mult": 1.0,
1051
+ "no_op": false,
1052
+ "replace_with_linear": false,
1053
+ "sparsify": null
1054
+ }
1055
+ },
1056
+ {
1057
+ "attention": {
1058
+ "n_heads_in_group": null,
1059
+ "no_op": true,
1060
+ "num_sink_tokens": null,
1061
+ "replace_with_linear": false,
1062
+ "sparsify": null,
1063
+ "unshifted_sink": false,
1064
+ "use_prefill_window_in_sink_attention": false,
1065
+ "window_length": null
1066
+ },
1067
+ "ffn": {
1068
+ "ffn_mult": 1.0,
1069
+ "no_op": false,
1070
+ "replace_with_linear": false,
1071
+ "sparsify": null
1072
+ }
1073
+ },
1074
+ {
1075
+ "attention": {
1076
+ "n_heads_in_group": null,
1077
+ "no_op": true,
1078
+ "num_sink_tokens": null,
1079
+ "replace_with_linear": false,
1080
+ "sparsify": null,
1081
+ "unshifted_sink": false,
1082
+ "use_prefill_window_in_sink_attention": false,
1083
+ "window_length": null
1084
+ },
1085
+ "ffn": {
1086
+ "ffn_mult": 1.0,
1087
+ "no_op": false,
1088
+ "replace_with_linear": false,
1089
+ "sparsify": null
1090
+ }
1091
+ },
1092
+ {
1093
+ "attention": {
1094
+ "n_heads_in_group": null,
1095
+ "no_op": true,
1096
+ "num_sink_tokens": null,
1097
+ "replace_with_linear": false,
1098
+ "sparsify": null,
1099
+ "unshifted_sink": false,
1100
+ "use_prefill_window_in_sink_attention": false,
1101
+ "window_length": null
1102
+ },
1103
+ "ffn": {
1104
+ "ffn_mult": 1.3125,
1105
+ "no_op": false,
1106
+ "replace_with_linear": false,
1107
+ "sparsify": null
1108
+ }
1109
+ },
1110
+ {
1111
+ "attention": {
1112
+ "n_heads_in_group": null,
1113
+ "no_op": true,
1114
+ "num_sink_tokens": null,
1115
+ "replace_with_linear": false,
1116
+ "sparsify": null,
1117
+ "unshifted_sink": false,
1118
+ "use_prefill_window_in_sink_attention": false,
1119
+ "window_length": null
1120
+ },
1121
+ "ffn": {
1122
+ "ffn_mult": 1.3125,
1123
+ "no_op": false,
1124
+ "replace_with_linear": false,
1125
+ "sparsify": null
1126
+ }
1127
+ },
1128
+ {
1129
+ "attention": {
1130
+ "n_heads_in_group": null,
1131
+ "no_op": true,
1132
+ "num_sink_tokens": null,
1133
+ "replace_with_linear": false,
1134
+ "sparsify": null,
1135
+ "unshifted_sink": false,
1136
+ "use_prefill_window_in_sink_attention": false,
1137
+ "window_length": null
1138
+ },
1139
+ "ffn": {
1140
+ "ffn_mult": 0.5,
1141
+ "no_op": false,
1142
+ "replace_with_linear": false,
1143
+ "sparsify": null
1144
+ }
1145
+ },
1146
+ {
1147
+ "attention": {
1148
+ "n_heads_in_group": null,
1149
+ "no_op": true,
1150
+ "num_sink_tokens": null,
1151
+ "replace_with_linear": false,
1152
+ "sparsify": null,
1153
+ "unshifted_sink": false,
1154
+ "use_prefill_window_in_sink_attention": false,
1155
+ "window_length": null
1156
+ },
1157
+ "ffn": {
1158
+ "ffn_mult": 0.5,
1159
+ "no_op": false,
1160
+ "replace_with_linear": false,
1161
+ "sparsify": null
1162
+ }
1163
+ },
1164
+ {
1165
+ "attention": {
1166
+ "n_heads_in_group": null,
1167
+ "no_op": true,
1168
+ "num_sink_tokens": null,
1169
+ "replace_with_linear": false,
1170
+ "sparsify": null,
1171
+ "unshifted_sink": false,
1172
+ "use_prefill_window_in_sink_attention": false,
1173
+ "window_length": null
1174
+ },
1175
+ "ffn": {
1176
+ "ffn_mult": 1.0,
1177
+ "no_op": false,
1178
+ "replace_with_linear": false,
1179
+ "sparsify": null
1180
+ }
1181
+ },
1182
+ {
1183
+ "attention": {
1184
+ "n_heads_in_group": null,
1185
+ "no_op": true,
1186
+ "num_sink_tokens": null,
1187
+ "replace_with_linear": false,
1188
+ "sparsify": null,
1189
+ "unshifted_sink": false,
1190
+ "use_prefill_window_in_sink_attention": false,
1191
+ "window_length": null
1192
+ },
1193
+ "ffn": {
1194
+ "ffn_mult": 1.0,
1195
+ "no_op": false,
1196
+ "replace_with_linear": false,
1197
+ "sparsify": null
1198
+ }
1199
+ },
1200
+ {
1201
+ "attention": {
1202
+ "n_heads_in_group": null,
1203
+ "no_op": true,
1204
+ "num_sink_tokens": null,
1205
+ "replace_with_linear": false,
1206
+ "sparsify": null,
1207
+ "unshifted_sink": false,
1208
+ "use_prefill_window_in_sink_attention": false,
1209
+ "window_length": null
1210
+ },
1211
+ "ffn": {
1212
+ "ffn_mult": 0.5,
1213
+ "no_op": false,
1214
+ "replace_with_linear": false,
1215
+ "sparsify": null
1216
+ }
1217
+ },
1218
+ {
1219
+ "attention": {
1220
+ "n_heads_in_group": null,
1221
+ "no_op": true,
1222
+ "num_sink_tokens": null,
1223
+ "replace_with_linear": false,
1224
+ "sparsify": null,
1225
+ "unshifted_sink": false,
1226
+ "use_prefill_window_in_sink_attention": false,
1227
+ "window_length": null
1228
+ },
1229
+ "ffn": {
1230
+ "ffn_mult": 0.5,
1231
+ "no_op": false,
1232
+ "replace_with_linear": false,
1233
+ "sparsify": null
1234
+ }
1235
+ },
1236
+ {
1237
+ "attention": {
1238
+ "n_heads_in_group": null,
1239
+ "no_op": true,
1240
+ "num_sink_tokens": null,
1241
+ "replace_with_linear": false,
1242
+ "sparsify": null,
1243
+ "unshifted_sink": false,
1244
+ "use_prefill_window_in_sink_attention": false,
1245
+ "window_length": null
1246
+ },
1247
+ "ffn": {
1248
+ "ffn_mult": 1.0,
1249
+ "no_op": false,
1250
+ "replace_with_linear": false,
1251
+ "sparsify": null
1252
+ }
1253
+ },
1254
+ {
1255
+ "attention": {
1256
+ "n_heads_in_group": null,
1257
+ "no_op": true,
1258
+ "num_sink_tokens": null,
1259
+ "replace_with_linear": false,
1260
+ "sparsify": null,
1261
+ "unshifted_sink": false,
1262
+ "use_prefill_window_in_sink_attention": false,
1263
+ "window_length": null
1264
+ },
1265
+ "ffn": {
1266
+ "ffn_mult": 0.5,
1267
+ "no_op": false,
1268
+ "replace_with_linear": false,
1269
+ "sparsify": null
1270
+ }
1271
+ },
1272
+ {
1273
+ "attention": {
1274
+ "n_heads_in_group": null,
1275
+ "no_op": true,
1276
+ "num_sink_tokens": null,
1277
+ "replace_with_linear": false,
1278
+ "sparsify": null,
1279
+ "unshifted_sink": false,
1280
+ "use_prefill_window_in_sink_attention": false,
1281
+ "window_length": null
1282
+ },
1283
+ "ffn": {
1284
+ "ffn_mult": 0.5,
1285
+ "no_op": false,
1286
+ "replace_with_linear": false,
1287
+ "sparsify": null
1288
+ }
1289
+ },
1290
+ {
1291
+ "attention": {
1292
+ "n_heads_in_group": 8,
1293
+ "no_op": false,
1294
+ "num_sink_tokens": null,
1295
+ "replace_with_linear": false,
1296
+ "sparsify": null,
1297
+ "unshifted_sink": false,
1298
+ "use_prefill_window_in_sink_attention": false,
1299
+ "window_length": null
1300
+ },
1301
+ "ffn": {
1302
+ "ffn_mult": 5.25,
1303
+ "no_op": false,
1304
+ "replace_with_linear": false,
1305
+ "sparsify": null
1306
+ }
1307
+ },
1308
+ {
1309
+ "attention": {
1310
+ "n_heads_in_group": 8,
1311
+ "no_op": false,
1312
+ "num_sink_tokens": null,
1313
+ "replace_with_linear": false,
1314
+ "sparsify": null,
1315
+ "unshifted_sink": false,
1316
+ "use_prefill_window_in_sink_attention": false,
1317
+ "window_length": null
1318
+ },
1319
+ "ffn": {
1320
+ "ffn_mult": 5.25,
1321
+ "no_op": false,
1322
+ "replace_with_linear": false,
1323
+ "sparsify": null
1324
+ }
1325
+ },
1326
+ {
1327
+ "attention": {
1328
+ "n_heads_in_group": 8,
1329
+ "no_op": false,
1330
+ "num_sink_tokens": null,
1331
+ "replace_with_linear": false,
1332
+ "sparsify": null,
1333
+ "unshifted_sink": false,
1334
+ "use_prefill_window_in_sink_attention": false,
1335
+ "window_length": null
1336
+ },
1337
+ "ffn": {
1338
+ "ffn_mult": 5.25,
1339
+ "no_op": false,
1340
+ "replace_with_linear": false,
1341
+ "sparsify": null
1342
+ }
1343
+ },
1344
+ {
1345
+ "attention": {
1346
+ "n_heads_in_group": 8,
1347
+ "no_op": false,
1348
+ "num_sink_tokens": null,
1349
+ "replace_with_linear": false,
1350
+ "sparsify": null,
1351
+ "unshifted_sink": false,
1352
+ "use_prefill_window_in_sink_attention": false,
1353
+ "window_length": null
1354
+ },
1355
+ "ffn": {
1356
+ "ffn_mult": 5.25,
1357
+ "no_op": false,
1358
+ "replace_with_linear": false,
1359
+ "sparsify": null
1360
+ }
1361
+ },
1362
+ {
1363
+ "attention": {
1364
+ "n_heads_in_group": 8,
1365
+ "no_op": false,
1366
+ "num_sink_tokens": null,
1367
+ "replace_with_linear": false,
1368
+ "sparsify": null,
1369
+ "unshifted_sink": false,
1370
+ "use_prefill_window_in_sink_attention": false,
1371
+ "window_length": null
1372
+ },
1373
+ "ffn": {
1374
+ "ffn_mult": 5.25,
1375
+ "no_op": false,
1376
+ "replace_with_linear": false,
1377
+ "sparsify": null
1378
+ }
1379
+ },
1380
+ {
1381
+ "attention": {
1382
+ "n_heads_in_group": 8,
1383
+ "no_op": false,
1384
+ "num_sink_tokens": null,
1385
+ "replace_with_linear": false,
1386
+ "sparsify": null,
1387
+ "unshifted_sink": false,
1388
+ "use_prefill_window_in_sink_attention": false,
1389
+ "window_length": null
1390
+ },
1391
+ "ffn": {
1392
+ "ffn_mult": 5.25,
1393
+ "no_op": false,
1394
+ "replace_with_linear": false,
1395
+ "sparsify": null
1396
+ }
1397
+ },
1398
+ {
1399
+ "attention": {
1400
+ "n_heads_in_group": 8,
1401
+ "no_op": false,
1402
+ "num_sink_tokens": null,
1403
+ "replace_with_linear": false,
1404
+ "sparsify": null,
1405
+ "unshifted_sink": false,
1406
+ "use_prefill_window_in_sink_attention": false,
1407
+ "window_length": null
1408
+ },
1409
+ "ffn": {
1410
+ "ffn_mult": 5.25,
1411
+ "no_op": false,
1412
+ "replace_with_linear": false,
1413
+ "sparsify": null
1414
+ }
1415
+ },
1416
+ {
1417
+ "attention": {
1418
+ "n_heads_in_group": 8,
1419
+ "no_op": false,
1420
+ "num_sink_tokens": null,
1421
+ "replace_with_linear": false,
1422
+ "sparsify": null,
1423
+ "unshifted_sink": false,
1424
+ "use_prefill_window_in_sink_attention": false,
1425
+ "window_length": null
1426
+ },
1427
+ "ffn": {
1428
+ "ffn_mult": 5.25,
1429
+ "no_op": false,
1430
+ "replace_with_linear": false,
1431
+ "sparsify": null
1432
+ }
1433
+ },
1434
+ {
1435
+ "attention": {
1436
+ "n_heads_in_group": 8,
1437
+ "no_op": false,
1438
+ "num_sink_tokens": null,
1439
+ "replace_with_linear": false,
1440
+ "sparsify": null,
1441
+ "unshifted_sink": false,
1442
+ "use_prefill_window_in_sink_attention": false,
1443
+ "window_length": null
1444
+ },
1445
+ "ffn": {
1446
+ "ffn_mult": 5.25,
1447
+ "no_op": false,
1448
+ "replace_with_linear": false,
1449
+ "sparsify": null
1450
+ }
1451
+ }
1452
+ ],
1453
+ "bos_token_id": 128000,
1454
+ "eos_token_id": [
1455
+ 128001,
1456
+ 128008,
1457
+ 128009
1458
+ ],
1459
+ "hidden_act": "silu",
1460
+ "hidden_size": 8192,
1461
+ "initializer_range": 0.02,
1462
+ "intermediate_size": null,
1463
+ "max_position_embeddings": 131072,
1464
+ "mlp_bias": false,
1465
+ "model_type": "nemotron-nas",
1466
+ "num_attention_heads": 64,
1467
+ "num_hidden_layers": 80,
1468
+ "num_key_value_heads": null,
1469
+ "pad_token_id": 128004,
1470
+ "pretraining_tp": 1,
1471
+ "rms_norm_eps": 1e-05,
1472
+ "rope_scaling": {
1473
+ "factor": 8.0,
1474
+ "high_freq_factor": 4.0,
1475
+ "low_freq_factor": 1.0,
1476
+ "original_max_position_embeddings": 8192,
1477
+ "rope_type": "llama3"
1478
+ },
1479
+ "rope_theta": 500000.0,
1480
+ "tie_word_embeddings": false,
1481
+ "torch_dtype": "bfloat16",
1482
+ "transformers_version": "4.52.2",
1483
+ "unsloth_fixed": true,
1484
+ "use_cache": true,
1485
+ "vocab_size": 128256
1486
+ }
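The `block_configs` array above is what makes this checkpoint heterogeneous: per layer, attention can be skipped entirely (`"no_op": true`) or run as grouped-query attention (`"n_heads_in_group": 8`, i.e. the 64 query heads share 8 KV heads), and the FFN width is set by `ffn_mult`. A minimal sketch of how one such entry can be interpreted, assuming the usual DeciLM rounding rule — the authoritative logic lives in `block_config.py`, added in this commit:

```python
# Sketch only: interprets one "block_configs" entry from config.json.
# The rounding helper is an assumption about the DeciLM convention; the
# real computation is in block_config.py / modeling_decilm.py.
def ffn_mult_to_intermediate_size(ffn_mult: float, hidden_size: int, multiple: int = 256) -> int:
    size = int(2 * ffn_mult * hidden_size / 3)
    return ((size + multiple - 1) // multiple) * multiple   # round up to a multiple of 256

hidden_size = 8192            # "hidden_size" in config.json
num_attention_heads = 64      # "num_attention_heads" in config.json

block = {                     # copied from one entry of "block_configs"
    "attention": {"no_op": False, "n_heads_in_group": 8},
    "ffn": {"no_op": False, "ffn_mult": 5.25},
}

if block["attention"]["no_op"]:
    print("attention is skipped in this layer")
else:
    kv_heads = num_attention_heads // block["attention"]["n_heads_in_group"]
    print(f"grouped-query attention with {kv_heads} KV heads")              # 8

print(ffn_mult_to_intermediate_size(block["ffn"]["ffn_mult"], hidden_size))  # 28672 for ffn_mult=5.25
```

Under the same rule, layers with `"ffn_mult": 1.3125`, `1.0`, or `0.5` get correspondingly narrower FFNs, and layers with `"attention": {"no_op": true}` carry no attention weights at all — which is why many layers in the weight map below have only MLP and layernorm tensors.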
configuration_decilm.py ADDED
@@ -0,0 +1,65 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Nvidia Corporation. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import dataclasses
17
+ import warnings
18
+ from typing import Dict, Any
19
+
20
+ from transformers.utils import is_flash_attn_2_available
21
+
22
+ from .block_config import BlockConfig
23
+ from .transformers_4_44_2__configuration_llama import LlamaConfig
24
+ from .transformers_4_44_2__modeling_rope_utils import \
25
+ rope_config_validation # fake import to make AutoConfig infer the dependency
26
+
27
+ rope_config_validation # this line is here to make sure that auto-formatting doesn't remove the import
28
+
29
+
30
+ class DeciLMConfig(LlamaConfig):
31
+ model_type = "nemotron-nas"
32
+
33
+ def __init__(
34
+ self,
35
+ block_configs: list[dict] | list[BlockConfig] = None,
36
+ **kwargs,
37
+ ):
38
+ attn_implementation = kwargs.pop("attn_implementation", None)
39
+ if attn_implementation is None and is_flash_attn_2_available():
40
+ attn_implementation = "flash_attention_2"
41
+
42
+ if block_configs is not None:
43
+ if isinstance(block_configs[0], dict):
44
+ block_configs = [BlockConfig(**conf) for conf in block_configs]
45
+
46
+ using_unshifted_sink = any([block_config.attention.unshifted_sink for block_config in block_configs])
47
+ if using_unshifted_sink and attn_implementation != "eager":
48
+ warnings.warn("Forcing attn_implementation='eager' since some attention layers use unshifted sink")
49
+ attn_implementation = "eager"
50
+
51
+ super().__init__(attn_implementation=attn_implementation, **kwargs)
52
+
53
+ self.intermediate_size = None
54
+ self.num_key_value_heads = None
55
+
56
+ if block_configs is not None:
57
+ assert len(block_configs) == self.num_hidden_layers
58
+
59
+ self.block_configs: list[BlockConfig] = block_configs
60
+
61
+ def to_dict(self) -> Dict[str, Any]:
62
+ self_dict = super().to_dict()
63
+ if self.block_configs is not None:
64
+ self_dict["block_configs"] = [dataclasses.asdict(conf) for conf in self.block_configs]
65
+ return self_dict
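Because `model_type` is the custom `nemotron-nas`, this configuration class (and the matching `modeling_decilm.py`) must be loaded from the repo itself. A minimal usage sketch, assuming the base repo id from the model card:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "nvidia/Llama-3_3-Nemotron-Super-49B-v1"   # assumption: adjust to wherever this upload lives

# trust_remote_code=True makes transformers import configuration_decilm.py /
# modeling_decilm.py from the repo instead of a built-in architecture.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.model_type)             # "nemotron-nas"
print(len(config.block_configs))     # one BlockConfig per layer (num_hidden_layers = 80)

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,      # matches "torch_dtype" in config.json
    device_map="auto",
    trust_remote_code=True,
)
```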
explainability.md ADDED
@@ -0,0 +1,12 @@
1
+ | Field: | Response: |
2
+ | :---- | :---- |
3
+ | Intended Application(s) & Domain(s): | Text generation, reasoning, summarization, and question answering. |
4
+ | Model Type: | Text-to-text transformer |
5
+ | Intended Users: | This model is intended for developers, researchers, and customers building/utilizing LLMs, while balancing accuracy and efficiency. |
6
+ | Output: | Text String(s) |
7
+ | Describe how the model works: | Generates text by predicting the next word or token based on the context provided in the input sequence using multiple self-attention layers. |
8
+ | Technical Limitations: | The model was trained on data that contains toxic language, unsafe content, and societal biases originally crawled from the internet. Therefore, the model may amplify those biases and return toxic responses, especially when prompted with toxic prompts. It may also generate answers that are inaccurate, omit key information, or include irrelevant, redundant, or socially unacceptable text, even if the prompt itself does not include anything explicitly offensive. The model demonstrates weakness to alignment-breaking attacks. Users are advised to deploy language model guardrails alongside this model to prevent potentially harmful outputs. |
9
+ | Verified to have met prescribed quality standards? | Yes |
10
+ | Performance Metrics: | Accuracy, Throughput, and user-side throughput |
11
+ | Potential Known Risks: | The model was optimized explicitly for instruction following and is therefore more susceptible to prompt injection and jailbreaking in various forms. This means the model should be paired with additional rails or system filtering to limit exposure to instructions from malicious sources \-- either directly or indirectly by retrieval (e.g. via visiting a website) \-- as they may yield outputs that can lead to harmful, system-level outcomes up to and including remote code execution in agentic systems when effective security controls, including guardrails, are not in place. The model was trained on data that contains toxic language and societal biases originally crawled from the internet. Therefore, the model may amplify those biases and return toxic responses, especially when prompted with toxic prompts. It may also generate answers that are inaccurate, omit key information, or include irrelevant, redundant, or socially unacceptable text, even if the prompt itself does not include anything explicitly offensive. |
12
+ | End User License Agreement: | Your use of this model is governed by the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/). Additional Information: [Llama 3.3 Community License Agreement](https://www.llama.com/llama3_3/license/). Built with Llama. |
flow.png ADDED

Git LFS Details

  • SHA256: a6d9325407851dabdab25700b536e94bb22b15635995d4525cd238285c266598
  • Pointer size: 131 Bytes
  • Size of remote file: 255 kB
generation_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 128000,
4
+ "do_sample": true,
5
+ "eos_token_id": [
6
+ 128001,
7
+ 128008,
8
+ 128009
9
+ ],
10
+ "transformers_version": "4.48.3"
11
+ }
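generation_config.json enables sampling by default (`"do_sample": true`) and treats any of the three ids 128001, 128008, and 128009 as end-of-sequence. Continuing the loading sketch above (the repo id is still an assumption):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(repo_id)   # repo_id as in the sketch above

messages = [{"role": "user", "content": "Summarize grouped-query attention in two sentences."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# do_sample and the three eos_token_id values are picked up automatically
# from generation_config.json; max_new_tokens is just an example value.
output = model.generate(input_ids, max_new_tokens=256)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```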
model-00001-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c783cb78f9b73fbc08c58fcdd9cdbaab3cc18b17affd87ccfbcfd5a2189b666
3
+ size 4987112064
model-00002-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:143911f6fce7cd8cc408e8aaf921ae1ad04347dd8f2f2a04e093c1df5a667467
3
+ size 4966157048
model-00003-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6ef7c455e68110552172b88d7bf0e43e71349a3f01c9ec33bc5d930113e2747
3
+ size 4999712064
model-00004-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d638446a5873772604c03629e0d9d1d711373d93a39213931701335d80bd5671
3
+ size 4907436920
model-00005-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f08de9d7aa97857017fe206d09945a8fec200f3b9de261dbe6796c9ed7d0f33
3
+ size 4664167416
model-00006-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f663499b22c16867fb889860a61e042f0ebf3b8357f782a2a7b4d79d63ce968
3
+ size 4999695232
model-00007-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4201ccb2ccd023139d597db81aac7d722b9cefefbbdd3562beac80855e26c3fb
3
+ size 4966157072
model-00008-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ed8f95f5e6700a13c03ed041ea60757e88f4aa2eec0017e2d95a36681153c91
3
+ size 4664150920
model-00009-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f343ef42b5b014e063e1ba8edd825401a16bef45a8170f053dfb5f8109009e1
3
+ size 4664167416
model-00010-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a734e91c0f7e32fc178f7279c625cfed322b4db284c308200d62f417b58a823
3
+ size 4664167416
model-00011-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d86d69f314697e3fdd3ec642f55ec1967fac1d338161bc51deaef975ab7d40f
3
+ size 4999695232
model-00012-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:325bed8f23a5e5049263ad557f140bee0af9062543ff42e7578c381402c3379f
3
+ size 4966157072
model-00013-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:669830c7f7520244fdbadf3a0cc363f733fe05fd996ccbc2a1725359afc81a02
3
+ size 4664150920
model-00014-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:854985b5c7703bf468e389c83bf59287587c239224c1ef8d785b6937408a5030
3
+ size 4664167416
model-00015-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca12a0f7457f6c7e2f751a9179ca0a839dd564cd48a9350ff499dfa0815152a6
3
+ size 4764847328
model-00016-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7743450f190f588700ab60d77392bc1b7a3791542f181b8680d3fdf082987da7
3
+ size 4764847664
model-00017-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d05c90d8fbf8635139018b7f272d9f552c4617a0d1d65bcb3b1b0af3501d64e0
3
+ size 4924432792
model-00018-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77430254bcd97f064105260fd826faec547976e8c24582f64c29c2b778a5e0d2
3
+ size 4664150920
model-00019-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1da1833a6bee3e5239a878aff41dd0892d2ba353f2d0178c5a342b1d5657b5ba
3
+ size 4664167416
model-00020-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0654718f1a0e04010298cd01be24357b009f418d6767519c6a899e452ae69fe9
3
+ size 4664167416
model-00021-of-00021.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fb3509b81a348aee61ea67998e9ced46f2a7b8255ad17157f716d79d319dc5e
3
+ size 3510649400
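The 21 shard files above only become a model through model.safetensors.index.json (next): its `weight_map` tells the loader which shard holds each tensor. A small sketch of reading the index directly, assuming a local clone of the repo:

```python
import json
from pathlib import Path

index = json.loads(Path("model.safetensors.index.json").read_text())
print(index["metadata"]["total_size"])    # 99734290432 bytes (~100 GB of bf16 weights)

# Which shard stores a given tensor?
shard = index["weight_map"]["model.layers.0.mlp.down_proj.weight"]
print(shard)                              # "model-00001-of-00021.safetensors"

# Optionally open just that shard lazily:
# from safetensors import safe_open
# with safe_open(shard, framework="pt") as f:
#     tensor = f.get_tensor("model.layers.0.mlp.down_proj.weight")
```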
model.safetensors.index.json ADDED
@@ -0,0 +1,575 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 99734290432
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00021-of-00021.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00021.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00021.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00021.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00021.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00021.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00021.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00021.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00021.safetensors",
15
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00021.safetensors",
16
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00021.safetensors",
17
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00021.safetensors",
18
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00021.safetensors",
19
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00021.safetensors",
20
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00021.safetensors",
21
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00021.safetensors",
22
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00021.safetensors",
23
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00021.safetensors",
24
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00021.safetensors",
25
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00021.safetensors",
26
+ "model.layers.10.input_layernorm.weight": "model-00004-of-00021.safetensors",
27
+ "model.layers.10.mlp.down_proj.weight": "model-00004-of-00021.safetensors",
28
+ "model.layers.10.mlp.gate_proj.weight": "model-00004-of-00021.safetensors",
29
+ "model.layers.10.mlp.up_proj.weight": "model-00004-of-00021.safetensors",
30
+ "model.layers.10.post_attention_layernorm.weight": "model-00004-of-00021.safetensors",
31
+ "model.layers.10.self_attn.k_proj.weight": "model-00004-of-00021.safetensors",
32
+ "model.layers.10.self_attn.o_proj.weight": "model-00004-of-00021.safetensors",
33
+ "model.layers.10.self_attn.q_proj.weight": "model-00004-of-00021.safetensors",
34
+ "model.layers.10.self_attn.v_proj.weight": "model-00004-of-00021.safetensors",
35
+ "model.layers.11.mlp.down_proj.weight": "model-00004-of-00021.safetensors",
36
+ "model.layers.11.mlp.gate_proj.weight": "model-00004-of-00021.safetensors",
37
+ "model.layers.11.mlp.up_proj.weight": "model-00004-of-00021.safetensors",
38
+ "model.layers.11.post_attention_layernorm.weight": "model-00004-of-00021.safetensors",
39
+ "model.layers.12.input_layernorm.weight": "model-00004-of-00021.safetensors",
40
+ "model.layers.12.mlp.down_proj.weight": "model-00005-of-00021.safetensors",
41
+ "model.layers.12.mlp.gate_proj.weight": "model-00004-of-00021.safetensors",
42
+ "model.layers.12.mlp.up_proj.weight": "model-00005-of-00021.safetensors",
43
+ "model.layers.12.post_attention_layernorm.weight": "model-00004-of-00021.safetensors",
44
+ "model.layers.12.self_attn.k_proj.weight": "model-00004-of-00021.safetensors",
45
+ "model.layers.12.self_attn.o_proj.weight": "model-00004-of-00021.safetensors",
46
+ "model.layers.12.self_attn.q_proj.weight": "model-00004-of-00021.safetensors",
47
+ "model.layers.12.self_attn.v_proj.weight": "model-00004-of-00021.safetensors",
48
+ "model.layers.13.input_layernorm.weight": "model-00005-of-00021.safetensors",
49
+ "model.layers.13.mlp.down_proj.weight": "model-00005-of-00021.safetensors",
50
+ "model.layers.13.mlp.gate_proj.weight": "model-00005-of-00021.safetensors",
51
+ "model.layers.13.mlp.up_proj.weight": "model-00005-of-00021.safetensors",
52
+ "model.layers.13.post_attention_layernorm.weight": "model-00005-of-00021.safetensors",
53
+ "model.layers.13.self_attn.k_proj.weight": "model-00005-of-00021.safetensors",
54
+ "model.layers.13.self_attn.o_proj.weight": "model-00005-of-00021.safetensors",
55
+ "model.layers.13.self_attn.q_proj.weight": "model-00005-of-00021.safetensors",
56
+ "model.layers.13.self_attn.v_proj.weight": "model-00005-of-00021.safetensors",
57
+ "model.layers.14.input_layernorm.weight": "model-00005-of-00021.safetensors",
58
+ "model.layers.14.mlp.down_proj.weight": "model-00005-of-00021.safetensors",
59
+ "model.layers.14.mlp.gate_proj.weight": "model-00005-of-00021.safetensors",
60
+ "model.layers.14.mlp.up_proj.weight": "model-00005-of-00021.safetensors",
61
+ "model.layers.14.post_attention_layernorm.weight": "model-00005-of-00021.safetensors",
62
+ "model.layers.14.self_attn.k_proj.weight": "model-00005-of-00021.safetensors",
63
+ "model.layers.14.self_attn.o_proj.weight": "model-00005-of-00021.safetensors",
64
+ "model.layers.14.self_attn.q_proj.weight": "model-00005-of-00021.safetensors",
65
+ "model.layers.14.self_attn.v_proj.weight": "model-00005-of-00021.safetensors",
66
+ "model.layers.15.input_layernorm.weight": "model-00005-of-00021.safetensors",
67
+ "model.layers.15.mlp.down_proj.weight": "model-00006-of-00021.safetensors",
68
+ "model.layers.15.mlp.gate_proj.weight": "model-00006-of-00021.safetensors",
69
+ "model.layers.15.mlp.up_proj.weight": "model-00006-of-00021.safetensors",
70
+ "model.layers.15.post_attention_layernorm.weight": "model-00005-of-00021.safetensors",
71
+ "model.layers.15.self_attn.k_proj.weight": "model-00005-of-00021.safetensors",
72
+ "model.layers.15.self_attn.o_proj.weight": "model-00005-of-00021.safetensors",
73
+ "model.layers.15.self_attn.q_proj.weight": "model-00005-of-00021.safetensors",
74
+ "model.layers.15.self_attn.v_proj.weight": "model-00005-of-00021.safetensors",
75
+ "model.layers.16.input_layernorm.weight": "model-00006-of-00021.safetensors",
76
+ "model.layers.16.mlp.down_proj.weight": "model-00006-of-00021.safetensors",
77
+ "model.layers.16.mlp.gate_proj.weight": "model-00006-of-00021.safetensors",
78
+ "model.layers.16.mlp.up_proj.weight": "model-00006-of-00021.safetensors",
79
+ "model.layers.16.post_attention_layernorm.weight": "model-00006-of-00021.safetensors",
80
+ "model.layers.16.self_attn.k_proj.weight": "model-00006-of-00021.safetensors",
81
+ "model.layers.16.self_attn.o_proj.weight": "model-00006-of-00021.safetensors",
82
+ "model.layers.16.self_attn.q_proj.weight": "model-00006-of-00021.safetensors",
83
+ "model.layers.16.self_attn.v_proj.weight": "model-00006-of-00021.safetensors",
84
+ "model.layers.17.input_layernorm.weight": "model-00006-of-00021.safetensors",
85
+ "model.layers.17.mlp.down_proj.weight": "model-00006-of-00021.safetensors",
86
+ "model.layers.17.mlp.gate_proj.weight": "model-00006-of-00021.safetensors",
87
+ "model.layers.17.mlp.up_proj.weight": "model-00006-of-00021.safetensors",
88
+ "model.layers.17.post_attention_layernorm.weight": "model-00006-of-00021.safetensors",
89
+ "model.layers.17.self_attn.k_proj.weight": "model-00006-of-00021.safetensors",
90
+ "model.layers.17.self_attn.o_proj.weight": "model-00006-of-00021.safetensors",
91
+ "model.layers.17.self_attn.q_proj.weight": "model-00006-of-00021.safetensors",
92
+ "model.layers.17.self_attn.v_proj.weight": "model-00006-of-00021.safetensors",
93
+ "model.layers.18.input_layernorm.weight": "model-00006-of-00021.safetensors",
94
+ "model.layers.18.mlp.down_proj.weight": "model-00007-of-00021.safetensors",
95
+ "model.layers.18.mlp.gate_proj.weight": "model-00007-of-00021.safetensors",
96
+ "model.layers.18.mlp.up_proj.weight": "model-00007-of-00021.safetensors",
97
+ "model.layers.18.post_attention_layernorm.weight": "model-00007-of-00021.safetensors",
98
+ "model.layers.18.self_attn.k_proj.weight": "model-00006-of-00021.safetensors",
99
+ "model.layers.18.self_attn.o_proj.weight": "model-00007-of-00021.safetensors",
100
+ "model.layers.18.self_attn.q_proj.weight": "model-00006-of-00021.safetensors",
101
+ "model.layers.18.self_attn.v_proj.weight": "model-00006-of-00021.safetensors",
102
+ "model.layers.19.input_layernorm.weight": "model-00007-of-00021.safetensors",
103
+ "model.layers.19.mlp.down_proj.weight": "model-00007-of-00021.safetensors",
104
+ "model.layers.19.mlp.gate_proj.weight": "model-00007-of-00021.safetensors",
105
+ "model.layers.19.mlp.up_proj.weight": "model-00007-of-00021.safetensors",
106
+ "model.layers.19.post_attention_layernorm.weight": "model-00007-of-00021.safetensors",
107
+ "model.layers.19.self_attn.k_proj.weight": "model-00007-of-00021.safetensors",
108
+ "model.layers.19.self_attn.o_proj.weight": "model-00007-of-00021.safetensors",
109
+ "model.layers.19.self_attn.q_proj.weight": "model-00007-of-00021.safetensors",
110
+ "model.layers.19.self_attn.v_proj.weight": "model-00007-of-00021.safetensors",
111
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00021.safetensors",
112
+ "model.layers.2.mlp.down_proj.weight": "model-00002-of-00021.safetensors",
113
+ "model.layers.2.mlp.gate_proj.weight": "model-00002-of-00021.safetensors",
114
+ "model.layers.2.mlp.up_proj.weight": "model-00002-of-00021.safetensors",
115
+ "model.layers.2.post_attention_layernorm.weight": "model-00002-of-00021.safetensors",
116
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00021.safetensors",
117
+ "model.layers.2.self_attn.o_proj.weight": "model-00002-of-00021.safetensors",
118
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00021.safetensors",
119
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00021.safetensors",
120
+ "model.layers.20.input_layernorm.weight": "model-00007-of-00021.safetensors",
121
+ "model.layers.20.mlp.down_proj.weight": "model-00007-of-00021.safetensors",
122
+ "model.layers.20.mlp.gate_proj.weight": "model-00007-of-00021.safetensors",
123
+ "model.layers.20.mlp.up_proj.weight": "model-00007-of-00021.safetensors",
124
+ "model.layers.20.post_attention_layernorm.weight": "model-00007-of-00021.safetensors",
125
+ "model.layers.20.self_attn.k_proj.weight": "model-00007-of-00021.safetensors",
126
+ "model.layers.20.self_attn.o_proj.weight": "model-00007-of-00021.safetensors",
127
+ "model.layers.20.self_attn.q_proj.weight": "model-00007-of-00021.safetensors",
128
+ "model.layers.20.self_attn.v_proj.weight": "model-00007-of-00021.safetensors",
129
+ "model.layers.21.input_layernorm.weight": "model-00007-of-00021.safetensors",
130
+ "model.layers.21.mlp.down_proj.weight": "model-00008-of-00021.safetensors",
131
+ "model.layers.21.mlp.gate_proj.weight": "model-00008-of-00021.safetensors",
132
+ "model.layers.21.mlp.up_proj.weight": "model-00008-of-00021.safetensors",
133
+ "model.layers.21.post_attention_layernorm.weight": "model-00008-of-00021.safetensors",
134
+ "model.layers.21.self_attn.k_proj.weight": "model-00008-of-00021.safetensors",
135
+ "model.layers.21.self_attn.o_proj.weight": "model-00008-of-00021.safetensors",
136
+ "model.layers.21.self_attn.q_proj.weight": "model-00008-of-00021.safetensors",
137
+ "model.layers.21.self_attn.v_proj.weight": "model-00008-of-00021.safetensors",
138
+ "model.layers.22.input_layernorm.weight": "model-00008-of-00021.safetensors",
139
+ "model.layers.22.mlp.down_proj.weight": "model-00008-of-00021.safetensors",
140
+ "model.layers.22.mlp.gate_proj.weight": "model-00008-of-00021.safetensors",
141
+ "model.layers.22.mlp.up_proj.weight": "model-00008-of-00021.safetensors",
142
+ "model.layers.22.post_attention_layernorm.weight": "model-00008-of-00021.safetensors",
143
+ "model.layers.22.self_attn.k_proj.weight": "model-00008-of-00021.safetensors",
144
+ "model.layers.22.self_attn.o_proj.weight": "model-00008-of-00021.safetensors",
145
+ "model.layers.22.self_attn.q_proj.weight": "model-00008-of-00021.safetensors",
146
+ "model.layers.22.self_attn.v_proj.weight": "model-00008-of-00021.safetensors",
147
+ "model.layers.23.input_layernorm.weight": "model-00008-of-00021.safetensors",
148
+ "model.layers.23.mlp.down_proj.weight": "model-00009-of-00021.safetensors",
149
+ "model.layers.23.mlp.gate_proj.weight": "model-00008-of-00021.safetensors",
150
+ "model.layers.23.mlp.up_proj.weight": "model-00008-of-00021.safetensors",
151
+ "model.layers.23.post_attention_layernorm.weight": "model-00008-of-00021.safetensors",
152
+ "model.layers.23.self_attn.k_proj.weight": "model-00008-of-00021.safetensors",
153
+ "model.layers.23.self_attn.o_proj.weight": "model-00008-of-00021.safetensors",
154
+ "model.layers.23.self_attn.q_proj.weight": "model-00008-of-00021.safetensors",
155
+ "model.layers.23.self_attn.v_proj.weight": "model-00008-of-00021.safetensors",
156
+ "model.layers.24.input_layernorm.weight": "model-00009-of-00021.safetensors",
157
+ "model.layers.24.mlp.down_proj.weight": "model-00009-of-00021.safetensors",
158
+ "model.layers.24.mlp.gate_proj.weight": "model-00009-of-00021.safetensors",
159
+ "model.layers.24.mlp.up_proj.weight": "model-00009-of-00021.safetensors",
160
+ "model.layers.24.post_attention_layernorm.weight": "model-00009-of-00021.safetensors",
161
+ "model.layers.24.self_attn.k_proj.weight": "model-00009-of-00021.safetensors",
162
+ "model.layers.24.self_attn.o_proj.weight": "model-00009-of-00021.safetensors",
163
+ "model.layers.24.self_attn.q_proj.weight": "model-00009-of-00021.safetensors",
164
+ "model.layers.24.self_attn.v_proj.weight": "model-00009-of-00021.safetensors",
165
+ "model.layers.25.input_layernorm.weight": "model-00009-of-00021.safetensors",
166
+ "model.layers.25.mlp.down_proj.weight": "model-00009-of-00021.safetensors",
167
+ "model.layers.25.mlp.gate_proj.weight": "model-00009-of-00021.safetensors",
168
+ "model.layers.25.mlp.up_proj.weight": "model-00009-of-00021.safetensors",
169
+ "model.layers.25.post_attention_layernorm.weight": "model-00009-of-00021.safetensors",
170
+ "model.layers.25.self_attn.k_proj.weight": "model-00009-of-00021.safetensors",
171
+ "model.layers.25.self_attn.o_proj.weight": "model-00009-of-00021.safetensors",
172
+ "model.layers.25.self_attn.q_proj.weight": "model-00009-of-00021.safetensors",
173
+ "model.layers.25.self_attn.v_proj.weight": "model-00009-of-00021.safetensors",
174
+ "model.layers.26.input_layernorm.weight": "model-00009-of-00021.safetensors",
175
+ "model.layers.26.mlp.down_proj.weight": "model-00010-of-00021.safetensors",
176
+ "model.layers.26.mlp.gate_proj.weight": "model-00009-of-00021.safetensors",
177
+ "model.layers.26.mlp.up_proj.weight": "model-00010-of-00021.safetensors",
178
+ "model.layers.26.post_attention_layernorm.weight": "model-00009-of-00021.safetensors",
179
+ "model.layers.26.self_attn.k_proj.weight": "model-00009-of-00021.safetensors",
180
+ "model.layers.26.self_attn.o_proj.weight": "model-00009-of-00021.safetensors",
181
+ "model.layers.26.self_attn.q_proj.weight": "model-00009-of-00021.safetensors",
182
+ "model.layers.26.self_attn.v_proj.weight": "model-00009-of-00021.safetensors",
183
+ "model.layers.27.input_layernorm.weight": "model-00010-of-00021.safetensors",
184
+ "model.layers.27.mlp.down_proj.weight": "model-00010-of-00021.safetensors",
185
+ "model.layers.27.mlp.gate_proj.weight": "model-00010-of-00021.safetensors",
186
+ "model.layers.27.mlp.up_proj.weight": "model-00010-of-00021.safetensors",
187
+ "model.layers.27.post_attention_layernorm.weight": "model-00010-of-00021.safetensors",
188
+ "model.layers.27.self_attn.k_proj.weight": "model-00010-of-00021.safetensors",
189
+ "model.layers.27.self_attn.o_proj.weight": "model-00010-of-00021.safetensors",
190
+ "model.layers.27.self_attn.q_proj.weight": "model-00010-of-00021.safetensors",
191
+ "model.layers.27.self_attn.v_proj.weight": "model-00010-of-00021.safetensors",
192
+ "model.layers.28.input_layernorm.weight": "model-00010-of-00021.safetensors",
193
+ "model.layers.28.mlp.down_proj.weight": "model-00010-of-00021.safetensors",
194
+ "model.layers.28.mlp.gate_proj.weight": "model-00010-of-00021.safetensors",
195
+ "model.layers.28.mlp.up_proj.weight": "model-00010-of-00021.safetensors",
196
+ "model.layers.28.post_attention_layernorm.weight": "model-00010-of-00021.safetensors",
197
+ "model.layers.28.self_attn.k_proj.weight": "model-00010-of-00021.safetensors",
198
+ "model.layers.28.self_attn.o_proj.weight": "model-00010-of-00021.safetensors",
199
+ "model.layers.28.self_attn.q_proj.weight": "model-00010-of-00021.safetensors",
200
+ "model.layers.28.self_attn.v_proj.weight": "model-00010-of-00021.safetensors",
201
+ "model.layers.29.input_layernorm.weight": "model-00010-of-00021.safetensors",
202
+ "model.layers.29.mlp.down_proj.weight": "model-00011-of-00021.safetensors",
203
+ "model.layers.29.mlp.gate_proj.weight": "model-00011-of-00021.safetensors",
204
+ "model.layers.29.mlp.up_proj.weight": "model-00011-of-00021.safetensors",
205
+ "model.layers.29.post_attention_layernorm.weight": "model-00010-of-00021.safetensors",
206
+ "model.layers.29.self_attn.k_proj.weight": "model-00010-of-00021.safetensors",
207
+ "model.layers.29.self_attn.o_proj.weight": "model-00010-of-00021.safetensors",
208
+ "model.layers.29.self_attn.q_proj.weight": "model-00010-of-00021.safetensors",
209
+ "model.layers.29.self_attn.v_proj.weight": "model-00010-of-00021.safetensors",
210
+ "model.layers.3.input_layernorm.weight": "model-00002-of-00021.safetensors",
211
+ "model.layers.3.mlp.down_proj.weight": "model-00002-of-00021.safetensors",
212
+ "model.layers.3.mlp.gate_proj.weight": "model-00002-of-00021.safetensors",
213
+ "model.layers.3.mlp.up_proj.weight": "model-00002-of-00021.safetensors",
214
+ "model.layers.3.post_attention_layernorm.weight": "model-00002-of-00021.safetensors",
215
+ "model.layers.3.self_attn.k_proj.weight": "model-00002-of-00021.safetensors",
216
+ "model.layers.3.self_attn.o_proj.weight": "model-00002-of-00021.safetensors",
217
+ "model.layers.3.self_attn.q_proj.weight": "model-00002-of-00021.safetensors",
218
+ "model.layers.3.self_attn.v_proj.weight": "model-00002-of-00021.safetensors",
219
+ "model.layers.30.input_layernorm.weight": "model-00011-of-00021.safetensors",
220
+ "model.layers.30.mlp.down_proj.weight": "model-00011-of-00021.safetensors",
221
+ "model.layers.30.mlp.gate_proj.weight": "model-00011-of-00021.safetensors",
222
+ "model.layers.30.mlp.up_proj.weight": "model-00011-of-00021.safetensors",
223
+ "model.layers.30.post_attention_layernorm.weight": "model-00011-of-00021.safetensors",
224
+ "model.layers.30.self_attn.k_proj.weight": "model-00011-of-00021.safetensors",
225
+ "model.layers.30.self_attn.o_proj.weight": "model-00011-of-00021.safetensors",
226
+ "model.layers.30.self_attn.q_proj.weight": "model-00011-of-00021.safetensors",
227
+ "model.layers.30.self_attn.v_proj.weight": "model-00011-of-00021.safetensors",
228
+ "model.layers.31.input_layernorm.weight": "model-00011-of-00021.safetensors",
229
+ "model.layers.31.mlp.down_proj.weight": "model-00011-of-00021.safetensors",
230
+ "model.layers.31.mlp.gate_proj.weight": "model-00011-of-00021.safetensors",
231
+ "model.layers.31.mlp.up_proj.weight": "model-00011-of-00021.safetensors",
232
+ "model.layers.31.post_attention_layernorm.weight": "model-00011-of-00021.safetensors",
233
+ "model.layers.31.self_attn.k_proj.weight": "model-00011-of-00021.safetensors",
234
+ "model.layers.31.self_attn.o_proj.weight": "model-00011-of-00021.safetensors",
235
+ "model.layers.31.self_attn.q_proj.weight": "model-00011-of-00021.safetensors",
236
+ "model.layers.31.self_attn.v_proj.weight": "model-00011-of-00021.safetensors",
237
+ "model.layers.32.input_layernorm.weight": "model-00011-of-00021.safetensors",
238
+ "model.layers.32.mlp.down_proj.weight": "model-00012-of-00021.safetensors",
239
+ "model.layers.32.mlp.gate_proj.weight": "model-00012-of-00021.safetensors",
240
+ "model.layers.32.mlp.up_proj.weight": "model-00012-of-00021.safetensors",
241
+ "model.layers.32.post_attention_layernorm.weight": "model-00012-of-00021.safetensors",
242
+ "model.layers.32.self_attn.k_proj.weight": "model-00011-of-00021.safetensors",
243
+ "model.layers.32.self_attn.o_proj.weight": "model-00012-of-00021.safetensors",
244
+ "model.layers.32.self_attn.q_proj.weight": "model-00011-of-00021.safetensors",
245
+ "model.layers.32.self_attn.v_proj.weight": "model-00011-of-00021.safetensors",
246
+ "model.layers.33.input_layernorm.weight": "model-00012-of-00021.safetensors",
247
+ "model.layers.33.mlp.down_proj.weight": "model-00012-of-00021.safetensors",
248
+ "model.layers.33.mlp.gate_proj.weight": "model-00012-of-00021.safetensors",
249
+ "model.layers.33.mlp.up_proj.weight": "model-00012-of-00021.safetensors",
250
+ "model.layers.33.post_attention_layernorm.weight": "model-00012-of-00021.safetensors",
251
+ "model.layers.33.self_attn.k_proj.weight": "model-00012-of-00021.safetensors",
252
+ "model.layers.33.self_attn.o_proj.weight": "model-00012-of-00021.safetensors",
253
+ "model.layers.33.self_attn.q_proj.weight": "model-00012-of-00021.safetensors",
254
+ "model.layers.33.self_attn.v_proj.weight": "model-00012-of-00021.safetensors",
255
+ "model.layers.34.input_layernorm.weight": "model-00012-of-00021.safetensors",
256
+ "model.layers.34.mlp.down_proj.weight": "model-00012-of-00021.safetensors",
257
+ "model.layers.34.mlp.gate_proj.weight": "model-00012-of-00021.safetensors",
258
+ "model.layers.34.mlp.up_proj.weight": "model-00012-of-00021.safetensors",
259
+ "model.layers.34.post_attention_layernorm.weight": "model-00012-of-00021.safetensors",
260
+ "model.layers.34.self_attn.k_proj.weight": "model-00012-of-00021.safetensors",
261
+ "model.layers.34.self_attn.o_proj.weight": "model-00012-of-00021.safetensors",
262
+ "model.layers.34.self_attn.q_proj.weight": "model-00012-of-00021.safetensors",
263
+ "model.layers.34.self_attn.v_proj.weight": "model-00012-of-00021.safetensors",
264
+ "model.layers.35.input_layernorm.weight": "model-00012-of-00021.safetensors",
265
+ "model.layers.35.mlp.down_proj.weight": "model-00013-of-00021.safetensors",
266
+ "model.layers.35.mlp.gate_proj.weight": "model-00013-of-00021.safetensors",
267
+ "model.layers.35.mlp.up_proj.weight": "model-00013-of-00021.safetensors",
268
+ "model.layers.35.post_attention_layernorm.weight": "model-00013-of-00021.safetensors",
269
+ "model.layers.35.self_attn.k_proj.weight": "model-00013-of-00021.safetensors",
270
+ "model.layers.35.self_attn.o_proj.weight": "model-00013-of-00021.safetensors",
271
+ "model.layers.35.self_attn.q_proj.weight": "model-00013-of-00021.safetensors",
272
+ "model.layers.35.self_attn.v_proj.weight": "model-00013-of-00021.safetensors",
273
+ "model.layers.36.input_layernorm.weight": "model-00013-of-00021.safetensors",
274
+ "model.layers.36.mlp.down_proj.weight": "model-00013-of-00021.safetensors",
275
+ "model.layers.36.mlp.gate_proj.weight": "model-00013-of-00021.safetensors",
276
+ "model.layers.36.mlp.up_proj.weight": "model-00013-of-00021.safetensors",
277
+ "model.layers.36.post_attention_layernorm.weight": "model-00013-of-00021.safetensors",
278
+ "model.layers.36.self_attn.k_proj.weight": "model-00013-of-00021.safetensors",
279
+ "model.layers.36.self_attn.o_proj.weight": "model-00013-of-00021.safetensors",
280
+ "model.layers.36.self_attn.q_proj.weight": "model-00013-of-00021.safetensors",
281
+ "model.layers.36.self_attn.v_proj.weight": "model-00013-of-00021.safetensors",
282
+ "model.layers.37.input_layernorm.weight": "model-00013-of-00021.safetensors",
283
+ "model.layers.37.mlp.down_proj.weight": "model-00014-of-00021.safetensors",
284
+ "model.layers.37.mlp.gate_proj.weight": "model-00013-of-00021.safetensors",
285
+ "model.layers.37.mlp.up_proj.weight": "model-00013-of-00021.safetensors",
286
+ "model.layers.37.post_attention_layernorm.weight": "model-00013-of-00021.safetensors",
287
+ "model.layers.37.self_attn.k_proj.weight": "model-00013-of-00021.safetensors",
288
+ "model.layers.37.self_attn.o_proj.weight": "model-00013-of-00021.safetensors",
289
+ "model.layers.37.self_attn.q_proj.weight": "model-00013-of-00021.safetensors",
290
+ "model.layers.37.self_attn.v_proj.weight": "model-00013-of-00021.safetensors",
291
+ "model.layers.38.input_layernorm.weight": "model-00014-of-00021.safetensors",
292
+ "model.layers.38.mlp.down_proj.weight": "model-00014-of-00021.safetensors",
293
+ "model.layers.38.mlp.gate_proj.weight": "model-00014-of-00021.safetensors",
294
+ "model.layers.38.mlp.up_proj.weight": "model-00014-of-00021.safetensors",
295
+ "model.layers.38.post_attention_layernorm.weight": "model-00014-of-00021.safetensors",
296
+ "model.layers.38.self_attn.k_proj.weight": "model-00014-of-00021.safetensors",
297
+ "model.layers.38.self_attn.o_proj.weight": "model-00014-of-00021.safetensors",
298
+ "model.layers.38.self_attn.q_proj.weight": "model-00014-of-00021.safetensors",
299
+ "model.layers.38.self_attn.v_proj.weight": "model-00014-of-00021.safetensors",
300
+ "model.layers.39.input_layernorm.weight": "model-00014-of-00021.safetensors",
301
+ "model.layers.39.mlp.down_proj.weight": "model-00014-of-00021.safetensors",
302
+ "model.layers.39.mlp.gate_proj.weight": "model-00014-of-00021.safetensors",
303
+ "model.layers.39.mlp.up_proj.weight": "model-00014-of-00021.safetensors",
304
+ "model.layers.39.post_attention_layernorm.weight": "model-00014-of-00021.safetensors",
305
+ "model.layers.39.self_attn.k_proj.weight": "model-00014-of-00021.safetensors",
306
+ "model.layers.39.self_attn.o_proj.weight": "model-00014-of-00021.safetensors",
307
+ "model.layers.39.self_attn.q_proj.weight": "model-00014-of-00021.safetensors",
308
+ "model.layers.39.self_attn.v_proj.weight": "model-00014-of-00021.safetensors",
309
+ "model.layers.4.input_layernorm.weight": "model-00002-of-00021.safetensors",
310
+ "model.layers.4.mlp.down_proj.weight": "model-00002-of-00021.safetensors",
311
+ "model.layers.4.mlp.gate_proj.weight": "model-00002-of-00021.safetensors",
312
+ "model.layers.4.mlp.up_proj.weight": "model-00002-of-00021.safetensors",
313
+ "model.layers.4.post_attention_layernorm.weight": "model-00002-of-00021.safetensors",
314
+ "model.layers.4.self_attn.k_proj.weight": "model-00002-of-00021.safetensors",
315
+ "model.layers.4.self_attn.o_proj.weight": "model-00002-of-00021.safetensors",
316
+ "model.layers.4.self_attn.q_proj.weight": "model-00002-of-00021.safetensors",
317
+ "model.layers.4.self_attn.v_proj.weight": "model-00002-of-00021.safetensors",
318
+ "model.layers.40.input_layernorm.weight": "model-00014-of-00021.safetensors",
319
+ "model.layers.40.mlp.down_proj.weight": "model-00015-of-00021.safetensors",
320
+ "model.layers.40.mlp.gate_proj.weight": "model-00014-of-00021.safetensors",
321
+ "model.layers.40.mlp.up_proj.weight": "model-00015-of-00021.safetensors",
322
+ "model.layers.40.post_attention_layernorm.weight": "model-00014-of-00021.safetensors",
323
+ "model.layers.40.self_attn.k_proj.weight": "model-00014-of-00021.safetensors",
324
+ "model.layers.40.self_attn.o_proj.weight": "model-00014-of-00021.safetensors",
325
+ "model.layers.40.self_attn.q_proj.weight": "model-00014-of-00021.safetensors",
326
+ "model.layers.40.self_attn.v_proj.weight": "model-00014-of-00021.safetensors",
327
+ "model.layers.41.input_layernorm.weight": "model-00015-of-00021.safetensors",
328
+ "model.layers.41.mlp.down_proj.weight": "model-00015-of-00021.safetensors",
329
+ "model.layers.41.mlp.gate_proj.weight": "model-00015-of-00021.safetensors",
330
+ "model.layers.41.mlp.up_proj.weight": "model-00015-of-00021.safetensors",
331
+ "model.layers.41.post_attention_layernorm.weight": "model-00015-of-00021.safetensors",
332
+ "model.layers.41.self_attn.k_proj.weight": "model-00015-of-00021.safetensors",
333
+ "model.layers.41.self_attn.o_proj.weight": "model-00015-of-00021.safetensors",
334
+ "model.layers.41.self_attn.q_proj.weight": "model-00015-of-00021.safetensors",
335
+ "model.layers.41.self_attn.v_proj.weight": "model-00015-of-00021.safetensors",
336
+ "model.layers.42.mlp.down_proj.weight": "model-00015-of-00021.safetensors",
337
+ "model.layers.42.mlp.gate_proj.weight": "model-00015-of-00021.safetensors",
338
+ "model.layers.42.mlp.up_proj.weight": "model-00015-of-00021.safetensors",
339
+ "model.layers.42.post_attention_layernorm.weight": "model-00015-of-00021.safetensors",
340
+ "model.layers.43.mlp.down_proj.weight": "model-00015-of-00021.safetensors",
341
+ "model.layers.43.mlp.gate_proj.weight": "model-00015-of-00021.safetensors",
342
+ "model.layers.43.mlp.up_proj.weight": "model-00015-of-00021.safetensors",
343
+ "model.layers.43.post_attention_layernorm.weight": "model-00015-of-00021.safetensors",
344
+ "model.layers.44.mlp.down_proj.weight": "model-00015-of-00021.safetensors",
345
+ "model.layers.44.mlp.gate_proj.weight": "model-00015-of-00021.safetensors",
346
+ "model.layers.44.mlp.up_proj.weight": "model-00015-of-00021.safetensors",
347
+ "model.layers.44.post_attention_layernorm.weight": "model-00015-of-00021.safetensors",
348
+ "model.layers.45.mlp.down_proj.weight": "model-00015-of-00021.safetensors",
349
+ "model.layers.45.mlp.gate_proj.weight": "model-00015-of-00021.safetensors",
350
+ "model.layers.45.mlp.up_proj.weight": "model-00015-of-00021.safetensors",
351
+ "model.layers.45.post_attention_layernorm.weight": "model-00015-of-00021.safetensors",
352
+ "model.layers.46.mlp.down_proj.weight": "model-00016-of-00021.safetensors",
353
+ "model.layers.46.mlp.gate_proj.weight": "model-00016-of-00021.safetensors",
354
+ "model.layers.46.mlp.up_proj.weight": "model-00016-of-00021.safetensors",
355
+ "model.layers.46.post_attention_layernorm.weight": "model-00015-of-00021.safetensors",
356
+ "model.layers.47.mlp.down_proj.weight": "model-00016-of-00021.safetensors",
357
+ "model.layers.47.mlp.gate_proj.weight": "model-00016-of-00021.safetensors",
358
+ "model.layers.47.mlp.up_proj.weight": "model-00016-of-00021.safetensors",
359
+ "model.layers.47.post_attention_layernorm.weight": "model-00016-of-00021.safetensors",
360
+ "model.layers.48.mlp.down_proj.weight": "model-00016-of-00021.safetensors",
361
+ "model.layers.48.mlp.gate_proj.weight": "model-00016-of-00021.safetensors",
362
+ "model.layers.48.mlp.up_proj.weight": "model-00016-of-00021.safetensors",
363
+ "model.layers.48.post_attention_layernorm.weight": "model-00016-of-00021.safetensors",
364
+ "model.layers.49.mlp.down_proj.weight": "model-00016-of-00021.safetensors",
365
+ "model.layers.49.mlp.gate_proj.weight": "model-00016-of-00021.safetensors",
366
+ "model.layers.49.mlp.up_proj.weight": "model-00016-of-00021.safetensors",
367
+ "model.layers.49.post_attention_layernorm.weight": "model-00016-of-00021.safetensors",
368
+ "model.layers.5.input_layernorm.weight": "model-00002-of-00021.safetensors",
369
+ "model.layers.5.mlp.down_proj.weight": "model-00003-of-00021.safetensors",
370
+ "model.layers.5.mlp.gate_proj.weight": "model-00003-of-00021.safetensors",
371
+ "model.layers.5.mlp.up_proj.weight": "model-00003-of-00021.safetensors",
372
+ "model.layers.5.post_attention_layernorm.weight": "model-00003-of-00021.safetensors",
373
+ "model.layers.5.self_attn.k_proj.weight": "model-00003-of-00021.safetensors",
374
+ "model.layers.5.self_attn.o_proj.weight": "model-00003-of-00021.safetensors",
375
+ "model.layers.5.self_attn.q_proj.weight": "model-00003-of-00021.safetensors",
376
+ "model.layers.5.self_attn.v_proj.weight": "model-00003-of-00021.safetensors",
377
+ "model.layers.50.mlp.down_proj.weight": "model-00016-of-00021.safetensors",
378
+ "model.layers.50.mlp.gate_proj.weight": "model-00016-of-00021.safetensors",
379
+ "model.layers.50.mlp.up_proj.weight": "model-00016-of-00021.safetensors",
380
+ "model.layers.50.post_attention_layernorm.weight": "model-00016-of-00021.safetensors",
381
+ "model.layers.51.mlp.down_proj.weight": "model-00016-of-00021.safetensors",
382
+ "model.layers.51.mlp.gate_proj.weight": "model-00016-of-00021.safetensors",
383
+ "model.layers.51.mlp.up_proj.weight": "model-00016-of-00021.safetensors",
384
+ "model.layers.51.post_attention_layernorm.weight": "model-00016-of-00021.safetensors",
385
+ "model.layers.52.input_layernorm.weight": "model-00016-of-00021.safetensors",
386
+ "model.layers.52.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
387
+ "model.layers.52.mlp.gate_proj.weight": "model-00016-of-00021.safetensors",
388
+ "model.layers.52.mlp.up_proj.weight": "model-00016-of-00021.safetensors",
389
+ "model.layers.52.post_attention_layernorm.weight": "model-00016-of-00021.safetensors",
390
+ "model.layers.52.self_attn.k_proj.weight": "model-00016-of-00021.safetensors",
391
+ "model.layers.52.self_attn.o_proj.weight": "model-00016-of-00021.safetensors",
392
+ "model.layers.52.self_attn.q_proj.weight": "model-00016-of-00021.safetensors",
393
+ "model.layers.52.self_attn.v_proj.weight": "model-00016-of-00021.safetensors",
394
+ "model.layers.53.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
395
+ "model.layers.53.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
396
+ "model.layers.53.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
397
+ "model.layers.53.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
398
+ "model.layers.54.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
399
+ "model.layers.54.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
400
+ "model.layers.54.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
401
+ "model.layers.54.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
402
+ "model.layers.55.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
403
+ "model.layers.55.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
404
+ "model.layers.55.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
405
+ "model.layers.55.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
406
+ "model.layers.56.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
407
+ "model.layers.56.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
408
+ "model.layers.56.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
409
+ "model.layers.56.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
410
+ "model.layers.57.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
411
+ "model.layers.57.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
412
+ "model.layers.57.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
413
+ "model.layers.57.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
414
+ "model.layers.58.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
415
+ "model.layers.58.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
416
+ "model.layers.58.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
417
+ "model.layers.58.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
418
+ "model.layers.59.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
419
+ "model.layers.59.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
420
+ "model.layers.59.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
421
+ "model.layers.59.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
422
+ "model.layers.6.mlp.down_proj.weight": "model-00003-of-00021.safetensors",
423
+ "model.layers.6.mlp.gate_proj.weight": "model-00003-of-00021.safetensors",
424
+ "model.layers.6.mlp.up_proj.weight": "model-00003-of-00021.safetensors",
425
+ "model.layers.6.post_attention_layernorm.weight": "model-00003-of-00021.safetensors",
426
+ "model.layers.60.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
427
+ "model.layers.60.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
428
+ "model.layers.60.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
429
+ "model.layers.60.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
430
+ "model.layers.61.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
431
+ "model.layers.61.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
432
+ "model.layers.61.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
433
+ "model.layers.61.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
434
+ "model.layers.62.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
435
+ "model.layers.62.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
436
+ "model.layers.62.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
437
+ "model.layers.62.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
438
+ "model.layers.63.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
439
+ "model.layers.63.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
440
+ "model.layers.63.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
441
+ "model.layers.63.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
442
+ "model.layers.64.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
443
+ "model.layers.64.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
444
+ "model.layers.64.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
445
+ "model.layers.64.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
446
+ "model.layers.65.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
447
+ "model.layers.65.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
448
+ "model.layers.65.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
449
+ "model.layers.65.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
450
+ "model.layers.66.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
451
+ "model.layers.66.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
452
+ "model.layers.66.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
453
+ "model.layers.66.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
454
+ "model.layers.67.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
455
+ "model.layers.67.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
456
+ "model.layers.67.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
457
+ "model.layers.67.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
458
+ "model.layers.68.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
459
+ "model.layers.68.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
460
+ "model.layers.68.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
461
+ "model.layers.68.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
462
+ "model.layers.69.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
463
+ "model.layers.69.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
464
+ "model.layers.69.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
465
+ "model.layers.69.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
466
+ "model.layers.7.mlp.down_proj.weight": "model-00003-of-00021.safetensors",
467
+ "model.layers.7.mlp.gate_proj.weight": "model-00003-of-00021.safetensors",
468
+ "model.layers.7.mlp.up_proj.weight": "model-00003-of-00021.safetensors",
469
+ "model.layers.7.post_attention_layernorm.weight": "model-00003-of-00021.safetensors",
470
+ "model.layers.70.mlp.down_proj.weight": "model-00017-of-00021.safetensors",
471
+ "model.layers.70.mlp.gate_proj.weight": "model-00017-of-00021.safetensors",
472
+ "model.layers.70.mlp.up_proj.weight": "model-00017-of-00021.safetensors",
473
+ "model.layers.70.post_attention_layernorm.weight": "model-00017-of-00021.safetensors",
474
+ "model.layers.71.input_layernorm.weight": "model-00017-of-00021.safetensors",
475
+ "model.layers.71.mlp.down_proj.weight": "model-00018-of-00021.safetensors",
476
+ "model.layers.71.mlp.gate_proj.weight": "model-00018-of-00021.safetensors",
477
+ "model.layers.71.mlp.up_proj.weight": "model-00018-of-00021.safetensors",
478
+ "model.layers.71.post_attention_layernorm.weight": "model-00018-of-00021.safetensors",
479
+ "model.layers.71.self_attn.k_proj.weight": "model-00018-of-00021.safetensors",
480
+ "model.layers.71.self_attn.o_proj.weight": "model-00018-of-00021.safetensors",
481
+ "model.layers.71.self_attn.q_proj.weight": "model-00018-of-00021.safetensors",
482
+ "model.layers.71.self_attn.v_proj.weight": "model-00018-of-00021.safetensors",
483
+ "model.layers.72.input_layernorm.weight": "model-00018-of-00021.safetensors",
484
+ "model.layers.72.mlp.down_proj.weight": "model-00018-of-00021.safetensors",
485
+ "model.layers.72.mlp.gate_proj.weight": "model-00018-of-00021.safetensors",
486
+ "model.layers.72.mlp.up_proj.weight": "model-00018-of-00021.safetensors",
487
+ "model.layers.72.post_attention_layernorm.weight": "model-00018-of-00021.safetensors",
488
+ "model.layers.72.self_attn.k_proj.weight": "model-00018-of-00021.safetensors",
489
+ "model.layers.72.self_attn.o_proj.weight": "model-00018-of-00021.safetensors",
490
+ "model.layers.72.self_attn.q_proj.weight": "model-00018-of-00021.safetensors",
491
+ "model.layers.72.self_attn.v_proj.weight": "model-00018-of-00021.safetensors",
492
+ "model.layers.73.input_layernorm.weight": "model-00018-of-00021.safetensors",
493
+ "model.layers.73.mlp.down_proj.weight": "model-00019-of-00021.safetensors",
494
+ "model.layers.73.mlp.gate_proj.weight": "model-00018-of-00021.safetensors",
495
+ "model.layers.73.mlp.up_proj.weight": "model-00018-of-00021.safetensors",
496
+ "model.layers.73.post_attention_layernorm.weight": "model-00018-of-00021.safetensors",
497
+ "model.layers.73.self_attn.k_proj.weight": "model-00018-of-00021.safetensors",
498
+ "model.layers.73.self_attn.o_proj.weight": "model-00018-of-00021.safetensors",
499
+ "model.layers.73.self_attn.q_proj.weight": "model-00018-of-00021.safetensors",
500
+ "model.layers.73.self_attn.v_proj.weight": "model-00018-of-00021.safetensors",
501
+ "model.layers.74.input_layernorm.weight": "model-00019-of-00021.safetensors",
502
+ "model.layers.74.mlp.down_proj.weight": "model-00019-of-00021.safetensors",
503
+ "model.layers.74.mlp.gate_proj.weight": "model-00019-of-00021.safetensors",
504
+ "model.layers.74.mlp.up_proj.weight": "model-00019-of-00021.safetensors",
505
+ "model.layers.74.post_attention_layernorm.weight": "model-00019-of-00021.safetensors",
506
+ "model.layers.74.self_attn.k_proj.weight": "model-00019-of-00021.safetensors",
507
+ "model.layers.74.self_attn.o_proj.weight": "model-00019-of-00021.safetensors",
508
+ "model.layers.74.self_attn.q_proj.weight": "model-00019-of-00021.safetensors",
509
+ "model.layers.74.self_attn.v_proj.weight": "model-00019-of-00021.safetensors",
510
+ "model.layers.75.input_layernorm.weight": "model-00019-of-00021.safetensors",
511
+ "model.layers.75.mlp.down_proj.weight": "model-00019-of-00021.safetensors",
512
+ "model.layers.75.mlp.gate_proj.weight": "model-00019-of-00021.safetensors",
513
+ "model.layers.75.mlp.up_proj.weight": "model-00019-of-00021.safetensors",
514
+ "model.layers.75.post_attention_layernorm.weight": "model-00019-of-00021.safetensors",
515
+ "model.layers.75.self_attn.k_proj.weight": "model-00019-of-00021.safetensors",
516
+ "model.layers.75.self_attn.o_proj.weight": "model-00019-of-00021.safetensors",
517
+ "model.layers.75.self_attn.q_proj.weight": "model-00019-of-00021.safetensors",
518
+ "model.layers.75.self_attn.v_proj.weight": "model-00019-of-00021.safetensors",
519
+ "model.layers.76.input_layernorm.weight": "model-00019-of-00021.safetensors",
520
+ "model.layers.76.mlp.down_proj.weight": "model-00020-of-00021.safetensors",
521
+ "model.layers.76.mlp.gate_proj.weight": "model-00019-of-00021.safetensors",
522
+ "model.layers.76.mlp.up_proj.weight": "model-00020-of-00021.safetensors",
523
+ "model.layers.76.post_attention_layernorm.weight": "model-00019-of-00021.safetensors",
524
+ "model.layers.76.self_attn.k_proj.weight": "model-00019-of-00021.safetensors",
525
+ "model.layers.76.self_attn.o_proj.weight": "model-00019-of-00021.safetensors",
526
+ "model.layers.76.self_attn.q_proj.weight": "model-00019-of-00021.safetensors",
527
+ "model.layers.76.self_attn.v_proj.weight": "model-00019-of-00021.safetensors",
528
+ "model.layers.77.input_layernorm.weight": "model-00020-of-00021.safetensors",
529
+ "model.layers.77.mlp.down_proj.weight": "model-00020-of-00021.safetensors",
530
+ "model.layers.77.mlp.gate_proj.weight": "model-00020-of-00021.safetensors",
531
+ "model.layers.77.mlp.up_proj.weight": "model-00020-of-00021.safetensors",
532
+ "model.layers.77.post_attention_layernorm.weight": "model-00020-of-00021.safetensors",
533
+ "model.layers.77.self_attn.k_proj.weight": "model-00020-of-00021.safetensors",
534
+ "model.layers.77.self_attn.o_proj.weight": "model-00020-of-00021.safetensors",
535
+ "model.layers.77.self_attn.q_proj.weight": "model-00020-of-00021.safetensors",
536
+ "model.layers.77.self_attn.v_proj.weight": "model-00020-of-00021.safetensors",
537
+ "model.layers.78.input_layernorm.weight": "model-00020-of-00021.safetensors",
538
+ "model.layers.78.mlp.down_proj.weight": "model-00020-of-00021.safetensors",
539
+ "model.layers.78.mlp.gate_proj.weight": "model-00020-of-00021.safetensors",
540
+ "model.layers.78.mlp.up_proj.weight": "model-00020-of-00021.safetensors",
541
+ "model.layers.78.post_attention_layernorm.weight": "model-00020-of-00021.safetensors",
542
+ "model.layers.78.self_attn.k_proj.weight": "model-00020-of-00021.safetensors",
543
+ "model.layers.78.self_attn.o_proj.weight": "model-00020-of-00021.safetensors",
544
+ "model.layers.78.self_attn.q_proj.weight": "model-00020-of-00021.safetensors",
545
+ "model.layers.78.self_attn.v_proj.weight": "model-00020-of-00021.safetensors",
546
+ "model.layers.79.input_layernorm.weight": "model-00020-of-00021.safetensors",
547
+ "model.layers.79.mlp.down_proj.weight": "model-00021-of-00021.safetensors",
548
+ "model.layers.79.mlp.gate_proj.weight": "model-00021-of-00021.safetensors",
549
+ "model.layers.79.mlp.up_proj.weight": "model-00021-of-00021.safetensors",
550
+ "model.layers.79.post_attention_layernorm.weight": "model-00020-of-00021.safetensors",
551
+ "model.layers.79.self_attn.k_proj.weight": "model-00020-of-00021.safetensors",
552
+ "model.layers.79.self_attn.o_proj.weight": "model-00020-of-00021.safetensors",
553
+ "model.layers.79.self_attn.q_proj.weight": "model-00020-of-00021.safetensors",
554
+ "model.layers.79.self_attn.v_proj.weight": "model-00020-of-00021.safetensors",
555
+ "model.layers.8.input_layernorm.weight": "model-00003-of-00021.safetensors",
556
+ "model.layers.8.mlp.down_proj.weight": "model-00003-of-00021.safetensors",
557
+ "model.layers.8.mlp.gate_proj.weight": "model-00003-of-00021.safetensors",
558
+ "model.layers.8.mlp.up_proj.weight": "model-00003-of-00021.safetensors",
559
+ "model.layers.8.post_attention_layernorm.weight": "model-00003-of-00021.safetensors",
560
+ "model.layers.8.self_attn.k_proj.weight": "model-00003-of-00021.safetensors",
561
+ "model.layers.8.self_attn.o_proj.weight": "model-00003-of-00021.safetensors",
562
+ "model.layers.8.self_attn.q_proj.weight": "model-00003-of-00021.safetensors",
563
+ "model.layers.8.self_attn.v_proj.weight": "model-00003-of-00021.safetensors",
564
+ "model.layers.9.input_layernorm.weight": "model-00003-of-00021.safetensors",
565
+ "model.layers.9.mlp.down_proj.weight": "model-00004-of-00021.safetensors",
566
+ "model.layers.9.mlp.gate_proj.weight": "model-00004-of-00021.safetensors",
567
+ "model.layers.9.mlp.up_proj.weight": "model-00004-of-00021.safetensors",
568
+ "model.layers.9.post_attention_layernorm.weight": "model-00004-of-00021.safetensors",
569
+ "model.layers.9.self_attn.k_proj.weight": "model-00003-of-00021.safetensors",
570
+ "model.layers.9.self_attn.o_proj.weight": "model-00004-of-00021.safetensors",
571
+ "model.layers.9.self_attn.q_proj.weight": "model-00003-of-00021.safetensors",
572
+ "model.layers.9.self_attn.v_proj.weight": "model-00003-of-00021.safetensors",
573
+ "model.norm.weight": "model-00021-of-00021.safetensors"
574
+ }
575
+ }
modeling_decilm.py ADDED
@@ -0,0 +1,1681 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Nvidia Corporation, Google Inc, HuggingFace Inc, EleutherAI. All rights reserved.
3
+ #
4
+ # This code for Nvidia's model is based on the Llama modeling code by HuggingFace,
5
+ # which is in turn based on EleutherAI's GPT-NeoX library and the GPT-NeoX and
6
+ # OPT implementations in this library.
7
+ # Sliding window code based on Gemma2 by Google.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers import GenerationConfig
30
+ from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING, GenerationMixin, GenerateOutput
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
33
+ from transformers.utils import (
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ is_flash_attn_greater_or_equal_2_10,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+
41
+ from .block_config import AttentionConfig, FFNConfig
42
+ from .configuration_decilm import DeciLMConfig
43
+ from .transformers_4_44_2__activations import ACT2FN
44
+ from .transformers_4_44_2__cache_utils import Cache, StaticCache
45
+ from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter
46
+ from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import _flash_attention_forward
47
+ from .transformers_4_44_2__modeling_outputs import (
48
+ BaseModelOutputWithPast,
49
+ CausalLMOutputWithPast,
50
+ QuestionAnsweringModelOutput,
51
+ SequenceClassifierOutputWithPast,
52
+ TokenClassifierOutput,
53
+ )
54
+ from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS
55
+ from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS
56
+ from .variable_cache import VariableCache
57
+
58
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[DeciLMConfig.model_type] = "DeciLMForCausalLM"
59
+ logger = logging.get_logger(__name__)
60
+
61
+ _CONFIG_FOR_DOC = "DeciLMConfig"
62
+
63
+
64
+ def _prepare_4d_causal_attention_mask_with_cache_position(
65
+ attention_mask: torch.Tensor,
66
+ sequence_length: int,
67
+ target_length: int,
68
+ dtype: torch.dtype,
69
+ device: torch.device,
70
+ min_dtype: float,
71
+ cache_position: torch.Tensor,
72
+ batch_size: int,
73
+ ):
74
+ """
75
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
76
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
77
+
78
+ Args:
79
+ attention_mask (`torch.Tensor`):
80
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
81
+ sequence_length (`int`):
82
+ The sequence length being processed.
83
+ target_length (`int`):
84
+ The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not filled yet).
85
+ dtype (`torch.dtype`):
86
+ The dtype to use for the 4D attention mask.
87
+ device (`torch.device`):
88
+ The device to place the 4D attention mask on.
89
+ min_dtype (`float`):
90
+ The minimum value representable with the dtype `dtype`.
91
+ cache_position (`torch.Tensor`):
92
+ Indices depicting the position of the input sequence tokens in the sequence.
93
+ batch_size (`int`):
94
+ Batch size.
95
+ """
96
+ if attention_mask is not None and attention_mask.dim() == 4:
97
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
98
+ causal_mask = attention_mask
99
+ else:
100
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
101
+ if sequence_length != 1:
102
+ causal_mask = torch.triu(causal_mask, diagonal=1)
103
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
104
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
105
+ if attention_mask is not None:
106
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
107
+ mask_length = attention_mask.shape[-1]
108
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
109
+ padding_mask = padding_mask == 0
110
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
111
+ padding_mask, min_dtype
112
+ )
113
+
114
+ return causal_mask
115
+
116
+
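As a quick illustration of the helper above, here is a hedged sketch that builds the 4D additive mask for two toy sequences of length 3, the second one left-padded by one token; all shapes and values are made up for illustration.

```python
import torch

# Toy padding mask: batch of 2, second sequence has its first position padded.
attn = torch.tensor([[1, 1, 1], [0, 1, 1]])
mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=attn,
    sequence_length=3,
    target_length=3,
    dtype=torch.float32,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(torch.float32).min,
    cache_position=torch.arange(3),
    batch_size=2,
)
print(mask.shape)  # torch.Size([2, 1, 3, 3]); masked entries hold min_dtype, visible ones hold 0
```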
117
+ class DeciLMRMSNorm(nn.Module):
118
+ def __init__(self, hidden_size, eps=1e-6):
119
+ """
120
+ DeciLMRMSNorm is equivalent to T5LayerNorm
121
+ """
122
+ super().__init__()
123
+ self.weight = nn.Parameter(torch.ones(hidden_size))
124
+ self.variance_epsilon = eps
125
+
126
+ def forward(self, hidden_states):
127
+ input_dtype = hidden_states.dtype
128
+ hidden_states = hidden_states.to(torch.float32)
129
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
130
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
131
+ return self.weight * hidden_states.to(input_dtype)
132
+
133
+ def extra_repr(self):
134
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
135
+
136
+
137
+ ALL_LAYERNORM_LAYERS.append(DeciLMRMSNorm)
138
+
139
+
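Since the norm above is stated to be equivalent to T5LayerNorm, a small hedged self-check with toy sizes and the freshly initialized weight of ones makes the formula concrete:

```python
import torch

x = torch.randn(2, 4, 8)
norm = DeciLMRMSNorm(hidden_size=8, eps=1e-6)
# With the default weight of ones, the module reduces to x * rsqrt(mean(x^2) + eps).
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), ref)
```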
140
+ class DeciLMRotaryEmbedding(nn.Module):
141
+ def __init__(
142
+ self,
143
+ dim=None,
144
+ max_position_embeddings=2048,
145
+ base=10000,
146
+ device=None,
147
+ scaling_factor=1.0,
148
+ rope_type="default",
149
+ config: Optional[DeciLMConfig] = None,
150
+ ):
151
+ super().__init__()
152
+ # TODO (joao): remove the `if` below, only used for BC
153
+ self.rope_kwargs = {}
154
+ if config is None:
155
+ logger.warning_once(
156
+ "`DeciLMRotaryEmbedding` can now be fully parameterized by passing the model config through the "
157
+ "`config` argument. All other arguments will be removed in v4.45"
158
+ )
159
+ self.rope_kwargs = {
160
+ "rope_type": rope_type,
161
+ "factor": scaling_factor,
162
+ "dim": dim,
163
+ "base": base,
164
+ "max_position_embeddings": max_position_embeddings,
165
+ }
166
+ self.rope_type = rope_type
167
+ self.max_seq_len_cached = max_position_embeddings
168
+ self.original_max_seq_len = max_position_embeddings
169
+ else:
170
+ # BC: "rope_type" was originally "type"
171
+ if config.rope_scaling is not None:
172
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
173
+ else:
174
+ self.rope_type = "default"
175
+ self.max_seq_len_cached = config.max_position_embeddings
176
+ self.original_max_seq_len = config.max_position_embeddings
177
+
178
+ self.config = config
179
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
180
+
181
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
182
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
183
+ self.original_inv_freq = self.inv_freq
184
+
185
+ def _dynamic_frequency_update(self, position_ids, device):
186
+ """
187
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
188
+ 1 - growing beyond the cached sequence length (allow scaling)
189
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
190
+ """
191
+ seq_len = torch.max(position_ids) + 1
192
+ if seq_len > self.max_seq_len_cached: # growth
193
+ inv_freq, self.attention_scaling = self.rope_init_fn(
194
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
195
+ )
196
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
197
+ self.max_seq_len_cached = seq_len
198
+
199
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
200
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
201
+ self.max_seq_len_cached = self.original_max_seq_len
202
+
203
+ @torch.no_grad()
204
+ def forward(self, x, position_ids):
205
+ if "dynamic" in self.rope_type:
206
+ self._dynamic_frequency_update(position_ids, device=x.device)
207
+
208
+ # Core RoPE block
209
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
210
+ position_ids_expanded = position_ids[:, None, :].float()
211
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
212
+ device_type = x.device.type
213
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
214
+ with torch.autocast(device_type=device_type, enabled=False):
215
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
216
+ emb = torch.cat((freqs, freqs), dim=-1)
217
+ cos = emb.cos()
218
+ sin = emb.sin()
219
+
220
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
221
+ cos = cos * self.attention_scaling
222
+ sin = sin * self.attention_scaling
223
+
224
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
225
+
226
+
227
+ class DeciLMLinearScalingRotaryEmbedding(DeciLMRotaryEmbedding):
228
+ """DeciLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
229
+
230
+ def __init__(self, *args, **kwargs):
231
+ logger.warning_once(
232
+ "`DeciLMLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
233
+ "`DeciLMRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
234
+ )
235
+ kwargs["rope_type"] = "linear"
236
+ super().__init__(*args, **kwargs)
237
+
238
+
239
+ class DeciLMDynamicNTKScalingRotaryEmbedding(DeciLMRotaryEmbedding):
240
+ """DeciLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
241
+
242
+ def __init__(self, *args, **kwargs):
243
+ logger.warning_once(
244
+ "`DeciLMDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
245
+ "`DeciLMRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
246
+ "__init__)."
247
+ )
248
+ kwargs["rope_type"] = "dynamic"
249
+ super().__init__(*args, **kwargs)
250
+
251
+
252
+ def rotate_half(x):
253
+ """Rotates half the hidden dims of the input."""
254
+ x1 = x[..., : x.shape[-1] // 2]
255
+ x2 = x[..., x.shape[-1] // 2:]
256
+ return torch.cat((-x2, x1), dim=-1)
257
+
258
+
259
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
260
+ """Applies Rotary Position Embedding to the query and key tensors.
261
+
262
+ Args:
263
+ q (`torch.Tensor`): The query tensor.
264
+ k (`torch.Tensor`): The key tensor.
265
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
266
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
267
+ position_ids (`torch.Tensor`, *optional*):
268
+ Deprecated and unused.
269
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
270
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
271
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
272
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
273
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
274
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
275
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
276
+ Returns:
277
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
278
+ """
279
+ cos = cos.unsqueeze(unsqueeze_dim)
280
+ sin = sin.unsqueeze(unsqueeze_dim)
281
+ q_embed = (q * cos) + (rotate_half(q) * sin)
282
+ k_embed = (k * cos) + (rotate_half(k) * sin)
283
+ return q_embed, k_embed
284
+
285
+
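To make the rotation concrete, here is a hedged usage sketch with made-up shapes; passing `dim`/`base` directly goes through the deprecated keyword path of `DeciLMRotaryEmbedding`, which only emits a warning:

```python
import torch

batch, heads, seq, head_dim = 1, 2, 5, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, 1, seq, head_dim)   # a single KV head; cos/sin broadcast over heads
position_ids = torch.arange(seq)[None, :]

rope = DeciLMRotaryEmbedding(dim=head_dim, max_position_embeddings=128, base=10000)
cos, sin = rope(q, position_ids)           # each of shape (1, seq, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```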
286
+ class DeciLMMLP(nn.Module):
287
+ def __init__(self,
288
+ config: DeciLMConfig,
289
+ ffn_config: FFNConfig,
290
+ ):
291
+ super().__init__()
292
+ self.config = config
293
+ self.ffn_config = ffn_config
294
+ self.hidden_size = config.hidden_size
295
+ self.intermediate_size = _ffn_mult_to_intermediate_size(
296
+ ffn_config.ffn_mult, config.hidden_size) # DeciLM-specific code
297
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
298
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
299
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
300
+ self.act_fn = ACT2FN[config.hidden_act]
301
+
302
+ if ffn_config.sparsify is not None:
303
+ self.register_full_backward_hook(sparsity_backward_hook)
304
+
305
+ def forward(self, x):
306
+ if self.config.pretraining_tp > 1:
307
+ slice = self.intermediate_size // self.config.pretraining_tp
308
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
309
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
310
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
311
+
312
+ gate_proj = torch.cat(
313
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
314
+ )
315
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
316
+
317
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
318
+ down_proj = [
319
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
320
+ ]
321
+ down_proj = sum(down_proj)
322
+ else:
323
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
324
+
325
+ return down_proj
326
+
327
+
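The non-tensor-parallel branch of the `forward` above is a standard gated MLP; a hedged functional sketch with made-up sizes (hidden = 8, intermediate = 16) and SiLU standing in for `config.hidden_act`, which is typically `"silu"` in Llama-derived configs:

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 8)
w_gate, w_up, w_down = torch.randn(16, 8), torch.randn(16, 8), torch.randn(8, 16)
# down_proj(act_fn(gate_proj(x)) * up_proj(x))
y = F.linear(F.silu(F.linear(x, w_gate)) * F.linear(x, w_up), w_down)
assert y.shape == (3, 8)
```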
328
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
329
+ """
330
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
331
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
332
+ """
333
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
334
+ if n_rep == 1:
335
+ return hidden_states
336
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
337
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
338
+
339
+
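The equivalence claimed in the docstring can be checked directly on toy tensors (a hedged sanity sketch):

```python
import torch

kv = torch.randn(2, 4, 7, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(kv, 3), kv.repeat_interleave(3, dim=1))
```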
340
+ class DeciLMAttention(nn.Module):
341
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
342
+
343
+ def __init__(self,
344
+ config: DeciLMConfig,
345
+ attention_config: AttentionConfig,
346
+ layer_idx: Optional[int] = None,
347
+ ):
348
+ super().__init__()
349
+ self.config = config
350
+ self.attention_config = attention_config
351
+ self.layer_idx = layer_idx
352
+ if layer_idx is None:
353
+ logger.warning_once(
354
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
355
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
356
+ "when creating this class."
357
+ )
358
+
359
+ self.attention_dropout = config.attention_dropout
360
+ self.hidden_size = config.hidden_size
361
+ self.num_heads = config.num_attention_heads
362
+ self.head_dim = self.hidden_size // self.num_heads
363
+ self.num_key_value_groups = attention_config.n_heads_in_group # DeciLM-specific code
364
+ self.num_key_value_heads = self.num_heads // self.num_key_value_groups # DeciLM-specific code
365
+ self.max_position_embeddings = config.max_position_embeddings
366
+ self.rope_theta = config.rope_theta
367
+ self.is_causal = True
368
+
369
+ if (self.head_dim * self.num_heads) != self.hidden_size:
370
+ raise ValueError(
371
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
372
+ f" and `num_heads`: {self.num_heads})."
373
+ )
374
+
375
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
376
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
377
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
378
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
379
+
380
+ # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
381
+ self.rotary_emb = DeciLMRotaryEmbedding(config=self.config)
382
+
383
+ if attention_config.sparsify is not None:
384
+ self.register_full_backward_hook(sparsity_backward_hook)
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states: torch.Tensor,
389
+ attention_mask: Optional[torch.Tensor] = None,
390
+ position_ids: Optional[torch.LongTensor] = None,
391
+ past_key_value: Optional[Cache] = None,
392
+ output_attentions: bool = False,
393
+ use_cache: bool = False,
394
+ cache_position: Optional[torch.LongTensor] = None,
395
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
396
+ **kwargs,
397
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
398
+ bsz, q_len, _ = hidden_states.size()
399
+ if self.config.pretraining_tp > 1:
400
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
401
+ query_slices = self.q_proj.weight.split(
402
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
403
+ )
404
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
405
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
406
+
407
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
408
+ query_states = torch.cat(query_states, dim=-1)
409
+
410
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
411
+ key_states = torch.cat(key_states, dim=-1)
412
+
413
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
414
+ value_states = torch.cat(value_states, dim=-1)
415
+
416
+ else:
417
+ query_states = self.q_proj(hidden_states)
418
+ key_states = self.k_proj(hidden_states)
419
+ value_states = self.v_proj(hidden_states)
420
+
421
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
422
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
423
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
424
+
425
+ if position_embeddings is None:
426
+ logger.warning_once(
427
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
428
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
429
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
430
+ "removed and `position_embeddings` will be mandatory."
431
+ )
432
+ cos, sin = self.rotary_emb(value_states, position_ids)
433
+ else:
434
+ cos, sin = position_embeddings
435
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
436
+
437
+ if past_key_value is not None:
438
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
439
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
440
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
441
+
442
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
443
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
444
+
445
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
446
+
447
+ if attention_mask is not None: # no matter the length, we just slice it
448
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
449
+ attn_weights = attn_weights + causal_mask
450
+
451
+ # upcast attention to fp32
452
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
453
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
454
+ attn_output = torch.matmul(attn_weights, value_states)
455
+
456
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
457
+ raise ValueError(
458
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
459
+ f" {attn_output.size()}"
460
+ )
461
+
462
+ attn_output = attn_output.transpose(1, 2).contiguous()
463
+
464
+ attn_output = attn_output.reshape(bsz, q_len, -1)
465
+
466
+ if self.config.pretraining_tp > 1:
467
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
468
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
469
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
470
+ else:
471
+ attn_output = self.o_proj(attn_output)
472
+
473
+ if not output_attentions:
474
+ attn_weights = None
475
+
476
+ return attn_output, attn_weights, past_key_value
477
+
478
+
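Worked through with hypothetical numbers (not taken from this checkpoint's config), the DeciLM-specific grouped-query bookkeeping in `DeciLMAttention.__init__` looks like this:

```python
# 64 query heads grouped 8-to-1 onto KV heads => 8 KV heads are stored per layer and
# later expanded back to 64 via repeat_kv(key_states, num_key_value_groups).
num_attention_heads, n_heads_in_group = 64, 8
num_key_value_groups = n_heads_in_group
num_key_value_heads = num_attention_heads // num_key_value_groups
assert num_key_value_heads == 8
```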
479
+ class DeciLMFlashAttention2(DeciLMAttention):
480
+ """
481
+ DeciLM flash attention module. This module inherits from `DeciLMAttention` as the weights of the module stay
482
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
483
+ flash attention and deal with padding tokens in case the input contains any of them.
484
+ """
485
+
486
+ def __init__(self, *args, **kwargs):
487
+ super().__init__(*args, **kwargs)
488
+
489
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
490
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
491
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
492
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
493
+
494
+ self.sliding_window = self.attention_config.prefill_sliding_window
495
+
496
+ def forward(
497
+ self,
498
+ hidden_states: torch.Tensor,
499
+ attention_mask: Optional[torch.LongTensor] = None,
500
+ position_ids: Optional[torch.LongTensor] = None,
501
+ past_key_value: Optional[Cache] = None,
502
+ output_attentions: bool = False,
503
+ use_cache: bool = False,
504
+ cache_position: Optional[torch.LongTensor] = None,
505
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
506
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
507
+ output_attentions = False
508
+
509
+ bsz, q_len, _ = hidden_states.size()
510
+
511
+ query_states = self.q_proj(hidden_states)
512
+ key_states = self.k_proj(hidden_states)
513
+ value_states = self.v_proj(hidden_states)
514
+
515
+ # Flash attention requires the input to have the shape
516
+ # batch_size x seq_length x num_heads x head_dim
517
+ # therefore we just need to keep the original shape
518
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
519
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
520
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
521
+
522
+ if position_embeddings is None:
523
+ logger.warning_once(
524
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
525
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
526
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
527
+ "removed and `position_embeddings` will be mandatory."
528
+ )
529
+ cos, sin = self.rotary_emb(value_states, position_ids)
530
+ else:
531
+ cos, sin = position_embeddings
532
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
533
+
534
+ if past_key_value is not None:
535
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
536
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
537
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
538
+
539
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
540
+ # to be able to avoid many of these transpose/reshape/view.
541
+ query_states = query_states.transpose(1, 2)
542
+ key_states = key_states.transpose(1, 2)
543
+ value_states = value_states.transpose(1, 2)
544
+
545
+ dropout_rate = self.attention_dropout if self.training else 0.0
546
+
547
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
548
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
549
+ # cast them back to the correct dtype just to be sure everything works as expected.
550
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
551
+ # in fp32. (DeciLMRMSNorm handles it correctly)
552
+
553
+ input_dtype = query_states.dtype
554
+ if input_dtype == torch.float32:
555
+ if torch.is_autocast_enabled():
556
+ target_dtype = torch.get_autocast_gpu_dtype()
557
+ # Handle the case where the model is quantized
558
+ elif hasattr(self.config, "_pre_quantization_dtype"):
559
+ target_dtype = self.config._pre_quantization_dtype
560
+ else:
561
+ target_dtype = self.q_proj.weight.dtype
562
+
563
+ logger.warning_once(
564
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
565
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
566
+ f" {target_dtype}."
567
+ )
568
+
569
+ query_states = query_states.to(target_dtype)
570
+ key_states = key_states.to(target_dtype)
571
+ value_states = value_states.to(target_dtype)
572
+
573
+ attn_output = _flash_attention_forward(
574
+ query_states,
575
+ key_states,
576
+ value_states,
577
+ attention_mask,
578
+ q_len,
579
+ position_ids=position_ids,
580
+ dropout=dropout_rate,
581
+ sliding_window=self.sliding_window,
582
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
583
+ is_causal=self.is_causal,
584
+ )
585
+
586
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
587
+ attn_output = self.o_proj(attn_output)
588
+
589
+ if not output_attentions:
590
+ attn_weights = None
591
+
592
+ return attn_output, attn_weights, past_key_value
593
+
594
+
595
+ DECILM_ATTENTION_CLASSES = {
596
+ "eager": DeciLMAttention,
597
+ "flash_attention_2": DeciLMFlashAttention2,
598
+ }
599
+
600
+
601
+ class DeciLMDecoderLayer(nn.Module):
602
+ # DeciLM-specific code
603
+ def __init__(self, config: DeciLMConfig, layer_idx: int):
604
+ super().__init__()
605
+ self.config = config
606
+ self.hidden_size = config.hidden_size
607
+ self.block_config = config.block_configs[layer_idx]
608
+ self.attention_config = self.block_config.attention
609
+ self.ffn_config = self.block_config.ffn
610
+
611
+ if not self.attention_config.no_op:
612
+ self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
613
+ if not self.attention_config.replace_with_linear:
614
+ self.self_attn = DECILM_ATTENTION_CLASSES[config._attn_implementation](
615
+ config=config, attention_config=self.attention_config, layer_idx=layer_idx)
616
+ else:
617
+ self.self_attn = DeciLMLinearAttention(config)
618
+
619
+ if not self.ffn_config.no_op:
620
+ self.post_attention_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
621
+ if not self.ffn_config.replace_with_linear:
622
+ self.mlp = DeciLMMLP(config, self.ffn_config)
623
+ else:
624
+ self.mlp = DeciLMLinearMLP(config)
625
+
626
+ self.is_sliding = self.attention_config.is_sliding
627
+ self.sliding_window = self.attention_config.prefill_sliding_window
628
+
629
+ def forward(
630
+ self,
631
+ hidden_states: torch.Tensor,
632
+ attention_mask: Optional[torch.Tensor] = None,
633
+ position_ids: Optional[torch.LongTensor] = None,
634
+ past_key_value: Optional[Cache] = None,
635
+ output_attentions: Optional[bool] = False,
636
+ use_cache: Optional[bool] = False,
637
+ cache_position: Optional[torch.LongTensor] = None,
638
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
639
+ **kwargs,
640
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
641
+ """
642
+ Args:
643
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
644
+ attention_mask (`torch.FloatTensor`, *optional*):
645
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
646
+ query_sequence_length, key_sequence_length)` if default attention is used.
647
+ output_attentions (`bool`, *optional*):
648
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
649
+ returned tensors for more detail.
650
+ use_cache (`bool`, *optional*):
651
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
652
+ (see `past_key_values`).
653
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
654
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
655
+ Indices depicting the position of the input sequence tokens in the sequence
656
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
657
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
658
+ with `head_dim` being the embedding dimension of each attention head.
659
+ kwargs (`dict`, *optional*):
660
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
661
+ into the model
662
+ """
663
+ if self.attention_config.unshifted_sink and self.attention_config.is_sink:
664
+ attention_mask = self._unshifted_sink_mask(
665
+ attention_mask, hidden_states,
666
+ self.attention_config.window_length, self.attention_config.num_sink_tokens)
667
+ else:
668
+ attention_mask = self._gemma2_window_mask(attention_mask, hidden_states, past_key_value)
669
+
670
+ self_attn_weights = None
671
+ present_key_value = past_key_value
672
+ if self.attention_config.no_op:
673
+ pass
674
+ elif self.attention_config.replace_with_linear:
675
+ residual = hidden_states
676
+ hidden_states = self.input_layernorm(hidden_states)
677
+ hidden_states = self.self_attn(hidden_states)
678
+ hidden_states = residual + hidden_states
679
+ else:
680
+ residual = hidden_states
681
+ hidden_states = self.input_layernorm(hidden_states)
682
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
683
+ hidden_states=hidden_states,
684
+ attention_mask=attention_mask,
685
+ position_ids=position_ids,
686
+ past_key_value=past_key_value,
687
+ output_attentions=output_attentions,
688
+ use_cache=use_cache,
689
+ cache_position=cache_position,
690
+ position_embeddings=position_embeddings,
691
+ **kwargs,
692
+ )
693
+ hidden_states = residual + hidden_states
694
+
695
+ if not self.ffn_config.no_op:
696
+ residual = hidden_states
697
+ hidden_states = self.post_attention_layernorm(hidden_states)
698
+ hidden_states = self.mlp(hidden_states)
699
+ hidden_states = residual + hidden_states
700
+
701
+ outputs = (hidden_states,)
702
+
703
+ if output_attentions:
704
+ outputs += (self_attn_weights,)
705
+
706
+ if use_cache:
707
+ outputs += (present_key_value,)
708
+
709
+ return outputs
710
+
711
+ def _gemma2_window_mask(self,
712
+ attention_mask: Optional[torch.Tensor],
713
+ hidden_states: torch.Tensor,
714
+ past_key_value: Optional[VariableCache],
715
+ ) -> Optional[torch.Tensor]:
716
+ if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
717
+ # With flash attention, the attention mask is a 2D tensor
718
+ if self.config._attn_implementation == "flash_attention_2":
719
+ if past_key_value is not None: # when decoding
720
+ attention_mask = attention_mask[:, -self.sliding_window:]
721
+ else:
722
+ min_dtype = torch.finfo(hidden_states.dtype).min
723
+ sliding_window_mask = torch.tril(
724
+ torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
725
+ )
726
+ attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
727
+ if attention_mask.shape[-1] <= 1: # when decoding
728
+ attention_mask = attention_mask[:, :, :, -self.sliding_window:]
729
+ return attention_mask
730
+
731
+ def _unshifted_sink_mask(self,
732
+ attention_mask: torch.Tensor,
733
+ hidden_states: torch.Tensor,
734
+ window_length: int,
735
+ num_sink_tokens: Optional[int],
736
+ ) -> torch.Tensor:
737
+ assert self.config._attn_implementation == "eager", "Unshifted sink is only supported in 'eager' mode."
738
+ assert attention_mask is not None, "The attention mask does not seem to be prepared"
739
+
740
+ attention_mask = attention_mask.clone()
741
+ min_dtype = torch.finfo(hidden_states.dtype).min
742
+
743
+ if window_length == 0:
744
+ attention_mask = torch.full_like(attention_mask, fill_value=min_dtype)
745
+ else:
746
+ query_length = attention_mask.shape[-2]
747
+ is_decode = (query_length == 1)
748
+ if is_decode:
749
+ attention_mask[:, :, :, :-window_length] = min_dtype
750
+ else:
751
+ sliding_window_mask = torch.tril(
752
+ torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-window_length
753
+ )
754
+ attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
755
+
756
+ attention_mask[:, :, :, :num_sink_tokens] = 0
757
+ return attention_mask
758
+
759
+
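A hedged sketch of the eager-mode sliding-window carve-out used in `_gemma2_window_mask` above, on a toy 1x1x6x6 additive mask of zeros and a made-up window of 3:

```python
import torch

min_dtype = torch.finfo(torch.float32).min
attention_mask = torch.zeros(1, 1, 6, 6)        # additive mask: 0 = visible
sliding_window = 3
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# keys more than `sliding_window` positions behind the query are now masked out
```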
760
+ DECILM_START_DOCSTRING = r"""
761
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
762
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
763
+ etc.)
764
+
765
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
766
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
767
+ and behavior.
768
+
769
+ Parameters:
770
+ config ([`DeciLMConfig`]):
771
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
772
+ load the weights associated with the model, only the configuration. Check out the
773
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
774
+ """
775
+
776
+
777
+ @add_start_docstrings(
778
+ "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
779
+ DECILM_START_DOCSTRING,
780
+ )
781
+ class DeciLMPreTrainedModel(PreTrainedModel):
782
+ config_class = DeciLMConfig
783
+ base_model_prefix = "model"
784
+ supports_gradient_checkpointing = True
785
+ _no_split_modules = ["DeciLMDecoderLayer"]
786
+ _skip_keys_device_placement = ["past_key_values"]
787
+ _supports_flash_attn_2 = True
788
+ _supports_sdpa = False
789
+ _supports_cache_class = True
790
+ _supports_quantized_cache = False
791
+ _supports_static_cache = True
792
+
793
+ def _init_weights(self, module):
794
+ std = self.config.initializer_range
795
+ if isinstance(module, nn.Linear):
796
+ module.weight.data.normal_(mean=0.0, std=std)
797
+ if module.bias is not None:
798
+ module.bias.data.zero_()
799
+ elif isinstance(module, nn.Embedding):
800
+ module.weight.data.normal_(mean=0.0, std=std)
801
+ if module.padding_idx is not None:
802
+ module.weight.data[module.padding_idx].zero_()
803
+
804
+ def _prepare_generation_config(
805
+ self, generation_config: Optional[GenerationConfig], **kwargs: dict
806
+ ) -> tuple[GenerationConfig, dict]:
807
+ # DeciLM-specific code
808
+ generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
809
+ generation_config.cache_implementation = "variable"
810
+ NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = VariableCache
811
+ return generation_config, model_kwargs
812
+
813
+
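Because the override above pins `generation_config.cache_implementation` to `"variable"` and registers `VariableCache`, a plain `generate()` call is enough; a hedged usage sketch follows (the local path is a placeholder, not a verified model id):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "./path-to-this-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype="auto")

inputs = tokenizer("Hello", return_tensors="pt")
# No cache arguments needed: _prepare_generation_config installs the VariableCache.
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```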
814
+ DECILM_INPUTS_DOCSTRING = r"""
815
+ Args:
816
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
817
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
818
+ it.
819
+
820
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
821
+ [`PreTrainedTokenizer.__call__`] for details.
822
+
823
+ [What are input IDs?](../glossary#input-ids)
824
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
825
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
826
+
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+
830
+ [What are attention masks?](../glossary#attention-mask)
831
+
832
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
833
+ [`PreTrainedTokenizer.__call__`] for details.
834
+
835
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
836
+ `past_key_values`).
837
+
838
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
839
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
840
+ information on the default strategy.
841
+
842
+ - 1 indicates the head is **not masked**,
843
+ - 0 indicates the head is **masked**.
844
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
845
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
846
+ config.n_positions - 1]`.
847
+
848
+ [What are position IDs?](../glossary#position-ids)
849
+ past_key_values (`VariableCache`, *optional*):
850
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
851
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
852
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
853
+
854
+ If passed to the forward function, past_key_values must be a VariableCache object (see imports).
855
+ For generation purposes, this is already handled inside model.generate().
856
+
857
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
858
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
859
+ of shape `(batch_size, sequence_length)`.
860
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
861
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
862
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
863
+ model's internal embedding lookup matrix.
864
+ use_cache (`bool`, *optional*):
865
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
866
+ `past_key_values`).
867
+ output_attentions (`bool`, *optional*):
868
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
869
+ tensors for more detail.
870
+ output_hidden_states (`bool`, *optional*):
871
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
872
+ more detail.
873
+ return_dict (`bool`, *optional*):
874
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
875
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
876
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
877
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
878
+ the complete sequence length.
879
+ """
880
+
881
+
882
+ @add_start_docstrings(
883
+ "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
884
+ DECILM_START_DOCSTRING,
885
+ )
886
+ class DeciLMModel(DeciLMPreTrainedModel):
887
+ """
888
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciLMDecoderLayer`]
889
+
890
+ Args:
891
+ config: DeciLMConfig
892
+ """
893
+
894
+ def __init__(self, config: DeciLMConfig):
895
+ super().__init__(config)
896
+ self.padding_idx = config.pad_token_id
897
+ self.vocab_size = config.vocab_size
898
+
899
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
900
+ self.layers = nn.ModuleList(
901
+ [DeciLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
902
+ )
903
+ self.norm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
904
+ self.rotary_emb = DeciLMRotaryEmbedding(config=config)
905
+ self.gradient_checkpointing = False
906
+
907
+ # Initialize weights and apply final processing
908
+ self.post_init()
909
+
910
+ def get_input_embeddings(self):
911
+ return self.embed_tokens
912
+
913
+ def set_input_embeddings(self, value):
914
+ self.embed_tokens = value
915
+
916
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
917
+ def forward(
918
+ self,
919
+ input_ids: torch.LongTensor = None,
920
+ attention_mask: Optional[torch.Tensor] = None,
921
+ position_ids: Optional[torch.LongTensor] = None,
922
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
923
+ inputs_embeds: Optional[torch.FloatTensor] = None,
924
+ use_cache: Optional[bool] = None,
925
+ output_attentions: Optional[bool] = None,
926
+ output_hidden_states: Optional[bool] = None,
927
+ return_dict: Optional[bool] = None,
928
+ cache_position: Optional[torch.LongTensor] = None,
929
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
930
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
931
+ output_hidden_states = (
932
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
933
+ )
934
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
935
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
936
+
937
+ if (input_ids is None) ^ (inputs_embeds is not None):
938
+ raise ValueError(
939
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
940
+ )
941
+
942
+ if self.gradient_checkpointing and self.training and use_cache:
943
+ logger.warning_once(
944
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
945
+ )
946
+ use_cache = False
947
+
948
+ if inputs_embeds is None:
949
+ inputs_embeds = self.embed_tokens(input_ids)
950
+
951
+ is_legacy_cache_format = (past_key_values is not None) and not isinstance(past_key_values, Cache)
952
+ if is_legacy_cache_format:
953
+ raise NotImplementedError("DeciLMModel does not support legacy cache format, please use a newer "
954
+ "transformers version or use VariableCache explicitly (see import in this file).")
955
+
956
+ if cache_position is None:
957
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
958
+ cache_position = torch.arange(
959
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
960
+ )
961
+ if position_ids is None:
962
+ position_ids = cache_position.unsqueeze(0)
963
+
964
+ causal_mask = self._update_causal_mask(
965
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
966
+ )
967
+ hidden_states = inputs_embeds
968
+
969
+ # create position embeddings to be shared across the decoder layers
970
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
971
+
972
+ # decoder layers
973
+ all_hidden_states = () if output_hidden_states else None
974
+ all_self_attns = () if output_attentions else None
975
+ next_decoder_cache = None
976
+
977
+ for decoder_layer in self.layers:
978
+ if output_hidden_states:
979
+ all_hidden_states += (hidden_states,)
980
+
981
+ if self.gradient_checkpointing and self.training:
982
+ layer_outputs = self._gradient_checkpointing_func(
983
+ decoder_layer.__call__,
984
+ hidden_states,
985
+ causal_mask,
986
+ position_ids,
987
+ past_key_values,
988
+ output_attentions,
989
+ use_cache,
990
+ cache_position,
991
+ position_embeddings,
992
+ )
993
+ else:
994
+ layer_outputs = decoder_layer(
995
+ hidden_states,
996
+ attention_mask=causal_mask,
997
+ position_ids=position_ids,
998
+ past_key_value=past_key_values,
999
+ output_attentions=output_attentions,
1000
+ use_cache=use_cache,
1001
+ cache_position=cache_position,
1002
+ position_embeddings=position_embeddings,
1003
+ )
1004
+
1005
+ hidden_states = layer_outputs[0]
1006
+
1007
+ if use_cache:
1008
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1009
+
1010
+ if output_attentions:
1011
+ all_self_attns += (layer_outputs[1],)
1012
+
1013
+ hidden_states = self.norm(hidden_states)
1014
+
1015
+ # add hidden states from the last decoder layer
1016
+ if output_hidden_states:
1017
+ all_hidden_states += (hidden_states,)
1018
+
1019
+ next_cache = next_decoder_cache if use_cache else None
1020
+
1021
+ if not return_dict:
1022
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1023
+ return BaseModelOutputWithPast(
1024
+ last_hidden_state=hidden_states,
1025
+ past_key_values=next_cache,
1026
+ hidden_states=all_hidden_states,
1027
+ attentions=all_self_attns,
1028
+ )
1029
+
1030
+ def _update_causal_mask(
1031
+ self,
1032
+ attention_mask: torch.Tensor,
1033
+ input_tensor: torch.Tensor,
1034
+ cache_position: torch.Tensor,
1035
+ past_key_values: Cache,
1036
+ output_attentions: bool,
1037
+ ):
1038
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1039
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1040
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1041
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1042
+
1043
+ if self.config._attn_implementation == "flash_attention_2":
1044
+ if attention_mask is not None and 0.0 in attention_mask:
1045
+ return attention_mask
1046
+ return None
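+ # E.g. a fully unpadded batch (attention_mask of all ones, or no mask at all) returns
+ # None here, so flash-attention runs its purely causal path; only masks that actually
+ # contain padding zeros are forwarded as-is.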
1047
+
1048
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1049
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1050
+ # to infer the attention mask.
1051
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1052
+ assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache"
1053
+ using_static_cache = isinstance(past_key_values, StaticCache)
1054
+
1055
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1056
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1057
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1058
+ attention_mask,
1059
+ inputs_embeds=input_tensor,
1060
+ past_key_values_length=past_seen_tokens,
1061
+ is_training=self.training,
1062
+ ) and all([not layer.is_sliding for layer in self.layers]):
1063
+ return None
1064
+
1065
+ dtype, device = input_tensor.dtype, input_tensor.device
1066
+ min_dtype = torch.finfo(dtype).min
1067
+ sequence_length = input_tensor.shape[1]
1068
+ if using_static_cache:
1069
+ target_length = past_key_values.get_max_length()
1070
+ else:
1071
+ target_length = (
1072
+ attention_mask.shape[-1]
1073
+ if isinstance(attention_mask, torch.Tensor)
1074
+ else past_seen_tokens + sequence_length + 1
1075
+ )
1076
+
1077
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1078
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1079
+ attention_mask,
1080
+ sequence_length=sequence_length,
1081
+ target_length=target_length,
1082
+ dtype=dtype,
1083
+ device=device,
1084
+ min_dtype=min_dtype,
1085
+ cache_position=cache_position,
1086
+ batch_size=input_tensor.shape[0],
1087
+ )
1088
+
1089
+ if (
1090
+ self.config._attn_implementation == "sdpa"
1091
+ and attention_mask is not None
1092
+ and attention_mask.device.type == "cuda"
1093
+ and not output_attentions
1094
+ ):
1095
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1096
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1097
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1098
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1099
+
1100
+ return causal_mask
1101
+
1102
+
1103
+ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1104
+ _tied_weights_keys = ["lm_head.weight"]
1105
+
1106
+ def __init__(self, config):
1107
+ super().__init__(config)
1108
+ self.model = DeciLMModel(config)
1109
+ self.vocab_size = config.vocab_size
1110
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1111
+
1112
+ # Initialize weights and apply final processing
1113
+ self.post_init()
1114
+
1115
+ def get_input_embeddings(self):
1116
+ return self.model.embed_tokens
1117
+
1118
+ def set_input_embeddings(self, value):
1119
+ self.model.embed_tokens = value
1120
+
1121
+ def get_output_embeddings(self):
1122
+ return self.lm_head
1123
+
1124
+ def set_output_embeddings(self, new_embeddings):
1125
+ self.lm_head = new_embeddings
1126
+
1127
+ def set_decoder(self, decoder):
1128
+ self.model = decoder
1129
+
1130
+ def get_decoder(self):
1131
+ return self.model
1132
+
1133
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
1134
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1135
+ def forward(
1136
+ self,
1137
+ input_ids: torch.LongTensor = None,
1138
+ attention_mask: Optional[torch.Tensor] = None,
1139
+ position_ids: Optional[torch.LongTensor] = None,
1140
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1141
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1142
+ labels: Optional[torch.LongTensor] = None,
1143
+ use_cache: Optional[bool] = None,
1144
+ output_attentions: Optional[bool] = None,
1145
+ output_hidden_states: Optional[bool] = None,
1146
+ return_dict: Optional[bool] = None,
1147
+ cache_position: Optional[torch.LongTensor] = None,
1148
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1149
+ r"""
1150
+ Args:
1151
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1152
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1153
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1154
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1155
+
1156
+ Return:
1157
+ """
1158
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1159
+ output_hidden_states = (
1160
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1161
+ )
1162
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1163
+
1164
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1165
+ outputs = self.model(
1166
+ input_ids=input_ids,
1167
+ attention_mask=attention_mask,
1168
+ position_ids=position_ids,
1169
+ past_key_values=past_key_values,
1170
+ inputs_embeds=inputs_embeds,
1171
+ use_cache=use_cache,
1172
+ output_attentions=output_attentions,
1173
+ output_hidden_states=output_hidden_states,
1174
+ return_dict=return_dict,
1175
+ cache_position=cache_position,
1176
+ )
1177
+
1178
+ hidden_states = outputs[0]
1179
+ if self.config.pretraining_tp > 1:
1180
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1181
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1182
+ logits = torch.cat(logits, dim=-1)
1183
+ else:
1184
+ logits = self.lm_head(hidden_states)
1185
+ logits = logits.float()
1186
+
1187
+ loss = None
1188
+ if labels is not None:
1189
+ # Shift so that tokens < n predict n
1190
+ shift_logits = logits[..., :-1, :].contiguous()
1191
+ shift_labels = labels[..., 1:].contiguous()
1192
+ # Flatten the tokens
1193
+ loss_fct = CrossEntropyLoss()
1194
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1195
+ shift_labels = shift_labels.view(-1)
1196
+ # Enable model parallelism
1197
+ shift_labels = shift_labels.to(shift_logits.device)
1198
+ loss = loss_fct(shift_logits, shift_labels)
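+ # For example, with labels [t0, t1, t2, t3] the logits at positions 0..2 are scored
+ # against t1..t3; the final position has no target, and any label set to -100 is
+ # ignored by CrossEntropyLoss.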
1199
+
1200
+ if not return_dict:
1201
+ output = (logits,) + outputs[1:]
1202
+ return (loss,) + output if loss is not None else output
1203
+
1204
+ return CausalLMOutputWithPast(
1205
+ loss=loss,
1206
+ logits=logits,
1207
+ past_key_values=outputs.past_key_values,
1208
+ hidden_states=outputs.hidden_states,
1209
+ attentions=outputs.attentions,
1210
+ )
1211
+
1212
+ def prepare_inputs_for_generation(
1213
+ self,
1214
+ input_ids,
1215
+ past_key_values=None,
1216
+ attention_mask=None,
1217
+ inputs_embeds=None,
1218
+ cache_position=None,
1219
+ position_ids=None,
1220
+ use_cache=True,
1221
+ **kwargs,
1222
+ ):
1223
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1224
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1225
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1226
+ if past_key_values is not None:
1227
+ if inputs_embeds is not None: # Exception 1
1228
+ input_ids = input_ids[:, -cache_position.shape[0]:]
1229
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1230
+ input_ids = input_ids[:, cache_position]
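+ # During regular decoding `cache_position` holds a single new index, so `input_ids`
+ # is sliced down to just the latest token; on the first (prefill) step the lengths
+ # already match and nothing is dropped.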
1231
+
1232
+ if attention_mask is not None and position_ids is None:
1233
+ # create position_ids on the fly for batch generation
1234
+ position_ids = attention_mask.long().cumsum(-1) - 1
1235
+ position_ids.masked_fill_(attention_mask == 0, 1)
1236
+ if past_key_values:
1237
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1238
+
1239
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1240
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
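+ # E.g. a left-padded row with attention_mask [0, 0, 1, 1, 1] gives cumsum-1 of
+ # [-1, -1, 0, 1, 2]; the padded slots are then filled with 1, so the real tokens
+ # still receive positions 0, 1, 2.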
1241
+
1242
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1243
+ if inputs_embeds is not None and cache_position[0] == 0:
1244
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1245
+ else:
1246
+ # The clone here is for the same reason as for `position_ids`.
1247
+ model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
1248
+
1249
+ assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache"
1250
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1251
+ if model_inputs["inputs_embeds"] is not None:
1252
+ batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1253
+ device = model_inputs["inputs_embeds"].device
1254
+ else:
1255
+ batch_size, sequence_length = model_inputs["input_ids"].shape
1256
+ device = model_inputs["input_ids"].device
1257
+
1258
+ dtype = self.lm_head.weight.dtype
1259
+ min_dtype = torch.finfo(dtype).min
1260
+
1261
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1262
+ attention_mask,
1263
+ sequence_length=sequence_length,
1264
+ target_length=past_key_values.get_max_length(),
1265
+ dtype=dtype,
1266
+ device=device,
1267
+ min_dtype=min_dtype,
1268
+ cache_position=cache_position,
1269
+ batch_size=batch_size,
1270
+ )
1271
+
1272
+ model_inputs.update(
1273
+ {
1274
+ "position_ids": position_ids,
1275
+ "cache_position": cache_position,
1276
+ "past_key_values": past_key_values,
1277
+ "use_cache": use_cache,
1278
+ "attention_mask": attention_mask,
1279
+ }
1280
+ )
1281
+ return model_inputs
1282
+
1283
+ def _maybe_initialize_input_ids_for_generation(
1284
+ self,
1285
+ inputs: Optional[torch.Tensor] = None,
1286
+ bos_token_id: Optional[torch.Tensor] = None,
1287
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
1288
+ ) -> torch.LongTensor:
1289
+ """
1290
+ Patches an HF bug that creates a wrong cache length when only inputs_embeds are passed to the model
1291
+ """
1292
+ input_ids = super()._maybe_initialize_input_ids_for_generation(
1293
+ inputs=inputs, bos_token_id=bos_token_id, model_kwargs=model_kwargs)
1294
+ if (
1295
+ "inputs_embeds" in model_kwargs
1296
+ and input_ids is not None
1297
+ and input_ids.shape[1] == 0
1298
+ ):
1299
+ batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
1300
+ input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
1301
+ return input_ids
1302
+
1303
+ def generate(
1304
+ self,
1305
+ inputs: Optional[torch.Tensor] = None,
1306
+ *args,
1307
+ **kwargs,
1308
+ ) -> Union[GenerateOutput, torch.LongTensor]:
1309
+ """
1310
+ Patches an HF bug that creates a wrong cache length when only inputs_embeds are passed to the model
1311
+ """
1312
+ only_passed_inputs_embeds = (
1313
+ "inputs_embeds" in kwargs and
1314
+ "input_ids" not in kwargs and
1315
+ inputs is None
1316
+ )
1317
+ if only_passed_inputs_embeds:
1318
+ input_sequence_length = kwargs["inputs_embeds"].shape[1]
1319
+
1320
+ generation_output = super().generate(inputs=inputs, *args, **kwargs)
1321
+
1322
+ if only_passed_inputs_embeds and isinstance(generation_output, torch.Tensor):
1323
+ generation_output = generation_output[:, input_sequence_length:]
1324
+
1325
+ return generation_output
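+ # Illustrative effect of the patch above (hypothetical tensors): calling
+ #   model.generate(inputs_embeds=prompt_embeds, max_new_tokens=8)
+ # with no input_ids returns only the newly generated token ids, because the prompt
+ # itself was fed as embeddings and has no token ids worth echoing back.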
1326
+
1327
+
1328
+ @add_start_docstrings(
1329
+ """
1330
+ The DeciLM Model transformer with a sequence classification head on top (linear layer).
1331
+
1332
+ [`DeciLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1333
+ (e.g. GPT-2) do.
1334
+
1335
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1336
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1337
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1338
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1339
+ each row of the batch).
1340
+ """,
1341
+ DECILM_START_DOCSTRING,
1342
+ )
1343
+ class DeciLMForSequenceClassification(DeciLMPreTrainedModel):
1344
+ def __init__(self, config):
1345
+ super().__init__(config)
1346
+ self.num_labels = config.num_labels
1347
+ self.model = DeciLMModel(config)
1348
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1349
+
1350
+ # Initialize weights and apply final processing
1351
+ self.post_init()
1352
+
1353
+ def get_input_embeddings(self):
1354
+ return self.model.embed_tokens
1355
+
1356
+ def set_input_embeddings(self, value):
1357
+ self.model.embed_tokens = value
1358
+
1359
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
1360
+ def forward(
1361
+ self,
1362
+ input_ids: Optional[torch.LongTensor] = None,
1363
+ attention_mask: Optional[torch.Tensor] = None,
1364
+ position_ids: Optional[torch.LongTensor] = None,
1365
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1366
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1367
+ labels: Optional[torch.LongTensor] = None,
1368
+ use_cache: Optional[bool] = None,
1369
+ output_attentions: Optional[bool] = None,
1370
+ output_hidden_states: Optional[bool] = None,
1371
+ return_dict: Optional[bool] = None,
1372
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1373
+ r"""
1374
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1375
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1376
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1377
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1378
+ """
1379
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1380
+
1381
+ transformer_outputs = self.model(
1382
+ input_ids,
1383
+ attention_mask=attention_mask,
1384
+ position_ids=position_ids,
1385
+ past_key_values=past_key_values,
1386
+ inputs_embeds=inputs_embeds,
1387
+ use_cache=use_cache,
1388
+ output_attentions=output_attentions,
1389
+ output_hidden_states=output_hidden_states,
1390
+ return_dict=return_dict,
1391
+ )
1392
+ hidden_states = transformer_outputs[0]
1393
+ logits = self.score(hidden_states)
1394
+
1395
+ if input_ids is not None:
1396
+ batch_size = input_ids.shape[0]
1397
+ else:
1398
+ batch_size = inputs_embeds.shape[0]
1399
+
1400
+ if self.config.pad_token_id is None and batch_size != 1:
1401
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1402
+ if self.config.pad_token_id is None:
1403
+ sequence_lengths = -1
1404
+ else:
1405
+ if input_ids is not None:
1406
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1407
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1408
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1409
+ sequence_lengths = sequence_lengths.to(logits.device)
1410
+ else:
1411
+ sequence_lengths = -1
1412
+
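+ # E.g. input_ids [[5, 6, 7, PAD]]: eq(PAD).argmax(-1) is 3, minus 1 gives 2, so the
+ # logits of the last real token (index 2) are pooled below; with no PAD present,
+ # argmax(-1) is 0 and 0 - 1 wraps to the final position via the modulo above.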
1413
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1414
+
1415
+ loss = None
1416
+ if labels is not None:
1417
+ labels = labels.to(logits.device)
1418
+ if self.config.problem_type is None:
1419
+ if self.num_labels == 1:
1420
+ self.config.problem_type = "regression"
1421
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1422
+ self.config.problem_type = "single_label_classification"
1423
+ else:
1424
+ self.config.problem_type = "multi_label_classification"
1425
+
1426
+ if self.config.problem_type == "regression":
1427
+ loss_fct = MSELoss()
1428
+ if self.num_labels == 1:
1429
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1430
+ else:
1431
+ loss = loss_fct(pooled_logits, labels)
1432
+ elif self.config.problem_type == "single_label_classification":
1433
+ loss_fct = CrossEntropyLoss()
1434
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1435
+ elif self.config.problem_type == "multi_label_classification":
1436
+ loss_fct = BCEWithLogitsLoss()
1437
+ loss = loss_fct(pooled_logits, labels)
1438
+ if not return_dict:
1439
+ output = (pooled_logits,) + transformer_outputs[1:]
1440
+ return ((loss,) + output) if loss is not None else output
1441
+
1442
+ return SequenceClassifierOutputWithPast(
1443
+ loss=loss,
1444
+ logits=pooled_logits,
1445
+ past_key_values=transformer_outputs.past_key_values,
1446
+ hidden_states=transformer_outputs.hidden_states,
1447
+ attentions=transformer_outputs.attentions,
1448
+ )
1449
+
1450
+
1451
+ @add_start_docstrings(
1452
+ """
1453
+ The DeciLM Model transformer with a span classification head on top for extractive question-answering tasks like
1454
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1455
+ """,
1456
+ DECILM_START_DOCSTRING,
1457
+ )
1458
+ class DeciLMForQuestionAnswering(DeciLMPreTrainedModel):
1459
+ base_model_prefix = "transformer"
1460
+
1461
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->DeciLM
1462
+ def __init__(self, config):
1463
+ super().__init__(config)
1464
+ self.transformer = DeciLMModel(config)
1465
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1466
+
1467
+ # Initialize weights and apply final processing
1468
+ self.post_init()
1469
+
1470
+ def get_input_embeddings(self):
1471
+ return self.transformer.embed_tokens
1472
+
1473
+ def set_input_embeddings(self, value):
1474
+ self.transformer.embed_tokens = value
1475
+
1476
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
1477
+ def forward(
1478
+ self,
1479
+ input_ids: Optional[torch.LongTensor] = None,
1480
+ attention_mask: Optional[torch.FloatTensor] = None,
1481
+ position_ids: Optional[torch.LongTensor] = None,
1482
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1483
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1484
+ start_positions: Optional[torch.LongTensor] = None,
1485
+ end_positions: Optional[torch.LongTensor] = None,
1486
+ output_attentions: Optional[bool] = None,
1487
+ output_hidden_states: Optional[bool] = None,
1488
+ return_dict: Optional[bool] = None,
1489
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1490
+ r"""
1491
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1492
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1493
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1494
+ are not taken into account for computing the loss.
1495
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1496
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1497
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1498
+ are not taken into account for computing the loss.
1499
+ """
1500
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1501
+
1502
+ outputs = self.transformer(
1503
+ input_ids,
1504
+ attention_mask=attention_mask,
1505
+ position_ids=position_ids,
1506
+ past_key_values=past_key_values,
1507
+ inputs_embeds=inputs_embeds,
1508
+ output_attentions=output_attentions,
1509
+ output_hidden_states=output_hidden_states,
1510
+ return_dict=return_dict,
1511
+ )
1512
+
1513
+ sequence_output = outputs[0]
1514
+
1515
+ logits = self.qa_outputs(sequence_output)
1516
+ start_logits, end_logits = logits.split(1, dim=-1)
1517
+ start_logits = start_logits.squeeze(-1).contiguous()
1518
+ end_logits = end_logits.squeeze(-1).contiguous()
1519
+
1520
+ total_loss = None
1521
+ if start_positions is not None and end_positions is not None:
1522
+ # If we are on multi-GPU, splitting may add an extra dimension; squeeze it
1523
+ if len(start_positions.size()) > 1:
1524
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1525
+ if len(end_positions.size()) > 1:
1526
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1527
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1528
+ ignored_index = start_logits.size(1)
1529
+ start_positions = start_positions.clamp(0, ignored_index)
1530
+ end_positions = end_positions.clamp(0, ignored_index)
1531
+
1532
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1533
+ start_loss = loss_fct(start_logits, start_positions)
1534
+ end_loss = loss_fct(end_logits, end_positions)
1535
+ total_loss = (start_loss + end_loss) / 2
1536
+
1537
+ if not return_dict:
1538
+ output = (start_logits, end_logits) + outputs[2:]
1539
+ return ((total_loss,) + output) if total_loss is not None else output
1540
+
1541
+ return QuestionAnsweringModelOutput(
1542
+ loss=total_loss,
1543
+ start_logits=start_logits,
1544
+ end_logits=end_logits,
1545
+ hidden_states=outputs.hidden_states,
1546
+ attentions=outputs.attentions,
1547
+ )
1548
+
1549
+
1550
+ @add_start_docstrings(
1551
+ """
1552
+ The DeciLM Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1553
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1554
+ """,
1555
+ DECILM_START_DOCSTRING,
1556
+ )
1557
+ class DeciLMForTokenClassification(DeciLMPreTrainedModel):
1558
+ def __init__(self, config):
1559
+ super().__init__(config)
1560
+ self.num_labels = config.num_labels
1561
+ self.model = DeciLMModel(config)
1562
+ if getattr(config, "classifier_dropout", None) is not None:
1563
+ classifier_dropout = config.classifier_dropout
1564
+ elif getattr(config, "hidden_dropout", None) is not None:
1565
+ classifier_dropout = config.hidden_dropout
1566
+ else:
1567
+ classifier_dropout = 0.1
1568
+ self.dropout = nn.Dropout(classifier_dropout)
1569
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1570
+
1571
+ # Initialize weights and apply final processing
1572
+ self.post_init()
1573
+
1574
+ def get_input_embeddings(self):
1575
+ return self.model.embed_tokens
1576
+
1577
+ def set_input_embeddings(self, value):
1578
+ self.model.embed_tokens = value
1579
+
1580
+ @add_start_docstrings_to_model_forward(DECILM_INPUTS_DOCSTRING)
1581
+ def forward(
1582
+ self,
1583
+ input_ids: Optional[torch.LongTensor] = None,
1584
+ attention_mask: Optional[torch.Tensor] = None,
1585
+ position_ids: Optional[torch.LongTensor] = None,
1586
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1587
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1588
+ labels: Optional[torch.LongTensor] = None,
1589
+ use_cache: Optional[bool] = None,
1590
+ output_attentions: Optional[bool] = None,
1591
+ output_hidden_states: Optional[bool] = None,
1592
+ return_dict: Optional[bool] = None,
1593
+ ) -> Union[Tuple, TokenClassifierOutput]:
1594
+ r"""
1595
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1596
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1597
+ config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed over
1598
+ every token position.
1599
+ """
1600
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1601
+
1602
+ outputs = self.model(
1603
+ input_ids,
1604
+ attention_mask=attention_mask,
1605
+ position_ids=position_ids,
1606
+ past_key_values=past_key_values,
1607
+ inputs_embeds=inputs_embeds,
1608
+ use_cache=use_cache,
1609
+ output_attentions=output_attentions,
1610
+ output_hidden_states=output_hidden_states,
1611
+ return_dict=return_dict,
1612
+ )
1613
+ sequence_output = outputs[0]
1614
+ sequence_output = self.dropout(sequence_output)
1615
+ logits = self.score(sequence_output)
1616
+
1617
+ loss = None
1618
+ if labels is not None:
1619
+ loss_fct = CrossEntropyLoss()
1620
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1621
+
1622
+ if not return_dict:
1623
+ output = (logits,) + outputs[2:]
1624
+ return ((loss,) + output) if loss is not None else output
1625
+
1626
+ return TokenClassifierOutput(
1627
+ loss=loss,
1628
+ logits=logits,
1629
+ hidden_states=outputs.hidden_states,
1630
+ attentions=outputs.attentions,
1631
+ )
1632
+
1633
+
1634
+ ########################################################################
1635
+ # DeciLM-specific code
1636
+ ########################################################################
1637
+
1638
+
1639
+ def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
1640
+ # DeciLM-specific code
1641
+ intermediate_size = int(2 * ffn_mult * n_embd / 3)
1642
+ return _find_multiple(intermediate_size, 256)
1643
+
1644
+
1645
+ def _find_multiple(n: int, k: int) -> int:
1646
+ # DeciLM-specific code
1647
+ if n % k == 0:
1648
+ return n
1649
+ return n + k - (n % k)
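+ # Worked example (hypothetical values): ffn_mult = 1.5 and n_embd = 8192 give
+ # int(2 * 1.5 * 8192 / 3) = 8192, already a multiple of 256; ffn_mult = 1.3 would
+ # give int(7099.73) = 7099, which _find_multiple rounds up to 7168.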
1650
+
1651
+
1652
+ class DeciLMLinearMLP(nn.Module):
1653
+ # DeciLM-specific code
1654
+ def __init__(self,
1655
+ config: DeciLMConfig,
1656
+ ):
1657
+ super().__init__()
1658
+ self.linear_mlp = nn.Linear(in_features=config.hidden_size,
1659
+ out_features=config.hidden_size,
1660
+ bias=False)
1661
+
1662
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1663
+ return self.linear_mlp.forward(x)
1664
+
1665
+
1666
+ class DeciLMLinearAttention(nn.Module):
1667
+ # DeciLM-specific code
1668
+ def __init__(self,
1669
+ config: DeciLMConfig,
1670
+ ):
1671
+ super().__init__()
1672
+ self.linear_attn = nn.Linear(in_features=config.hidden_size,
1673
+ out_features=config.hidden_size,
1674
+ bias=False)
1675
+
1676
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1677
+ return self.linear_attn.forward(x)
1678
+
1679
+
1680
+ def sparsity_backward_hook(*args, **kwargs):
1681
+ raise NotImplementedError("No support for sparsity when training HF DeciLM (inference is ok though)")
privacy.md ADDED
@@ -0,0 +1,9 @@
1
+ | Field: | Response: |
2
+ | :---- | :---- |
3
+ | Generatable or Reverse engineerable personal data? | None |
4
+ | Was consent obtained for any personal data used? | None Known |
5
+ | Personal data used to create this model? | None Known |
6
+ | How often is dataset reviewed? | Before Release |
7
+ | Is there provenance for all datasets used in training? | Yes |
8
+ | Does data labeling (annotation, metadata) comply with privacy laws? | Yes |
9
+ | Applicable NVIDIA Privacy Policy | [https://www.nvidia.com/en-us/about-nvidia/privacy-policy/](https://www.nvidia.com/en-us/about-nvidia/privacy-policy/) |
safety.md ADDED
@@ -0,0 +1,6 @@
1
+ | Field: | Response: |
2
+ | :---- | :---- |
3
+ | Model Application(s): | Chat, Instruction Following, Chatbot Development, Code Generation, Reasoning |
4
+ | Describe life critical application (if present): | None Known (please see referenced Known Risks in the Explainability subcard). |
5
+ | Use Case Restrictions: | Your use of this model is governed by the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/). Additional Information: [Llama 3.3 Community License Agreement](https://www.llama.com/llama3_3/license/). Built with Llama. |
6
+ | Model and Dataset Restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation. Restrictions enforce dataset access during training, and dataset license constraints adhered to. Model checkpoints are made available on Hugging Face and NGC, and may become available on cloud providers' model catalog. |
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin_of_text|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|eot_id|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|finetune_right_pad_id|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b9e4e7fb171f92fd137b777cc2714bf87d11576700a1dcd7a399e7bbe39537b
3
+ size 17209920
tokenizer_config.json ADDED
@@ -0,0 +1,2067 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "added_tokens_decoder": {
4
+ "128000": {
5
+ "content": "<|begin_of_text|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "128001": {
13
+ "content": "<|end_of_text|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "128002": {
21
+ "content": "<|reserved_special_token_0|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "128003": {
29
+ "content": "<|reserved_special_token_1|>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "128004": {
37
+ "content": "<|finetune_right_pad_id|>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "128005": {
45
+ "content": "<|reserved_special_token_2|>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "128006": {
53
+ "content": "<|start_header_id|>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "128007": {
61
+ "content": "<|end_header_id|>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "128008": {
69
+ "content": "<|eom_id|>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "128009": {
77
+ "content": "<|eot_id|>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "128010": {
85
+ "content": "<|python_tag|>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "128011": {
93
+ "content": "<|reserved_special_token_3|>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "128012": {
101
+ "content": "<|reserved_special_token_4|>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "128013": {
109
+ "content": "<|reserved_special_token_5|>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "128014": {
117
+ "content": "<|reserved_special_token_6|>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "128015": {
125
+ "content": "<|reserved_special_token_7|>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "128016": {
133
+ "content": "<|reserved_special_token_8|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "128017": {
141
+ "content": "<|reserved_special_token_9|>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "128018": {
149
+ "content": "<|reserved_special_token_10|>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "128019": {
157
+ "content": "<|reserved_special_token_11|>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "128020": {
165
+ "content": "<|reserved_special_token_12|>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "128021": {
173
+ "content": "<|reserved_special_token_13|>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "128022": {
181
+ "content": "<|reserved_special_token_14|>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "128023": {
189
+ "content": "<|reserved_special_token_15|>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "128024": {
197
+ "content": "<|reserved_special_token_16|>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "128025": {
205
+ "content": "<|reserved_special_token_17|>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "128026": {
213
+ "content": "<|reserved_special_token_18|>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "128027": {
221
+ "content": "<|reserved_special_token_19|>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "128028": {
229
+ "content": "<|reserved_special_token_20|>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "128029": {
237
+ "content": "<|reserved_special_token_21|>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "128030": {
245
+ "content": "<|reserved_special_token_22|>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "128031": {
253
+ "content": "<|reserved_special_token_23|>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "128032": {
261
+ "content": "<|reserved_special_token_24|>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "128033": {
269
+ "content": "<|reserved_special_token_25|>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "128034": {
277
+ "content": "<|reserved_special_token_26|>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "128035": {
285
+ "content": "<|reserved_special_token_27|>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "128036": {
293
+ "content": "<|reserved_special_token_28|>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "128037": {
301
+ "content": "<|reserved_special_token_29|>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "128038": {
309
+ "content": "<|reserved_special_token_30|>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "128039": {
317
+ "content": "<|reserved_special_token_31|>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "128040": {
325
+ "content": "<|reserved_special_token_32|>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "128041": {
333
+ "content": "<|reserved_special_token_33|>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "128042": {
341
+ "content": "<|reserved_special_token_34|>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "128043": {
349
+ "content": "<|reserved_special_token_35|>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "128044": {
357
+ "content": "<|reserved_special_token_36|>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "128045": {
365
+ "content": "<|reserved_special_token_37|>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "128046": {
373
+ "content": "<|reserved_special_token_38|>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "128047": {
381
+ "content": "<|reserved_special_token_39|>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "128048": {
389
+ "content": "<|reserved_special_token_40|>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "128049": {
397
+ "content": "<|reserved_special_token_41|>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "128050": {
405
+ "content": "<|reserved_special_token_42|>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "128051": {
413
+ "content": "<|reserved_special_token_43|>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "128052": {
421
+ "content": "<|reserved_special_token_44|>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "128053": {
429
+ "content": "<|reserved_special_token_45|>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "128054": {
437
+ "content": "<|reserved_special_token_46|>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "128055": {
445
+ "content": "<|reserved_special_token_47|>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "128056": {
453
+ "content": "<|reserved_special_token_48|>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "128057": {
461
+ "content": "<|reserved_special_token_49|>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "128058": {
469
+ "content": "<|reserved_special_token_50|>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "128059": {
477
+ "content": "<|reserved_special_token_51|>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "128060": {
485
+ "content": "<|reserved_special_token_52|>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "128061": {
493
+ "content": "<|reserved_special_token_53|>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "128062": {
501
+ "content": "<|reserved_special_token_54|>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "128063": {
509
+ "content": "<|reserved_special_token_55|>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "128064": {
517
+ "content": "<|reserved_special_token_56|>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "128065": {
525
+ "content": "<|reserved_special_token_57|>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "128066": {
533
+ "content": "<|reserved_special_token_58|>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "128067": {
541
+ "content": "<|reserved_special_token_59|>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "128068": {
549
+ "content": "<|reserved_special_token_60|>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "128069": {
557
+ "content": "<|reserved_special_token_61|>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "128070": {
565
+ "content": "<|reserved_special_token_62|>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "128071": {
573
+ "content": "<|reserved_special_token_63|>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "128072": {
581
+ "content": "<|reserved_special_token_64|>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "128073": {
589
+ "content": "<|reserved_special_token_65|>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "128074": {
597
+ "content": "<|reserved_special_token_66|>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "128075": {
605
+ "content": "<|reserved_special_token_67|>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "128076": {
613
+ "content": "<|reserved_special_token_68|>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "128077": {
621
+ "content": "<|reserved_special_token_69|>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "128078": {
629
+ "content": "<|reserved_special_token_70|>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "128079": {
637
+ "content": "<|reserved_special_token_71|>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "128080": {
645
+ "content": "<|reserved_special_token_72|>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "128081": {
653
+ "content": "<|reserved_special_token_73|>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "128082": {
661
+ "content": "<|reserved_special_token_74|>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "128083": {
669
+ "content": "<|reserved_special_token_75|>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "128084": {
677
+ "content": "<|reserved_special_token_76|>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "128085": {
685
+ "content": "<|reserved_special_token_77|>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "128086": {
693
+ "content": "<|reserved_special_token_78|>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "128087": {
701
+ "content": "<|reserved_special_token_79|>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "128088": {
709
+ "content": "<|reserved_special_token_80|>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "128089": {
717
+ "content": "<|reserved_special_token_81|>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "128090": {
725
+ "content": "<|reserved_special_token_82|>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "128091": {
733
+ "content": "<|reserved_special_token_83|>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "128092": {
741
+ "content": "<|reserved_special_token_84|>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "128093": {
749
+ "content": "<|reserved_special_token_85|>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "128094": {
757
+ "content": "<|reserved_special_token_86|>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "128095": {
765
+ "content": "<|reserved_special_token_87|>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "128096": {
773
+ "content": "<|reserved_special_token_88|>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "128097": {
781
+ "content": "<|reserved_special_token_89|>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "128098": {
789
+ "content": "<|reserved_special_token_90|>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "128099": {
797
+ "content": "<|reserved_special_token_91|>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "128100": {
805
+ "content": "<|reserved_special_token_92|>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "128101": {
813
+ "content": "<|reserved_special_token_93|>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "128102": {
821
+ "content": "<|reserved_special_token_94|>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ },
828
+ "128103": {
829
+ "content": "<|reserved_special_token_95|>",
830
+ "lstrip": false,
831
+ "normalized": false,
832
+ "rstrip": false,
833
+ "single_word": false,
834
+ "special": true
835
+ },
836
+ "128104": {
837
+ "content": "<|reserved_special_token_96|>",
838
+ "lstrip": false,
839
+ "normalized": false,
840
+ "rstrip": false,
841
+ "single_word": false,
842
+ "special": true
843
+ },
844
+ "128105": {
845
+ "content": "<|reserved_special_token_97|>",
846
+ "lstrip": false,
847
+ "normalized": false,
848
+ "rstrip": false,
849
+ "single_word": false,
850
+ "special": true
851
+ },
852
+ "128106": {
853
+ "content": "<|reserved_special_token_98|>",
854
+ "lstrip": false,
855
+ "normalized": false,
856
+ "rstrip": false,
857
+ "single_word": false,
858
+ "special": true
859
+ },
860
+ "128107": {
861
+ "content": "<|reserved_special_token_99|>",
862
+ "lstrip": false,
863
+ "normalized": false,
864
+ "rstrip": false,
865
+ "single_word": false,
866
+ "special": true
867
+ },
868
+ "128108": {
869
+ "content": "<|reserved_special_token_100|>",
870
+ "lstrip": false,
871
+ "normalized": false,
872
+ "rstrip": false,
873
+ "single_word": false,
874
+ "special": true
875
+ },
876
+ "128109": {
877
+ "content": "<|reserved_special_token_101|>",
878
+ "lstrip": false,
879
+ "normalized": false,
880
+ "rstrip": false,
881
+ "single_word": false,
882
+ "special": true
883
+ },
884
+ "128110": {
885
+ "content": "<|reserved_special_token_102|>",
886
+ "lstrip": false,
887
+ "normalized": false,
888
+ "rstrip": false,
889
+ "single_word": false,
890
+ "special": true
891
+ },
892
+ "128111": {
893
+ "content": "<|reserved_special_token_103|>",
894
+ "lstrip": false,
895
+ "normalized": false,
896
+ "rstrip": false,
897
+ "single_word": false,
898
+ "special": true
899
+ },
900
+ "128112": {
901
+ "content": "<|reserved_special_token_104|>",
902
+ "lstrip": false,
903
+ "normalized": false,
904
+ "rstrip": false,
905
+ "single_word": false,
906
+ "special": true
907
+ },
908
+ "128113": {
909
+ "content": "<|reserved_special_token_105|>",
910
+ "lstrip": false,
911
+ "normalized": false,
912
+ "rstrip": false,
913
+ "single_word": false,
914
+ "special": true
915
+ },
916
+ "128114": {
917
+ "content": "<|reserved_special_token_106|>",
918
+ "lstrip": false,
919
+ "normalized": false,
920
+ "rstrip": false,
921
+ "single_word": false,
922
+ "special": true
923
+ },
924
+ "128115": {
925
+ "content": "<|reserved_special_token_107|>",
926
+ "lstrip": false,
927
+ "normalized": false,
928
+ "rstrip": false,
929
+ "single_word": false,
930
+ "special": true
931
+ },
932
+ "128116": {
933
+ "content": "<|reserved_special_token_108|>",
934
+ "lstrip": false,
935
+ "normalized": false,
936
+ "rstrip": false,
937
+ "single_word": false,
938
+ "special": true
939
+ },
940
+ "128117": {
941
+ "content": "<|reserved_special_token_109|>",
942
+ "lstrip": false,
943
+ "normalized": false,
944
+ "rstrip": false,
945
+ "single_word": false,
946
+ "special": true
947
+ },
948
+ "128118": {
949
+ "content": "<|reserved_special_token_110|>",
950
+ "lstrip": false,
951
+ "normalized": false,
952
+ "rstrip": false,
953
+ "single_word": false,
954
+ "special": true
955
+ },
956
+ "128119": {
957
+ "content": "<|reserved_special_token_111|>",
958
+ "lstrip": false,
959
+ "normalized": false,
960
+ "rstrip": false,
961
+ "single_word": false,
962
+ "special": true
963
+ },
964
+ "128120": {
965
+ "content": "<|reserved_special_token_112|>",
966
+ "lstrip": false,
967
+ "normalized": false,
968
+ "rstrip": false,
969
+ "single_word": false,
970
+ "special": true
971
+ },
972
+ "128121": {
973
+ "content": "<|reserved_special_token_113|>",
974
+ "lstrip": false,
975
+ "normalized": false,
976
+ "rstrip": false,
977
+ "single_word": false,
978
+ "special": true
979
+ },
980
+ "128122": {
981
+ "content": "<|reserved_special_token_114|>",
982
+ "lstrip": false,
983
+ "normalized": false,
984
+ "rstrip": false,
985
+ "single_word": false,
986
+ "special": true
987
+ },
988
+ "128123": {
989
+ "content": "<|reserved_special_token_115|>",
990
+ "lstrip": false,
991
+ "normalized": false,
992
+ "rstrip": false,
993
+ "single_word": false,
994
+ "special": true
995
+ },
996
+ "128124": {
997
+ "content": "<|reserved_special_token_116|>",
998
+ "lstrip": false,
999
+ "normalized": false,
1000
+ "rstrip": false,
1001
+ "single_word": false,
1002
+ "special": true
1003
+ },
1004
+ "128125": {
1005
+ "content": "<|reserved_special_token_117|>",
1006
+ "lstrip": false,
1007
+ "normalized": false,
1008
+ "rstrip": false,
1009
+ "single_word": false,
1010
+ "special": true
1011
+ },
1012
+ "128126": {
1013
+ "content": "<|reserved_special_token_118|>",
1014
+ "lstrip": false,
1015
+ "normalized": false,
1016
+ "rstrip": false,
1017
+ "single_word": false,
1018
+ "special": true
1019
+ },
1020
+ "128127": {
1021
+ "content": "<|reserved_special_token_119|>",
1022
+ "lstrip": false,
1023
+ "normalized": false,
1024
+ "rstrip": false,
1025
+ "single_word": false,
1026
+ "special": true
1027
+ },
1028
+ "128128": {
1029
+ "content": "<|reserved_special_token_120|>",
1030
+ "lstrip": false,
1031
+ "normalized": false,
1032
+ "rstrip": false,
1033
+ "single_word": false,
1034
+ "special": true
1035
+ },
1036
+ "128129": {
1037
+ "content": "<|reserved_special_token_121|>",
1038
+ "lstrip": false,
1039
+ "normalized": false,
1040
+ "rstrip": false,
1041
+ "single_word": false,
1042
+ "special": true
1043
+ },
1044
+ "128130": {
1045
+ "content": "<|reserved_special_token_122|>",
1046
+ "lstrip": false,
1047
+ "normalized": false,
1048
+ "rstrip": false,
1049
+ "single_word": false,
1050
+ "special": true
1051
+ },
1052
+ "128131": {
1053
+ "content": "<|reserved_special_token_123|>",
1054
+ "lstrip": false,
1055
+ "normalized": false,
1056
+ "rstrip": false,
1057
+ "single_word": false,
1058
+ "special": true
1059
+ },
1060
+ "128132": {
1061
+ "content": "<|reserved_special_token_124|>",
1062
+ "lstrip": false,
1063
+ "normalized": false,
1064
+ "rstrip": false,
1065
+ "single_word": false,
1066
+ "special": true
1067
+ },
1068
+ "128133": {
1069
+ "content": "<|reserved_special_token_125|>",
1070
+ "lstrip": false,
1071
+ "normalized": false,
1072
+ "rstrip": false,
1073
+ "single_word": false,
1074
+ "special": true
1075
+ },
1076
+ "128134": {
1077
+ "content": "<|reserved_special_token_126|>",
1078
+ "lstrip": false,
1079
+ "normalized": false,
1080
+ "rstrip": false,
1081
+ "single_word": false,
1082
+ "special": true
1083
+ },
1084
+ "128135": {
1085
+ "content": "<|reserved_special_token_127|>",
1086
+ "lstrip": false,
1087
+ "normalized": false,
1088
+ "rstrip": false,
1089
+ "single_word": false,
1090
+ "special": true
1091
+ },
1092
+ "128136": {
1093
+ "content": "<|reserved_special_token_128|>",
1094
+ "lstrip": false,
1095
+ "normalized": false,
1096
+ "rstrip": false,
1097
+ "single_word": false,
1098
+ "special": true
1099
+ },
1100
+ "128137": {
1101
+ "content": "<|reserved_special_token_129|>",
1102
+ "lstrip": false,
1103
+ "normalized": false,
1104
+ "rstrip": false,
1105
+ "single_word": false,
1106
+ "special": true
1107
+ },
1108
+ "128138": {
1109
+ "content": "<|reserved_special_token_130|>",
1110
+ "lstrip": false,
1111
+ "normalized": false,
1112
+ "rstrip": false,
1113
+ "single_word": false,
1114
+ "special": true
1115
+ },
1116
+ "128139": {
1117
+ "content": "<|reserved_special_token_131|>",
1118
+ "lstrip": false,
1119
+ "normalized": false,
1120
+ "rstrip": false,
1121
+ "single_word": false,
1122
+ "special": true
1123
+ },
1124
+ "128140": {
1125
+ "content": "<|reserved_special_token_132|>",
1126
+ "lstrip": false,
1127
+ "normalized": false,
1128
+ "rstrip": false,
1129
+ "single_word": false,
1130
+ "special": true
1131
+ },
1132
+ "128141": {
1133
+ "content": "<|reserved_special_token_133|>",
1134
+ "lstrip": false,
1135
+ "normalized": false,
1136
+ "rstrip": false,
1137
+ "single_word": false,
1138
+ "special": true
1139
+ },
1140
+ "128142": {
1141
+ "content": "<|reserved_special_token_134|>",
1142
+ "lstrip": false,
1143
+ "normalized": false,
1144
+ "rstrip": false,
1145
+ "single_word": false,
1146
+ "special": true
1147
+ },
1148
+ "128143": {
1149
+ "content": "<|reserved_special_token_135|>",
1150
+ "lstrip": false,
1151
+ "normalized": false,
1152
+ "rstrip": false,
1153
+ "single_word": false,
1154
+ "special": true
1155
+ },
1156
+ "128144": {
1157
+ "content": "<|reserved_special_token_136|>",
1158
+ "lstrip": false,
1159
+ "normalized": false,
1160
+ "rstrip": false,
1161
+ "single_word": false,
1162
+ "special": true
1163
+ },
1164
+ "128145": {
1165
+ "content": "<|reserved_special_token_137|>",
1166
+ "lstrip": false,
1167
+ "normalized": false,
1168
+ "rstrip": false,
1169
+ "single_word": false,
1170
+ "special": true
1171
+ },
1172
+ "128146": {
1173
+ "content": "<|reserved_special_token_138|>",
1174
+ "lstrip": false,
1175
+ "normalized": false,
1176
+ "rstrip": false,
1177
+ "single_word": false,
1178
+ "special": true
1179
+ },
1180
+ "128147": {
1181
+ "content": "<|reserved_special_token_139|>",
1182
+ "lstrip": false,
1183
+ "normalized": false,
1184
+ "rstrip": false,
1185
+ "single_word": false,
1186
+ "special": true
1187
+ },
1188
+ "128148": {
1189
+ "content": "<|reserved_special_token_140|>",
1190
+ "lstrip": false,
1191
+ "normalized": false,
1192
+ "rstrip": false,
1193
+ "single_word": false,
1194
+ "special": true
1195
+ },
1196
+ "128149": {
1197
+ "content": "<|reserved_special_token_141|>",
1198
+ "lstrip": false,
1199
+ "normalized": false,
1200
+ "rstrip": false,
1201
+ "single_word": false,
1202
+ "special": true
1203
+ },
1204
+ "128150": {
1205
+ "content": "<|reserved_special_token_142|>",
1206
+ "lstrip": false,
1207
+ "normalized": false,
1208
+ "rstrip": false,
1209
+ "single_word": false,
1210
+ "special": true
1211
+ },
1212
+ "128151": {
1213
+ "content": "<|reserved_special_token_143|>",
1214
+ "lstrip": false,
1215
+ "normalized": false,
1216
+ "rstrip": false,
1217
+ "single_word": false,
1218
+ "special": true
1219
+ },
1220
+ "128152": {
1221
+ "content": "<|reserved_special_token_144|>",
1222
+ "lstrip": false,
1223
+ "normalized": false,
1224
+ "rstrip": false,
1225
+ "single_word": false,
1226
+ "special": true
1227
+ },
1228
+ "128153": {
1229
+ "content": "<|reserved_special_token_145|>",
1230
+ "lstrip": false,
1231
+ "normalized": false,
1232
+ "rstrip": false,
1233
+ "single_word": false,
1234
+ "special": true
1235
+ },
1236
+ "128154": {
1237
+ "content": "<|reserved_special_token_146|>",
1238
+ "lstrip": false,
1239
+ "normalized": false,
1240
+ "rstrip": false,
1241
+ "single_word": false,
1242
+ "special": true
1243
+ },
1244
+ "128155": {
1245
+ "content": "<|reserved_special_token_147|>",
1246
+ "lstrip": false,
1247
+ "normalized": false,
1248
+ "rstrip": false,
1249
+ "single_word": false,
1250
+ "special": true
1251
+ },
1252
+ "128156": {
1253
+ "content": "<|reserved_special_token_148|>",
1254
+ "lstrip": false,
1255
+ "normalized": false,
1256
+ "rstrip": false,
1257
+ "single_word": false,
1258
+ "special": true
1259
+ },
1260
+ "128157": {
1261
+ "content": "<|reserved_special_token_149|>",
1262
+ "lstrip": false,
1263
+ "normalized": false,
1264
+ "rstrip": false,
1265
+ "single_word": false,
1266
+ "special": true
1267
+ },
1268
+ "128158": {
1269
+ "content": "<|reserved_special_token_150|>",
1270
+ "lstrip": false,
1271
+ "normalized": false,
1272
+ "rstrip": false,
1273
+ "single_word": false,
1274
+ "special": true
1275
+ },
1276
+ "128159": {
1277
+ "content": "<|reserved_special_token_151|>",
1278
+ "lstrip": false,
1279
+ "normalized": false,
1280
+ "rstrip": false,
1281
+ "single_word": false,
1282
+ "special": true
1283
+ },
1284
+ "128160": {
1285
+ "content": "<|reserved_special_token_152|>",
1286
+ "lstrip": false,
1287
+ "normalized": false,
1288
+ "rstrip": false,
1289
+ "single_word": false,
1290
+ "special": true
1291
+ },
1292
+ "128161": {
1293
+ "content": "<|reserved_special_token_153|>",
1294
+ "lstrip": false,
1295
+ "normalized": false,
1296
+ "rstrip": false,
1297
+ "single_word": false,
1298
+ "special": true
1299
+ },
1300
+ "128162": {
1301
+ "content": "<|reserved_special_token_154|>",
1302
+ "lstrip": false,
1303
+ "normalized": false,
1304
+ "rstrip": false,
1305
+ "single_word": false,
1306
+ "special": true
1307
+ },
1308
+ "128163": {
1309
+ "content": "<|reserved_special_token_155|>",
1310
+ "lstrip": false,
1311
+ "normalized": false,
1312
+ "rstrip": false,
1313
+ "single_word": false,
1314
+ "special": true
1315
+ },
1316
+ "128164": {
1317
+ "content": "<|reserved_special_token_156|>",
1318
+ "lstrip": false,
1319
+ "normalized": false,
1320
+ "rstrip": false,
1321
+ "single_word": false,
1322
+ "special": true
1323
+ },
1324
+ "128165": {
1325
+ "content": "<|reserved_special_token_157|>",
1326
+ "lstrip": false,
1327
+ "normalized": false,
1328
+ "rstrip": false,
1329
+ "single_word": false,
1330
+ "special": true
1331
+ },
1332
+ "128166": {
1333
+ "content": "<|reserved_special_token_158|>",
1334
+ "lstrip": false,
1335
+ "normalized": false,
1336
+ "rstrip": false,
1337
+ "single_word": false,
1338
+ "special": true
1339
+ },
1340
+ "128167": {
1341
+ "content": "<|reserved_special_token_159|>",
1342
+ "lstrip": false,
1343
+ "normalized": false,
1344
+ "rstrip": false,
1345
+ "single_word": false,
1346
+ "special": true
1347
+ },
1348
+ "128168": {
1349
+ "content": "<|reserved_special_token_160|>",
1350
+ "lstrip": false,
1351
+ "normalized": false,
1352
+ "rstrip": false,
1353
+ "single_word": false,
1354
+ "special": true
1355
+ },
1356
+ "128169": {
1357
+ "content": "<|reserved_special_token_161|>",
1358
+ "lstrip": false,
1359
+ "normalized": false,
1360
+ "rstrip": false,
1361
+ "single_word": false,
1362
+ "special": true
1363
+ },
1364
+ "128170": {
1365
+ "content": "<|reserved_special_token_162|>",
1366
+ "lstrip": false,
1367
+ "normalized": false,
1368
+ "rstrip": false,
1369
+ "single_word": false,
1370
+ "special": true
1371
+ },
1372
+ "128171": {
1373
+ "content": "<|reserved_special_token_163|>",
1374
+ "lstrip": false,
1375
+ "normalized": false,
1376
+ "rstrip": false,
1377
+ "single_word": false,
1378
+ "special": true
1379
+ },
1380
+ "128172": {
1381
+ "content": "<|reserved_special_token_164|>",
1382
+ "lstrip": false,
1383
+ "normalized": false,
1384
+ "rstrip": false,
1385
+ "single_word": false,
1386
+ "special": true
1387
+ },
1388
+ "128173": {
1389
+ "content": "<|reserved_special_token_165|>",
1390
+ "lstrip": false,
1391
+ "normalized": false,
1392
+ "rstrip": false,
1393
+ "single_word": false,
1394
+ "special": true
1395
+ },
1396
+ "128174": {
1397
+ "content": "<|reserved_special_token_166|>",
1398
+ "lstrip": false,
1399
+ "normalized": false,
1400
+ "rstrip": false,
1401
+ "single_word": false,
1402
+ "special": true
1403
+ },
1404
+ "128175": {
1405
+ "content": "<|reserved_special_token_167|>",
1406
+ "lstrip": false,
1407
+ "normalized": false,
1408
+ "rstrip": false,
1409
+ "single_word": false,
1410
+ "special": true
1411
+ },
1412
+ "128176": {
1413
+ "content": "<|reserved_special_token_168|>",
1414
+ "lstrip": false,
1415
+ "normalized": false,
1416
+ "rstrip": false,
1417
+ "single_word": false,
1418
+ "special": true
1419
+ },
1420
+ "128177": {
1421
+ "content": "<|reserved_special_token_169|>",
1422
+ "lstrip": false,
1423
+ "normalized": false,
1424
+ "rstrip": false,
1425
+ "single_word": false,
1426
+ "special": true
1427
+ },
1428
+ "128178": {
1429
+ "content": "<|reserved_special_token_170|>",
1430
+ "lstrip": false,
1431
+ "normalized": false,
1432
+ "rstrip": false,
1433
+ "single_word": false,
1434
+ "special": true
1435
+ },
1436
+ "128179": {
1437
+ "content": "<|reserved_special_token_171|>",
1438
+ "lstrip": false,
1439
+ "normalized": false,
1440
+ "rstrip": false,
1441
+ "single_word": false,
1442
+ "special": true
1443
+ },
1444
+ "128180": {
1445
+ "content": "<|reserved_special_token_172|>",
1446
+ "lstrip": false,
1447
+ "normalized": false,
1448
+ "rstrip": false,
1449
+ "single_word": false,
1450
+ "special": true
1451
+ },
1452
+ "128181": {
1453
+ "content": "<|reserved_special_token_173|>",
1454
+ "lstrip": false,
1455
+ "normalized": false,
1456
+ "rstrip": false,
1457
+ "single_word": false,
1458
+ "special": true
1459
+ },
1460
+ "128182": {
1461
+ "content": "<|reserved_special_token_174|>",
1462
+ "lstrip": false,
1463
+ "normalized": false,
1464
+ "rstrip": false,
1465
+ "single_word": false,
1466
+ "special": true
1467
+ },
1468
+ "128183": {
1469
+ "content": "<|reserved_special_token_175|>",
1470
+ "lstrip": false,
1471
+ "normalized": false,
1472
+ "rstrip": false,
1473
+ "single_word": false,
1474
+ "special": true
1475
+ },
1476
+ "128184": {
1477
+ "content": "<|reserved_special_token_176|>",
1478
+ "lstrip": false,
1479
+ "normalized": false,
1480
+ "rstrip": false,
1481
+ "single_word": false,
1482
+ "special": true
1483
+ },
1484
+ "128185": {
1485
+ "content": "<|reserved_special_token_177|>",
1486
+ "lstrip": false,
1487
+ "normalized": false,
1488
+ "rstrip": false,
1489
+ "single_word": false,
1490
+ "special": true
1491
+ },
1492
+ "128186": {
1493
+ "content": "<|reserved_special_token_178|>",
1494
+ "lstrip": false,
1495
+ "normalized": false,
1496
+ "rstrip": false,
1497
+ "single_word": false,
1498
+ "special": true
1499
+ },
1500
+ "128187": {
1501
+ "content": "<|reserved_special_token_179|>",
1502
+ "lstrip": false,
1503
+ "normalized": false,
1504
+ "rstrip": false,
1505
+ "single_word": false,
1506
+ "special": true
1507
+ },
1508
+ "128188": {
1509
+ "content": "<|reserved_special_token_180|>",
1510
+ "lstrip": false,
1511
+ "normalized": false,
1512
+ "rstrip": false,
1513
+ "single_word": false,
1514
+ "special": true
1515
+ },
1516
+ "128189": {
1517
+ "content": "<|reserved_special_token_181|>",
1518
+ "lstrip": false,
1519
+ "normalized": false,
1520
+ "rstrip": false,
1521
+ "single_word": false,
1522
+ "special": true
1523
+ },
1524
+ "128190": {
1525
+ "content": "<|reserved_special_token_182|>",
1526
+ "lstrip": false,
1527
+ "normalized": false,
1528
+ "rstrip": false,
1529
+ "single_word": false,
1530
+ "special": true
1531
+ },
1532
+ "128191": {
1533
+ "content": "<|reserved_special_token_183|>",
1534
+ "lstrip": false,
1535
+ "normalized": false,
1536
+ "rstrip": false,
1537
+ "single_word": false,
1538
+ "special": true
1539
+ },
1540
+ "128192": {
1541
+ "content": "<|reserved_special_token_184|>",
1542
+ "lstrip": false,
1543
+ "normalized": false,
1544
+ "rstrip": false,
1545
+ "single_word": false,
1546
+ "special": true
1547
+ },
1548
+ "128193": {
1549
+ "content": "<|reserved_special_token_185|>",
1550
+ "lstrip": false,
1551
+ "normalized": false,
1552
+ "rstrip": false,
1553
+ "single_word": false,
1554
+ "special": true
1555
+ },
1556
+ "128194": {
1557
+ "content": "<|reserved_special_token_186|>",
1558
+ "lstrip": false,
1559
+ "normalized": false,
1560
+ "rstrip": false,
1561
+ "single_word": false,
1562
+ "special": true
1563
+ },
1564
+ "128195": {
1565
+ "content": "<|reserved_special_token_187|>",
1566
+ "lstrip": false,
1567
+ "normalized": false,
1568
+ "rstrip": false,
1569
+ "single_word": false,
1570
+ "special": true
1571
+ },
1572
+ "128196": {
1573
+ "content": "<|reserved_special_token_188|>",
1574
+ "lstrip": false,
1575
+ "normalized": false,
1576
+ "rstrip": false,
1577
+ "single_word": false,
1578
+ "special": true
1579
+ },
1580
+ "128197": {
1581
+ "content": "<|reserved_special_token_189|>",
1582
+ "lstrip": false,
1583
+ "normalized": false,
1584
+ "rstrip": false,
1585
+ "single_word": false,
1586
+ "special": true
1587
+ },
1588
+ "128198": {
1589
+ "content": "<|reserved_special_token_190|>",
1590
+ "lstrip": false,
1591
+ "normalized": false,
1592
+ "rstrip": false,
1593
+ "single_word": false,
1594
+ "special": true
1595
+ },
1596
+ "128199": {
1597
+ "content": "<|reserved_special_token_191|>",
1598
+ "lstrip": false,
1599
+ "normalized": false,
1600
+ "rstrip": false,
1601
+ "single_word": false,
1602
+ "special": true
1603
+ },
1604
+ "128200": {
1605
+ "content": "<|reserved_special_token_192|>",
1606
+ "lstrip": false,
1607
+ "normalized": false,
1608
+ "rstrip": false,
1609
+ "single_word": false,
1610
+ "special": true
1611
+ },
1612
+ "128201": {
1613
+ "content": "<|reserved_special_token_193|>",
1614
+ "lstrip": false,
1615
+ "normalized": false,
1616
+ "rstrip": false,
1617
+ "single_word": false,
1618
+ "special": true
1619
+ },
1620
+ "128202": {
1621
+ "content": "<|reserved_special_token_194|>",
1622
+ "lstrip": false,
1623
+ "normalized": false,
1624
+ "rstrip": false,
1625
+ "single_word": false,
1626
+ "special": true
1627
+ },
1628
+ "128203": {
1629
+ "content": "<|reserved_special_token_195|>",
1630
+ "lstrip": false,
1631
+ "normalized": false,
1632
+ "rstrip": false,
1633
+ "single_word": false,
1634
+ "special": true
1635
+ },
1636
+ "128204": {
1637
+ "content": "<|reserved_special_token_196|>",
1638
+ "lstrip": false,
1639
+ "normalized": false,
1640
+ "rstrip": false,
1641
+ "single_word": false,
1642
+ "special": true
1643
+ },
1644
+ "128205": {
1645
+ "content": "<|reserved_special_token_197|>",
1646
+ "lstrip": false,
1647
+ "normalized": false,
1648
+ "rstrip": false,
1649
+ "single_word": false,
1650
+ "special": true
1651
+ },
1652
+ "128206": {
1653
+ "content": "<|reserved_special_token_198|>",
1654
+ "lstrip": false,
1655
+ "normalized": false,
1656
+ "rstrip": false,
1657
+ "single_word": false,
1658
+ "special": true
1659
+ },
1660
+ "128207": {
1661
+ "content": "<|reserved_special_token_199|>",
1662
+ "lstrip": false,
1663
+ "normalized": false,
1664
+ "rstrip": false,
1665
+ "single_word": false,
1666
+ "special": true
1667
+ },
1668
+ "128208": {
1669
+ "content": "<|reserved_special_token_200|>",
1670
+ "lstrip": false,
1671
+ "normalized": false,
1672
+ "rstrip": false,
1673
+ "single_word": false,
1674
+ "special": true
1675
+ },
1676
+ "128209": {
1677
+ "content": "<|reserved_special_token_201|>",
1678
+ "lstrip": false,
1679
+ "normalized": false,
1680
+ "rstrip": false,
1681
+ "single_word": false,
1682
+ "special": true
1683
+ },
1684
+ "128210": {
1685
+ "content": "<|reserved_special_token_202|>",
1686
+ "lstrip": false,
1687
+ "normalized": false,
1688
+ "rstrip": false,
1689
+ "single_word": false,
1690
+ "special": true
1691
+ },
1692
+ "128211": {
1693
+ "content": "<|reserved_special_token_203|>",
1694
+ "lstrip": false,
1695
+ "normalized": false,
1696
+ "rstrip": false,
1697
+ "single_word": false,
1698
+ "special": true
1699
+ },
1700
+ "128212": {
1701
+ "content": "<|reserved_special_token_204|>",
1702
+ "lstrip": false,
1703
+ "normalized": false,
1704
+ "rstrip": false,
1705
+ "single_word": false,
1706
+ "special": true
1707
+ },
1708
+ "128213": {
1709
+ "content": "<|reserved_special_token_205|>",
1710
+ "lstrip": false,
1711
+ "normalized": false,
1712
+ "rstrip": false,
1713
+ "single_word": false,
1714
+ "special": true
1715
+ },
1716
+ "128214": {
1717
+ "content": "<|reserved_special_token_206|>",
1718
+ "lstrip": false,
1719
+ "normalized": false,
1720
+ "rstrip": false,
1721
+ "single_word": false,
1722
+ "special": true
1723
+ },
1724
+ "128215": {
1725
+ "content": "<|reserved_special_token_207|>",
1726
+ "lstrip": false,
1727
+ "normalized": false,
1728
+ "rstrip": false,
1729
+ "single_word": false,
1730
+ "special": true
1731
+ },
1732
+ "128216": {
1733
+ "content": "<|reserved_special_token_208|>",
1734
+ "lstrip": false,
1735
+ "normalized": false,
1736
+ "rstrip": false,
1737
+ "single_word": false,
1738
+ "special": true
1739
+ },
1740
+ "128217": {
1741
+ "content": "<|reserved_special_token_209|>",
1742
+ "lstrip": false,
1743
+ "normalized": false,
1744
+ "rstrip": false,
1745
+ "single_word": false,
1746
+ "special": true
1747
+ },
1748
+ "128218": {
1749
+ "content": "<|reserved_special_token_210|>",
1750
+ "lstrip": false,
1751
+ "normalized": false,
1752
+ "rstrip": false,
1753
+ "single_word": false,
1754
+ "special": true
1755
+ },
1756
+ "128219": {
1757
+ "content": "<|reserved_special_token_211|>",
1758
+ "lstrip": false,
1759
+ "normalized": false,
1760
+ "rstrip": false,
1761
+ "single_word": false,
1762
+ "special": true
1763
+ },
1764
+ "128220": {
1765
+ "content": "<|reserved_special_token_212|>",
1766
+ "lstrip": false,
1767
+ "normalized": false,
1768
+ "rstrip": false,
1769
+ "single_word": false,
1770
+ "special": true
1771
+ },
1772
+ "128221": {
1773
+ "content": "<|reserved_special_token_213|>",
1774
+ "lstrip": false,
1775
+ "normalized": false,
1776
+ "rstrip": false,
1777
+ "single_word": false,
1778
+ "special": true
1779
+ },
1780
+ "128222": {
1781
+ "content": "<|reserved_special_token_214|>",
1782
+ "lstrip": false,
1783
+ "normalized": false,
1784
+ "rstrip": false,
1785
+ "single_word": false,
1786
+ "special": true
1787
+ },
1788
+ "128223": {
1789
+ "content": "<|reserved_special_token_215|>",
1790
+ "lstrip": false,
1791
+ "normalized": false,
1792
+ "rstrip": false,
1793
+ "single_word": false,
1794
+ "special": true
1795
+ },
1796
+ "128224": {
1797
+ "content": "<|reserved_special_token_216|>",
1798
+ "lstrip": false,
1799
+ "normalized": false,
1800
+ "rstrip": false,
1801
+ "single_word": false,
1802
+ "special": true
1803
+ },
1804
+ "128225": {
1805
+ "content": "<|reserved_special_token_217|>",
1806
+ "lstrip": false,
1807
+ "normalized": false,
1808
+ "rstrip": false,
1809
+ "single_word": false,
1810
+ "special": true
1811
+ },
1812
+ "128226": {
1813
+ "content": "<|reserved_special_token_218|>",
1814
+ "lstrip": false,
1815
+ "normalized": false,
1816
+ "rstrip": false,
1817
+ "single_word": false,
1818
+ "special": true
1819
+ },
1820
+ "128227": {
1821
+ "content": "<|reserved_special_token_219|>",
1822
+ "lstrip": false,
1823
+ "normalized": false,
1824
+ "rstrip": false,
1825
+ "single_word": false,
1826
+ "special": true
1827
+ },
1828
+ "128228": {
1829
+ "content": "<|reserved_special_token_220|>",
1830
+ "lstrip": false,
1831
+ "normalized": false,
1832
+ "rstrip": false,
1833
+ "single_word": false,
1834
+ "special": true
1835
+ },
1836
+ "128229": {
1837
+ "content": "<|reserved_special_token_221|>",
1838
+ "lstrip": false,
1839
+ "normalized": false,
1840
+ "rstrip": false,
1841
+ "single_word": false,
1842
+ "special": true
1843
+ },
1844
+ "128230": {
1845
+ "content": "<|reserved_special_token_222|>",
1846
+ "lstrip": false,
1847
+ "normalized": false,
1848
+ "rstrip": false,
1849
+ "single_word": false,
1850
+ "special": true
1851
+ },
1852
+ "128231": {
1853
+ "content": "<|reserved_special_token_223|>",
1854
+ "lstrip": false,
1855
+ "normalized": false,
1856
+ "rstrip": false,
1857
+ "single_word": false,
1858
+ "special": true
1859
+ },
1860
+ "128232": {
1861
+ "content": "<|reserved_special_token_224|>",
1862
+ "lstrip": false,
1863
+ "normalized": false,
1864
+ "rstrip": false,
1865
+ "single_word": false,
1866
+ "special": true
1867
+ },
1868
+ "128233": {
1869
+ "content": "<|reserved_special_token_225|>",
1870
+ "lstrip": false,
1871
+ "normalized": false,
1872
+ "rstrip": false,
1873
+ "single_word": false,
1874
+ "special": true
1875
+ },
1876
+ "128234": {
1877
+ "content": "<|reserved_special_token_226|>",
1878
+ "lstrip": false,
1879
+ "normalized": false,
1880
+ "rstrip": false,
1881
+ "single_word": false,
1882
+ "special": true
1883
+ },
1884
+ "128235": {
1885
+ "content": "<|reserved_special_token_227|>",
1886
+ "lstrip": false,
1887
+ "normalized": false,
1888
+ "rstrip": false,
1889
+ "single_word": false,
1890
+ "special": true
1891
+ },
1892
+ "128236": {
1893
+ "content": "<|reserved_special_token_228|>",
1894
+ "lstrip": false,
1895
+ "normalized": false,
1896
+ "rstrip": false,
1897
+ "single_word": false,
1898
+ "special": true
1899
+ },
1900
+ "128237": {
1901
+ "content": "<|reserved_special_token_229|>",
1902
+ "lstrip": false,
1903
+ "normalized": false,
1904
+ "rstrip": false,
1905
+ "single_word": false,
1906
+ "special": true
1907
+ },
1908
+ "128238": {
1909
+ "content": "<|reserved_special_token_230|>",
1910
+ "lstrip": false,
1911
+ "normalized": false,
1912
+ "rstrip": false,
1913
+ "single_word": false,
1914
+ "special": true
1915
+ },
1916
+ "128239": {
1917
+ "content": "<|reserved_special_token_231|>",
1918
+ "lstrip": false,
1919
+ "normalized": false,
1920
+ "rstrip": false,
1921
+ "single_word": false,
1922
+ "special": true
1923
+ },
1924
+ "128240": {
1925
+ "content": "<|reserved_special_token_232|>",
1926
+ "lstrip": false,
1927
+ "normalized": false,
1928
+ "rstrip": false,
1929
+ "single_word": false,
1930
+ "special": true
1931
+ },
1932
+ "128241": {
1933
+ "content": "<|reserved_special_token_233|>",
1934
+ "lstrip": false,
1935
+ "normalized": false,
1936
+ "rstrip": false,
1937
+ "single_word": false,
1938
+ "special": true
1939
+ },
1940
+ "128242": {
1941
+ "content": "<|reserved_special_token_234|>",
1942
+ "lstrip": false,
1943
+ "normalized": false,
1944
+ "rstrip": false,
1945
+ "single_word": false,
1946
+ "special": true
1947
+ },
1948
+ "128243": {
1949
+ "content": "<|reserved_special_token_235|>",
1950
+ "lstrip": false,
1951
+ "normalized": false,
1952
+ "rstrip": false,
1953
+ "single_word": false,
1954
+ "special": true
1955
+ },
1956
+ "128244": {
1957
+ "content": "<|reserved_special_token_236|>",
1958
+ "lstrip": false,
1959
+ "normalized": false,
1960
+ "rstrip": false,
1961
+ "single_word": false,
1962
+ "special": true
1963
+ },
1964
+ "128245": {
1965
+ "content": "<|reserved_special_token_237|>",
1966
+ "lstrip": false,
1967
+ "normalized": false,
1968
+ "rstrip": false,
1969
+ "single_word": false,
1970
+ "special": true
1971
+ },
1972
+ "128246": {
1973
+ "content": "<|reserved_special_token_238|>",
1974
+ "lstrip": false,
1975
+ "normalized": false,
1976
+ "rstrip": false,
1977
+ "single_word": false,
1978
+ "special": true
1979
+ },
1980
+ "128247": {
1981
+ "content": "<|reserved_special_token_239|>",
1982
+ "lstrip": false,
1983
+ "normalized": false,
1984
+ "rstrip": false,
1985
+ "single_word": false,
1986
+ "special": true
1987
+ },
1988
+ "128248": {
1989
+ "content": "<|reserved_special_token_240|>",
1990
+ "lstrip": false,
1991
+ "normalized": false,
1992
+ "rstrip": false,
1993
+ "single_word": false,
1994
+ "special": true
1995
+ },
1996
+ "128249": {
1997
+ "content": "<|reserved_special_token_241|>",
1998
+ "lstrip": false,
1999
+ "normalized": false,
2000
+ "rstrip": false,
2001
+ "single_word": false,
2002
+ "special": true
2003
+ },
2004
+ "128250": {
2005
+ "content": "<|reserved_special_token_242|>",
2006
+ "lstrip": false,
2007
+ "normalized": false,
2008
+ "rstrip": false,
2009
+ "single_word": false,
2010
+ "special": true
2011
+ },
2012
+ "128251": {
2013
+ "content": "<|reserved_special_token_243|>",
2014
+ "lstrip": false,
2015
+ "normalized": false,
2016
+ "rstrip": false,
2017
+ "single_word": false,
2018
+ "special": true
2019
+ },
2020
+ "128252": {
2021
+ "content": "<|reserved_special_token_244|>",
2022
+ "lstrip": false,
2023
+ "normalized": false,
2024
+ "rstrip": false,
2025
+ "single_word": false,
2026
+ "special": true
2027
+ },
2028
+ "128253": {
2029
+ "content": "<|reserved_special_token_245|>",
2030
+ "lstrip": false,
2031
+ "normalized": false,
2032
+ "rstrip": false,
2033
+ "single_word": false,
2034
+ "special": true
2035
+ },
2036
+ "128254": {
2037
+ "content": "<|reserved_special_token_246|>",
2038
+ "lstrip": false,
2039
+ "normalized": false,
2040
+ "rstrip": false,
2041
+ "single_word": false,
2042
+ "special": true
2043
+ },
2044
+ "128255": {
2045
+ "content": "<|reserved_special_token_247|>",
2046
+ "lstrip": false,
2047
+ "normalized": false,
2048
+ "rstrip": false,
2049
+ "single_word": false,
2050
+ "special": true
2051
+ }
2052
+ },
2053
+ "bos_token": "<|begin_of_text|>",
2054
+ "clean_up_tokenization_spaces": true,
2055
+ "eos_token": "<|eot_id|>",
2056
+ "extra_special_tokens": {},
2057
+ "model_input_names": [
2058
+ "input_ids",
2059
+ "attention_mask"
2060
+ ],
2061
+ "model_max_length": 131072,
2062
+ "pad_token": "<|finetune_right_pad_id|>",
2063
+ "padding_side": "left",
2064
+ "tokenizer_class": "PreTrainedTokenizer",
2065
+ "unk_token": null,
2066
+ "chat_template": "{{- bos_token }}{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content']|trim %}{%- set messages = messages[1:] %}{%- else %}{%- set system_message = \"\" %}{%- endif %}{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}{{- system_message }}{{- \"<|eot_id|>\" }}{%- for message in messages %}{%- if message['role'] == 'assistant' and '</think>' in message['content'] %}{%- set content = message['content'].split('</think>')[-1].lstrip() %}{%- else %}{%- set content = message['content'] %}{%- endif %}{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' + content | trim + '<|eot_id|>' }}{%- endfor %}{%- if add_generation_prompt %}{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{%- endif %}"
2067
+ }
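
The `chat_template` entry above controls how a list of messages is rendered into the Llama-3 style `<|start_header_id|>…<|eot_id|>` format (with `</think>` reasoning stripped from prior assistant turns) before tokenization. A minimal sketch of exercising it is below; the local path is illustrative and simply stands for any directory containing this commit's `tokenizer.json` and `tokenizer_config.json`.

```python
# Minimal sketch: render a conversation with the chat_template defined above.
# "./Llama-3_3-Nemotron-Super-49B-v1" is an assumed local checkout path.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./Llama-3_3-Nemotron-Super-49B-v1")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# apply_chat_template follows the Jinja template: bos token, one system block,
# one header/eot block per message, then an open assistant header because
# add_generation_prompt=True.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```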
transformers_4_44_2__activations.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from collections import OrderedDict
17
+
18
+ import torch
19
+ from packaging import version
20
+ from torch import Tensor, nn
21
+
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class PytorchGELUTanh(nn.Module):
29
+ """
30
+ A fast C implementation of the tanh approximation of the GeLU activation function. See
31
+ https://arxiv.org/abs/1606.08415.
32
+
33
+ This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
34
+ match due to rounding errors.
35
+ """
36
+
37
+ def __init__(self):
38
+ super().__init__()
39
+ if version.parse(torch.__version__) < version.parse("1.12.0"):
40
+ raise ImportError(
41
+ f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
42
+ "PytorchGELUTanh. Please upgrade torch."
43
+ )
44
+
45
+ def forward(self, input: Tensor) -> Tensor:
46
+ return nn.functional.gelu(input, approximate="tanh")
47
+
48
+
49
+ class NewGELUActivation(nn.Module):
50
+ """
51
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
52
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
53
+ """
54
+
55
+ def forward(self, input: Tensor) -> Tensor:
56
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
57
+
58
+
59
+ class GELUActivation(nn.Module):
60
+ """
61
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
62
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
63
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
64
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
65
+ """
66
+
67
+ def __init__(self, use_gelu_python: bool = False):
68
+ super().__init__()
69
+ if use_gelu_python:
70
+ self.act = self._gelu_python
71
+ else:
72
+ self.act = nn.functional.gelu
73
+
74
+ def _gelu_python(self, input: Tensor) -> Tensor:
75
+ return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
76
+
77
+ def forward(self, input: Tensor) -> Tensor:
78
+ return self.act(input)
79
+
80
+
81
+ class FastGELUActivation(nn.Module):
82
+ """
83
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
84
+ """
85
+
86
+ def forward(self, input: Tensor) -> Tensor:
87
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
88
+
89
+
90
+ class QuickGELUActivation(nn.Module):
91
+ """
92
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
93
+ """
94
+
95
+ def forward(self, input: Tensor) -> Tensor:
96
+ return input * torch.sigmoid(1.702 * input)
97
+
98
+
99
+ class ClippedGELUActivation(nn.Module):
100
+ """
101
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
102
+ it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
103
+ https://arxiv.org/abs/2004.09602.
104
+
105
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
106
+ initially created.
107
+
108
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
109
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
110
+ """
111
+
112
+ def __init__(self, min: float, max: float):
113
+ if min > max:
114
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
115
+
116
+ super().__init__()
117
+ self.min = min
118
+ self.max = max
119
+
120
+ def forward(self, x: Tensor) -> Tensor:
121
+ return torch.clip(gelu(x), self.min, self.max)
122
+
123
+
124
+ class AccurateGELUActivation(nn.Module):
125
+ """
126
+ Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
127
+ https://github.com/hendrycks/GELUs
128
+
129
+ Implemented along with MEGA (Moving Average Equipped Gated Attention)
130
+ """
131
+
132
+ def __init__(self):
133
+ super().__init__()
134
+ self.precomputed_constant = math.sqrt(2 / math.pi)
135
+
136
+ def forward(self, input: Tensor) -> Tensor:
137
+ return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
138
+
139
+
140
+ class MishActivation(nn.Module):
141
+ """
142
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
143
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
144
+ """
145
+
146
+ def __init__(self):
147
+ super().__init__()
148
+ if version.parse(torch.__version__) < version.parse("1.9.0"):
149
+ self.act = self._mish_python
150
+ else:
151
+ self.act = nn.functional.mish
152
+
153
+ def _mish_python(self, input: Tensor) -> Tensor:
154
+ return input * torch.tanh(nn.functional.softplus(input))
155
+
156
+ def forward(self, input: Tensor) -> Tensor:
157
+ return self.act(input)
158
+
159
+
160
+ class LinearActivation(nn.Module):
161
+ """
162
+ Applies the linear activation function, i.e. forwarding input directly to output.
163
+ """
164
+
165
+ def forward(self, input: Tensor) -> Tensor:
166
+ return input
167
+
168
+
169
+ class LaplaceActivation(nn.Module):
170
+ """
171
+ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
172
+ https://arxiv.org/abs/2209.10655
173
+
174
+ Inspired by squared relu, but with bounded range and gradient for better stability
175
+ """
176
+
177
+ def forward(self, input, mu=0.707107, sigma=0.282095):
178
+ input = (input - mu).div(sigma * math.sqrt(2.0))
179
+ return 0.5 * (1.0 + torch.erf(input))
180
+
181
+
182
+ class ReLUSquaredActivation(nn.Module):
183
+ """
184
+ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
185
+ """
186
+
187
+ def forward(self, input):
188
+ relu_applied = nn.functional.relu(input)
189
+ squared = torch.square(relu_applied)
190
+ return squared
191
+
192
+
193
+ class ClassInstantier(OrderedDict):
194
+ def __getitem__(self, key):
195
+ content = super().__getitem__(key)
196
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
197
+ return cls(**kwargs)
198
+
199
+
200
+ ACT2CLS = {
201
+ "gelu": GELUActivation,
202
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
203
+ "gelu_fast": FastGELUActivation,
204
+ "gelu_new": NewGELUActivation,
205
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
206
+ "gelu_pytorch_tanh": PytorchGELUTanh,
207
+ "gelu_accurate": AccurateGELUActivation,
208
+ "laplace": LaplaceActivation,
209
+ "leaky_relu": nn.LeakyReLU,
210
+ "linear": LinearActivation,
211
+ "mish": MishActivation,
212
+ "quick_gelu": QuickGELUActivation,
213
+ "relu": nn.ReLU,
214
+ "relu2": ReLUSquaredActivation,
215
+ "relu6": nn.ReLU6,
216
+ "sigmoid": nn.Sigmoid,
217
+ "silu": nn.SiLU,
218
+ "swish": nn.SiLU,
219
+ "tanh": nn.Tanh,
220
+ }
221
+ ACT2FN = ClassInstantier(ACT2CLS)
222
+
223
+
224
+ def get_activation(activation_string):
225
+ if activation_string in ACT2FN:
226
+ return ACT2FN[activation_string]
227
+ else:
228
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
229
+
230
+
231
+ # For backwards compatibility with: from activations import gelu_python
232
+ gelu_python = get_activation("gelu_python")
233
+ gelu_new = get_activation("gelu_new")
234
+ gelu = get_activation("gelu")
235
+ gelu_fast = get_activation("gelu_fast")
236
+ quick_gelu = get_activation("quick_gelu")
237
+ silu = get_activation("silu")
238
+ mish = get_activation("mish")
239
+ linear_act = get_activation("linear")
transformers_4_44_2__cache_utils.py ADDED
@@ -0,0 +1,1347 @@
1
+ import copy
2
+ import importlib.metadata
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from packaging import version
10
+
11
+ from transformers.configuration_utils import PretrainedConfig
12
+ from transformers.utils import is_torchdynamo_compiling, logging
13
+
14
+
15
+ logger = logging.get_logger(__name__)
16
+
17
+
18
+ class Cache(torch.nn.Module):
19
+ """
20
+ Base, abstract class for all caches. The actual data structure is specific to each subclass.
21
+ """
22
+
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ def update(
27
+ self,
28
+ key_states: torch.Tensor,
29
+ value_states: torch.Tensor,
30
+ layer_idx: int,
31
+ cache_kwargs: Optional[Dict[str, Any]] = None,
32
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
33
+ """
34
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
35
+
36
+ Parameters:
37
+ key_states (`torch.Tensor`):
38
+ The new key states to cache.
39
+ value_states (`torch.Tensor`):
40
+ The new value states to cache.
41
+ layer_idx (`int`):
42
+ The index of the layer to cache the states for.
43
+ cache_kwargs (`Dict[str, Any]`, `optional`):
44
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
45
+ cache to be created.
46
+
47
+ Return:
48
+ A tuple containing the updated key and value states.
49
+ """
50
+ raise NotImplementedError("Make sure to implement `update` in a subclass.")
51
+
52
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
53
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
54
+ # TODO: deprecate this function in favor of `cache_position`
55
+ raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
56
+
57
+ def get_max_length(self) -> Optional[int]:
58
+ """Returns the maximum sequence length of the cached states, if there is any."""
59
+ raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
60
+
61
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
62
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
63
+ # Cache without size limit -> all cache is usable
64
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
65
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
66
+ max_length = self.get_max_length()
67
+ previous_seq_length = self.get_seq_length(layer_idx)
68
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
69
+ return max_length - new_seq_length
70
+ return previous_seq_length
71
+
72
+ def reorder_cache(self, beam_idx: torch.LongTensor):
73
+ """Reorders the cache for beam search, given the selected beam indices."""
74
+ for layer_idx in range(len(self.key_cache)):
75
+ device = self.key_cache[layer_idx].device
76
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
77
+ device = self.value_cache[layer_idx].device
78
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
79
+
80
+ @property
81
+ def seen_tokens(self):
82
+ logger.warning_once(
83
+ "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
84
+ "model input instead."
85
+ )
86
+ if hasattr(self, "_seen_tokens"):
87
+ return self._seen_tokens
88
+ else:
89
+ return None
90
+
91
+
92
+ @dataclass
93
+ class CacheConfig:
94
+ """
95
+ Base class for cache configs
96
+ """
97
+
98
+ cache_implementation: None
99
+
100
+ @classmethod
101
+ def from_dict(cls, config_dict, **kwargs):
102
+ """
103
+ Constructs a CacheConfig instance from a dictionary of parameters.
104
+ Args:
105
+ config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
106
+ **kwargs: Additional keyword arguments to override dictionary values.
107
+
108
+ Returns:
109
+ CacheConfig: Instance of CacheConfig constructed from the dictionary.
110
+ """
111
+ config = cls(**config_dict)
112
+ to_remove = []
113
+ for key, value in kwargs.items():
114
+ if hasattr(config, key):
115
+ setattr(config, key, value)
116
+ to_remove.append(key)
117
+ for key in to_remove:
118
+ kwargs.pop(key, None)
119
+ return config
120
+
121
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
122
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
123
+ """
124
+ Save this instance to a JSON file.
125
+
126
+ Args:
127
+ json_file_path (`str` or `os.PathLike`):
128
+ Path to the JSON file in which this configuration instance's parameters will be saved.
129
+ use_diff (`bool`, *optional*, defaults to `True`):
130
+ If set to `True`, only the difference between the config instance and the default
131
+ `QuantizationConfig()` is serialized to JSON file.
132
+ """
133
+ with open(json_file_path, "w", encoding="utf-8") as writer:
134
+ config_dict = self.to_dict()
135
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
136
+
137
+ writer.write(json_string)
138
+
139
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
140
+ def to_dict(self) -> Dict[str, Any]:
141
+ """
142
+ Serializes this instance to a Python dictionary. Returns:
143
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
144
+ """
145
+ return copy.deepcopy(self.__dict__)
146
+
147
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
148
+ def __iter__(self):
149
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
150
+ for attr, value in copy.deepcopy(self.__dict__).items():
151
+ yield attr, value
152
+
153
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
154
+ def __repr__(self):
155
+ return f"{self.__class__.__name__} {self.to_json_string()}"
156
+
157
+ def to_json_string(self):
158
+ """
159
+ Serializes this instance to a JSON formatted string.
160
+ Returns:
161
+ str: JSON formatted string representing the configuration instance.
162
+ """
163
+ return json.dumps(self.__dict__, indent=2) + "\n"
164
+
165
+ # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
166
+ def update(self, **kwargs):
167
+ """
168
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
169
+ returning all the unused kwargs.
170
+
171
+ Args:
172
+ kwargs (`Dict[str, Any]`):
173
+ Dictionary of attributes to tentatively update this class.
174
+
175
+ Returns:
176
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
177
+ """
178
+ to_remove = []
179
+ for key, value in kwargs.items():
180
+ if hasattr(self, key):
181
+ setattr(self, key, value)
182
+ to_remove.append(key)
183
+
184
+ # Remove all the attributes that were updated, without modifying the input dict
185
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
186
+ return unused_kwargs
187
+
188
+
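A minimal sketch of how `from_dict` and `update` behave, assuming a hypothetical `MyCacheConfig` subclass (not defined in this file): unknown keys are left untouched and handed back to the caller.

from dataclasses import dataclass

@dataclass
class MyCacheConfig(CacheConfig):  # hypothetical subclass, for illustration only
    cache_implementation: str = "static"
    max_cache_len: int = 1024

cfg = MyCacheConfig.from_dict({"max_cache_len": 2048})
unused = cfg.update(max_cache_len=4096, not_a_field=1)
print(cfg.max_cache_len, unused)  # 4096 {'not_a_field': 1}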
189
+ class DynamicCache(Cache):
190
+ """
191
+ A cache that grows dynamically as more tokens are generated. This is the default for generative models.
192
+
193
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
194
+ `[batch_size, num_heads, seq_len, head_dim]`.
195
+
196
+ Example:
197
+
198
+ ```python
199
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
200
+
201
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
202
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
203
+
204
+ >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
205
+
206
+ >>> # Prepare a cache class and pass it to model's forward
207
+ >>> past_key_values = DynamicCache()
208
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
209
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
210
+ ```
211
+ """
212
+
213
+ def __init__(self) -> None:
214
+ super().__init__()
215
+ self.key_cache: List[torch.Tensor] = []
216
+ self.value_cache: List[torch.Tensor] = []
217
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
218
+
219
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
220
+ """
221
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
222
+ sequence length.
223
+ """
224
+ if layer_idx < len(self):
225
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx])
226
+ else:
227
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
228
+
229
+ def __iter__(self):
230
+ """
231
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
232
+ keys and values
233
+ """
234
+ for layer_idx in range(len(self)):
235
+ yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
236
+
237
+ def __len__(self):
238
+ """
239
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
240
+ to the number of layers in the model.
241
+ """
242
+ return len(self.key_cache)
243
+
244
+ def update(
245
+ self,
246
+ key_states: torch.Tensor,
247
+ value_states: torch.Tensor,
248
+ layer_idx: int,
249
+ cache_kwargs: Optional[Dict[str, Any]] = None,
250
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
251
+ """
252
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
253
+
254
+ Parameters:
255
+ key_states (`torch.Tensor`):
256
+ The new key states to cache.
257
+ value_states (`torch.Tensor`):
258
+ The new value states to cache.
259
+ layer_idx (`int`):
260
+ The index of the layer to cache the states for.
261
+ cache_kwargs (`Dict[str, Any]`, `optional`):
262
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
263
+
264
+ Return:
265
+ A tuple containing the updated key and value states.
266
+ """
267
+ # Update the number of seen tokens
268
+ if layer_idx == 0:
269
+ self._seen_tokens += key_states.shape[-2]
270
+
271
+ # Update the cache
272
+ if len(self.key_cache) <= layer_idx:
273
+ self.key_cache.append(key_states)
274
+ self.value_cache.append(value_states)
275
+ else:
276
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
277
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
278
+
279
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
280
+
281
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
282
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
283
+ # TODO: deprecate this function in favor of `cache_position`
284
+ if len(self.key_cache) <= layer_idx:
285
+ return 0
286
+ return self.key_cache[layer_idx].shape[-2]
287
+
288
+ def get_max_length(self) -> Optional[int]:
289
+ """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
290
+ return None
291
+
292
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
293
+ """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
294
+ backward compatibility."""
295
+ legacy_cache = ()
296
+ for layer_idx in range(len(self)):
297
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
298
+ return legacy_cache
299
+
300
+ @classmethod
301
+ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
302
+ """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
303
+ backward compatibility."""
304
+ cache = cls()
305
+ if past_key_values is not None:
306
+ for layer_idx in range(len(past_key_values)):
307
+ key_states, value_states = past_key_values[layer_idx]
308
+ cache.update(key_states, value_states, layer_idx)
309
+ return cache
310
+
311
+ def crop(self, max_length: int):
312
+ """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
313
+ negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
314
+ # In case it is negative
315
+ if max_length < 0:
316
+ max_length = self.get_seq_length() - abs(max_length)
317
+
318
+ if self.get_seq_length() <= max_length:
319
+ return
320
+
321
+ self._seen_tokens = max_length
322
+ for idx in range(len(self.key_cache)):
323
+ self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
324
+ self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
325
+
326
+ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
327
+ """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
328
+ `_split_model_inputs()` in `generation.utils`"""
329
+ out = []
330
+ for i in range(0, full_batch_size, split_size):
331
+ current_split = DynamicCache()
332
+ current_split._seen_tokens = self._seen_tokens
333
+ current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
334
+ current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
335
+ out.append(current_split)
336
+ return out
337
+
338
+ @classmethod
339
+ def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
340
+ """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
341
+ `generation.utils`"""
342
+ cache = cls()
343
+ for idx in range(len(splits[0])):
344
+ layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
345
+ layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
346
+ cache.update(layer_keys, layer_values, idx)
347
+ return cache
348
+
349
+ def batch_repeat_interleave(self, repeats: int):
350
+ """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
351
+ for layer_idx in range(len(self)):
352
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
353
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
354
+
355
+ def batch_select_indices(self, indices: torch.Tensor):
356
+ """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
357
+ for layer_idx in range(len(self)):
358
+ self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
359
+ self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
360
+
361
+
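A short sketch of the legacy-format round trip implemented above (`to_legacy_cache` / `from_legacy_cache`); the shapes below are made up for illustration.

import torch

legacy = tuple(
    (torch.zeros(1, 4, 6, 8), torch.zeros(1, 4, 6, 8))  # (key, value) per layer
    for _ in range(2)
)
cache = DynamicCache.from_legacy_cache(legacy)
assert len(cache) == 2 and cache.get_seq_length() == 6
assert cache.to_legacy_cache()[0][0].shape == legacy[0][0].shape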
362
+ class OffloadedCache(DynamicCache):
363
+ """
364
+ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
365
+ Useful for generating from models with very long context.
366
+
367
+ In addition to the default CUDA stream, where all forward() computations happen,
368
+ this class uses another stream, the prefetch stream, which it creates itself.
369
+ Since scheduling of operations on separate streams happens independently, this class uses
370
+ the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
371
+ The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
372
+ ensure the eviction is scheduled after all computations on that cache are finished.
373
+ """
374
+
375
+ def __init__(self) -> None:
376
+ if not torch.cuda.is_available():
377
+ raise RuntimeError("OffloadedCache can only be used with a GPU")
378
+ super().__init__()
379
+ self.original_device = []
380
+ self.prefetch_stream = torch.cuda.Stream()
381
+ self.beam_idx = None # used to delay beam search operations
382
+
383
+ def prefetch_layer(self, layer_idx: int):
384
+ "Starts prefetching the next layer cache"
385
+ if layer_idx < len(self):
386
+ with torch.cuda.stream(self.prefetch_stream):
387
+ # Prefetch next layer tensors to GPU
388
+ device = self.original_device[layer_idx]
389
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
390
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
391
+
392
+ def evict_previous_layer(self, layer_idx: int):
393
+ "Moves the previous layer cache to the CPU"
394
+ if len(self) > 2:
395
+ # We do it on the default stream so it occurs after all earlier computations on these tensors are done
396
+ prev_layer_idx = (layer_idx - 1) % len(self)
397
+ self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
398
+ self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
399
+
400
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
401
+ "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
402
+ if layer_idx < len(self):
403
+ # Evict the previous layer if necessary
404
+ torch.cuda.current_stream().synchronize()
405
+ self.evict_previous_layer(layer_idx)
406
+ # Load current layer cache to its original device if not already there
407
+ original_device = self.original_device[layer_idx]
408
+ self.prefetch_stream.synchronize()
409
+ key_tensor = self.key_cache[layer_idx]
410
+ value_tensor = self.value_cache[layer_idx]
411
+ # Now deal with beam search ops which were delayed
412
+ if self.beam_idx is not None:
413
+ self.beam_idx = self.beam_idx.to(original_device)
414
+ key_tensor = key_tensor.index_select(0, self.beam_idx)
415
+ value_tensor = value_tensor.index_select(0, self.beam_idx)
416
+ # Prefetch the next layer
417
+ self.prefetch_layer((layer_idx + 1) % len(self))
418
+ return (key_tensor, value_tensor)
419
+ else:
420
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
421
+
422
+ def reorder_cache(self, beam_idx: torch.LongTensor):
423
+ """Saves the beam indices and reorders the cache when the tensor is back to its device."""
424
+ # We delay this operation until the tensors are back to their original
425
+ # device because performing torch.index_select on the CPU is very slow
426
+ del self.beam_idx
427
+ self.beam_idx = beam_idx.clone()
428
+
429
+ def update(
430
+ self,
431
+ key_states: torch.Tensor,
432
+ value_states: torch.Tensor,
433
+ layer_idx: int,
434
+ cache_kwargs: Optional[Dict[str, Any]] = None,
435
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
436
+ """
437
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
438
+ Parameters:
439
+ key_states (`torch.Tensor`):
440
+ The new key states to cache.
441
+ value_states (`torch.Tensor`):
442
+ The new value states to cache.
443
+ layer_idx (`int`):
444
+ The index of the layer to cache the states for.
445
+ cache_kwargs (`Dict[str, Any]`, `optional`):
446
+ Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
447
+ Return:
448
+ A tuple containing the updated key and value states.
449
+ """
450
+ # Update the number of seen tokens
451
+ if layer_idx == 0:
452
+ self._seen_tokens += key_states.shape[-2]
453
+
454
+ # Update the cache
455
+ if len(self.key_cache) <= layer_idx:
456
+ self.key_cache.append(key_states)
457
+ self.value_cache.append(value_states)
458
+ self.original_device.append(key_states.device)
459
+ self.evict_previous_layer(layer_idx)
460
+ else:
461
+ key_tensor, value_tensor = self[layer_idx]
462
+ self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
463
+ self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)
464
+
465
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
466
+
467
+ # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
468
+ # if a method is not supposed to be supported in a subclass we should set it to None
469
+ from_legacy_cache = None
470
+
471
+ to_legacy_cache = None
472
+
473
+
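Since `OffloadedCache` is a drop-in replacement for `DynamicCache`, it can be passed to a model the same way. A hedged sketch (requires CUDA; the model name is only an example, any decoder-only model that supports the new cache API would do):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # example only
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("A very long prompt ...", return_tensors="pt").to("cuda")

# Keys/values are kept on the CPU except for the layer currently being computed.
past_key_values = OffloadedCache()
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)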
474
+ class SinkCache(Cache):
475
+ """
476
+ A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
477
+ generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
478
+ tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
479
+
480
+ It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
481
+ `[batch_size, num_heads, seq_len, head_dim]`.
482
+
483
+ Parameters:
484
+ window_length (`int`):
485
+ The length of the context window.
486
+ num_sink_tokens (`int`):
487
+ The number of sink tokens. See the original paper for more information.
488
+
489
+ Example:
490
+
491
+ ```python
492
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
493
+
494
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
495
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
496
+
497
+ >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
498
+
499
+ >>> # Prepare a cache class and pass it to model's forward
500
+ >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
501
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
502
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
503
+ ```
504
+ """
505
+
506
+ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
507
+ super().__init__()
508
+ self.key_cache: List[torch.Tensor] = []
509
+ self.value_cache: List[torch.Tensor] = []
510
+ self.window_length = window_length
511
+ self.num_sink_tokens = num_sink_tokens
512
+ self.cos_sin_rerotation_cache = {}
513
+ self._cos_cache = None
514
+ self._sin_cache = None
515
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
516
+
517
+ @staticmethod
518
+ def _rotate_half(x):
519
+ x1 = x[..., : x.shape[-1] // 2]
520
+ x2 = x[..., x.shape[-1] // 2 :]
521
+ return torch.cat((-x2, x1), dim=-1)
522
+
523
+ def _apply_key_rotary_pos_emb(
524
+ self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
525
+ ) -> torch.Tensor:
526
+ rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
527
+ return rotated_key_states
528
+
529
+ def _get_rerotation_cos_sin(
530
+ self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
531
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
532
+ if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
533
+ # Upcast to float32 temporarily for better accuracy
534
+ cos = cos.to(torch.float32)
535
+ sin = sin.to(torch.float32)
536
+
537
+ # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
538
+ original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
539
+ shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
540
+ original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
541
+ shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
542
+ rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
543
+ rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
544
+
545
+ self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
546
+ rerotation_cos.to(key_states.dtype).unsqueeze(0),
547
+ rerotation_sin.to(key_states.dtype).unsqueeze(0),
548
+ )
549
+ return self.cos_sin_rerotation_cache[key_states.shape[-2]]
550
+
551
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
552
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
553
+ # TODO: deprecate this function in favor of `cache_position`
554
+ # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
555
+ if len(self.key_cache) <= layer_idx:
556
+ return 0
557
+ return self.key_cache[layer_idx].shape[-2]
558
+
559
+ def get_max_length(self) -> Optional[int]:
560
+ """Returns the maximum sequence length of the cached states."""
561
+ return self.window_length
562
+
563
+ def update(
564
+ self,
565
+ key_states: torch.Tensor,
566
+ value_states: torch.Tensor,
567
+ layer_idx: int,
568
+ cache_kwargs: Optional[Dict[str, Any]] = None,
569
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
570
+ """
571
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
572
+
573
+ Parameters:
574
+ key_states (`torch.Tensor`):
575
+ The new key states to cache.
576
+ value_states (`torch.Tensor`):
577
+ The new value states to cache.
578
+ layer_idx (`int`):
579
+ The index of the layer to cache the states for.
580
+ cache_kwargs (`Dict[str, Any]`, `optional`):
581
+ Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
582
+ `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
583
+ rotation as the tokens are shifted.
584
+
585
+ Return:
586
+ A tuple containing the updated key and value states.
587
+ """
588
+ # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
589
+ # with partially rotated position embeddings, like Phi or Persimmon.
590
+ sin = cache_kwargs.get("sin")
591
+ cos = cache_kwargs.get("cos")
592
+ partial_rotation_size = cache_kwargs.get("partial_rotation_size")
593
+ using_rope = cos is not None and sin is not None
594
+
595
+ # Update the number of seen tokens
596
+ if layer_idx == 0:
597
+ self._seen_tokens += key_states.shape[-2]
598
+
599
+ # Update the sin/cos cache, which holds sin/cos values for all possible positions
600
+ if using_rope and layer_idx == 0:
601
+ # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
602
+ # after all RoPE models have a llama-like cache utilization.
603
+ if cos.dim() == 2:
604
+ self._cos_cache = cos
605
+ self._sin_cache = sin
606
+ else:
607
+ if self._cos_cache is None:
608
+ self._cos_cache = cos[0, ...]
609
+ self._sin_cache = sin[0, ...]
610
+ elif self._cos_cache.shape[0] < self.window_length:
611
+ self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
612
+ self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
613
+
614
+ # [bsz, num_heads, seq_len, head_dim]
615
+ if len(self.key_cache) <= layer_idx:
616
+ # Empty cache
617
+ self.key_cache.append(key_states)
618
+ self.value_cache.append(value_states)
619
+
620
+ elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
621
+ # Growing cache
622
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
623
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
624
+
625
+ else:
626
+ # Shifting cache
627
+ keys_to_keep = self.key_cache[layer_idx][
628
+ :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
629
+ ]
630
+
631
+ # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
632
+ if using_rope:
633
+ rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
634
+ key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
635
+ )
636
+ if partial_rotation_size is not None:
637
+ keys_to_keep, keys_pass = (
638
+ keys_to_keep[..., :partial_rotation_size],
639
+ keys_to_keep[..., partial_rotation_size:],
640
+ )
641
+ keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
642
+ if partial_rotation_size is not None:
643
+ keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
644
+
645
+ # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
646
+ sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
647
+ self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
648
+
649
+ sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
650
+ values_to_keep = self.value_cache[layer_idx][
651
+ :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
652
+ ]
653
+ self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
654
+
655
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
656
+
657
+
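A small worked example of the eviction arithmetic in the "shifting cache" branch above (the numbers are made up): once the cache is full, every update keeps the sink tokens, the most recent shifted tokens, and the new tokens, so the total length stays at `window_length`.

window_length, num_sink_tokens, new_tokens = 8, 2, 1
kept_shifted = window_length - num_sink_tokens - new_tokens  # the `-window_length + num_sink_tokens + new_tokens:` slice
assert num_sink_tokens + kept_shifted + new_tokens == window_length  # 2 + 5 + 1 == 8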
658
+ class StaticCache(Cache):
659
+ """
660
+ Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
661
+
662
+ Parameters:
663
+ config (`PretrainedConfig`):
664
+ The configuration file defining the shape-related attributes required to initialize the static cache.
665
+ max_batch_size (`int`):
666
+ The maximum batch size with which the model will be used.
667
+ max_cache_len (`int`):
668
+ The maximum sequence length with which the model will be used.
669
+ device (`torch.device`):
670
+ The device on which the cache should be initialized. Should be the same as the layer.
671
+ dtype (*optional*, defaults to `torch.float32`):
672
+ The default `dtype` to use when initializing the layer.
673
+
674
+ Example:
675
+
676
+ ```python
677
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
678
+
679
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
680
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
681
+
682
+ >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
683
+
684
+ >>> # Prepare a cache class and pass it to model's forward
685
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
686
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
687
+ >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
688
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
689
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
690
+ ```
691
+ """
692
+
693
+ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
694
+ super().__init__()
695
+ self.max_batch_size = max_batch_size
696
+ self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
697
+ # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
698
+ self.head_dim = (
699
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
700
+ )
701
+
702
+ self.dtype = dtype if dtype is not None else torch.float32
703
+ self.num_key_value_heads = (
704
+ config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
705
+ )
706
+
707
+ self.key_cache: List[torch.Tensor] = []
708
+ self.value_cache: List[torch.Tensor] = []
709
+ # Note: There will be significant perf decrease if switching to use 5D tensors instead.
710
+ cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
711
+ for idx in range(config.num_hidden_layers):
712
+ new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
713
+ new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
714
+ # Notes:
715
+ # 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
716
+ # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
717
+ # it is not needed anyway)
718
+ # 2. `torch.export()` requires mutations to be registered as buffers.
719
+ if not is_torchdynamo_compiling():
720
+ self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
721
+ self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
722
+ new_layer_key_cache = getattr(self, f"key_cache_{idx}")
723
+ new_layer_value_cache = getattr(self, f"value_cache_{idx}")
724
+ torch._dynamo.mark_static_address(new_layer_key_cache)
725
+ torch._dynamo.mark_static_address(new_layer_value_cache)
726
+ self.key_cache.append(new_layer_key_cache)
727
+ self.value_cache.append(new_layer_value_cache)
728
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
729
+
730
+ def update(
731
+ self,
732
+ key_states: torch.Tensor,
733
+ value_states: torch.Tensor,
734
+ layer_idx: int,
735
+ cache_kwargs: Optional[Dict[str, Any]] = None,
736
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
737
+ """
738
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
739
+ It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
740
+
741
+ Parameters:
742
+ key_states (`torch.Tensor`):
743
+ The new key states to cache.
744
+ value_states (`torch.Tensor`):
745
+ The new value states to cache.
746
+ layer_idx (`int`):
747
+ The index of the layer to cache the states for.
748
+ cache_kwargs (`Dict[str, Any]`, `optional`):
749
+ Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
750
+ to know where to write in the cache.
751
+
752
+ Return:
753
+ A tuple containing the updated key and value states.
754
+ """
755
+ # Update the number of seen tokens
756
+ if layer_idx == 0:
757
+ self._seen_tokens += key_states.shape[-2]
758
+
759
+ cache_position = cache_kwargs.get("cache_position")
760
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
761
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
762
+ k_out = self.key_cache[layer_idx]
763
+ v_out = self.value_cache[layer_idx]
764
+
765
+ if cache_position is None:
766
+ k_out.copy_(key_states)
767
+ v_out.copy_(value_states)
768
+ else:
769
+ # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
770
+ # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and explicitly performs an in-place
772
+ # operation, which avoids copies and uses less memory.
772
+ try:
773
+ k_out.index_copy_(2, cache_position, key_states)
774
+ v_out.index_copy_(2, cache_position, value_states)
775
+ except NotImplementedError:
776
+ # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
777
+ k_out[:, :, cache_position] = key_states
778
+ v_out[:, :, cache_position] = value_states
779
+
780
+ return k_out, v_out
781
+
782
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
783
+ """Returns the sequence length of the cached states that were seen by the model."""
784
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
785
+ # limit the check to the first batch member and head dimension.
786
+ # TODO: deprecate this function in favor of `cache_position`
787
+ # return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
788
+ return self._seen_tokens
789
+
790
+ def get_max_length(self) -> Optional[int]:
791
+ """Returns the maximum sequence length of the cached states."""
792
+ return self.max_cache_len
793
+
794
+ def reset(self):
795
+ """Resets the cache values while preserving the objects"""
797
+ self._seen_tokens = 0
797
+ for layer_idx in range(len(self.key_cache)):
798
+ # In-place ops prevent breaking the static address
799
+ self.key_cache[layer_idx].zero_()
800
+ self.value_cache[layer_idx].zero_()
801
+
802
+
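A minimal sketch of driving `update` directly with `cache_position` (the `SimpleNamespace` config below is only a stand-in for a real `PretrainedConfig`; the values are made up):

import torch
from types import SimpleNamespace

config = SimpleNamespace(
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
    hidden_size=32, head_dim=8, max_position_embeddings=64,
)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

# Prefill 5 tokens, then decode 1 more; `cache_position` tells the cache where to write.
k, v = torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8)
cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(5)})
k1, v1 = torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8)
k_out, _ = cache.update(k1, v1, layer_idx=0, cache_kwargs={"cache_position": torch.tensor([5])})
print(k_out.shape)  # torch.Size([1, 4, 16, 8]) -- the full pre-allocated buffer is returned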
803
+ class SlidingWindowCache(StaticCache):
804
+ """
805
+ Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
806
+ Every time we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`.
807
+ If true (which means the cache cannot hold all the old key value states and new states together because of the sliding window constraint),
808
+ we need to do a cyclic shift based on `indices` to replace the oldest states with the new key value states passed in.
809
+
810
+ The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
811
+
812
+ indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
813
+ tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
814
+ 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
815
+ 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
816
+ 55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
817
+
818
+ We overwrite the cache using these indices, then we always write at `cache_position` (clamped to `sliding_window`).
819
+
820
+ Parameters:
821
+ config (`PretrainedConfig`):
822
+ The configuration file defining the shape-related attributes required to initialize the static cache.
823
+ max_batch_size (`int`):
824
+ The maximum batch size with which the model will be used.
825
+ max_cache_len (`int`):
826
+ The maximum sequence length with which the model will be used.
827
+ device (`torch.device`):
828
+ The device on which the cache should be initialized. Should be the same as the layer.
829
+ dtype (*optional*, defaults to `torch.float32`):
830
+ The default `dtype` to use when initializing the layer.
831
+
832
+ Example:
833
+
834
+ ```python
835
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache
836
+
837
+ >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
838
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
839
+
840
+ >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
841
+
842
+ >>> # Prepare a cache class and pass it to model's forward
843
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
844
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
845
+ >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
846
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
847
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
848
+ ```
849
+ """
850
+
851
+ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
852
+ super().__init__(config, max_batch_size, max_cache_len, device, dtype)
853
+ if not hasattr(config, "sliding_window") or config.sliding_window is None:
854
+ raise ValueError(
855
+ "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
856
+ "sliding window attention, please check if there is a `sliding_window` field in the model "
857
+ "config and it's not set to None."
858
+ )
859
+ max_cache_len = min(config.sliding_window, max_cache_len)
860
+ super().__init__(
861
+ config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
862
+ )
863
+
864
+ def update(
865
+ self,
866
+ key_states: torch.Tensor,
867
+ value_states: torch.Tensor,
868
+ layer_idx: int,
869
+ cache_kwargs: Optional[Dict[str, Any]] = None,
870
+ ) -> Tuple[torch.Tensor]:
871
+ cache_position = cache_kwargs.get("cache_position")
872
+ k_out = self.key_cache[layer_idx]
873
+ v_out = self.value_cache[layer_idx]
874
+
875
+ # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
876
+ if cache_position.shape[0] > self.max_cache_len:
877
+ k_out = key_states[:, :, -self.max_cache_len :, :]
878
+ v_out = value_states[:, :, -self.max_cache_len :, :]
879
+ # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
880
+ self.key_cache[layer_idx] += k_out
881
+ self.value_cache[layer_idx] += v_out
882
+ # we should return the whole states instead of k_out, v_out to take the whole prompt
883
+ # into consideration when building kv cache instead of just throwing away tokens outside of the window
884
+ return key_states, value_states
885
+
886
+ slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
887
+ cache_position = cache_position.clamp(0, self.max_cache_len - 1)
888
+ to_shift = cache_position >= self.max_cache_len - 1
889
+ indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
890
+
891
+ k_out = k_out[:, :, indices]
892
+ v_out = v_out[:, :, indices]
893
+
894
+ try:
895
+ cache_position = cache_position.to(device=k_out.device)
896
+ k_out.index_copy_(2, cache_position, key_states)
897
+ v_out.index_copy_(2, cache_position, value_states)
898
+ except NotImplementedError:
899
+ # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
900
+ k_out[:, :, cache_position] = key_states
901
+ v_out[:, :, cache_position] = value_states
902
+
903
+ # `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
904
+ self.key_cache[layer_idx].zero_()
905
+ self.value_cache[layer_idx].zero_()
906
+
907
+ self.key_cache[layer_idx] += k_out
908
+ self.value_cache[layer_idx] += v_out
909
+
910
+ return k_out, v_out
911
+
912
+ def get_max_length(self) -> Optional[int]:
913
+ # in theory there is no limit because the sliding window size is fixed no matter how long the sentence is
914
+ return None
915
+
916
+ def reset(self):
917
+ for layer_idx in range(len(self.key_cache)):
918
+ # In-place ops prevent breaking the static address
919
+ self.key_cache[layer_idx].zero_()
920
+ self.value_cache[layer_idx].zero_()
921
+
922
+
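A tiny worked example of the index computation in `update` above, assuming `max_cache_len == 4` and a decode step at position 7:

import torch

max_cache_len = 4
slicing = torch.ones(max_cache_len, dtype=torch.long).cumsum(0)  # tensor([1, 2, 3, 4])
cache_position = torch.tensor([7]).clamp(0, max_cache_len - 1)   # tensor([3])
to_shift = cache_position >= max_cache_len - 1                    # tensor([True])
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
print(indices)  # tensor([1, 2, 3, 0]) -- the oldest slot rolls to the end before position 3 is overwritten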
923
+ class EncoderDecoderCache(Cache):
924
+ """
925
+ Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
926
+ cross-attention caches.
927
+
928
+ Example:
929
+
930
+ ```python
931
+ >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
932
+
933
+ >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
934
+ >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
935
+
936
+ >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
937
+
938
+ >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
939
+ >>> self_attention_cache = DynamicCache()
940
+ >>> cross_attention_cache = DynamicCache()
941
+ >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
942
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
943
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
944
+ ```
945
+
946
+ """
947
+
948
+ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
949
+ super().__init__()
950
+ self.self_attention_cache = self_attention_cache
951
+ self.cross_attention_cache = cross_attention_cache
952
+
953
+ self.is_updated = {}
954
+ for layer_idx in range(len(cross_attention_cache.key_cache)):
955
+ self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
956
+
957
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
958
+ """
959
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
960
+ sequence length.
961
+ """
962
+ if layer_idx < len(self):
963
+ return (
964
+ self.self_attention_cache.key_cache[layer_idx],
965
+ self.self_attention_cache.value_cache[layer_idx],
966
+ self.cross_attention_cache.key_cache[layer_idx],
967
+ self.cross_attention_cache.value_cache[layer_idx],
968
+ )
969
+ else:
970
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
971
+
972
+ def __len__(self):
973
+ """
974
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
975
+ to the number of layers in the model.
976
+ """
977
+ return len(self.self_attention_cache)
978
+
979
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
980
+ """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
981
+ legacy_cache = ()
982
+ if len(self.cross_attention_cache) > 0:
983
+ for self_attn, cross_attn in zip(
984
+ self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
985
+ ):
986
+ legacy_cache += (self_attn + cross_attn,)
987
+ else:
988
+ legacy_cache = self.self_attention_cache.to_legacy_cache()
989
+ return legacy_cache
990
+
991
+ @classmethod
992
+ def from_legacy_cache(
993
+ cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
994
+ ) -> "EncoderDecoderCache":
995
+ """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
996
+ cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
997
+ if past_key_values is not None:
998
+ for layer_idx in range(len(past_key_values)):
999
+ key_states, value_states = past_key_values[layer_idx][:2]
1000
+ cache.self_attention_cache.update(key_states, value_states, layer_idx)
1001
+ if len(past_key_values[layer_idx]) > 2:
1002
+ key_states, value_states = past_key_values[layer_idx][2:]
1003
+ cache.cross_attention_cache.update(key_states, value_states, layer_idx)
1004
+ cache.is_updated[layer_idx] = True
1005
+ return cache
1006
+
1007
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
1008
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
1009
+ if len(self.self_attention_cache.key_cache) <= layer_idx:
1010
+ return 0
1011
+ return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
1012
+
1013
+ def reset(self):
1014
+ if hasattr(self.self_attention_cache, "reset"):
1015
+ self.self_attention_cache.reset()
1016
+ if hasattr(self.cross_attention_cache, "reset"):
1017
+ self.cross_attention_cache.reset()
1018
+ elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
1019
+ raise ValueError(
1020
+ "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
1021
+ "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
1022
+ f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
1023
+ f"{self.cross_attention_cache.__str__()} for the cross attention cache."
1024
+ )
1025
+ for layer_idx in self.is_updated:
1026
+ self.is_updated[layer_idx] = False
1027
+
1028
+ def reorder_cache(self, beam_idx: torch.LongTensor):
1029
+ """Reorders the cache for beam search, given the selected beam indices."""
1030
+ self.self_attention_cache.reorder_cache(beam_idx)
1031
+ self.cross_attention_cache.reorder_cache(beam_idx)
1032
+
1033
+ def check_dynamic_cache(self, method: str):
1034
+ if not (
1035
+ isinstance(self.self_attention_cache, DynamicCache)
1036
+ and isinstance(self.cross_attention_cache, DynamicCache)
1037
+ ):
1038
+ raise ValueError(
1039
+ f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
1040
+ f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
1041
+ )
1042
+
1043
+ # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
1044
+ def crop(self, maximum_length: int):
1045
+ """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
1046
+ negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
1047
+ self.check_dynamic_cache(self.crop.__name__)
1048
+ self.self_attention_cache.crop(maximum_length)
1049
+
1050
+ def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
1051
+ """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
1052
+ `_split_model_inputs()` in `generation.utils`"""
1053
+ self.check_dynamic_cache(self.batch_split.__name__)
1054
+ self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
1055
+ cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
1056
+
1057
+ out = []
1058
+ for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
1059
+ out.append(EncoderDecoderCache(self_attn, cross_attn))
1060
+ return out
1061
+
1062
+ @classmethod
1063
+ def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
1064
+ """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
1065
+ `generation.utils`"""
1066
+ self_attention_cache = DynamicCache()
1067
+ cross_attention_cache = DynamicCache()
1068
+ for idx in range(len(splits[0])):
1069
+ layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
1070
+ layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
1071
+ self_attention_cache.update(layer_keys, layer_values, idx)
1072
+
1073
+ layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
1074
+ layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
1075
+ cross_attention_cache.update(layer_keys, layer_values, idx)
1076
+ return cls(self_attention_cache, cross_attention_cache)
1077
+
1078
+ def batch_repeat_interleave(self, repeats: int):
1079
+ """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
1080
+ self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
1081
+ self.self_attention_cache.batch_repeat_interleave(repeats)
1082
+ self.cross_attention_cache.batch_repeat_interleave(repeats)
1083
+
1084
+ def batch_select_indices(self, indices: torch.Tensor):
1085
+ """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
1086
+ self.check_dynamic_cache(self.batch_select_indices.__name__)
1087
+ self.self_attention_cache.batch_select_indices(indices)
1088
+ self.cross_attention_cache.batch_select_indices(indices)
1089
+
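In the legacy format handled above, each per-layer entry is a 4-tuple `(self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)`. A quick round-trip sketch with dummy tensors:

import torch

kv = torch.zeros(1, 4, 6, 8)
legacy = ((kv, kv, kv, kv),)  # a single layer
cache = EncoderDecoderCache.from_legacy_cache(legacy)
assert cache.is_updated[0] and len(cache.to_legacy_cache()[0]) == 4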
1090
+
1091
+ class HybridCache(Cache):
1092
+ """
1093
+ Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between local sliding window attention
1094
+ and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention
1095
+ and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class.
1096
+
1097
+ Parameters:
1098
+ config (`PretrainedConfig`):
1099
+ The configuration file defining the shape-related attributes required to initialize the static cache.
1100
+ max_batch_size (`int`):
1101
+ The maximum batch size with which the model will be used.
1102
+ max_cache_len (`int`):
1103
+ The maximum sequence length with which the model will be used.
1104
+ device (`torch.device`, *optional*, defaults to `"cpu"`):
1105
+ The device on which the cache should be initialized. Should be the same as the layer.
1106
+ dtype (*optional*, defaults to `torch.float32`):
1107
+ The default `dtype` to use when initializing the layer.
1108
+
1109
+ Example:
1110
+
1111
+ ```python
1112
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
1113
+
1114
+ >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
1115
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
1116
+
1117
+ >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
1118
+
1119
+ >>> # Prepare a cache class and pass it to model's forward
1120
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
1121
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
1122
+ >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
1123
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
1124
+ >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
1125
+ ```
1126
+ """
1127
+
1128
+ def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
1129
+ super().__init__()
1130
+ if not hasattr(config, "sliding_window") or config.sliding_window is None:
1131
+ raise ValueError(
1132
+ "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
1133
+ "sliding window attention, please check if there is a `sliding_window` field in the model "
1134
+ "config and it's not set to None."
1135
+ )
1136
+ self.max_cache_len = max_cache_len
1137
+ self.max_batch_size = max_batch_size
1138
+ # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
1139
+ self.head_dim = (
1140
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
1141
+ )
1142
+
1143
+ self.dtype = dtype if dtype is not None else torch.float32
1144
+ self.num_key_value_heads = (
1145
+ config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
1146
+ )
1147
+ self.is_sliding = torch.tensor(
1148
+ [not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
1149
+ )
1150
+ self.key_cache: List[torch.Tensor] = []
1151
+ self.value_cache: List[torch.Tensor] = []
1152
+ global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
1153
+ sliding_cache_shape = (
1154
+ max_batch_size,
1155
+ self.num_key_value_heads,
1156
+ min(config.sliding_window, max_cache_len),
1157
+ self.head_dim,
1158
+ )
1159
+ for i in range(config.num_hidden_layers):
1160
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
1161
+ # breaks when updating the cache.
1162
+ cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
1163
+ new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
1164
+ new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
1165
+ torch._dynamo.mark_static_address(new_layer_key_cache)
1166
+ torch._dynamo.mark_static_address(new_layer_value_cache)
1167
+ self.key_cache.append(new_layer_key_cache)
1168
+ self.value_cache.append(new_layer_value_cache)
1169
+
1170
+ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
1171
+ if cache_position.shape[0] > max_cache_len:
1172
+ k_out = key_states[:, :, -max_cache_len:, :]
1173
+ v_out = value_states[:, :, -max_cache_len:, :]
1174
+ # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
1175
+ self.key_cache[layer_idx] += k_out
1176
+ self.value_cache[layer_idx] += v_out
1177
+ # we should return the whole states instead of k_out, v_out to take the whole prompt
1178
+ # into consideration when building kv cache instead of just throwing away tokens outside of the window
1179
+ return key_states, value_states
1180
+
1181
+ slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
1182
+ cache_position = cache_position.clamp(0, max_cache_len - 1)
1183
+ to_shift = cache_position >= max_cache_len - 1
1184
+ indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
1185
+ k_out = k_out[:, :, indices]
1186
+ v_out = v_out[:, :, indices]
1187
+
1188
+ k_out[:, :, cache_position] = key_states
1189
+ v_out[:, :, cache_position] = value_states
1190
+ # `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
1191
+ self.key_cache[layer_idx].zero_()
1192
+ self.value_cache[layer_idx].zero_()
1193
+
1194
+ self.key_cache[layer_idx] += k_out
1195
+ self.value_cache[layer_idx] += v_out
1196
+ return k_out, v_out
1197
+
1198
+ def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
1199
+ k_out[:, :, cache_position] = key_states
1200
+ v_out[:, :, cache_position] = value_states
1201
+
1202
+ self.key_cache[layer_idx] = k_out
1203
+ self.value_cache[layer_idx] = v_out
1204
+ return k_out, v_out
1205
+
1206
+ def update(
1207
+ self,
1208
+ key_states: torch.Tensor,
1209
+ value_states: torch.Tensor,
1210
+ layer_idx: int,
1211
+ cache_kwargs: Optional[Dict[str, Any]] = None,
1212
+ ) -> Tuple[torch.Tensor]:
1213
+ cache_position = cache_kwargs.get("cache_position")
1214
+ sliding_window = cache_kwargs.get("sliding_window")
1215
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
1216
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
1217
+ k_out = self.key_cache[layer_idx]
1218
+ v_out = self.value_cache[layer_idx]
1219
+ if sliding_window:
1220
+ update_fn = self._sliding_update
1221
+ else:
1222
+ update_fn = self._static_update
1223
+
1224
+ return update_fn(
1225
+ cache_position,
1226
+ layer_idx,
1227
+ key_states,
1228
+ value_states,
1229
+ k_out,
1230
+ v_out,
1231
+ k_out.shape[2],
1232
+ )
1233
+
1234
+ def get_max_length(self) -> Optional[int]:
1235
+ # in theory there is no limit because the sliding window size is fixed
1236
+ # no matter how long the sentence is
1237
+ return self.max_cache_len
1238
+
1239
+ def get_seq_length(self, layer_idx: Optional[int] = 0):
1240
+ return None
1241
+
1242
+ def reset(self):
1243
+ """Resets the cache values while preserving the objects"""
1244
+ for layer_idx in range(len(self.key_cache)):
1245
+ # In-place ops prevent breaking the static address
1246
+ self.key_cache[layer_idx].zero_()
1247
+ self.value_cache[layer_idx].zero_()
1248
+
1249
+
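A quick check of the even/odd layer split encoded in `is_sliding` above (the layer count is made up): even-numbered layers get the sliding-window buffer, odd-numbered layers get the full global buffer.

num_hidden_layers = 4
is_sliding = [not bool(i % 2) for i in range(num_hidden_layers)]
print(is_sliding)  # [True, False, True, False]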
1250
+ class MambaCache:
1251
+ """
1252
+ Cache for mamba model which does not have attention mechanism and key value states.
1253
+
1254
+ Arguments:
1255
+ config (`PretrainedConfig`):
1256
+ The configuration file defining the shape-related attributes required to initialize the static cache.
1257
+ max_batch_size (`int`):
1258
+ The maximum batch size with which the model will be used.
1259
+ dtype (*optional*, defaults to `torch.float16`):
1260
+ The default `dtype` to use when initializing the layer.
1261
+ device (`torch.device`, *optional*):
1262
+ The device on which the cache should be initialized. Should be the same as the layer.
1263
+
1264
+ Attributes:
1265
+ dtype: (`torch.dtype`):
1266
+ The default `dtype` used to initialize the cache.
1267
+ intermediate_size: (`int`):
1268
+ Model's intermediate_size taken from config.
1269
+ ssm_state_size: (`int`):
1270
+ Model's state_size taken from config.
1271
+ conv_kernel_size: (`int`):
1272
+ Model's convolution kernel size taken from config.
1273
+ conv_states: (`torch.Tensor`):
1274
+ A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
1275
+ ssm_states: (`torch.Tensor`):
1276
+ A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states.
1277
+
1278
+ Example:
1279
+
1280
+ ```python
1281
+ >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
1282
+
1283
+ >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
1284
+ >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
1285
+
1286
+ >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
1287
+
1288
+ >>> # Prepare a cache class and pass it to model's forward
1289
+ >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
1290
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
1291
+ >>> past_kv = outputs.past_key_values
1292
+ ```
1293
+ """
1294
+
1295
+ def __init__(
1296
+ self,
1297
+ config: PretrainedConfig,
1298
+ max_batch_size: int,
1299
+ dtype: torch.dtype = torch.float16,
1300
+ device: Optional[str] = None,
1301
+ **kwargs,
1302
+ ):
1303
+ self.dtype = dtype
1304
+ self.max_batch_size = max_batch_size
1305
+ self.intermediate_size = config.intermediate_size
1306
+ self.ssm_state_size = config.state_size
1307
+ self.conv_kernel_size = config.conv_kernel
1308
+
1309
+ self.conv_states: torch.Tensor = torch.zeros(
1310
+ config.num_hidden_layers,
1311
+ self.max_batch_size,
1312
+ self.intermediate_size,
1313
+ self.conv_kernel_size,
1314
+ device=device,
1315
+ dtype=dtype,
1316
+ )
1317
+ self.ssm_states: torch.Tensor = torch.zeros(
1318
+ config.num_hidden_layers,
1319
+ self.max_batch_size,
1320
+ self.intermediate_size,
1321
+ self.ssm_state_size,
1322
+ device=device,
1323
+ dtype=dtype,
1324
+ )
1325
+
1326
+ torch._dynamo.mark_static_address(self.conv_states)
1327
+ torch._dynamo.mark_static_address(self.ssm_states)
1328
+
1329
+ def update_conv_state(
1330
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
1331
+ ) -> torch.Tensor:
1332
+ conv_state = self.conv_states[layer_idx]
1333
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
1334
+
1335
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
1336
+ conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
1337
+ self.conv_states[layer_idx].zero_()
1338
+ self.conv_states[layer_idx] += conv_state
1339
+ return self.conv_states[layer_idx]
1340
+
1341
+ def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
1342
+ self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
1343
+ return self.ssm_states[layer_idx]
1344
+
1345
+ def reset(self):
1346
+ self.conv_states.zero_()
1347
+ self.ssm_states.zero_()
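The essence of `update_conv_state` above, shown on a simplified single-channel state with the batch dimension dropped (values are made up): the state is rolled left by one and the newest column is written at the clamped position, so it always holds the last `conv_kernel_size` inputs.

import torch

state = torch.tensor([[1., 2., 3., 4.]])  # [intermediate_size=1, conv_kernel_size=4]
state = state.roll(shifts=-1, dims=-1)    # tensor([[2., 3., 4., 1.]])
state[:, -1] = 5.                         # tensor([[2., 3., 4., 5.]])
print(state)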
transformers_4_44_2__configuration_llama.py ADDED
@@ -0,0 +1,203 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """LLaMA model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from .transformers_4_44_2__modeling_rope_utils import rope_config_validation
24
+
25
+
26
+ class LlamaConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a similar configuration to that of the LLaMA-7B.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 32000):
38
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`LlamaModel`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 11008):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer decoder.
48
+ num_key_value_heads (`int`, *optional*):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
55
+ `num_attention_heads`.
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
57
+ The non-linear activation function (function or string) in the decoder.
58
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
59
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
60
+ Llama 2 up to 4096, CodeLlama up to 16384.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
64
+ The epsilon used by the rms normalization layers.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ pad_token_id (`int`, *optional*):
69
+ Padding token id.
70
+ bos_token_id (`int`, *optional*, defaults to 1):
71
+ Beginning of stream token id.
72
+ eos_token_id (`int`, *optional*, defaults to 2):
73
+ End of stream token id.
74
+ pretraining_tp (`int`, *optional*, defaults to 1):
75
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
76
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
77
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
78
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
79
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80
+ Whether to tie weight embeddings
81
+ rope_theta (`float`, *optional*, defaults to 10000.0):
82
+ The base period of the RoPE embeddings.
83
+ rope_scaling (`Dict`, *optional*):
84
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
85
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
86
+ accordingly.
87
+ Expected contents:
88
+ `rope_type` (`str`):
89
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
90
+ 'llama3'], with 'default' being the original RoPE implementation.
91
+ `factor` (`float`, *optional*):
92
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
93
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
94
+ original maximum pre-trained length.
95
+ `original_max_position_embeddings` (`int`, *optional*):
96
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
97
+ pretraining.
98
+ `attention_factor` (`float`, *optional*):
99
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
100
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
101
+ `factor` field to infer the suggested value.
102
+ `beta_fast` (`float`, *optional*):
103
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
104
+ ramp function. If unspecified, it defaults to 32.
105
+ `beta_slow` (`float`, *optional*):
106
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
107
+ ramp function. If unspecified, it defaults to 1.
108
+ `short_factor` (`List[float]`, *optional*):
109
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
110
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
111
+ size divided by the number of attention heads divided by 2
112
+ `long_factor` (`List[float]`, *optional*):
113
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
114
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
115
+ size divided by the number of attention heads divided by 2
116
+ `low_freq_factor` (`float`, *optional*):
117
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
118
+ `high_freq_factor` (`float`, *optional*):
119
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
120
+ attention_bias (`bool`, *optional*, defaults to `False`):
121
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
122
+ attention_dropout (`float`, *optional*, defaults to 0.0):
123
+ The dropout ratio for the attention probabilities.
124
+ mlp_bias (`bool`, *optional*, defaults to `False`):
125
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
126
+
127
+ ```python
128
+ >>> from transformers import LlamaModel, LlamaConfig
129
+
130
+ >>> # Initializing a LLaMA llama-7b style configuration
131
+ >>> configuration = LlamaConfig()
132
+
133
+ >>> # Initializing a model from the llama-7b style configuration
134
+ >>> model = LlamaModel(configuration)
135
+
136
+ >>> # Accessing the model configuration
137
+ >>> configuration = model.config
138
+ ```"""
139
+
140
+ model_type = "llama"
141
+ keys_to_ignore_at_inference = ["past_key_values"]
142
+
143
+ def __init__(
144
+ self,
145
+ vocab_size=32000,
146
+ hidden_size=4096,
147
+ intermediate_size=11008,
148
+ num_hidden_layers=32,
149
+ num_attention_heads=32,
150
+ num_key_value_heads=None,
151
+ hidden_act="silu",
152
+ max_position_embeddings=2048,
153
+ initializer_range=0.02,
154
+ rms_norm_eps=1e-6,
155
+ use_cache=True,
156
+ pad_token_id=None,
157
+ bos_token_id=1,
158
+ eos_token_id=2,
159
+ pretraining_tp=1,
160
+ tie_word_embeddings=False,
161
+ rope_theta=10000.0,
162
+ rope_scaling=None,
163
+ attention_bias=False,
164
+ attention_dropout=0.0,
165
+ mlp_bias=False,
166
+ **kwargs,
167
+ ):
168
+ self.vocab_size = vocab_size
169
+ self.max_position_embeddings = max_position_embeddings
170
+ self.hidden_size = hidden_size
171
+ self.intermediate_size = intermediate_size
172
+ self.num_hidden_layers = num_hidden_layers
173
+ self.num_attention_heads = num_attention_heads
174
+
175
+ # for backward compatibility
176
+ if num_key_value_heads is None:
177
+ num_key_value_heads = num_attention_heads
178
+
179
+ self.num_key_value_heads = num_key_value_heads
180
+ self.hidden_act = hidden_act
181
+ self.initializer_range = initializer_range
182
+ self.rms_norm_eps = rms_norm_eps
183
+ self.pretraining_tp = pretraining_tp
184
+ self.use_cache = use_cache
185
+ self.rope_theta = rope_theta
186
+ self.rope_scaling = rope_scaling
187
+ self.attention_bias = attention_bias
188
+ self.attention_dropout = attention_dropout
189
+ self.mlp_bias = mlp_bias
190
+
191
+ # Validate the correctness of rotary position embeddings parameters
192
+ # BC: if there is a 'type' field, move it to 'rope_type'.
193
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
194
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
195
+ rope_config_validation(self)
196
+
197
+ super().__init__(
198
+ pad_token_id=pad_token_id,
199
+ bos_token_id=bos_token_id,
200
+ eos_token_id=eos_token_id,
201
+ tie_word_embeddings=tie_word_embeddings,
202
+ **kwargs,
203
+ )
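The vendored `LlamaConfig` above mirrors the upstream `transformers.LlamaConfig`, so its `rope_scaling` validation can be exercised directly against the upstream class (with a recent transformers release). A small sketch with illustrative values; these are placeholders, not the settings of this checkpoint:

```python
from transformers import LlamaConfig  # same interface as the vendored class above

# Illustrative llama3-style rope scaling; values are placeholders, not this model's config.
config = LlamaConfig(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,  # grouped-query attention: 4 query heads share each KV head
    max_position_embeddings=131072,
    rope_theta=500000.0,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "original_max_position_embeddings": 8192,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
    },
)
print(config.rope_scaling["rope_type"])  # "llama3"
```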
transformers_4_44_2__modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,482 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+
19
+
20
+ @dataclass
21
+ class AttentionMaskConverter:
22
+ """
23
+ A utility attention mask class that allows one to:
24
+ - Create a causal 4d mask
25
+        - Create a causal 4d mask with sliding window
26
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
27
+ key_value_length) that can be multiplied with attention scores
28
+
29
+ Examples:
30
+
31
+ ```python
32
+ >>> import torch
33
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
+
35
+ >>> converter = AttentionMaskConverter(True)
36
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
37
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
38
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
39
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
40
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
41
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
42
+ ```
43
+
44
+ Parameters:
45
+ is_causal (`bool`):
46
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
47
+
48
+ sliding_window (`int`, *optional*):
49
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
50
+ """
51
+
52
+ is_causal: bool
53
+ sliding_window: int
54
+
55
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
56
+ self.is_causal = is_causal
57
+ self.sliding_window = sliding_window
58
+
59
+ if self.sliding_window is not None and self.sliding_window <= 0:
60
+ raise ValueError(
61
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
62
+ )
63
+
64
+ def to_causal_4d(
65
+ self,
66
+ batch_size: int,
67
+ query_length: int,
68
+ key_value_length: int,
69
+ dtype: torch.dtype,
70
+ device: Union[torch.device, "str"] = "cpu",
71
+ ) -> Optional[torch.Tensor]:
72
+ """
73
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
74
+ bias to upper right hand triangular matrix (causal mask).
75
+ """
76
+ if not self.is_causal:
77
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
78
+
79
+ # If shape is not cached, create a new causal mask and cache it
80
+ input_shape = (batch_size, query_length)
81
+ past_key_values_length = key_value_length - query_length
82
+
83
+ # create causal mask
84
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
85
+ causal_4d_mask = None
86
+ if input_shape[-1] > 1 or self.sliding_window is not None:
87
+ causal_4d_mask = self._make_causal_mask(
88
+ input_shape,
89
+ dtype,
90
+ device=device,
91
+ past_key_values_length=past_key_values_length,
92
+ sliding_window=self.sliding_window,
93
+ )
94
+
95
+ return causal_4d_mask
96
+
97
+ def to_4d(
98
+ self,
99
+ attention_mask_2d: torch.Tensor,
100
+ query_length: int,
101
+ dtype: torch.dtype,
102
+ key_value_length: Optional[int] = None,
103
+ ) -> torch.Tensor:
104
+ """
105
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
106
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
107
+ causal, a causal mask will be added.
108
+ """
109
+ input_shape = (attention_mask_2d.shape[0], query_length)
110
+
111
+ # create causal mask
112
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
113
+ causal_4d_mask = None
114
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
115
+ if key_value_length is None:
116
+ raise ValueError(
117
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
118
+ )
119
+
120
+ past_key_values_length = key_value_length - query_length
121
+ causal_4d_mask = self._make_causal_mask(
122
+ input_shape,
123
+ dtype,
124
+ device=attention_mask_2d.device,
125
+ past_key_values_length=past_key_values_length,
126
+ sliding_window=self.sliding_window,
127
+ )
128
+ elif self.sliding_window is not None:
129
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
130
+
131
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
132
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
133
+ attention_mask_2d.device
134
+ )
135
+
136
+ if causal_4d_mask is not None:
137
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
138
+
139
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
140
+ expanded_4d_mask = expanded_attn_mask
141
+
142
+ return expanded_4d_mask
143
+
144
+ @staticmethod
145
+ def _make_causal_mask(
146
+ input_ids_shape: torch.Size,
147
+ dtype: torch.dtype,
148
+ device: torch.device,
149
+ past_key_values_length: int = 0,
150
+ sliding_window: Optional[int] = None,
151
+ ):
152
+ """
153
+        Make the causal mask used for uni-directional (causal) self-attention.
154
+ """
155
+ bsz, tgt_len = input_ids_shape
156
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
157
+ mask_cond = torch.arange(mask.size(-1), device=device)
158
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
159
+
160
+ mask = mask.to(dtype)
161
+
162
+ if past_key_values_length > 0:
163
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
164
+
165
+ # add lower triangular sliding window mask if necessary
166
+ if sliding_window is not None:
167
+ diagonal = past_key_values_length - sliding_window - 1
168
+
169
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
170
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
171
+
172
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
173
+
174
+ @staticmethod
175
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
176
+ """
177
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
178
+ """
179
+ bsz, src_len = mask.size()
180
+ tgt_len = tgt_len if tgt_len is not None else src_len
181
+
182
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
183
+
184
+ inverted_mask = 1.0 - expanded_mask
185
+
186
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
187
+
188
+ @staticmethod
189
+ def _unmask_unattended(
190
+ expanded_mask: torch.FloatTensor,
191
+ min_dtype: float,
192
+ ):
193
+ # fmt: off
194
+ """
195
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
196
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
197
+ Details: https://github.com/pytorch/pytorch/issues/110213
198
+
199
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
200
+ `attention_mask` is [bsz, src_seq_len].
201
+
202
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
203
+
204
+ For example, if `expanded_mask` is (e.g. here left-padding case)
205
+ ```
206
+ [[[[0, 0, 0],
207
+ [0, 0, 0],
208
+ [0, 0, 1]]],
209
+ [[[1, 0, 0],
210
+ [1, 1, 0],
211
+ [1, 1, 1]]],
212
+ [[[0, 0, 0],
213
+ [0, 1, 0],
214
+ [0, 1, 1]]]]
215
+ ```
216
+ then the modified `expanded_mask` will be
217
+ ```
218
+ [[[[1, 1, 1], <-- modified
219
+ [1, 1, 1], <-- modified
220
+ [0, 0, 1]]],
221
+ [[[1, 0, 0],
222
+ [1, 1, 0],
223
+ [1, 1, 1]]],
224
+ [[[1, 1, 1], <-- modified
225
+ [0, 1, 0],
226
+ [0, 1, 1]]]]
227
+ ```
228
+ """
229
+ # fmt: on
230
+ if expanded_mask.dtype == torch.bool:
231
+ raise ValueError(
232
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
233
+ )
234
+
235
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
236
+
237
+ @staticmethod
238
+ def _ignore_causal_mask_sdpa(
239
+ attention_mask: Optional[torch.Tensor],
240
+ inputs_embeds: torch.Tensor,
241
+ past_key_values_length: int,
242
+ sliding_window: Optional[int] = None,
243
+ is_training: bool = False,
244
+ ) -> bool:
245
+ """
246
+        Detects whether the optional user-specified attention_mask and the automatically created causal mask can be ignored when PyTorch's SDPA is used, relying instead on SDPA's `is_causal` argument.
247
+
248
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
249
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
250
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
251
+ """
252
+
253
+ _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
254
+ key_value_length = query_length + past_key_values_length
255
+
256
+ is_tracing = (
257
+ torch.jit.is_tracing()
258
+ or isinstance(inputs_embeds, torch.fx.Proxy)
259
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
260
+ )
261
+
262
+ ignore_causal_mask = False
263
+
264
+ if attention_mask is None:
265
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
266
+ # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
267
+ # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
268
+ #
269
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
270
+ if (
271
+ (is_training or not is_tracing)
272
+ and (query_length == 1 or key_value_length == query_length)
273
+ and (sliding_window is None or key_value_length < sliding_window)
274
+ ):
275
+ ignore_causal_mask = True
276
+ elif sliding_window is None or key_value_length < sliding_window:
277
+ if len(attention_mask.shape) == 4:
278
+ return False
279
+ elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
280
+ if query_length == 1 or key_value_length == query_length:
281
+ # For query_length == 1, causal attention and bi-directional attention are the same.
282
+ ignore_causal_mask = True
283
+
284
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
285
+ # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
286
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
287
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
288
+
289
+ return ignore_causal_mask
290
+
291
+
292
+ def _prepare_4d_causal_attention_mask(
293
+ attention_mask: Optional[torch.Tensor],
294
+ input_shape: Union[torch.Size, Tuple, List],
295
+ inputs_embeds: torch.Tensor,
296
+ past_key_values_length: int,
297
+ sliding_window: Optional[int] = None,
298
+ ):
299
+ """
300
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
301
+ `(batch_size, key_value_length)`
302
+
303
+ Args:
304
+ attention_mask (`torch.Tensor` or `None`):
305
+ A 2D attention mask of shape `(batch_size, key_value_length)`
306
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
307
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
308
+ inputs_embeds (`torch.Tensor`):
309
+ The embedded inputs as a torch Tensor.
310
+ past_key_values_length (`int`):
311
+ The length of the key value cache.
312
+ sliding_window (`int`, *optional*):
313
+ If the model uses windowed attention, a sliding window should be passed.
314
+ """
315
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
316
+
317
+ key_value_length = input_shape[-1] + past_key_values_length
318
+
319
+ # 4d mask is passed through the layers
320
+ if attention_mask is not None and len(attention_mask.shape) == 2:
321
+ attention_mask = attn_mask_converter.to_4d(
322
+ attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
323
+ )
324
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
325
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
326
+ if tuple(attention_mask.shape) != expected_shape:
327
+ raise ValueError(
328
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
329
+ )
330
+ else:
331
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
332
+ inverted_mask = 1.0 - attention_mask
333
+ attention_mask = inverted_mask.masked_fill(
334
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
335
+ )
336
+ else:
337
+ attention_mask = attn_mask_converter.to_causal_4d(
338
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
339
+ )
340
+
341
+ return attention_mask
342
+
343
+
344
+ # Adapted from _prepare_4d_causal_attention_mask
345
+ def _prepare_4d_causal_attention_mask_for_sdpa(
346
+ attention_mask: Optional[torch.Tensor],
347
+ input_shape: Union[torch.Size, Tuple, List],
348
+ inputs_embeds: torch.Tensor,
349
+ past_key_values_length: int,
350
+ sliding_window: Optional[int] = None,
351
+ ):
352
+ """
353
+ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
354
+
355
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
356
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
357
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
358
+ """
359
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
360
+
361
+ key_value_length = input_shape[-1] + past_key_values_length
362
+
363
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
364
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
365
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
366
+ is_tracing = (
367
+ torch.jit.is_tracing()
368
+ or isinstance(inputs_embeds, torch.fx.Proxy)
369
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
370
+ )
371
+
372
+ ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
373
+ attention_mask=attention_mask,
374
+ inputs_embeds=inputs_embeds,
375
+ past_key_values_length=past_key_values_length,
376
+ sliding_window=sliding_window,
377
+ )
378
+
379
+ if ignore_causal_mask:
380
+ expanded_4d_mask = None
381
+ elif attention_mask is None:
382
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
383
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
384
+ )
385
+ else:
386
+ if attention_mask.dim() == 4:
387
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
388
+ if attention_mask.max() != 0:
389
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0")
390
+ expanded_4d_mask = attention_mask
391
+ else:
392
+ expanded_4d_mask = attn_mask_converter.to_4d(
393
+ attention_mask,
394
+ input_shape[-1],
395
+ dtype=inputs_embeds.dtype,
396
+ key_value_length=key_value_length,
397
+ )
398
+
399
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
400
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
401
+ # Details: https://github.com/pytorch/pytorch/issues/110213
402
+ if not is_tracing and expanded_4d_mask.device.type == "cuda":
403
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
404
+ expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
405
+ )
406
+
407
+ return expanded_4d_mask
408
+
409
+
410
+ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
411
+ """
412
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
413
+ `(batch_size, key_value_length)`
414
+
415
+ Args:
416
+ mask (`torch.Tensor`):
417
+ A 2D attention mask of shape `(batch_size, key_value_length)`
418
+ dtype (`torch.dtype`):
419
+ The torch dtype the created mask shall have.
420
+ tgt_len (`int`):
421
+ The target length or query length the created mask shall have.
422
+ """
423
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
424
+
425
+
426
+ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
427
+ """
428
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
429
+ `(batch_size, key_value_length)`
430
+
431
+ Args:
432
+ mask (`torch.Tensor`):
433
+ A 2D attention mask of shape `(batch_size, key_value_length)`
434
+ dtype (`torch.dtype`):
435
+ The torch dtype the created mask shall have.
436
+ tgt_len (`int`):
437
+ The target length or query length the created mask shall have.
438
+ """
439
+ _, key_value_length = mask.shape
440
+ tgt_len = tgt_len if tgt_len is not None else key_value_length
441
+
442
+ is_tracing = (
443
+ torch.jit.is_tracing()
444
+ or isinstance(mask, torch.fx.Proxy)
445
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
446
+ )
447
+
448
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
449
+ if not is_tracing and torch.all(mask == 1):
450
+ return None
451
+ else:
452
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
453
+
454
+
455
+ def _create_4d_causal_attention_mask(
456
+ input_shape: Union[torch.Size, Tuple, List],
457
+ dtype: torch.dtype,
458
+ device: torch.device,
459
+ past_key_values_length: int = 0,
460
+ sliding_window: Optional[int] = None,
461
+ ) -> Optional[torch.Tensor]:
462
+ """
463
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
464
+
465
+ Args:
466
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
467
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
468
+ dtype (`torch.dtype`):
469
+ The torch dtype the created mask shall have.
470
+ device (`int`):
471
+ The torch device the created mask shall have.
472
+ sliding_window (`int`, *optional*):
473
+ If the model uses windowed attention, a sliding window should be passed.
474
+ """
475
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
476
+
477
+ key_value_length = past_key_values_length + input_shape[-1]
478
+ attention_mask = attn_mask_converter.to_causal_4d(
479
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
480
+ )
481
+
482
+ return attention_mask
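The converter above is also shipped upstream as `transformers.modeling_attn_mask_utils.AttentionMaskConverter`; the sketch below imports that upstream copy for convenience and shows `to_causal_4d` producing a bottom-right-aligned causal mask when four key positions are already in the KV cache (toy shapes):

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter  # upstream copy of the class above

converter = AttentionMaskConverter(is_causal=True)
# 2 new query tokens attending over 4 cached + 2 new key positions.
mask = converter.to_causal_4d(batch_size=1, query_length=2, key_value_length=6, dtype=torch.float32)

print(mask.shape)               # torch.Size([1, 1, 2, 6])
print((mask == 0).int()[0, 0])  # rows show which key positions each query token may attend to
# tensor([[1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1]], dtype=torch.int32)
```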
transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py ADDED
@@ -0,0 +1,348 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import os
18
+ from typing import Optional, Tuple, Union
19
+
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+ from functools import lru_cache
25
+ import importlib.metadata
26
+ import importlib.util
27
+ from packaging import version
28
+
29
+ from transformers.utils import is_flash_attn_2_available
30
+
31
+
32
+ if is_flash_attn_2_available():
33
+ try:
34
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
35
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
36
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
37
+ except ImportError:
38
+ raise "Unable to import flash_attn"
39
+
40
+
41
+ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
42
+ # Check if the package spec exists and grab its version to avoid importing a local directory
43
+ package_exists = importlib.util.find_spec(pkg_name) is not None
44
+ package_version = "N/A"
45
+ if package_exists:
46
+ try:
47
+ # Primary method to get the package version
48
+ package_version = importlib.metadata.version(pkg_name)
49
+ except importlib.metadata.PackageNotFoundError:
50
+ # Fallback method: Only for "torch" and versions containing "dev"
51
+ if pkg_name == "torch":
52
+ try:
53
+ package = importlib.import_module(pkg_name)
54
+ temp_version = getattr(package, "__version__", "N/A")
55
+ # Check if the version contains "dev"
56
+ if "dev" in temp_version:
57
+ package_version = temp_version
58
+ package_exists = True
59
+ else:
60
+ package_exists = False
61
+ except ImportError:
62
+ # If the package can't be imported, it's not available
63
+ package_exists = False
64
+ else:
65
+ # For packages other than "torch", don't attempt the fallback and set as not available
66
+ package_exists = False
67
+ if return_version:
68
+ return package_exists, package_version
69
+ else:
70
+ return package_exists
71
+
72
+
73
+ @lru_cache()
74
+ def is_flash_attn_greater_or_equal(library_version: str):
75
+ if not _is_package_available("flash_attn"):
76
+ return False
77
+
78
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
79
+
80
+
81
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
82
+ """
83
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
84
+
85
+ Arguments:
86
+ attention_mask (`torch.Tensor`):
87
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
88
+
89
+ Return:
90
+ indices (`torch.Tensor`):
91
+ The indices of non-masked tokens from the flattened input sequence.
92
+ cu_seqlens (`torch.Tensor`):
93
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
94
+ max_seqlen_in_batch (`int`):
95
+ Maximum sequence length in batch.
96
+ """
97
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
98
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
99
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
100
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
101
+ return (
102
+ indices,
103
+ cu_seqlens,
104
+ max_seqlen_in_batch,
105
+ )
106
+
107
+
108
+ def _upad_input(
109
+ query_layer: torch.Tensor,
110
+ key_layer: torch.Tensor,
111
+ value_layer: torch.Tensor,
112
+ attention_mask: torch.Tensor,
113
+ query_length: int,
114
+ ):
115
+ """
116
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
117
+
118
+ This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
119
+ tensors for query, key, value tensors.
120
+
121
+ Arguments:
122
+ query_layer (`torch.Tensor`):
123
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
124
+ key_layer (`torch.Tensor`):
125
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
126
+ value_layer (`torch.Tensor`):
127
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
128
+ attention_mask (`torch.Tensor`):
129
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
130
+ query_length (`int`):
131
+ Target length.
132
+
133
+ Return:
134
+ query_layer (`torch.Tensor`):
135
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
136
+ key_layer (`torch.Tensor`):
137
+ Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
138
+ value_layer (`torch.Tensor`):
139
+ Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
140
+ indices_q (`torch.Tensor`):
141
+ The indices of non-masked tokens from the flattened input target sequence.
142
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
143
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
144
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
145
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
146
+ """
147
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
148
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
149
+
150
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
151
+ value_layer = index_first_axis(
152
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
153
+ )
154
+ if query_length == kv_seq_len:
155
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
156
+ cu_seqlens_q = cu_seqlens_k
157
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
158
+ indices_q = indices_k
159
+ elif query_length == 1:
160
+ max_seqlen_in_batch_q = 1
161
+ cu_seqlens_q = torch.arange(
162
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
163
+ ) # There is a memcpy here, that is very bad.
164
+ indices_q = cu_seqlens_q[:-1]
165
+ query_layer = query_layer.squeeze(1)
166
+ else:
167
+ # The -q_len: slice assumes left padding.
168
+ attention_mask = attention_mask[:, -query_length:]
169
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
170
+
171
+ return (
172
+ query_layer,
173
+ key_layer,
174
+ value_layer,
175
+ indices_q,
176
+ (cu_seqlens_q, cu_seqlens_k),
177
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
178
+ )
179
+
180
+
181
+ def prepare_fa2_from_position_ids(query, key, value, position_ids):
182
+ """
183
+ This function returns necessary arguments to call `flash_attn_varlen_func`.
184
+ All three query, key, value states will be flattened.
185
+ Cummulative lengths of each examples in the batch will be extracted from position_ids.
186
+
187
+ NOTE: ideally cummulative lengths should be prepared at the data collator stage
188
+
189
+ Arguments:
190
+ query (`torch.Tensor`):
191
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
192
+ key (`torch.Tensor`):
193
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
194
+ value (`torch.Tensor`):
195
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
196
+ position_ids (`torch.Tensor`):
197
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
198
+
199
+ Return:
200
+ query (`torch.Tensor`):
201
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
202
+ key (`torch.Tensor`):
203
+ Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
204
+ value (`torch.Tensor`):
205
+ Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
206
+ indices_q (`torch.Tensor`):
207
+ The indices of non-masked tokens from the flattened input target sequence.
208
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
209
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
210
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
211
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
212
+ """
213
+ query = query.view(-1, query.size(-2), query.size(-1))
214
+ key = key.view(-1, key.size(-2), key.size(-1))
215
+ value = value.view(-1, value.size(-2), value.size(-1))
216
+ position_ids = position_ids.flatten()
217
+ indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
218
+
219
+ cu_seq_lens = torch.cat(
220
+ (
221
+ indices_q[position_ids == 0],
222
+ torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
223
+ )
224
+ )
225
+
226
+ max_length = position_ids.max() + 1
227
+
228
+ return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
229
+
230
+
231
+ def _flash_attention_forward(
232
+ query_states: torch.Tensor,
233
+ key_states: torch.Tensor,
234
+ value_states: torch.Tensor,
235
+ attention_mask: torch.Tensor,
236
+ query_length: int,
237
+ is_causal: bool,
238
+ dropout: float = 0.0,
239
+ position_ids: Optional[torch.Tensor] = None,
240
+ softmax_scale: Optional[float] = None,
241
+ sliding_window: Optional[int] = None,
242
+ use_top_left_mask: bool = False,
243
+ softcap: Optional[float] = None,
244
+ deterministic: bool = None,
245
+ ):
246
+ """
247
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
248
+ first unpad the input, then computes the attention scores and pad the final attention scores.
249
+
250
+ Args:
251
+ query_states (`torch.Tensor`):
252
+ Input query states to be passed to Flash Attention API
253
+ key_states (`torch.Tensor`):
254
+ Input key states to be passed to Flash Attention API
255
+ value_states (`torch.Tensor`):
256
+ Input value states to be passed to Flash Attention API
257
+ attention_mask (`torch.Tensor`):
258
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
259
+ position of padding tokens and 1 for the position of non-padding tokens.
260
+ dropout (`float`):
261
+ Attention dropout
262
+ softmax_scale (`float`, *optional*):
263
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
264
+ use_top_left_mask (`bool`, defaults to `False`):
265
+ flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
266
+ softcap (`float`, *optional*):
267
+ Softcap for the attention logits, used e.g. in gemma2.
268
+ deterministic (`bool`, *optional*):
269
+ Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
270
+ """
271
+ if not use_top_left_mask:
272
+ causal = is_causal
273
+ else:
274
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
275
+ causal = is_causal and query_length != 1
276
+
277
+ # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
278
+ use_sliding_windows = (
279
+ _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
280
+ )
281
+ flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
282
+
283
+ if is_flash_attn_greater_or_equal("2.4.1"):
284
+ if deterministic is None:
285
+ deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
286
+ flash_kwargs["deterministic"] = deterministic
287
+
288
+ if softcap is not None:
289
+ flash_kwargs["softcap"] = softcap
290
+
291
+ # Contains at least one padding token in the sequence
292
+ if attention_mask is not None:
293
+ batch_size = query_states.shape[0]
294
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
295
+ query_states, key_states, value_states, attention_mask, query_length
296
+ )
297
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
298
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
299
+
300
+ attn_output_unpad = flash_attn_varlen_func(
301
+ query_states,
302
+ key_states,
303
+ value_states,
304
+ cu_seqlens_q=cu_seqlens_q,
305
+ cu_seqlens_k=cu_seqlens_k,
306
+ max_seqlen_q=max_seqlen_in_batch_q,
307
+ max_seqlen_k=max_seqlen_in_batch_k,
308
+ dropout_p=dropout,
309
+ softmax_scale=softmax_scale,
310
+ causal=causal,
311
+ **flash_kwargs,
312
+ )
313
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
314
+
315
+ # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
316
+ # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
317
+ # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
318
+ elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
319
+ batch_size = query_states.size(0)
320
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
321
+ query_states, key_states, value_states, position_ids
322
+ )
323
+
324
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
325
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
326
+
327
+ attn_output = flash_attn_varlen_func(
328
+ query_states,
329
+ key_states,
330
+ value_states,
331
+ cu_seqlens_q=cu_seqlens_q,
332
+ cu_seqlens_k=cu_seqlens_k,
333
+ max_seqlen_q=max_seqlen_in_batch_q,
334
+ max_seqlen_k=max_seqlen_in_batch_k,
335
+ dropout_p=dropout,
336
+ softmax_scale=softmax_scale,
337
+ causal=causal,
338
+ **flash_kwargs,
339
+ )
340
+
341
+ attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
342
+
343
+ else:
344
+ attn_output = flash_attn_func(
345
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
346
+ )
347
+
348
+ return attn_output
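The unpad/repad path above hinges on the three values `_get_unpad_data` derives from the padding mask. A standalone sketch that re-runs the same three lines of logic on a tiny right-padded batch (inline, to avoid importing the vendored module):

```python
import torch
import torch.nn.functional as F

# Batch of two sequences, lengths 3 and 5, right-padded to 5 (1 = real token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

print(indices.tolist())     # [0, 1, 2, 5, 6, 7, 8, 9] -- real-token positions in the flattened batch
print(cu_seqlens.tolist())  # [0, 3, 8] -- boundaries passed to flash_attn_varlen_func
print(int(seqlens.max()))   # 5 -- max_seqlen_in_batch
```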
transformers_4_44_2__modeling_outputs.py ADDED
The diff for this file is too large to render. See raw diff
 
transformers_4_44_2__modeling_rope_utils.py ADDED
@@ -0,0 +1,559 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Optional, Tuple
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import is_torch_available, logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ if is_torch_available():
26
+ import torch
27
+
28
+
29
+ def _compute_default_rope_parameters(
30
+ config: Optional[PretrainedConfig] = None,
31
+ device: Optional["torch.device"] = None,
32
+ seq_len: Optional[int] = None,
33
+ **rope_kwargs,
34
+ ) -> Tuple["torch.Tensor", float]:
35
+ """
36
+ Computes the inverse frequencies according to the original RoPE implementation
37
+ Args:
38
+ config ([`~transformers.PretrainedConfig`]):
39
+ The model configuration.
40
+ device (`torch.device`):
41
+ The device to use for initialization of the inverse frequencies.
42
+ seq_len (`int`, *optional*):
43
+ The current sequence length. Unused for this type of RoPE.
44
+ rope_kwargs (`Dict`, *optional*):
45
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
46
+ Returns:
47
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
48
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
49
+ """
50
+ if config is not None and len(rope_kwargs) > 0:
51
+ raise ValueError(
52
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
53
+ f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
54
+ )
55
+ if len(rope_kwargs) > 0:
56
+ base = rope_kwargs["base"]
57
+ dim = rope_kwargs["dim"]
58
+ elif config is not None:
59
+ base = config.rope_theta
60
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
61
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
62
+ dim = int(head_dim * partial_rotary_factor)
63
+
64
+ attention_factor = 1.0 # Unused in this type of RoPE
65
+
66
+ # Compute the inverse frequencies
67
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
68
+ return inv_freq, attention_factor
69
+
70
+
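A tiny standalone sketch of the inverse-frequency computation performed by `_compute_default_rope_parameters` above, using a toy `head_dim` and `rope_theta` rather than this model's values; the linear-scaling variant defined next simply divides these frequencies by `factor`:

```python
import torch

base, dim = 10000.0, 8  # toy rope_theta and rotary dimension
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq)         # tensor([1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03])

factor = 4.0            # linear scaling: positions are effectively stretched by this factor
print(inv_freq / factor)
```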
71
+ def _compute_linear_scaling_rope_parameters(
72
+ config: Optional[PretrainedConfig] = None,
73
+ device: Optional["torch.device"] = None,
74
+ seq_len: Optional[int] = None,
75
+ **rope_kwargs,
76
+ ) -> Tuple["torch.Tensor", float]:
77
+ """
78
+ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
79
+ Args:
80
+ config ([`~transformers.PretrainedConfig`]):
81
+ The model configuration.
82
+ device (`torch.device`):
83
+ The device to use for initialization of the inverse frequencies.
84
+ seq_len (`int`, *optional*):
85
+ The current sequence length. Unused for this type of RoPE.
86
+ rope_kwargs (`Dict`, *optional*):
87
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
88
+ Returns:
89
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
90
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
91
+ """
92
+ if config is not None and len(rope_kwargs) > 0:
93
+ raise ValueError(
94
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
95
+ f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
96
+ )
97
+ if len(rope_kwargs) > 0:
98
+ factor = rope_kwargs["factor"]
99
+ elif config is not None:
100
+ factor = config.rope_scaling["factor"]
101
+
102
+ # Gets the default RoPE parameters
103
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
104
+
105
+ # Then applies linear scaling to the frequencies.
106
+ # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
107
+ # applying scaling to the inverse frequencies is equivalent.
108
+ inv_freq /= factor
109
+ return inv_freq, attention_factor
110
+
111
+
112
+ def _compute_dynamic_ntk_parameters(
113
+ config: Optional[PretrainedConfig] = None,
114
+ device: Optional["torch.device"] = None,
115
+ seq_len: Optional[int] = None,
116
+ **rope_kwargs,
117
+ ) -> Tuple["torch.Tensor", float]:
118
+ """
119
+ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
120
+ Args:
121
+ config ([`~transformers.PretrainedConfig`]):
122
+ The model configuration.
123
+ device (`torch.device`):
124
+ The device to use for initialization of the inverse frequencies.
125
+ seq_len (`int`, *optional*):
126
+ The current sequence length, used to update the dynamic RoPE at inference time.
127
+ rope_kwargs (`Dict`, *optional*):
128
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
129
+ Returns:
130
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
131
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
132
+ """
133
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
134
+ if config is not None and len(rope_kwargs) > 0:
135
+ raise ValueError(
136
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
137
+ f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
138
+ )
139
+ if len(rope_kwargs) > 0:
140
+ base = rope_kwargs["base"]
141
+ dim = rope_kwargs["dim"]
142
+ max_position_embeddings = rope_kwargs["max_position_embeddings"]
143
+ factor = rope_kwargs["factor"]
144
+ elif config is not None:
145
+ base = config.rope_theta
146
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
147
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
148
+ dim = int(head_dim * partial_rotary_factor)
149
+ max_position_embeddings = config.max_position_embeddings
150
+ factor = config.rope_scaling["factor"]
151
+
152
+ attention_factor = 1.0 # Unused in this type of RoPE
153
+
154
+ # seq_len: default to max_position_embeddings, e.g. at init time
155
+ seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
156
+
157
+ # Compute the inverse frequencies
158
+ base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
159
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
160
+ return inv_freq, attention_factor
161
+
162
+
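The dynamic-NTK rule above enlarges the RoPE base once the observed sequence length exceeds the training context. A hedged numeric sketch (toy values, not part of the vendored file):

```python
# Illustrative only: the dynamic-NTK base grows with the observed sequence length.
import torch

base, dim, max_pos, factor = 10000.0, 128, 4096, 2.0

def ntk_inv_freq(seq_len: int) -> torch.Tensor:
    seq_len = max(seq_len, max_pos)  # never shrink below the training length
    scaled_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    return 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))

short, long = ntk_inv_freq(1024), ntk_inv_freq(16384)
assert torch.all(long <= short)  # longer inputs -> lower frequencies (slower rotation)
```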
163
+ def _compute_yarn_parameters(
164
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
165
+ ) -> Tuple["torch.Tensor", float]:
166
+ """
167
+ Computes the inverse frequencies with YaRN scaling. Please refer to the
168
+ [original paper](https://arxiv.org/abs/2309.00071)
169
+ Args:
170
+ config ([`~transformers.PretrainedConfig`]):
171
+ The model configuration.
172
+ device (`torch.device`):
173
+ The device to use for initialization of the inverse frequencies.
174
+ seq_len (`int`, *optional*):
175
+ The current sequence length. Unused for this type of RoPE.
176
+ rope_kwargs (`Dict`, *optional*):
177
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
178
+ Returns:
179
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
180
+ post-processing scaling factor applied to the computed cos/sin.
181
+ """
182
+ # No need to keep BC with yarn, unreleased when this new pattern was created.
183
+ if len(rope_kwargs) > 0:
184
+ raise ValueError(
185
+ f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
186
+ )
187
+
188
+ base = config.rope_theta
189
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
190
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
191
+ dim = int(head_dim * partial_rotary_factor)
192
+ max_position_embeddings = config.max_position_embeddings
193
+ factor = config.rope_scaling["factor"]
194
+
195
+ # Sets the attention factor as suggested in the paper
196
+ attention_factor = config.rope_scaling.get("attention_factor")
197
+ if attention_factor is None:
198
+ attention_factor = 0.1 * math.log(factor) + 1.0
199
+
200
+ # Optional config options
201
+ # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
202
+ beta_fast = config.rope_scaling.get("beta_fast") or 32
203
+ beta_slow = config.rope_scaling.get("beta_slow") or 1
204
+
205
+ # Compute the inverse frequencies
206
+ def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
207
+ """Inverse dimension formula to find the dimension based on the number of rotations"""
208
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
209
+
210
+ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
211
+ """Find dimension range bounds based on rotations"""
212
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
213
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
214
+ return max(low, 0), min(high, dim - 1)
215
+
216
+ def linear_ramp_factor(min, max, dim):
217
+ if min == max:
218
+ max += 0.001 # Prevent singularity
219
+
220
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
221
+ ramp_func = torch.clamp(linear_func, 0, 1)
222
+ return ramp_func
223
+
224
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
225
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
226
+ pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
227
+ inv_freq_extrapolation = 1.0 / pos_freqs
228
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
229
+
230
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
231
+
232
+ # Get n-dimensional rotational scaling corrected for extrapolation
233
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
234
+ inv_freq = (
235
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
236
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
237
+ )
238
+
239
+ return inv_freq, attention_factor
240
+
241
+
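The YaRN computation above blends interpolated and extrapolated frequencies per dimension using a linear ramp over the correction range. A hedged sketch with made-up range bounds (the real `low`/`high` are derived from `beta_fast`/`beta_slow`):

```python
# Illustrative only: ramp-blending interpolated and extrapolated RoPE frequencies.
import torch

dim, factor = 16, 4.0
pos_freqs = 10000.0 ** (torch.arange(0, dim, 2).float() / dim)
inv_freq_extrapolation = 1.0 / pos_freqs              # unscaled end
inv_freq_interpolation = 1.0 / (factor * pos_freqs)   # scaled end

low, high = 2, 5  # hypothetical correction range
ramp = torch.clamp((torch.arange(dim // 2).float() - low) / (high - low), 0, 1)
extrapolation_factor = 1 - ramp

inv_freq = inv_freq_interpolation * (1 - extrapolation_factor) + inv_freq_extrapolation * extrapolation_factor
assert torch.isclose(inv_freq[0], inv_freq_extrapolation[0])    # early dims: extrapolated
assert torch.isclose(inv_freq[-1], inv_freq_interpolation[-1])  # late dims: interpolated
```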
242
+ def _compute_longrope_parameters(
243
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
244
+ ) -> Tuple["torch.Tensor", float]:
245
+ """
246
+ Computes the inverse frequencies with LongRoPE scaling. Please refer to the
247
+ [original implementation](https://github.com/microsoft/LongRoPE)
248
+ Args:
249
+ config ([`~transformers.PretrainedConfig`]):
250
+ The model configuration.
251
+ device (`torch.device`):
252
+ The device to use for initialization of the inverse frequencies.
253
+ seq_len (`int`, *optional*):
254
+ The current sequence length. Unused for this type of RoPE.
255
+ rope_kwargs (`Dict`, *optional*):
256
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
257
+ Returns:
258
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
259
+ post-processing scaling factor applied to the computed cos/sin.
260
+ """
261
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
262
+ # No need to keep BC with longrope, unreleased when this new pattern was created.
263
+ if len(rope_kwargs) > 0:
264
+ raise ValueError(
265
+ "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
266
+ f"{rope_kwargs}"
267
+ )
268
+
269
+ base = config.rope_theta
270
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
271
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
272
+ dim = int(head_dim * partial_rotary_factor)
273
+ long_factor = config.rope_scaling["long_factor"]
274
+ short_factor = config.rope_scaling["short_factor"]
275
+ factor = config.rope_scaling.get("factor")
276
+ attention_factor = config.rope_scaling.get("attention_factor")
277
+
278
+ # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
279
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
280
+ # values to compute the default attention scaling factor, instead of using `factor`.
281
+ if hasattr(config, "original_max_position_embeddings"):
282
+ max_position_embeddings = config.original_max_position_embeddings
283
+ expanded_max_position_embeddings = config.max_position_embeddings
284
+ factor = expanded_max_position_embeddings / max_position_embeddings
285
+ else:
286
+ max_position_embeddings = config.max_position_embeddings
287
+ expanded_max_position_embeddings = max_position_embeddings * factor
288
+
289
+ # Sets the attention factor as suggested in the paper
290
+ if attention_factor is None:
291
+ if factor <= 1.0:
292
+ attention_factor = 1.0
293
+ else:
294
+ attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
295
+
296
+ # Compute the inverse frequencies -- scaled based on the target sequence length
297
+ if expanded_max_position_embeddings > max_position_embeddings:
298
+ ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
299
+ else:
300
+ ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
301
+ inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
302
+ inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
303
+
304
+ return inv_freq, attention_factor
305
+
306
+
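LongRoPE, as implemented above, rescales each frequency by a per-dimension factor taken from `long_factor` when running beyond the original context and from `short_factor` otherwise. A minimal sketch with toy values (not the vendored code path):

```python
# Illustrative only: choosing long_factor vs. short_factor from the expanded context length.
import torch

dim, base = 8, 10000.0
original_max_pos, expanded_max_pos = 4096, 16384
short_factor, long_factor = [1.0] * (dim // 2), [4.0] * (dim // 2)

ext = long_factor if expanded_max_pos > original_max_pos else short_factor
ext_factors = torch.tensor(ext, dtype=torch.float32)
inv_freq = 1.0 / (ext_factors * base ** (torch.arange(0, dim, 2).float() / dim))
print(inv_freq)  # each default frequency divided by the chosen per-dimension factor
```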
307
+ def _compute_llama3_parameters(
308
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
309
+ ) -> Tuple["torch.Tensor", float]:
310
+ """
311
+ Computes the inverse frequencies for llama 3.1.
312
+
313
+ Args:
314
+ config ([`~transformers.PretrainedConfig`]):
315
+ The model configuration.
316
+ device (`torch.device`):
317
+ The device to use for initialization of the inverse frequencies.
318
+ seq_len (`int`, *optional*):
319
+ The current sequence length. Unused for this type of RoPE.
320
+ rope_kwargs (`Dict`, *optional*):
321
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
322
+ Returns:
323
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
324
+ post-processing scaling factor applied to the computed cos/sin.
325
+ """
326
+ # Gets the default RoPE parameters
327
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
328
+
329
+ factor = config.rope_scaling["factor"] # `8` in the original implementation
330
+ low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
331
+ high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
332
+ old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
333
+
334
+ low_freq_wavelen = old_context_len / low_freq_factor
335
+ high_freq_wavelen = old_context_len / high_freq_factor
336
+
337
+ wavelen = 2 * math.pi / inv_freq
338
+ # wavelen < high_freq_wavelen: do nothing
339
+ # wavelen > low_freq_wavelen: divide by factor
340
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
341
+ # otherwise: interpolate between the two, using a smooth factor
342
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
343
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
344
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
345
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
346
+
347
+ return inv_freq_llama, attention_factor
348
+
349
+
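The llama3 variant above splits the spectrum three ways: high-frequency components are kept as-is, low-frequency components are divided by `factor`, and the band in between is smoothly interpolated. A hedged standalone sketch mirroring that logic with illustrative constants:

```python
# Illustrative only: llama3-style piecewise rescaling of RoPE frequencies.
import math
import torch

inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))
factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192

wavelen = 2 * math.pi / inv_freq
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)  # low freqs: divide
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed = (1 - smooth) * scaled / factor + smooth * scaled
is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
scaled = torch.where(is_medium, smoothed, scaled)                              # medium freqs: blend
```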
350
+ # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
351
+ # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
352
+ # parameterizations, as long as the callable has the same signature.
353
+ ROPE_INIT_FUNCTIONS = {
354
+ "default": _compute_default_rope_parameters,
355
+ "linear": _compute_linear_scaling_rope_parameters,
356
+ "dynamic": _compute_dynamic_ntk_parameters,
357
+ "yarn": _compute_yarn_parameters,
358
+ "longrope": _compute_longrope_parameters,
359
+ "llama3": _compute_llama3_parameters,
360
+ }
361
+
362
+
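As the comment above says, new `{'rope_type': callable}` pairs can be appended to `ROPE_INIT_FUNCTIONS` to register custom parameterizations. A hedged sketch (the `my_rope` type and its behaviour are invented for illustration and assume this module's names are in scope):

```python
# Illustrative only: registering a hypothetical custom RoPE parameterization.
from typing import Optional, Tuple
import torch

def _compute_my_rope_parameters(config=None, device=None, seq_len: Optional[int] = None,
                                **rope_kwargs) -> Tuple[torch.Tensor, float]:
    # Start from the default frequencies, then halve them (purely as a demonstration).
    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
    return inv_freq / 2.0, attention_factor

ROPE_INIT_FUNCTIONS["my_rope"] = _compute_my_rope_parameters
```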
363
+ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
364
+ """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
365
+ # BC: "rope_type" was originally "type" -- let's gracefully handle it
366
+ if "rope_type" not in received_keys and "type" in received_keys:
367
+ received_keys -= {"type"}
368
+ received_keys.add("rope_type")
369
+
370
+ missing_keys = required_keys - received_keys
371
+ if missing_keys:
372
+ raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
373
+
374
+ if optional_keys is not None:
375
+ unused_keys = received_keys - required_keys - optional_keys
376
+ else:
377
+ unused_keys = received_keys - required_keys
378
+ if unused_keys:
379
+ logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
380
+
381
+
382
+ def _validate_default_rope_parameters(config: PretrainedConfig):
383
+ rope_scaling = config.rope_scaling
384
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
385
+ required_keys = {"rope_type"}
386
+ received_keys = set(rope_scaling.keys())
387
+ _check_received_keys(rope_type, received_keys, required_keys)
388
+
389
+
390
+ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
391
+ rope_scaling = config.rope_scaling
392
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
393
+ required_keys = {"rope_type", "factor"}
394
+ received_keys = set(rope_scaling.keys())
395
+ _check_received_keys(rope_type, received_keys, required_keys)
396
+
397
+ factor = rope_scaling["factor"]
398
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
399
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
400
+
401
+
402
+ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
403
+ rope_scaling = config.rope_scaling
404
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
405
+ required_keys = {"rope_type", "factor"}
406
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
407
+ optional_keys = {"original_max_position_embeddings"}
408
+ received_keys = set(rope_scaling.keys())
409
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
410
+
411
+ factor = rope_scaling["factor"]
412
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
413
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
414
+
415
+
416
+ def _validate_yarn_parameters(config: PretrainedConfig):
417
+ rope_scaling = config.rope_scaling
418
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
419
+ required_keys = {"rope_type", "factor"}
420
+ optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
421
+ received_keys = set(rope_scaling.keys())
422
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
423
+
424
+ factor = rope_scaling["factor"]
425
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
426
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
427
+
428
+ attention_factor = rope_scaling.get("attention_factor")
429
+ if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
430
+ logger.warning(
431
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
432
+ )
433
+ beta_fast = rope_scaling.get("beta_fast")
434
+ if beta_fast is not None and not isinstance(beta_fast, float):
435
+ logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
436
+ beta_slow = rope_scaling.get("beta_slow")
437
+ if beta_slow is not None and not isinstance(beta_slow, float):
438
+ logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
439
+
440
+ if (beta_fast or 32) < (beta_slow or 1):
441
+ logger.warning(
442
+ f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
443
+ f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
444
+ )
445
+
446
+
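For reference, a hedged example (values are illustrative) of a `rope_scaling` dictionary that satisfies the YaRN validator above: `factor` is required, the remaining keys are optional:

```python
# Illustrative only: a rope_scaling entry accepted by _validate_yarn_parameters.
rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,            # required, float >= 1
    "attention_factor": 1.1,  # optional, float > 0
    "beta_fast": 32.0,        # optional, defaults to 32
    "beta_slow": 1.0,         # optional, defaults to 1
}
```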
447
+ def _validate_longrope_parameters(config: PretrainedConfig):
448
+ rope_scaling = config.rope_scaling
449
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
450
+ required_keys = {"rope_type", "short_factor", "long_factor"}
451
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
452
+ optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
453
+ received_keys = set(rope_scaling.keys())
454
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
455
+
456
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
457
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
458
+ dim = int(head_dim * partial_rotary_factor)
459
+
460
+ short_factor = rope_scaling.get("short_factor")
461
+ if not isinstance(short_factor, list) or not all(isinstance(x, (int, float)) for x in short_factor):
462
+ logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
463
+ if not len(short_factor) == dim // 2:
464
+ logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
465
+
466
+ long_factor = rope_scaling.get("long_factor")
467
+ if not isinstance(long_factor, list) or not all(isinstance(x, (int, float)) for x in long_factor):
468
+ logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
469
+ if not len(long_factor) == dim // 2:
470
+ logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
471
+
472
+ # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
473
+ # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
474
+ # unique to longrope (= undesirable)
475
+ if hasattr(config, "original_max_position_embeddings"):
476
+ logger.warning_once(
477
+ "This model has set an `original_max_position_embeddings` field, to be used together with "
478
+ "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling` "
479
+ "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
480
+ "as it is compatible with most model architectures."
481
+ )
482
+ else:
483
+ factor = rope_scaling.get("factor")
484
+ if factor is None:
485
+ logger.warning("Missing required keys in `rope_scaling`: 'factor'")
486
+ elif not isinstance(factor, float) or factor < 1.0:
487
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
488
+
489
+ attention_factor = rope_scaling.get("attention_factor")
490
+ if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
491
+ logger.warning(
492
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
493
+ )
494
+
495
+
496
+ def _validate_llama3_parameters(config: PretrainedConfig):
497
+ rope_scaling = config.rope_scaling
498
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
499
+ required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
500
+ received_keys = set(rope_scaling.keys())
501
+ _check_received_keys(rope_type, received_keys, required_keys)
502
+
503
+ factor = rope_scaling["factor"]
504
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
505
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
506
+
507
+ low_freq_factor = rope_scaling["low_freq_factor"]
508
+ high_freq_factor = rope_scaling["high_freq_factor"]
509
+ if low_freq_factor is None or not isinstance(low_freq_factor, float):
510
+ logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
511
+ if high_freq_factor is None or not isinstance(high_freq_factor, float):
512
+ logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
513
+ if high_freq_factor <= low_freq_factor:
514
+ logger.warning(
515
+ "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
516
+ f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
517
+ )
518
+
519
+ original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
520
+ if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
521
+ logger.warning(
522
+ "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
523
+ f"{original_max_position_embeddings}"
524
+ )
525
+ if original_max_position_embeddings >= config.max_position_embeddings:
526
+ logger.warning(
527
+ "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
528
+ f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
529
+ )
530
+
531
+
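Likewise, a hedged example (numbers mirror the Llama 3.1 defaults quoted in the comments further up) of a `rope_scaling` dictionary that passes the llama3 validator above:

```python
# Illustrative only: a rope_scaling entry accepted by _validate_llama3_parameters.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,  # must be < config.max_position_embeddings
}
```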
532
+ # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
533
+ ROPE_VALIDATION_FUNCTIONS = {
534
+ "default": _validate_default_rope_parameters,
535
+ "linear": _validate_linear_scaling_rope_parameters,
536
+ "dynamic": _validate_dynamic_scaling_rope_parameters,
537
+ "yarn": _validate_yarn_parameters,
538
+ "longrope": _validate_longrope_parameters,
539
+ "llama3": _validate_llama3_parameters,
540
+ }
541
+
542
+
543
+ def rope_config_validation(config: PretrainedConfig):
544
+ """
545
+ Validate the RoPE config arguments, given a `PretrainedConfig` object
546
+ """
547
+ rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
548
+ if rope_scaling is None:
549
+ return
550
+
551
+ # BC: "rope_type" was originally "type"
552
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
553
+ validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
554
+ if validation_fn is not None:
555
+ validation_fn(config)
556
+ else:
557
+ logger.warning(
558
+ f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
559
+ )
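End of the vendored RoPE utilities. A brief usage sketch (hedged: `LlamaConfig` is used here only as a convenient `PretrainedConfig` subclass) of how `rope_config_validation` is typically called from a configuration's `__init__`:

```python
# Illustrative only: validating a config's rope_scaling with the helpers above.
from transformers import LlamaConfig

config = LlamaConfig(
    max_position_embeddings=131072,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)
rope_config_validation(config)  # logs warnings / raises KeyError on malformed entries
```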
transformers_4_44_2__pytorch_utils.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from torch import nn
16
+
17
+ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
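`ALL_LAYERNORM_LAYERS` exists so Trainer-style utilities can identify normalization layers. A hedged, self-contained sketch of its usual purpose, collecting parameter names to exclude from weight decay (the helper name is illustrative):

```python
# Illustrative only: excluding LayerNorm parameters (and biases) from weight decay.
from torch import nn

ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

def get_decay_parameter_names(model: nn.Module) -> list[str]:
    norm_param_names = {
        f"{module_name}.{param_name}"
        for module_name, module in model.named_modules()
        if isinstance(module, tuple(ALL_LAYERNORM_LAYERS))
        for param_name, _ in module.named_parameters()
    }
    return [name for name, _ in model.named_parameters()
            if name not in norm_param_names and "bias" not in name]

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
print(get_decay_parameter_names(model))  # ['0.weight']
```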
variable_cache.py ADDED
@@ -0,0 +1,139 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Nvidia Corporation. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from copy import deepcopy
17
+ from typing import Optional, Dict, Any, Tuple
18
+
19
+ import torch
20
+ from transformers.cache_utils import Cache # used to let GenerationMixin know that we use a Cache object
21
+
22
+ from .configuration_decilm import DeciLMConfig
23
+ from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, SinkCache, StaticCache, SlidingWindowCache
24
+
25
+
26
+ class VariableCache(Cache_4_44_2, Cache):
27
+ """
28
+ A Cache object that supports a different Cache implementation for every layer,
29
+ including layers without any kv-cache.
30
+ Implemented using a list of Cache objects, each represents a "model" with 1 layer.
31
+ The default implementation for the layer caches is StaticCache.
32
+ The cache of each layer is allocated to the same gpu as the layer itself.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions
38
+ config: DeciLMConfig,
39
+ batch_size: Optional[int] = None,
40
+ max_cache_len: Optional[int] = None,
41
+ dtype: torch.dtype = torch.float32,
42
+ max_batch_size: Optional[int] = None,
43
+ **kwargs,
44
+ ) -> None:
45
+ Cache_4_44_2.__init__(self)
46
+
47
+ self.config = deepcopy(config)
48
+ self.max_batch_size = batch_size or max_batch_size
49
+ self.batch_size = self.max_batch_size
50
+ self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
51
+ self.dtype = dtype
52
+
53
+ self.layer_caches: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
54
+ self.layer_devices: list[torch.device | None] = [None] * config.num_hidden_layers
55
+
56
+ def update(
57
+ self,
58
+ key_states: torch.Tensor,
59
+ value_states: torch.Tensor,
60
+ layer_idx: int,
61
+ cache_kwargs: Optional[Dict[str, Any]] = None,
62
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
63
+ if self.layer_caches[layer_idx] is None:
64
+ self.layer_devices[layer_idx] = key_states.device
65
+ self._init_layer_cache(layer_idx)
66
+
67
+ layer_cache = self.layer_caches[layer_idx]
68
+ assert layer_cache is not None, f"Trying to update the cache of a cache-less layer: {layer_idx=}"
69
+
70
+ k_out, v_out = layer_cache.update(key_states=key_states,
71
+ value_states=value_states,
72
+ layer_idx=0,
73
+ cache_kwargs=cache_kwargs)
74
+ seq_len = self.get_seq_length(layer_idx)
75
+ k_out = k_out[:, :, :seq_len, :]
76
+ v_out = v_out[:, :, :seq_len, :]
77
+ return k_out, v_out
78
+
79
+ def _init_layer_cache(self, layer_idx: int) -> None:
80
+ block_config = self.config.block_configs[layer_idx]
81
+ attention_config = block_config.attention
82
+
83
+ if attention_config.no_op or attention_config.replace_with_linear:
84
+ return None
85
+
86
+ device = self.layer_devices[layer_idx]
87
+ assert device is not None, f"Trying to init layer cache for {layer_idx=} without device"
88
+
89
+ config = deepcopy(self.config)
90
+ config.num_hidden_layers = 1
91
+ config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
92
+
93
+ if attention_config.window_length is not None:
94
+ if not attention_config.is_sink:
95
+ config.sliding_window = attention_config.window_length
96
+ self.layer_caches[layer_idx] = SlidingWindowCache(config=config,
97
+ max_batch_size=self.max_batch_size,
98
+ max_cache_len=self.max_cache_len,
99
+ device=device,
100
+ dtype=self.dtype)
101
+ return
102
+ elif not attention_config.unshifted_sink:
103
+ self.layer_caches[layer_idx] = SinkCache(window_length=attention_config.window_length,
104
+ num_sink_tokens=attention_config.num_sink_tokens)
105
+ return
106
+
107
+ self.layer_caches[layer_idx] = StaticCache(config=config,
108
+ max_batch_size=self.max_batch_size,
109
+ max_cache_len=self.max_cache_len,
110
+ device=device,
111
+ dtype=self.dtype)
112
+
113
+ def _get_first_real_cache(self) -> Cache:
114
+ for layer_cache in self.layer_caches:
115
+ if layer_cache is not None:
116
+ return layer_cache
117
+ raise ValueError("No real cache found, all layer caches are None.")
118
+
119
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
120
+ if layer_idx == 0 and self.layer_caches[0] is None:
121
+ try:
122
+ layer_cache = self._get_first_real_cache()
123
+ except ValueError:
124
+ return 0
125
+ else:
126
+ layer_cache = self.layer_caches[layer_idx]
127
+ return layer_cache.get_seq_length()
128
+
129
+ def get_max_length(self) -> Optional[int]:
130
+ """Returns the maximum sequence length of the cached states."""
131
+ return self.max_cache_len
132
+
133
+ def reset(self):
134
+ for layer_idx in range(len(self.layer_caches)):
135
+ layer_cache = self.layer_caches[layer_idx]
136
+ if hasattr(layer_cache, "reset"):
137
+ layer_cache.reset()
138
+ else:
139
+ self._init_layer_cache(layer_idx)
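A hedged usage sketch for `VariableCache` (assumptions: the model is loaded with `trust_remote_code=True` so this class is available from the repository's `variable_cache` module, the repository path below is a placeholder, and `generate` forwards `past_key_values` to the model; exact argument names may differ across transformers versions):

```python
# Illustrative only: pre-allocating a VariableCache and handing it to generation.
# VariableCache is assumed to be imported from this repository's variable_cache module.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/this/model/repo"  # placeholder, not a real identifier
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
cache = VariableCache(config=model.config, batch_size=1, max_cache_len=2048, dtype=model.dtype)
outputs = model.generate(**inputs, past_key_values=cache, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```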