SophieA17 committed
Commit edd20a2 · verified · 1 Parent(s): b0406a5

Upload 4 files

Files changed (4)
  1. README.md +34 -186
  2. config.json +6 -0
  3. configuration_sophie0.py +60 -0
  4. modeling_sophie0.py +767 -0
README.md CHANGED
@@ -1,199 +1,47 @@
+ Sophie0-SFT
+ 
+ ### Introduction
+ 
+ Sophie0 is a 0.5B large language model implemented from scratch as a solo project. Its core goal is to run the full training pipeline end to end: **pretraining (Pretrain)**, **supervised fine-tuning (SFT)**, **direct preference optimization (DPO)**, and explicit **chain-of-thought reasoning** trained with **Group Relative Policy Optimization (GRPO)**. Pretraining uses BAAI's open multi-domain datasets, about 11B tokens in total, and costs 52x4 GPU hours; SFT uses 9.74M rows of dialogue data, including BAAI data and math CoT data, and costs 24x8 GPU hours; DPO uses 159.3k preference pairs, including BAAI preference data and English dialogues distilled from Llama 3, and costs 1x4 GPU hours; the GRPO stage runs SFT and GRPO on the Knights & Knaves 3ppl dataset plus chains of thought distilled from DeepSeek-R1, with 1.5k samples for its SFT phase and 500 prompts for GRPO, costing 10 min x 1 GPU and 51x2 GPU hours respectively.
+ 
+ In addition, the project explores the feasibility and implementation of training entirely on variable-length (varlen) sequences in the downstream SFT and DPO stages, making full use of the `varlen attention` and `varlen RoPE` kernels shipped with flash attention 2. It also examines how the padding tokens introduced by batched inference affect the outputs, and how a varlen-compatible KV Cache class allows seamless switching between padded inference and padding-free varlen inference directly on top of the Huggingface GenerationMixin interface.
+ 
+ See [here](https://github.com/Sophie10001b/sophie0) for more details.
+ 
+ ### QuickStart
+ 
+ The model can be called directly as follows:
+ 
+ ```python
+ import os
+ import torch
+ 
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+ 
+ model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained("SophieA17/Sophie0-SFT", trust_remote_code=True)
+ tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("SophieA17/Sophie0-SFT", trust_remote_code=True)
+ 
+ model = model.to(device="cuda:0", dtype=torch.bfloat16)
+ inputs = [
+     "<s><user>Could you please introduce yourself?</s>\n",
+     "<s><user>Where is the best place for traveling in summer?</s>\n"
+ ]
+ 
+ input_ids = tokenizer(inputs, return_tensors="pt", padding=True, padding_side="left", return_token_type_ids=False).to(model.device)
+ 
+ generation_config = GenerationConfig(
+     bos_token_id=tokenizer.bos_token_id,
+     eos_token_id=tokenizer.eos_token_id,
+     pad_token_id=tokenizer.pad_token_id,
+     max_new_tokens=1024,
+     do_sample=True,
+     top_k=20,
+     top_p=0.8,
+     temperature=0.8,
+     repetition_penalty=1.1,
+     use_cache=True
+ )
+ 
+ outputs = model.generate(**input_ids, use_varlen_inference=True, generation_config=generation_config)
+ outputs = tokenizer.batch_decode(outputs, skip_special_tokens=False)
+ ```
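For comparison, the same batch can also go through the standard padded decoding path mentioned in the introduction; a minimal sketch that continues from the QuickStart above, assuming that omitting `use_varlen_inference` falls back to the padded KV cache:

```python
# Padded-path sketch (assumption: without use_varlen_inference=True, generate()
# uses the standard padded cache together with the left-padded attention mask).
outputs_padded = model.generate(**input_ids, generation_config=generation_config)
print(tokenizer.batch_decode(outputs_padded, skip_special_tokens=False))
```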
config.json CHANGED
@@ -1,7 +1,13 @@
  {
+   "_name_or_path": "Sophie0-SFT",
    "architectures": [
      "Sophie0ForCausalLM"
    ],
+   "auto_map": {
+     "AutoConfig": "configuration_sophie0.Sophie0Config",
+     "AutoModel": "modeling_sophie0.Sophie0Model",
+     "AutoModelForCausalLM": "modeling_sophie0.Sophie0ForCausalLM"
+   },
    "bos_token_id": 0,
    "bot_token_id": 10,
    "dropout": 0.0,
configuration_sophie0.py ADDED
@@ -0,0 +1,60 @@
+ from typing import Dict, List, Tuple, Union, Optional, Any
+ from transformers.configuration_utils import PretrainedConfig
+ 
+ class Sophie0Config(PretrainedConfig):
+ 
+     model_type = 'transformer'
+     keys_to_ignore_at_inference = ['past_key_values']
+ 
+     def __init__(
+         self,
+         hidden_size: int = 1024,
+         num_hidden_layers: int = 28,
+         num_heads: int = 16,
+         num_kv_heads: int = 8,
+         window_size: Optional[int] = None,
+         rope_base: Optional[int] = int(1e6),
+         intermediate_size: Optional[int] = 4096,
+         hidden_act: str = "swish",
+         eps: float = 1e-5,
+         use_cache: bool = True,
+         pad_token_id: int = 3,
+         bos_token_id: int = 0,
+         eos_token_id: int = 1,
+         prompt_token_id: int = 8,
+         user_token_id: int = 9,
+         bot_token_id: int = 10,
+         tie_word_embeddings: bool = True,
+         vocab_size: int = 65536,
+         dropout: float = 0.0,
+         right_shift: bool = False,
+         **kwargs,
+     ):
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads
+         self.window_size = window_size
+         self.rope_base = rope_base
+ 
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+ 
+         self.eps = eps
+         self.use_cache = use_cache
+ 
+         self.vocab_size = vocab_size
+         self.dropout = dropout
+         self.right_shift = right_shift
+ 
+         self.prompt_token_id = prompt_token_id
+         self.user_token_id = user_token_id
+         self.bot_token_id = bot_token_id
+ 
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
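The defaults above describe the 0.5B configuration used by this checkpoint; a minimal sketch of constructing it directly (assuming the file is importable as a plain module):

```python
from configuration_sophie0 import Sophie0Config

# Defaults: 28 layers, hidden size 1024, 16 query / 8 KV heads, 65536-token vocab.
config = Sophie0Config()
print(config.num_hidden_layers, config.hidden_size, config.num_heads, config.num_kv_heads, config.vocab_size)
```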
modeling_sophie0.py ADDED
@@ -0,0 +1,767 @@
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as checkpoint
7
+ import transformers
8
+
9
+ from typing import Optional, Dict, Tuple, List, Union, Unpack, Sequence, Any
10
+ from flash_attn import (
11
+ flash_attn_kvpacked_func,
12
+ flash_attn_varlen_func
13
+ )
14
+ from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb
15
+ from flash_attn.ops.triton.layer_norm import RMSNorm
16
+ from flash_attn.modules.mlp import GatedMlp
17
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss
18
+ from einops import rearrange
19
+ from itertools import chain
20
+ from flash_attn.bert_padding import unpad_input
21
+
22
+ from .configuration_sophie0 import Sophie0Config
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.generation import GenerationMixin
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
26
+
27
+ #########################################################
28
+ # --- basic functions ---
29
+ #########################################################
30
+ class Cache(transformers.cache_utils.Cache):
31
+ """
32
+ A cache used for storing the key/value states produced during standard padded batch inference.
33
+
34
+ **Input:**
35
+ - attn_state: Cache for standard attention, tuple(size(bsz, k_len/v_len, dmodel) * 2)
36
+ """
37
+
38
+ is_compileable = True
39
+
40
+ def __init__(self, cache_position: int = 0):
41
+ super().__init__()
42
+
43
+ self.states: List[Dict[str, Any]] = []
44
+ self._cache_position = [cache_position] # Used in `generate` to keep tally of how many tokens the cache has seen
45
+
46
+ def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
47
+ if layer_idx < len(self):
48
+ return self.states[layer_idx]
49
+ else:
50
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
51
+
52
+ def __iter__(self):
53
+ for state in self.states: yield state
54
+
55
+ def __len__(self):
56
+ return len(self.states)
57
+
58
+ def update(
59
+ self,
60
+ attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
61
+ layer_idx: int = 0,
62
+ offset: Optional[int] = 1,
63
+ cache_kwargs: Optional[Dict[str, Any]] = {},
64
+ ) -> Dict[str, Any]:
65
+ """
66
+ Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.
67
+
68
+ Args:
69
+ attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
70
+ The new attention key/value states to cache.
71
+ layer_idx (`int`, defaults to 0):
72
+ The index of the layer to cache the states for.
73
+ offset (`int`, `optional`, defaults to 1):
74
+ The number of new tokens being processed.
75
+ cache_kwargs (`Dict[str, Any]`, `optional`):
76
+ Additional arguments for the cache subclass.
77
+
78
+ Return:
79
+ Dictionary of the updated state.
80
+ """
81
+
82
+ # Update the number of seen tokens
83
+ if len(self._cache_position) <= layer_idx:
84
+ self._cache_position.append(0)
85
+
86
+ self._cache_position[layer_idx] += offset
87
+
88
+ if attn_state is not None:
89
+ input_size = attn_state[0].shape[-2]
90
+ window_size = cache_kwargs.get('window_size', None)
91
+ if not isinstance(attn_state, Tuple) or len(attn_state) != 2:
92
+ raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
93
+ if len(self.states) <= layer_idx:
94
+ if attn_state is not None:
95
+ if window_size is not None and input_size > window_size:
96
+ attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
97
+ attn_state[1][..., -window_size:, :].contiguous())
98
+ state = dict(
99
+ attn_state=attn_state,
100
+ )
101
+ self.states.append(state)
102
+ else:
103
+ state = self.states[layer_idx]
104
+ if attn_state is not None:
105
+ if state['attn_state'] is None:
106
+ if window_size is not None and input_size > window_size:
107
+ attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
108
+ attn_state[1][..., -window_size:, :].contiguous())
109
+ else:
110
+ key_state, value_state = state['attn_state']
111
+ if window_size is not None and key_state.shape[-2] == window_size:
112
+ # DO NOT allocate new memory if the cache is full
113
+ # roll the key/value states to the left by `input_size`
114
+ key_state = key_state.roll(-input_size, -2)
115
+ value_state = value_state.roll(-input_size, -2)
116
+ # replace the last `input_size` tokens with the new key/value states
117
+ key_state[..., -input_size:, :] = attn_state[0]
118
+ value_state[..., -input_size:, :] = attn_state[1]
119
+ attn_state = (key_state, value_state)
120
+ else:
121
+ attn_state = (torch.cat([key_state, attn_state[0]], -2),
122
+ torch.cat([value_state, attn_state[1]], -2),)
123
+ state['attn_state'] = attn_state
124
+
125
+ return state
126
+
127
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
128
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
129
+ if len(self.states) <= layer_idx:
130
+ return 0
131
+ return self._cache_position[layer_idx]
132
+
133
+ def get_max_length(self) -> Optional[int]:
134
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
135
+ return None
136
+
137
+ def to_legacy_cache(self) -> Tuple:
138
+ return tuple(self.states)
139
+
140
+ def reorder_cache(self, beam_idx: torch.LongTensor):
141
+ """Reorders the cache for beam search, given the selected beam indices."""
142
+ for layer_idx in range(len(self.states)):
143
+ for k in self.states[layer_idx].keys():
144
+ if isinstance(self.states[layer_idx][k], torch.Tensor):
145
+ device = self.states[layer_idx][k].device
146
+ self.states[layer_idx][k] = self.states[layer_idx][k].index_select(0, beam_idx.to(device))
147
+ elif isinstance(self.states[layer_idx][k], Tuple):
148
+ _temp = []
149
+ for i in range(len(self.states[layer_idx][k])):
150
+ device = self.states[layer_idx][k][i].device
151
+ _temp.append(self.states[layer_idx][k][i].index_select(0, beam_idx.to(device)))
152
+ self.states[layer_idx][k] = tuple(_temp)
153
+
154
+ @classmethod
155
+ @torch.compiler.disable
156
+ def from_legacy_cache(
157
+ cls,
158
+ past_key_values: Optional[Tuple] = None,
159
+ cache_position: int = 0
160
+ ):
161
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
162
+
163
+ cache = cls(cache_position)
164
+ if isinstance(past_key_values, list):
165
+ for layer_idx in range(len(past_key_values)):
166
+ cache.states.append(past_key_values[layer_idx])
167
+ return cache
168
+
169
+ class VarlenCache(transformers.cache_utils.Cache):
170
+ """
171
+ A varlen cache used for storing hidden states produced by varlen batch inference.
172
+
173
+ **Input:**
174
+ - attn_state: Cache for standard attention, tuple(size(total_nnz, dmodel) * 2)
175
+ """
176
+
177
+ is_compileable = True
178
+
179
+ def __init__(self, cache_position: int = 0, batch_size: int = 1, device: str | torch.device = None):
180
+ super().__init__()
181
+
182
+ self.states: List[Dict[str, Any]] = []
183
+ self._cache_position = [torch.full((batch_size,), cache_position, dtype=torch.int64, device=device)] # Used in `generate` to keep tally of how many tokens the cache has seen
184
+ self.batch_size = batch_size
185
+ self.device = device
186
+
187
+ def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
188
+ if layer_idx < len(self):
189
+ return self.states[layer_idx]
190
+ else:
191
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
192
+
193
+ def __iter__(self):
194
+ for state in self.states: yield state
195
+
196
+ def __len__(self):
197
+ return len(self.states)
198
+
199
+ def update(
200
+ self,
201
+ attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
202
+ cu_seqlens: torch.LongTensor = None,
203
+ layer_idx: int = 0,
204
+ cache_kwargs: Optional[Dict[str, Any]] = {},
205
+ ) -> Dict[str, Any]:
206
+ """
207
+ Updates the cache with the new `attn_state` for the layer `layer_idx`.
208
+
209
+ Args:
210
+ attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
211
+ The new attention key/value states to cache, sizes (total_nnz, hidden_size)
212
+ cu_seqlens (`torch.LongTensor`):
213
+ the accumulated sequence length for current states, sizes (bsz + 1,)
214
+ layer_idx (`int`, defaults to 0):
215
+ The index of the layer to cache the states for.
216
+ cache_kwargs (`Dict[str, Any]`, `optional`):
217
+ Additional arguments for the cache subclass.
218
+
219
+ Return:
220
+ Dictionary of the updated state.
221
+ """
222
+
223
+ if attn_state is not None:
224
+ if not isinstance(attn_state, Tuple) or len(attn_state) != 2:
225
+ raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
226
+
227
+ dtype = attn_state[0].dtype
228
+ device = attn_state[0].device
229
+ hidden_size = attn_state[0].size(-1)
230
+
231
+ # Case 1: prefill at the 1st step
232
+ if len(self._cache_position) <= layer_idx:
233
+ self._cache_position.append(
234
+ torch.zeros((cu_seqlens.size(0) - 1,), dtype=torch.int64, device=cu_seqlens.device)
235
+ )
236
+
237
+ kv_seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
238
+ kv_seqlens_cpu = kv_seqlens.cpu().tolist()
239
+ self._cache_position[layer_idx] += kv_seqlens
240
+
241
+ if len(self.states) <= layer_idx:
242
+ key_state, value_state = list(map(lambda x: torch.split(x, kv_seqlens_cpu), attn_state))
243
+ state = dict(
244
+ attn_state=(key_state, value_state),
245
+ cu_seqlens=cu_seqlens,
246
+ max_seqlen=kv_seqlens.max().item()
247
+ )
248
+ self.states.append(state)
249
+
250
+ # Case 2: append current step's kv cache
251
+ else:
252
+ state = self.states[layer_idx]
253
+ if state["attn_state"] is not None:
254
+ key_state, value_state = list(map(lambda x: torch.split(x, kv_seqlens_cpu), attn_state))
255
+ key_cache, value_cache = state['attn_state']
256
+ old_cu_seqlens = state['cu_seqlens']
257
+
258
+ key_cache = tuple(map(lambda x, y: torch.cat([x, y], dim=0), key_cache, key_state))
259
+ value_cache = tuple(map(lambda x, y: torch.cat([x, y], dim=0), value_cache, value_state))
260
+
261
+ new_cu_seqlens = old_cu_seqlens + cu_seqlens
262
+ state.update(
263
+ attn_state=(key_cache, value_cache),
264
+ cu_seqlens=new_cu_seqlens,
265
+ max_seqlen=(new_cu_seqlens[1:] - new_cu_seqlens[:-1]).max().item()
266
+ )
267
+ return state
268
+
269
+ def get_kv_cache(self, state: Dict[str, Any]) -> Tuple[torch.Tensor, torch.Tensor]:
270
+ return tuple(map(lambda x: torch.cat(x, 0), state['attn_state']))
271
+
272
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> torch.Tensor:
273
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
274
+ if len(self.states) <= layer_idx:
275
+ return torch.zeros(self.batch_size, dtype=torch.int64, device=self.device)
276
+ return self._cache_position[layer_idx]
277
+
278
+ def get_cu_seq_length(self, layer_idx: Optional[int] = 0) -> torch.Tensor:
279
+ """Returns the accumulated sequence length of the cached states. A layer index can be optionally passed."""
280
+ if len(self.states) <= layer_idx:
281
+ return torch.zeros(self.batch_size + 1, dtype=torch.int64, device=self.device)
282
+ return self.states[layer_idx]['cu_seqlens']
283
+
284
+ def get_max_length(self) -> Optional[int]:
285
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
286
+ return None
287
+
288
+ def to_legacy_cache(self) -> Tuple:
289
+ return tuple(self.states)
290
+
291
+ def reorder_cache(self, beam_idx: torch.LongTensor):
292
+ """Reorders the cache for beam search, given the selected beam indices."""
293
+ raise NotImplementedError("Varlen batch inference does not support beam search yet.")
294
+
295
+ @classmethod
296
+ @torch.compiler.disable
297
+ def from_legacy_cache(
298
+ cls,
299
+ past_key_values: Optional[Tuple] = None,
300
+ cache_position: int = 0,
301
+ batch_size: int = 1,
302
+ device: str | torch.device = None
303
+ ):
304
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
305
+
306
+ cache = cls(cache_position, batch_size=batch_size, device=device)
307
+ if isinstance(past_key_values, list):
308
+ for layer_idx in range(len(past_key_values)):
309
+ cache.states.append(past_key_values[layer_idx])
310
+ return cache
311
+
312
+ @torch.no_grad()
313
+ def linear_init(
314
+ linear: nn.Linear,
315
+ distribution: Optional[str]='normal',
316
+ zero_bias: Optional[bool]=False,
317
+ gain: Optional[float]=1.0
318
+ ) ->None:
319
+ if distribution == 'normal':
320
+ nn.init.xavier_normal_(linear.weight, gain=gain)
321
+ elif distribution == 'uniform':
322
+ nn.init.xavier_uniform_(linear.weight, gain=gain)
323
+ if linear.bias is not None:
324
+ if zero_bias:
325
+ nn.init.zeros_(linear.bias)
326
+ else:
327
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(linear.weight)
328
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
329
+ nn.init.uniform_(linear.bias, -bound, bound)
330
+
331
+
332
+ @torch.no_grad()
333
+ def embedding_init(embedding: nn.Embedding) ->None:
334
+ fan_out = embedding.weight.size(1)
335
+ std = 1.0 * math.sqrt(1.0 / float(fan_out))
336
+ nn.init.normal_(embedding.weight, 0., std)
337
+ if embedding.padding_idx is not None:
338
+ embedding.weight[embedding.padding_idx].fill_(0)
339
+
340
+ def sparse_to_dense(src: torch.Tensor, length: torch.Tensor) ->torch.Tensor:
341
+ maxLength = length.max().item()
342
+
343
+ length = length.cpu().numpy()
344
+ broadcastIdx = np.arange(length[0], dtype=np.int64)
345
+ for i in range(1, length.shape[0]): broadcastIdx = np.concatenate([broadcastIdx, np.arange(length[i], dtype=np.int64) + maxLength * i], axis=0)
346
+ broadcastIdx = torch.tensor(broadcastIdx, dtype=torch.int64, device=src.device)
347
+
348
+ tgt = torch.zeros((length.shape[0] * maxLength, src.size(-1)), dtype=src.dtype, device=src.device)
349
+ tgt[broadcastIdx] = src
350
+ tgt = tgt.reshape(length.shape[0], maxLength, -1).contiguous()
351
+ return tgt
352
+
353
+ #########################################################
354
+ # --- model ---
355
+ #########################################################
356
+ class FullAttention(nn.Module):
357
+ def __init__(
358
+ self,
359
+ hidden_size: int,
360
+ num_heads: int,
361
+ num_kv_heads: int,
362
+ rotary_base: int,
363
+ dropout: float,
364
+ layer_idx: int,
365
+ **kwargs
366
+ ):
367
+ super(FullAttention, self).__init__()
368
+
369
+ self.hidden_size = hidden_size
370
+ self.num_q_heads = num_heads
371
+ self.num_kv_heads = num_kv_heads
372
+ self.head_size = hidden_size // num_heads
373
+ self.dropout = dropout
374
+ self.layer_idx = layer_idx
375
+
376
+ self.qkv = nn.Linear(hidden_size, hidden_size + 2 * num_kv_heads * self.head_size, bias=False)
377
+ self.out = nn.Linear(hidden_size, hidden_size, bias=False)
378
+ self.rotary = RotaryEmbedding(dim=self.head_size, base=rotary_base)
379
+
380
+ self._init_weights()
381
+
382
+ def _init_weights(self):
383
+ for k, v in self.named_modules():
384
+ if isinstance(v, nn.Linear): linear_init(v, zero_bias=True)
385
+
386
+ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor=None, max_seqlen: int=None, causal: bool=True, past_key_values: Cache | VarlenCache=None):
387
+ """
388
+ Training with varlen:
389
+
390
+ x -> size(B*L, D)
391
+ cu_seqlens -> size(B+1)
392
+
393
+ Generating with padding:
394
+
395
+ x -> size(B, L, D)
396
+ cu_seqlens -> None
397
+ """
398
+
399
+ if cu_seqlens is None:
400
+ qkv: torch.Tensor = self.qkv(x)
401
+ qkv = rearrange(qkv, "B L (H D) -> B L H D", H=(self.num_q_heads + 2 * self.num_kv_heads), D=self.head_size)
402
+ q, kv = torch.split(qkv, [self.num_q_heads, 2 * self.num_kv_heads], dim=-2)
403
+ kv = rearrange(kv, "B L (C H) D -> B L C H D", C=2, H=self.num_kv_heads)
404
+
405
+ if past_key_values is not None:
406
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
407
+ _max_seqlen = q.size(1) + seqlen_offset
408
+ q, kv = self.rotary(q, kv, seqlen_offset=seqlen_offset, max_seqlen=_max_seqlen, num_heads_q=self.num_q_heads)
409
+
410
+ k, v = kv.unbind(dim=2)
411
+ k, v = past_key_values.update(
412
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
413
+ layer_idx=self.layer_idx,
414
+ offset=q.size(1),
415
+ cache_kwargs=dict()
416
+ )["attn_state"]
417
+ k, v = rearrange(k, "... (H D) -> ... H D", H=self.num_kv_heads, D=self.head_size), rearrange(v, "... (H D) -> ... H D", H=self.num_kv_heads, D=self.head_size)
418
+
419
+ kv = torch.cat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
420
+ else:
421
+ q, kv = self.rotary(q, kv)
422
+
423
+ out = flash_attn_kvpacked_func(q, kv, dropout_p=self.dropout if self.training else 0, causal=causal)
424
+ out = self.out(rearrange(out, "B L H D -> B L (H D)"))
425
+ else:
426
+ qkv: torch.Tensor = self.qkv(x)
427
+ qkv = rearrange(qkv, "L (H D) -> L H D", H=(self.num_q_heads + 2 * self.num_kv_heads), D=self.head_size)
428
+ q, k, v = torch.split(qkv, [self.num_q_heads, self.num_kv_heads, self.num_kv_heads], dim=-2)
429
+
430
+ if past_key_values is not None:
431
+ assert isinstance(past_key_values, VarlenCache)
432
+
433
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
434
+ _seqlen = cu_seqlens[1:] - cu_seqlens[:-1]
435
+ _max_seqlen = (seqlen_offset + _seqlen).max().item()
436
+
437
+ self.rotary._update_cos_sin_cache(seqlen=_max_seqlen, device=q.device, dtype=q.dtype)
438
+ q, k = apply_rotary_emb(q, self.rotary._cos_cached, self.rotary._sin_cached, seqlen_offsets=seqlen_offset, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen),\
439
+ apply_rotary_emb(k, self.rotary._cos_cached, self.rotary._sin_cached, seqlen_offsets=seqlen_offset, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
440
+
441
+ new_cache = past_key_values.update(
442
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
443
+ cu_seqlens=cu_seqlens,
444
+ layer_idx=self.layer_idx,
445
+ cache_kwargs=dict()
446
+ )
447
+ k, v = past_key_values.get_kv_cache(new_cache)
448
+ k, v = rearrange(k, "... (H D) -> ... H D", H=self.num_kv_heads, D=self.head_size), rearrange(v, "... (H D) -> ... H D", H=self.num_kv_heads, D=self.head_size)
449
+
450
+ kv_cu_seqlens, kv_max_seqlen = new_cache['cu_seqlens'], new_cache['max_seqlen']
451
+
452
+ out = flash_attn_varlen_func(q, k, v, cu_seqlens, kv_cu_seqlens, max_seqlen, kv_max_seqlen, dropout_p=self.dropout if self.training else 0, causal=causal)
453
+ else:
454
+ self.rotary._update_cos_sin_cache(seqlen=max_seqlen, device=q.device, dtype=q.dtype)
455
+ q, k = apply_rotary_emb(q, self.rotary._cos_cached, self.rotary._sin_cached, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen),\
456
+ apply_rotary_emb(k, self.rotary._cos_cached, self.rotary._sin_cached, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
457
+
458
+ out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=self.dropout if self.training else 0, causal=causal)
459
+ out = self.out(rearrange(out, "L H D -> L (H D)"))
460
+
461
+ return out, None, past_key_values
462
+
463
+ class TransformerBlock(nn.Module):
464
+ def __init__(self, config: Sophie0Config, layer_idx: int):
465
+ super().__init__()
466
+
467
+ self.config = config
468
+ self.layer_idx = layer_idx
469
+
470
+ self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.eps)
471
+ self.attn = FullAttention(
472
+ hidden_size=config.hidden_size,
473
+ num_heads=config.num_heads,
474
+ num_kv_heads=config.num_kv_heads,
475
+ rotary_base=config.rope_base,
476
+ dropout=config.dropout,
477
+ layer_idx=self.layer_idx
478
+ )
479
+ self.ffn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.eps)
480
+ self.ffn = GatedMlp(
481
+ in_features=config.hidden_size,
482
+ hidden_features=config.intermediate_size,
483
+ activation=F.silu,
484
+ bias1=False,
485
+ bias2=False,
486
+ multiple_of=1
487
+ )
488
+
489
+ self._init_weights()
490
+
491
+ def _init_weights(self):
492
+ for k, v in self.ffn.named_modules():
493
+ if isinstance(v, nn.Linear): linear_init(v, zero_bias=True)
494
+
495
+ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor=None, max_seqlen: int=None, causal: bool=True, past_key_values: Cache=None):
496
+ out, _, past_key_values = self.attn(self.attn_norm(x), cu_seqlens, max_seqlen, causal, past_key_values)
497
+ x = x + out
498
+ x = x + self.ffn(self.ffn_norm(x))
499
+
500
+ return (x, _, past_key_values)
501
+
502
+
503
+ class Sophie0PreTrainedModel(PreTrainedModel):
504
+ config_class = Sophie0Config
505
+ supports_gradient_checkpointing = True
506
+ _supports_cache_class = True
507
+ _no_split_modules = ["TransformerBlock"]
508
+
509
+ def __init__(self, *inputs, **kwargs):
510
+ super().__init__(*inputs, **kwargs)
511
+
512
+ def _init_weights(self, module: nn.Module):
513
+ if isinstance(module, nn.Embedding):
514
+ embedding_init(module)
515
+ elif isinstance(module, nn.Linear):
516
+ linear_init(module, zero_bias=True)
517
+
518
+ class Sophie0Model(Sophie0PreTrainedModel):
519
+ def __init__(self, config: Sophie0Config, **kwargs):
520
+ super().__init__(config, **kwargs)
521
+
522
+ self.padding_idx = config.pad_token_id
523
+ self.vocab_size = config.vocab_size
524
+
525
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
526
+ self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
527
+ self.norm = RMSNorm(config.hidden_size, eps=config.eps)
528
+
529
+ self.post_init()
530
+
531
+ def get_input_embeddings(self):
532
+ return self.embeddings
533
+
534
+ def set_input_embeddings(self, value):
535
+ self.embeddings = value
536
+
537
+ def forward(
538
+ self,
539
+ input_ids: Optional[torch.LongTensor] = None,
540
+ cu_seqlens: Optional[torch.LongTensor] = None,
541
+ max_seqlen: Optional[int] = None,
542
+ attention_mask: Optional[torch.Tensor] = None,
543
+ inputs_embeds: Optional[torch.FloatTensor] = None,
544
+ past_key_values: Optional[Union[Cache, VarlenCache, List[torch.FloatTensor]]] = None,
545
+ use_cache: Optional[bool] = None,
546
+ output_attentions: Optional[bool] = None,
547
+ output_hidden_states: Optional[bool] = None,
548
+ return_dict: Optional[bool] = True,
549
+ **kwargs: Unpack[Dict]
550
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
551
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
552
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
553
+ return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", False)
554
+
555
+ if input_ids is not None and inputs_embeds is not None:
556
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
557
+ if input_ids is None and inputs_embeds is None:
558
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
559
+
560
+ if inputs_embeds is None:
561
+ inputs_embeds = self.embeddings(input_ids)
562
+ hidden_states = inputs_embeds
563
+
564
+ if cu_seqlens is not None:
565
+ if use_cache and not isinstance(past_key_values, VarlenCache): past_key_values = VarlenCache.from_legacy_cache(past_key_values, batch_size=cu_seqlens.size(0)-1, device=cu_seqlens.device)
566
+ else:
567
+ if use_cache and not isinstance(past_key_values, Cache): past_key_values = Cache.from_legacy_cache(past_key_values)
568
+
569
+ if kwargs.get("use_gradient_checkpoint", False) is True and self.supports_gradient_checkpointing and self.training: self.gradient_checkpointing = True
570
+ else: self.gradient_checkpointing = False
571
+
572
+ all_hidden_states = () if output_hidden_states else None
573
+ for layer in self.layers:
574
+ if output_hidden_states: all_hidden_states += (hidden_states,)
575
+
576
+ if self.gradient_checkpointing and self.training:
577
+ hidden_states, _, past_key_values = checkpoint.checkpoint(
578
+ layer.__call__,
579
+ hidden_states,
580
+ cu_seqlens,
581
+ max_seqlen,
582
+ True,
583
+ past_key_values,
584
+ use_reentrant=False
585
+ )
586
+ else:
587
+ hidden_states, _, past_key_values = layer(hidden_states, cu_seqlens, max_seqlen, True, past_key_values)
588
+
589
+ hidden_states = self.norm(hidden_states)
590
+
591
+ if not return_dict:
592
+ return tuple(v for v in [hidden_states, all_hidden_states, past_key_values] if v is not None)
593
+
594
+ return BaseModelOutputWithPast(
595
+ last_hidden_state=hidden_states,
596
+ past_key_values=past_key_values,
597
+ hidden_states=all_hidden_states,
598
+ attentions=None
599
+ )
600
+
601
+ class Sophie0ForCausalLM(Sophie0PreTrainedModel, GenerationMixin):
602
+ _tied_weights_keys = ["lm_head.weight"]
603
+
604
+ def __init__(self, config: Sophie0Config):
605
+ super().__init__(config)
606
+
607
+ self.model = Sophie0Model(config)
608
+ self.vocab_size = config.vocab_size
609
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
610
+ self.criterion = None
611
+
612
+ self.post_init()
613
+
614
+ def get_input_embeddings(self):
615
+ return self.model.embeddings
616
+
617
+ def set_input_embeddings(self, value):
618
+ self.model.embeddings = value
619
+
620
+ def get_output_embeddings(self):
621
+ return self.lm_head
622
+
623
+ def set_output_embeddings(self, new_embeddings):
624
+ self.lm_head = new_embeddings
625
+
626
+ def set_decoder(self, decoder):
627
+ self.model = decoder
628
+
629
+ def get_decoder(self):
630
+ return self.model
631
+
632
+ def generate(self, *args, **kwargs):
633
+ try:
634
+ return super().generate(*args, **kwargs)
635
+ except AttributeError as exception:
636
+ if 'past_key_values' in str(exception):
637
+ raise AttributeError(
638
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
639
+ f"which is not supported for {self.__class__.__name__}. "
640
+ f"Try another generation strategy instead. "
641
+ f"For the available generation strategies, check this doc: "
642
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
643
+ )
644
+ else:
645
+ raise exception
646
+
647
+ def prepare_inputs_for_generation(
648
+ self,
649
+ input_ids: torch.LongTensor,
650
+ past_key_values: Optional[Cache] = None,
651
+ attention_mask: Optional[torch.Tensor] = None,
652
+ inputs_embeds: Optional[torch.Tensor] = None,
653
+ cache_position: Optional[int] = None,
654
+ use_cache: Optional[bool] = True,
655
+ logits_to_keep = None,
656
+ cu_seqlens: Optional[torch.LongTensor] = None,
657
+ max_seqlen: Optional[int] = None,
658
+ use_varlen_inference: Optional[bool]=False,
659
+ **kwargs
660
+ ):
661
+ if inputs_embeds is not None and len(past_key_values) == 0:
662
+ model_inputs = {'inputs_embeds': inputs_embeds}
663
+ else:
664
+ if past_key_values is not None and len(past_key_values) > 0:
665
+ input_ids = input_ids[:, -1:]
666
+ if isinstance(past_key_values, VarlenCache):
667
+ input_ids = input_ids.squeeze(-1)
668
+ cu_seqlens = torch.arange(past_key_values.batch_size + 1, dtype=torch.int32, device=input_ids.device)
669
+ max_seqlen = 1
670
+ else:
671
+ if use_varlen_inference:
672
+ input_ids, _, cu_seqlens, max_seqlen, _ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
673
+ input_ids = input_ids.squeeze(-1)
674
+
675
+ model_inputs = {'input_ids': input_ids.contiguous()}
676
+
677
+ if logits_to_keep is not None:
678
+ model_inputs['logits_to_keep'] = logits_to_keep
679
+
680
+ model_inputs.update({
681
+ 'past_key_values': past_key_values,
682
+ 'use_cache': use_cache,
683
+ 'cu_seqlens': cu_seqlens,
684
+ 'max_seqlen': max_seqlen
685
+ })
686
+ return model_inputs
687
+
688
+ def forward(
689
+ self,
690
+ input_ids: Optional[torch.LongTensor] = None,
691
+ cu_seqlens: Optional[torch.LongTensor] = None,
692
+ max_seqlen: Optional[int] = None,
693
+ use_varlen_inference: Optional[bool]=False,
694
+ attention_mask: Optional[torch.Tensor] = None,
695
+ inputs_embeds: Optional[torch.FloatTensor] = None,
696
+ past_key_values: Optional[Union[Cache, VarlenCache, List[torch.FloatTensor]]] = None,
697
+ labels: Optional[torch.LongTensor] = None,
698
+ labels_mask: Optional[torch.Tensor] = None,
699
+ use_cache: Optional[bool] = None,
700
+ output_attentions: Optional[bool] = None,
701
+ output_hidden_states: Optional[bool] = None,
702
+ return_dict: Optional[bool] = True,
703
+ **kwargs: Unpack[Dict]
704
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
705
+ output_attentions = output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
706
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
707
+ return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", False)
708
+
709
+ outputs: BaseModelOutputWithPast = self.model(
710
+ input_ids=input_ids,
711
+ cu_seqlens=cu_seqlens,
712
+ max_seqlen=max_seqlen,
713
+ attention_mask=attention_mask,
714
+ inputs_embeds=inputs_embeds,
715
+ past_key_values=past_key_values,
716
+ use_cache=use_cache,
717
+ output_attentions=output_attentions,
718
+ output_hidden_states=output_hidden_states,
719
+ return_dict=return_dict,
720
+ **kwargs
721
+ )
722
+
723
+ hidden_states = outputs.last_hidden_state
724
+ logits = self.lm_head(hidden_states)
725
+ past_key_values = outputs.past_key_values
726
+
727
+ loss = None
728
+ if labels is not None:
729
+ self.criterion = CrossEntropyLoss(ignore_index=self.config.pad_token_id, reduction="mean" if labels_mask is None else "none")
730
+ if logits.dim() == 2: # varlen
731
+ assert labels.dim() == 1
732
+ loss = self.criterion(logits, labels)
733
+ if labels_mask is not None:
734
+ loss = loss * labels_mask
735
+ loss = loss.sum() / labels_mask.sum()
736
+ else:
737
+ loss = loss.mean()
738
+ else:
739
+ assert labels.dim() == 2
740
+ if self.config.right_shift:
741
+ labels = labels[:, 1:]
742
+ logits = logits[:, :-1].contiguous()
743
+
744
+ loss = self.criterion(logits.flatten(0, 1), labels.flatten(0, 1))
745
+ if labels_mask is not None:
746
+ loss = loss * labels_mask.flatten(0, 1)
747
+ loss = loss.sum() / labels_mask.sum()
748
+ else:
749
+ loss = loss.mean()
750
+
751
+ else:
752
+ if isinstance(past_key_values, VarlenCache):
753
+ kv_cu_seqlens = past_key_values.get_cu_seq_length()
754
+ if logits.size(0) > past_key_values.batch_size: logits = logits.index_select(0, kv_cu_seqlens[1:] - 1)
755
+ logits = logits.unsqueeze(1)
756
+
757
+ if not return_dict:
758
+ output = (logits,) + outputs[1:]
759
+ return (loss,) + output if loss is not None else output
760
+
761
+ return CausalLMOutputWithPast(
762
+ loss=loss,
763
+ logits=logits,
764
+ past_key_values=past_key_values,
765
+ hidden_states=outputs.hidden_states,
766
+ attentions=outputs.attentions
767
+ )
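Besides `generate`, the varlen path can be driven directly through `forward` by passing packed token ids together with `cu_seqlens`, which is how the varlen SFT/DPO training described in the README feeds the model. A minimal sketch with random token ids (assumes a CUDA device with flash-attn installed; illustrative only, not a training recipe):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("SophieA17/Sophie0-SFT", trust_remote_code=True)
model = model.to(device="cuda:0", dtype=torch.bfloat16)

# Two "documents" of lengths 5 and 3 packed into one flat sequence of 8 tokens.
input_ids = torch.randint(11, 65536, (8,), device="cuda:0")
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda:0")

out = model(input_ids=input_ids, cu_seqlens=cu_seqlens, max_seqlen=5, use_cache=False)
print(out.logits.shape)  # (8, vocab_size) in the varlen path
```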