maomaocun committed
Commit: 347e409
Parent(s): 0bf3d37
README copy.md ADDED
@@ -0,0 +1,8 @@
1
+ ---
2
+ license: mit
3
+ library_name: transformers
4
+ pipeline_tag: text-generation
5
+ ---
6
+ # MDM-1.7B
7
+
8
+ We introduce MDM-1.7B, a 1.7B-parameter diffusion language model trained entirely from scratch on 1.1T open-source tokens.
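
A minimal loading sketch, assuming the custom InternLM2 classes are resolved through `trust_remote_code` via the `auto_map` in `config.json` below; the repo id is a placeholder:

```python
import torch
from transformers import AutoTokenizer, AutoModel

repo = "maomaocun/MDM-1.7B"  # hypothetical identifier; substitute the actual repo path
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo, torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda()
```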
config.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "_name_or_path": "/cpfs02/shared/llmit6/liudawei/xpuyu_work_dirs/internlm2-1_8b-myds-llada-v3/hf-310000-anneal4k",
3
+ "architectures": [
4
+ "InternLM2ForCausalLM"
5
+ ],
6
+ "attn_implementation": "sdpa",
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_internlm2.InternLM2Config",
9
+ "AutoModel": "modeling_internlm2.InternLM2ForCausalLM",
10
+ "AutoModelForCausalLM": "modeling_internlm2.InternLM2ForCausalLM"
11
+ },
12
+ "bias": false,
13
+ "bos_token_id": 1,
14
+ "eos_token_id": 2,
15
+ "fuse_cross_entropy": true,
16
+ "hidden_act": "silu",
17
+ "hidden_size": 2048,
18
+ "initializer_range": 0.006,
19
+ "intermediate_size": 8192,
20
+ "is_causal": false,
21
+ "max_position_embeddings": 32768,
22
+ "model_type": "internlm2",
23
+ "num_attention_heads": 16,
24
+ "num_hidden_layers": 24,
25
+ "num_key_value_heads": 8,
26
+ "pad_token_id": 2,
27
+ "pretraining_tp": 1,
28
+ "rms_norm_eps": 1e-05,
29
+ "rope_scaling": null,
30
+ "rope_theta": 10000.0,
31
+ "tie_word_embeddings": false,
32
+ "torch_dtype": "bfloat16",
33
+ "transformers_version": "4.46.0",
34
+ "use_cache": false,
35
+ "vocab_size": 128512
36
+ }
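
The `auto_map` entries above route `AutoConfig`/`AutoModel` to the `configuration_internlm2.py` and `modeling_internlm2.py` files added in this commit, and the diffusion-specific switches are stored as plain config fields (`is_causal: false`, `use_cache: false`). A hedged inspection sketch (placeholder repo id; extra keys such as `is_causal` are kept as attributes by `PretrainedConfig`):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("maomaocun/MDM-1.7B", trust_remote_code=True)  # placeholder id
print(cfg.model_type)  # "internlm2"
print(cfg.is_causal)   # False -> the lower-triangular causal mask is disabled for diffusion
print(cfg.use_cache)   # False -> no KV cache during iterative denoising
```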
configuration_internlm2.py ADDED
@@ -0,0 +1,180 @@
1
+ # coding=utf-8
2
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """ InternLM2 model configuration"""
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25
+
26
+
27
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28
+ class InternLM2Config(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32000):
40
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41
+ `inputs_ids` passed when calling [`InternLM2Model`]
42
+ hidden_size (`int`, *optional*, defaults to 4096):
43
+ Dimension of the hidden representations.
44
+ intermediate_size (`int`, *optional*, defaults to 11008):
45
+ Dimension of the MLP representations.
46
+ num_hidden_layers (`int`, *optional*, defaults to 32):
47
+ Number of hidden layers in the Transformer decoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 32):
49
+ Number of attention heads for each attention layer in the Transformer decoder.
50
+ num_key_value_heads (`int`, *optional*):
51
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA); otherwise GQA is used. When
54
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55
+ by mean-pooling all the original heads within that group. For more details, check out [this
56
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57
+ `num_attention_heads`.
58
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
+ The non-linear activation function (function or string) in the decoder.
60
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
61
+ The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens.
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ pad_token_id (`int`, *optional*):
70
+ Padding token id.
71
+ bos_token_id (`int`, *optional*, defaults to 1):
72
+ Beginning of stream token id.
73
+ eos_token_id (`int`, *optional*, defaults to 2):
74
+ End of stream token id.
75
+ pretraining_tp (`int`, *optional*, defaults to 1):
76
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
77
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism)
78
+ to understand more about it. This value is necessary to ensure exact reproducibility
79
+ of the pretraining results. Please refer to [this
80
+ issue](https://github.com/pytorch/pytorch/issues/76232).
81
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
82
+ Whether to tie weight embeddings
83
+ rope_theta (`float`, *optional*, defaults to 10000.0):
84
+ The base period of the RoPE embeddings.
85
+ rope_scaling (`Dict`, *optional*):
86
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
87
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
88
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
89
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
90
+ these scaling strategies behave:
91
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
92
+ experimental feature, subject to breaking API changes in future versions.
93
+ """
94
+ _auto_class = "AutoConfig"
95
+ model_type = "internlm2"
96
+ keys_to_ignore_at_inference = ["past_key_values"]
97
+
98
+ def __init__( # pylint: disable=W0102
99
+ self,
100
+ vocab_size=103168,
101
+ hidden_size=4096,
102
+ intermediate_size=11008,
103
+ num_hidden_layers=32,
104
+ num_attention_heads=32,
105
+ num_key_value_heads=None,
106
+ hidden_act="silu",
107
+ max_position_embeddings=2048,
108
+ initializer_range=0.02,
109
+ rms_norm_eps=1e-6,
110
+ use_cache=True,
111
+ pad_token_id=0,
112
+ bos_token_id=1,
113
+ eos_token_id=2,
114
+ pretraining_tp=1,
115
+ tie_word_embeddings=False,
116
+ bias=True,
117
+ rope_theta=10000,
118
+ rope_scaling=None,
119
+ attn_implementation=None,
120
+ **kwargs,
121
+ ):
122
+ self.vocab_size = vocab_size
123
+ self.max_position_embeddings = max_position_embeddings
124
+ self.hidden_size = hidden_size
125
+ self.intermediate_size = intermediate_size
126
+ self.num_hidden_layers = num_hidden_layers
127
+ self.num_attention_heads = num_attention_heads
128
+ self.bias = bias
129
+
130
+ if num_key_value_heads is None:
131
+ num_key_value_heads = num_attention_heads
132
+ self.num_key_value_heads = num_key_value_heads
133
+
134
+ self.hidden_act = hidden_act
135
+ self.initializer_range = initializer_range
136
+ self.rms_norm_eps = rms_norm_eps
137
+ self.pretraining_tp = pretraining_tp
138
+ self.use_cache = use_cache
139
+ self.rope_theta = rope_theta
140
+ self.rope_scaling = rope_scaling
141
+ self._rope_scaling_validation()
142
+ self.attn_implementation = attn_implementation
143
+ if self.attn_implementation is None:
144
+ self.attn_implementation = "sdpa"
145
+
146
+ super().__init__(
147
+ pad_token_id=pad_token_id,
148
+ bos_token_id=bos_token_id,
149
+ eos_token_id=eos_token_id,
150
+ tie_word_embeddings=tie_word_embeddings,
151
+ **kwargs,
152
+ )
153
+
154
+ def _rope_scaling_validation(self):
155
+ """
156
+ Validate the `rope_scaling` configuration.
157
+ """
158
+ if self.rope_scaling is None:
159
+ return
160
+
161
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
162
+ raise ValueError(
163
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
164
+ f"got {self.rope_scaling}"
165
+ )
166
+ rope_scaling_type = self.rope_scaling.get("type", None)
167
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
168
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
169
+ raise ValueError(
170
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
171
+ )
172
+ if (
173
+ rope_scaling_factor is None
174
+ or not isinstance(rope_scaling_factor, (float, int))
175
+ or rope_scaling_factor < 1.0
176
+ ):
177
+ raise ValueError(
178
+ f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
179
+ f"of type {type(rope_scaling_factor)}"
180
+ )
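
As a sketch of how the class above behaves (assuming the file is importable from the working directory), the constructor mirrors the values in `config.json`, and `_rope_scaling_validation` rejects anything other than a `{"type": "linear"|"dynamic", "factor": >= 1}` dictionary:

```python
from configuration_internlm2 import InternLM2Config  # assumes this file is on the import path

# Values mirroring config.json above.
cfg = InternLM2Config(
    vocab_size=128512,
    hidden_size=2048,
    intermediate_size=8192,
    num_hidden_layers=24,
    num_attention_heads=16,
    num_key_value_heads=8,
    bias=False,
    rope_theta=10000.0,
)
print(cfg.num_key_value_heads)  # 8 -> grouped-query attention with 2 query heads per KV head

# A malformed rope_scaling dict raises ValueError via _rope_scaling_validation.
try:
    InternLM2Config(rope_scaling={"type": "yarn", "factor": 2.0})
except ValueError as err:
    print(err)
```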
generate.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+ from transformers import AutoTokenizer, AutoModel
5
+
6
+ def add_gumbel_noise(logits, temperature):
7
+ if temperature == 0:
8
+ return logits
9
+ logits = logits.to(torch.float64)
10
+ noise = torch.rand_like(logits, dtype=torch.float64)
11
+ gumbel_noise = (- torch.log(noise)) ** temperature
12
+ return logits.exp() / gumbel_noise
13
+
14
+ def get_num_transfer_tokens(mask_index, steps):
15
+ mask_num = mask_index.sum(dim=1, keepdim=True)
16
+ base = mask_num // steps
17
+ remainder = mask_num % steps
18
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
19
+ for i in range(mask_num.size(0)):
20
+ num_transfer_tokens[i, :remainder[i]] += 1
21
+ return num_transfer_tokens
22
+
23
+ @torch.no_grad()
24
+ def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
25
+ cfg_scale=0., remasking='low_confidence', mask_id=128108):
26
+ x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
27
+ x[:, :prompt.shape[1]] = prompt.clone()
28
+ prompt_index = (x != mask_id)
29
+ assert gen_length % block_length == 0
30
+ num_blocks = gen_length // block_length
31
+ assert steps % num_blocks == 0
32
+ steps = steps // num_blocks
33
+ for num_block in range(num_blocks):
34
+ block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
35
+ num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
36
+ for i in range(steps):
37
+ mask_index = (x == mask_id)
38
+ if cfg_scale > 0.:
39
+ un_x = x.clone()
40
+ un_x[prompt_index] = mask_id
41
+ x_ = torch.cat([x, un_x], dim=0)
42
+ logits = model(x_).logits
43
+ logits, un_logits = torch.chunk(logits, 2, dim=0)
44
+ logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
45
+ else:
46
+ logits = model(x).logits
47
+ logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
48
+ x0 = torch.argmax(logits_with_noise, dim=-1)
49
+ if remasking == 'low_confidence':
50
+ p = F.softmax(logits, dim=-1)
51
+ x0_p = torch.squeeze(
52
+ torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
53
+ elif remasking == 'random':
54
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
55
+ else:
56
+ raise NotImplementedError(remasking)
57
+ x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf
58
+ x0 = torch.where(mask_index, x0, x)
59
+ confidence = torch.where(mask_index, x0_p, -np.inf)
60
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
61
+ for j in range(confidence.shape[0]):
62
+ k = int(num_transfer_tokens[j, i].item())
63
+ _, select_index = torch.topk(confidence[j], k=k)
64
+ transfer_index[j, select_index] = True
65
+ x[transfer_index] = x0[transfer_index]
66
+ return x[:, prompt.shape[1]:]
67
+
68
+ if __name__ == "__main__":
69
+ model_path = "/cpfs02/shared/llmit6/liudawei/xpuyu_work_dirs/internlm2-1_8b-myds-llada-sft-v3/pretrain-310000-yhc-padto2power/20250430200836/release"
70
+ device = torch.device("cuda")
71
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
72
+ model = AutoModel.from_pretrained(
73
+ model_path,
74
+ torch_dtype=torch.bfloat16,
75
+ device_map=None,
76
+ trust_remote_code=True
77
+ ).to(device)
78
+ input_text = "Question: Jen and Tyler are gymnasts practicing flips. Jen is practicing the triple-flip while Tyler is practicing the double-flip. Jen did sixteen triple-flips during practice. Tyler flipped in the air half the number of times Jen did. How many double-flips did Tyler do?\nAnswer: Jen did 16 triple-flips, so she did 16 * 3 = <<16*3=48>>48 flips.\nTyler did half the number of flips, so he did 48 / 2 = <<48/2=24>>24 flips.\nA double flip has two flips, so Tyler did 24 / 2 = <<24/2=12>>12 double-flips.\n#### 12\n\nQuestion: Four people in a law firm are planning a party. Mary will buy a platter of pasta for $20 and a loaf of bread for $2. Elle and Andrea will split the cost for buying 4 cans of soda which cost $1.50 each, and chicken wings for $10. Joe will buy a cake that costs $5. How much more will Mary spend than the rest of the firm put together?\nAnswer: Mary will spend $20 + $2 = $<<20+2=22>>22.\nElle and Andrea will spend $1.5 x 4 = $<<1.5*4=6>>6 for the soda.\nElle and Andrea will spend $6 + $10 = $<<6+10=16>>16 for the soda and chicken wings.\nElle, Andrea, and Joe together will spend $16 + $5 = $<<16+5=21>>21.\nSo, Mary will spend $22 - $21 = $<<22-21=1>>1 more than all of them combined.\n#### 1\n\nQuestion: A charcoal grill burns fifteen coals to ash every twenty minutes of grilling. The grill ran for long enough to burn three bags of coals. Each bag of coal contains 60 coals. How long did the grill run?\nAnswer: The grill burned 3 * 60 = <<3*60=180>>180 coals.\nIt takes 20 minutes to burn 15 coals, so the grill ran for 180 / 15 * 20 = <<180/15*20=240>>240 minutes.\n#### 240\n\nQuestion: A bear is preparing to hibernate for the winter and needs to gain 1000 pounds. At the end of summer, the bear feasts on berries and small woodland animals. During autumn, it devours acorns and salmon. It gained a fifth of the weight it needed from berries during summer, and during autumn, it gained twice that amount from acorns. Salmon made up half of the remaining weight it had needed to gain. How many pounds did it gain eating small animals?\nAnswer: The bear gained 1 / 5 * 1000 = <<1/5*1000=200>>200 pounds from berries.\nIt gained 2 * 200 = <<2*200=400>>400 pounds from acorns.\nIt still needed 1000 - 200 - 400 = <<1000-200-400=400>>400 pounds.\nThus, it gained 400 / 2 = <<400/2=200>>200 pounds from salmon.\nTherefore, the bear gained 400 - 200 = <<400-200=200>>200 pounds from small animals.\n#### 200\n\nQuestion: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nAnswer:"
79
+ prompt = tokenizer.apply_chat_template(
80
+ [{"role": "user", "content": input_text}],
81
+ add_generation_prompt=True,
82
+ return_tensors="pt"
83
+ ).to(device)
84
+ print(f"Input text: {input_text}")
85
+ result_ids = generate(model, prompt)
86
+ result_text = tokenizer.decode(result_ids[0], skip_special_tokens=True)
87
+ print(f"Final output: {result_text}")
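
For intuition about the schedule used above (not part of the committed file): `get_num_transfer_tokens` spreads the masked positions of each block evenly over the denoising steps, giving the earliest steps one extra token when the count does not divide evenly. A standalone check:

```python
import torch

def get_num_transfer_tokens(mask_index, steps):
    # Same logic as the helper in generate.py, duplicated so this snippet runs on its own.
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    out = torch.zeros(mask_num.size(0), steps, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        out[i, :remainder[i]] += 1
    return out

mask = torch.ones(1, 10, dtype=torch.bool)     # 10 masked positions in one block
print(get_num_transfer_tokens(mask, steps=4))  # tensor([[3, 3, 2, 2]])
```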
generation_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": [
5
+ 2,
6
+ 128131
7
+ ],
8
+ "pad_token_id": 2,
9
+ "transformers_version": "4.46.0",
10
+ "use_cache": false
11
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f720f858050cd3042b4be86fdf2c5eb82ebc17333e11701e4e2913a6feddd11a
3
+ size 4072889152
modeling_internlm2.py ADDED
@@ -0,0 +1,2127 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch InternLM2 model."""
17
+
18
+ """2024/12/13 只修改了 rmsnorm"""
19
+ # from fla.modules.activations import swiglu_linear
20
+ from .configuration_internlm2 import InternLM2Config
21
+ # from fla.modules import (
22
+ # FusedCrossEntropyLoss, RMSNorm, RotaryEmbedding,
23
+ # FusedLinearDiffusionCrossEntropyLoss)
24
+ from transformers.utils import (
25
+ add_start_docstrings,
26
+ add_start_docstrings_to_model_forward,
27
+ is_flash_attn_greater_or_equal_2_10,
28
+ logging,
29
+ replace_return_docstrings,
30
+ )
31
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPast,
35
+ CausalLMOutputWithPast,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
41
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
42
+ from transformers.activations import ACT2FN
43
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
44
+ from torch import nn
45
+ from einops import rearrange
46
+ import torch.utils.checkpoint
47
+ import torch.nn.functional as F
48
+ import torch
49
+ from typing import List, Optional, Tuple, Union
50
+ import threading
51
+ import queue
52
+ import math
53
+
54
+
55
+ try:
56
+ from transformers.generation.streamers import BaseStreamer
57
+ except Exception:
58
+ BaseStreamer = None
59
+
60
+
61
+ try:
62
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
63
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
64
+ except Exception:
65
+ pass
66
+
67
+ try:
68
+ support_bf16_triu = torch.__version__ >= "2.1.0"
69
+ except Exception:
70
+ support_bf16_triu = False
71
+
72
+ logger = logging.get_logger(__name__)
73
+
74
+ _CONFIG_FOR_DOC = "InternLM2Config"
75
+
76
+ '''All modified parts are annotated with single-quote comments
77
+
78
+ Changed the loss computation: only InternLM2ForCausalLM.forward() should need to change
79
+ Changed the RMSNorm
80
+
81
+ TODO: distinguish attention_mask from causal_mask; they are currently used interchangeably
82
+
83
+ The former is used for padding inputs (ignoring pad tokens) and may be needed by both AR and diffusion
84
+ The latter is the lower-triangular causal mask matrix and must be disabled during the diffusion process
85
+ InternLM2FlashAttention2 has been changed to treat attn_mask as a padding mask and to switch padding to variable-length logic
86
+ is_causal is set in config.json to disable the causal_mask
87
+ '''
88
+
89
+
90
+ def _get_unpad_data(attention_mask):
91
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
92
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
93
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
94
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0,
95
+ dtype=torch.int32), (1, 0)) # pylint: disable=E1102
96
+ return (
97
+ indices,
98
+ cu_seqlens,
99
+ max_seqlen_in_batch,
100
+ )
101
+
102
+
103
+ '''
104
+ class RMSNorm(nn.Module):
105
+ """InternLM2RMSNorm is equivalent to T5LayerNorm."""
106
+
107
+ def __init__(self, hidden_size, eps=1e-6):
108
+ super().__init__()
109
+ self.weight = nn.Parameter(torch.ones(hidden_size))
110
+ self.variance_epsilon = eps
111
+
112
+ def forward(self, hidden_states):
113
+ input_dtype = hidden_states.dtype
114
+ hidden_states = hidden_states.to(torch.float32)
115
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
116
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
117
+ return self.weight * hidden_states.to(input_dtype)
118
+
119
+
120
+ ALL_LAYERNORM_LAYERS.append(InternLM2RMSNorm)
121
+ '''
122
+ # ALL_LAYERNORM_LAYERS.append(RMSNorm)
123
+
124
+ class RMSNorm(nn.Module):
125
+ """InternLM2RMSNorm is equivalent to T5LayerNorm."""
126
+
127
+ def __init__(self, hidden_size, eps=1e-6):
128
+ super().__init__()
129
+ self.weight = nn.Parameter(torch.ones(hidden_size))
130
+ self.variance_epsilon = eps
131
+
132
+ def forward(self, hidden_states):
133
+ input_dtype = hidden_states.dtype
134
+ hidden_states = hidden_states.to(torch.float32)
135
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
136
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
137
+ return self.weight * hidden_states.to(input_dtype)
138
+
139
+
140
+ # class RMSNorm(nn.Module):
141
+ # """
142
+ # RMS layer norm, a simplified :class:`LayerNorm` implementation
143
+ # """
144
+
145
+ # def __init__(
146
+ # self,
147
+ # config: ModelConfig,
148
+ # size: Optional[int] = None,
149
+ # elementwise_affine: Optional[bool] = None,
150
+ # eps: float = 1e-5,
151
+ # ):
152
+ # # super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
153
+ # super().__init__()
154
+ # self.config = config
155
+ # self.eps = eps
156
+ # self.normalized_shape = (size or config.d_model,)
157
+ # if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
158
+ # self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
159
+ # use_bias = self.config.bias_for_layer_norm
160
+ # if use_bias is None:
161
+ # use_bias = self.config.include_bias
162
+ # if use_bias:
163
+ # self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
164
+ # else:
165
+ # self.register_parameter("bias", None)
166
+ # else:
167
+ # self.register_parameter("bias", None)
168
+ # self.register_parameter("weight", None)
169
+
170
+ # def forward(self, x: torch.Tensor) -> torch.Tensor:
171
+ # with torch.autocast(enabled=False, device_type=x.device.type):
172
+ # og_dtype = x.dtype
173
+ # x = x.to(torch.float32)
174
+ # variance = x.pow(2).mean(-1, keepdim=True)
175
+ # x = x * torch.rsqrt(variance + self.eps)
176
+ # x = x.to(og_dtype)
177
+
178
+ # if self.weight is not None:
179
+ # if self.bias is not None:
180
+ # return self.weight * x + self.bias
181
+ # else:
182
+ # return self.weight * x
183
+ # else:
184
+ # return x
185
+
186
+
187
+ class InternLM2RotaryEmbedding(nn.Module):
188
+ """Rotary Position Embedding for the InternLM2 model. Credits to the Reddit user /u/lucidrains."""
189
+
190
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
191
+ super().__init__()
192
+ self.scaling_factor = scaling_factor
193
+ self.dim = dim
194
+ self.max_position_embeddings = max_position_embeddings
195
+ self.base = base
196
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim,
197
+ 2, dtype=torch.int64).float().to(device) / self.dim))
198
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
199
+ # For BC we register cos and sin cached
200
+ self.max_seq_len_cached = max_position_embeddings
201
+
202
+ @torch.no_grad()
203
+ def forward(self, x, position_ids):
204
+ # x: [bs, num_attention_heads, seq_len, head_size]
205
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(
206
+ position_ids.shape[0], -1, 1)
207
+ position_ids_expanded = position_ids[:, None, :].float()
208
+ # Force float32 since bfloat16 loses precision on long contexts
209
+ # See https://github.com/huggingface/transformers/pull/29285
210
+ device_type = x.device.type
211
+ device_type = device_type if isinstance(
212
+ device_type, str) and device_type != "mps" else "cpu"
213
+ with torch.autocast(device_type=device_type, enabled=False):
214
+ freqs = (inv_freq_expanded.float() @
215
+ position_ids_expanded.float()).transpose(1, 2)
216
+ emb = torch.cat((freqs, freqs), dim=-1)
217
+ cos = emb.cos()
218
+ sin = emb.sin()
219
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
220
+
221
+
222
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
223
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
224
+
225
+ def forward(self, x, position_ids):
226
+ # difference to the original RoPE: a scaling factor is applied to the position ids
227
+ position_ids = position_ids.float() / self.scaling_factor
228
+ cos, sin = super().forward(x, position_ids)
229
+ return cos, sin
230
+
231
+
232
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
233
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
234
+ Credits to the Reddit users /u/bloc97 and /u/emozilla"""
235
+
236
+ def forward(self, x, position_ids):
237
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
238
+ seq_len = torch.max(position_ids) + 1
239
+ if seq_len > self.max_position_embeddings:
240
+ base = self.base * (
241
+ (self.scaling_factor * seq_len /
242
+ self.max_position_embeddings) - (self.scaling_factor - 1)
243
+ ) ** (self.dim / (self.dim - 2))
244
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2,
245
+ dtype=torch.int64).float().to(x.device) / self.dim))
246
+ # TODO joao: this may break with compilation
247
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
248
+
249
+ cos, sin = super().forward(x, position_ids)
250
+ return cos, sin
251
+
252
+
253
+ def rotate_half(x):
254
+ """Rotates half the hidden dims of the input."""
255
+ x1 = x[..., : x.shape[-1] // 2]
256
+ x2 = x[..., x.shape[-1] // 2:]
257
+ return torch.cat((-x2, x1), dim=-1)
258
+
259
+
260
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): # pylint: disable=unused-argument
261
+ """Applies Rotary Position Embedding to the query and key tensors.
262
+
263
+ Args:
264
+ q (`torch.Tensor`): The query tensor.
265
+ k (`torch.Tensor`): The key tensor.
266
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
267
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
268
+ position_ids (`torch.Tensor`, *optional*):
269
+ Deprecated and unused.
270
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
271
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
272
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
273
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
274
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
275
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
276
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
277
+ Returns:
278
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
279
+ """
280
+ cos = cos.unsqueeze(unsqueeze_dim)
281
+ sin = sin.unsqueeze(unsqueeze_dim)
282
+ q_embed = (q * cos) + (rotate_half(q) * sin)
283
+ k_embed = (k * cos) + (rotate_half(k) * sin)
284
+ return q_embed, k_embed
285
+
286
+
287
+ class InternLM2MLP(nn.Module):
288
+ """MLP for InternLM2 model."""
289
+
290
+ def __init__(self, config):
291
+ super().__init__()
292
+ self.config = config
293
+ self.hidden_size = config.hidden_size
294
+ self.intermediate_size = config.intermediate_size
295
+ self.w1 = nn.Linear(
296
+ self.hidden_size, self.intermediate_size, bias=False)
297
+ self.w3 = nn.Linear(
298
+ self.hidden_size, self.intermediate_size, bias=False)
299
+ self.w2 = nn.Linear(self.intermediate_size,
300
+ self.hidden_size, bias=False)
301
+ self.act_fn = ACT2FN[config.hidden_act]
302
+
303
+ def forward(self, x):
304
+ a = self.w1(x)
305
+ b = self.w3(x)
306
+ act = self.act_fn(a) * b
307
+ output = self.w2(act)
308
+ return output
309
+
310
+
311
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
312
+ """
313
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
314
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
315
+ """
316
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
317
+ if n_rep == 1:
318
+ return hidden_states
319
+ hidden_states = hidden_states[:, :, None, :, :].expand(
320
+ batch, num_key_value_heads, n_rep, slen, head_dim)
321
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
322
+
323
+
324
+ class InternLM2Attention(nn.Module):
325
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
326
+
327
+ def __init__(self, config: InternLM2Config, layer_idx: Optional[int] = None):
328
+ super().__init__()
329
+ self.config = config
330
+ self.layer_idx = layer_idx
331
+ if layer_idx is None:
332
+ logger.warning_once(
333
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
334
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
335
+ "when creating this class."
336
+ )
337
+
338
+ self.hidden_size = config.hidden_size
339
+ self.num_heads = config.num_attention_heads
340
+ self.head_dim = self.hidden_size // self.num_heads
341
+ self.num_key_value_heads = config.num_key_value_heads
342
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
343
+ self.max_position_embeddings = config.max_position_embeddings
344
+ self.rope_theta = config.rope_theta
345
+ self.is_causal = config.is_causal
346
+
347
+ if (self.head_dim * self.num_heads) != self.hidden_size:
348
+ raise ValueError(
349
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
350
+ f" and `num_heads`: {self.num_heads})."
351
+ )
352
+
353
+ self.wqkv = nn.Linear(
354
+ self.hidden_size,
355
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
356
+ bias=config.bias,
357
+ )
358
+ self.wo = nn.Linear(self.num_heads * self.head_dim,
359
+ self.hidden_size, bias=config.bias)
360
+
361
+ self._init_rope()
362
+
363
+ def _init_rope(self):
364
+
365
+ # self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
366
+ '''Changed rotary
367
+ self.rotary_emb could be deprecated; just use self.rotary directly instead
368
+ Note that their return values and arguments differ
369
+ '''
370
+ if self.config.rope_scaling is None:
371
+ self.rotary_emb = InternLM2RotaryEmbedding(
372
+ self.head_dim,
373
+ max_position_embeddings=self.max_position_embeddings,
374
+ base=self.rope_theta,
375
+ )
376
+ else:
377
+ scaling_type = self.config.rope_scaling["type"]
378
+ scaling_factor = self.config.rope_scaling["factor"]
379
+ if scaling_type == "linear":
380
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
381
+ self.head_dim,
382
+ max_position_embeddings=self.max_position_embeddings,
383
+ scaling_factor=scaling_factor,
384
+ base=self.rope_theta,
385
+ )
386
+ elif scaling_type == "dynamic":
387
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
388
+ self.head_dim,
389
+ max_position_embeddings=self.max_position_embeddings,
390
+ scaling_factor=scaling_factor,
391
+ base=self.rope_theta,
392
+ )
393
+ else:
394
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
395
+
396
+ def forward(
397
+ self,
398
+ hidden_states: torch.Tensor,
399
+ attention_mask: Optional[torch.Tensor] = None,
400
+ position_ids: Optional[torch.LongTensor] = None,
401
+ past_key_value: Optional[Cache] = None,
402
+ output_attentions: bool = False,
403
+ use_cache: bool = False, # pylint: disable=unused-argument
404
+ cache_position: Optional[torch.LongTensor] = None,
405
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
406
+ raise RuntimeError("Eager attention is not supported yet; please switch to flash attention")
407
+ bsz, q_len, _ = hidden_states.size()
408
+
409
+ # print("flash-attn is not used here; the plain forward path is used")
410
+ if self.config.pretraining_tp > 1:
411
+ # split qkv_states by tp size
412
+ key_value_slicing = (self.num_key_value_heads *
413
+ self.head_dim) // self.config.pretraining_tp
414
+ qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
415
+ qkv_states = torch.cat(
416
+ [F.linear(hidden_states, qkv_slice) for qkv_slice in qkv_slices], dim=-1 # pylint: disable=E1102
417
+ )
418
+ else:
419
+ qkv_states = self.wqkv(hidden_states)
420
+
421
+ qkv_states = rearrange(
422
+ qkv_states,
423
+ "b q (h gs d) -> b q h gs d",
424
+ gs=2 + self.num_key_value_groups,
425
+ d=self.head_dim,
426
+ )
427
+
428
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
429
+ query_states = rearrange(
430
+ query_states, "b q h gs d -> b q (h gs) d").transpose(1, 2)
431
+ key_states = qkv_states[..., -2, :].transpose(1, 2)
432
+ value_states = qkv_states[..., -1, :].transpose(1, 2)
433
+ '''Changed rotary'''
434
+ cos, sin = self.rotary_emb(value_states, position_ids)
435
+ query_states, key_states = apply_rotary_pos_emb(
436
+ query_states, key_states, cos, sin, position_ids)
437
+ # Everything from `seqlen_offset, max_seqlen = 0, q_len`
438
+ # to the computation of query_states and key_states is modified code (the new computation)
439
+ '''
440
+ seqlen_offset, max_seqlen = 0, q_len
441
+ if past_key_value is not None:
442
+ seqlen_offset = past_key_value.get_seq_length(self.layer_idx)
443
+ max_seqlen = query_states.shape[1] + seqlen_offset
444
+
445
+ if attention_mask is not None:
446
+ # to deliminate the offsets of padding tokens
447
+ seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
448
+ max_seqlen = query_states.shape[1] + max(seqlen_offset)
449
+
450
+ if self.max_position_embeddings is not None:
451
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
452
+ query_states, key_states = self.rotary(query_states, key_states, seqlen_offset, max_seqlen)
453
+ '''
454
+ if past_key_value is not None:
455
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
456
+ '''Changed rotary'''
457
+ cache_kwargs = {"sin": sin, "cos": cos,
458
+ "cache_position": cache_position}
459
+ '''
460
+ cache_kwargs = {"cache_position": cache_position}
461
+ '''
462
+ key_states, value_states = past_key_value.update(
463
+ key_states, value_states, self.layer_idx, cache_kwargs)
464
+
465
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
466
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
467
+
468
+ attn_weights = torch.matmul(
469
+ query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
470
+
471
+ if attention_mask is not None: # no matter the length, we just slice it
472
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
473
+ attn_weights = attn_weights + causal_mask
474
+
475
+ # upcast attention to fp32
476
+ attn_weights = nn.functional.softmax(
477
+ attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
478
+ attn_output = torch.matmul(attn_weights, value_states)
479
+
480
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
481
+ raise ValueError(
482
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
483
+ f" {attn_output.size()}"
484
+ )
485
+
486
+ attn_output = attn_output.transpose(1, 2).contiguous()
487
+
488
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
489
+
490
+ if self.config.pretraining_tp > 1:
491
+ attn_output = attn_output.split(
492
+ self.hidden_size // self.config.pretraining_tp, dim=2)
493
+ o_proj_slices = self.wo.weight.split(
494
+ self.hidden_size // self.config.pretraining_tp, dim=1)
495
+ attn_output = sum(
496
+ [
497
+ F.linear(attn_output[i], o_proj_slices[i]
498
+ ) # pylint: disable=E1102
499
+ for i in range(self.config.pretraining_tp)
500
+ ]
501
+ )
502
+ else:
503
+ attn_output = self.wo(attn_output)
504
+
505
+ if not output_attentions:
506
+ attn_weights = None
507
+
508
+ return attn_output, attn_weights, past_key_value
509
+
510
+
511
+ class InternLM2FlashAttention2(InternLM2Attention):
512
+ """
513
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
514
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
515
+ flash attention and deal with padding tokens in case the input contains any of them.
516
+ """
517
+
518
+ def __init__(self, *args, **kwargs):
519
+ super().__init__(*args, **kwargs)
520
+
521
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
522
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
523
+ # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
524
+ # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
525
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
526
+ # produces a wrong mask (top-left).
527
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
528
+
529
+ def forward(
530
+ self,
531
+ hidden_states: torch.Tensor,
532
+ attention_mask: Optional[torch.LongTensor] = None,
533
+ position_ids: Optional[torch.LongTensor] = None,
534
+ past_key_value: Optional[Cache] = None,
535
+ output_attentions: bool = False,
536
+ use_cache: bool = False,
537
+ cache_position: Optional[torch.LongTensor] = None,
538
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
539
+ if isinstance(past_key_value, StaticCache):
540
+ raise ValueError(
541
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
542
+ "make sure to use `sdpa` in the mean time, and open an issue at "
543
+ "https://github.com/huggingface/transformers"
544
+ )
545
+
546
+ output_attentions = False
547
+
548
+ bsz, q_len, _ = hidden_states.size()
549
+
550
+ qkv_states = self.wqkv(hidden_states)
551
+
552
+ qkv_states = rearrange(
553
+ qkv_states,
554
+ "b q (h gs d) -> b q h gs d",
555
+ gs=2 + self.num_key_value_groups,
556
+ d=self.head_dim,
557
+ )
558
+
559
+ query_states = qkv_states[..., :self.num_key_value_groups, :]
560
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
561
+ key_states = qkv_states[..., -2, :]
562
+ value_states = qkv_states[..., -1, :]
563
+
564
+ '''Changed rotary'''
565
+ query_states = query_states.transpose(1, 2)
566
+ key_states = key_states.transpose(1, 2)
567
+ value_states = value_states.transpose(1, 2)
568
+
569
+ cos, sin = self.rotary_emb(value_states, position_ids)
570
+ query_states, key_states = apply_rotary_pos_emb(
571
+ query_states, key_states, cos, sin, position_ids)
572
+ # Everything from `seqlen_offset, max_seqlen = 0, q_len`
573
+ # to the computation of query_states and key_states is modified code (the new computation)
574
+ '''
575
+ seqlen_offset, max_seqlen = 0, q_len
576
+ if past_key_value is not None:
577
+ seqlen_offset = past_key_value.get_seq_length(self.layer_idx)
578
+ max_seqlen = query_states.shape[1] + seqlen_offset
579
+
580
+ if attention_mask is not None:
581
+ # to deliminate the offsets of padding tokens
582
+ seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
583
+ max_seqlen = query_states.shape[1] + max(seqlen_offset)
584
+
585
+ if self.max_position_embeddings is not None:
586
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
587
+ query_states, key_states = self.rotary(query_states, key_states, seqlen_offset, max_seqlen)
588
+ '''
589
+ if past_key_value is not None:
590
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
591
+ '''Changed rotary'''
592
+ cache_kwargs = {"sin": sin, "cos": cos,
593
+ "cache_position": cache_position}
594
+ '''
595
+ cache_kwargs = {"cache_position": cache_position}
596
+ '''
597
+ key_states, value_states = past_key_value.update(
598
+ key_states, value_states, self.layer_idx, cache_kwargs)
599
+
600
+ '''Changed rotary'''
601
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
602
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
603
+ # to be able to avoid many of these transpose/reshape/view.
604
+ query_states = query_states.transpose(1, 2)
605
+ key_states = key_states.transpose(1, 2)
606
+ value_states = value_states.transpose(1, 2)
607
+
608
+ # dropout_rate = self.attention_dropout if self.training else 0.0
609
+ dropout_rate = 0.0
610
+
611
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
612
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
613
+ # cast them back in the correct dtype just to be sure everything works as expected.
614
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
615
+ # in fp32. (InternLM2RMSNorm handles it correctly)
616
+
617
+ input_dtype = query_states.dtype
618
+ if input_dtype == torch.float32:
619
+ if torch.is_autocast_enabled():
620
+ target_dtype = torch.get_autocast_gpu_dtype()
621
+ # Handle the case where the model is quantized
622
+ elif hasattr(self.config, "_pre_quantization_dtype"):
623
+ target_dtype = self.config._pre_quantization_dtype
624
+ else:
625
+ target_dtype = self.wqkv.weight.dtype
626
+
627
+ logger.warning_once(
628
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
629
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
630
+ f" {target_dtype}."
631
+ )
632
+
633
+ query_states = query_states.to(target_dtype)
634
+ key_states = key_states.to(target_dtype)
635
+ value_states = value_states.to(target_dtype)
636
+
637
+ attn_output = self._flash_attention_forward(
638
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
639
+ )
640
+
641
+ attn_output = attn_output.reshape(
642
+ bsz, q_len, self.hidden_size).contiguous()
643
+ attn_output = self.wo(attn_output)
644
+
645
+ if not output_attentions:
646
+ attn_weights = None
647
+
648
+ return attn_output, attn_weights, past_key_value # pylint: disable=E0606
649
+
650
+ def _flash_attention_forward(
651
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
652
+ ):
653
+ """
654
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
655
+ first unpad the input, then computes the attention scores and pad the final attention scores.
656
+
657
+ Args:
658
+ query_states (`torch.Tensor`):
659
+ Input query states to be passed to Flash Attention API
660
+ key_states (`torch.Tensor`):
661
+ Input key states to be passed to Flash Attention API
662
+ value_states (`torch.Tensor`):
663
+ Input value states to be passed to Flash Attention API
664
+ attention_mask (`torch.Tensor`):
665
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
666
+ position of padding tokens and 1 for the position of non-padding tokens.
667
+ dropout (`float`):
668
+ Attention dropout
669
+ softmax_scale (`float`, *optional*):
670
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
671
+ """
672
+ if not self._flash_attn_uses_top_left_mask:
673
+ causal = self.is_causal
674
+ else:
675
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
676
+ # For details, please see the comment in InternLM2FlashAttention2 __init__.
677
+ causal = self.is_causal and query_length != 1
678
+
679
+ # Contains at least one padding token in the sequence
680
+ if attention_mask is not None:
681
+ batch_size = query_states.shape[0]
682
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
683
+ query_states, key_states, value_states, attention_mask, query_length
684
+ )
685
+
686
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
687
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
688
+
689
+ attn_output_unpad = flash_attn_varlen_func( # pylint: disable=E0606
690
+ query_states,
691
+ key_states,
692
+ value_states,
693
+ cu_seqlens_q=cu_seqlens_q,
694
+ cu_seqlens_k=cu_seqlens_k,
695
+ max_seqlen_q=max_seqlen_in_batch_q,
696
+ max_seqlen_k=max_seqlen_in_batch_k,
697
+ dropout_p=dropout,
698
+ softmax_scale=softmax_scale,
699
+ causal=causal,
700
+ )
701
+
702
+ attn_output = pad_input(
703
+ attn_output_unpad, indices_q, batch_size, query_length) # pylint: disable=E0606
704
+ else:
705
+ attn_output = flash_attn_func( # pylint: disable=E0606
706
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
707
+ )
708
+
709
+ return attn_output
710
+
711
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
712
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
713
+ attention_mask)
714
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
715
+
716
+ key_layer = index_first_axis( # pylint: disable=E0606
717
+ key_layer.reshape(batch_size * kv_seq_len,
718
+ num_key_value_heads, head_dim), indices_k
719
+ )
720
+ value_layer = index_first_axis( # pylint: disable=E0606
721
+ value_layer.reshape(batch_size * kv_seq_len,
722
+ num_key_value_heads, head_dim), indices_k
723
+ )
724
+ if query_length == kv_seq_len:
725
+ query_layer = index_first_axis( # pylint: disable=E0606
726
+ query_layer.reshape(batch_size * kv_seq_len,
727
+ self.num_heads, head_dim), indices_k
728
+ )
729
+ cu_seqlens_q = cu_seqlens_k
730
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
731
+ indices_q = indices_k
732
+ elif query_length == 1:
733
+ max_seqlen_in_batch_q = 1
734
+ cu_seqlens_q = torch.arange(
735
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
736
+ ) # There is a memcpy here, that is very bad.
737
+ indices_q = cu_seqlens_q[:-1]
738
+ query_layer = query_layer.squeeze(1)
739
+ else:
740
+ # The -q_len: slice assumes left padding.
741
+ attention_mask = attention_mask[:, -query_length:]
742
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( # pylint: disable=E0606
743
+ query_layer, attention_mask
744
+ )
745
+
746
+ return (
747
+ query_layer,
748
+ key_layer,
749
+ value_layer,
750
+ indices_q,
751
+ (cu_seqlens_q, cu_seqlens_k),
752
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
753
+ )
754
+
755
+
756
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->InternLM2
757
+ class InternLM2SdpaAttention(InternLM2Attention):
758
+ """
759
+ InternLM2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
760
+ `InternLM2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
761
+ to adapt to SDPA API.
762
+ """
763
+
764
+ # Adapted from InternLM2Attention.forward
765
+ def forward(
766
+ self,
767
+ hidden_states: torch.Tensor,
768
+ attention_mask: Optional[torch.Tensor] = None,
769
+ position_ids: Optional[torch.LongTensor] = None,
770
+ past_key_value: Optional[Cache] = None,
771
+ output_attentions: bool = False,
772
+ use_cache: bool = False,
773
+ cache_position: Optional[torch.LongTensor] = None,
774
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
775
+ if output_attentions:
776
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"`
777
+ # once this is implemented.
778
+ logger.warning_once(
779
+ "InternLM2Model uses InternLM2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` "
780
+ "does not support `output_attentions=True`. Falling back to the manual attention implementation, "
781
+ "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
782
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
783
+ )
784
+ return super().forward(
785
+ hidden_states=hidden_states,
786
+ attention_mask=attention_mask,
787
+ position_ids=position_ids,
788
+ past_key_value=past_key_value,
789
+ output_attentions=output_attentions,
790
+ use_cache=use_cache,
791
+ cache_position=cache_position,
792
+ )
793
+
794
+ bsz, q_len, _ = hidden_states.size()
795
+
796
+ qkv_states = self.wqkv(hidden_states)
797
+
798
+ qkv_states = rearrange(
799
+ qkv_states,
800
+ "b q (h gs d) -> b q h gs d",
801
+ gs=2 + self.num_key_value_groups,
802
+ d=self.head_dim,
803
+ )
804
+
805
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
806
+ query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
807
+ key_states = qkv_states[..., -2, :]
808
+ value_states = qkv_states[..., -1, :]
809
+
810
+ query_states = query_states.transpose(1, 2)
811
+ key_states = key_states.transpose(1, 2)
812
+ value_states = value_states.transpose(1, 2)
813
+
814
+ '''Changed rotary
815
+ cos, sin = self.rotary_emb(value_states, position_ids)
816
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
817
+ # Everything from `seqlen_offset, max_seqlen = 0, q_len`
818
+ # to the computation of query_states and key_states is modified code (the new computation)
819
+ '''
820
+ seqlen_offset, max_seqlen = 0, q_len
821
+ if past_key_value is not None:
822
+ seqlen_offset = past_key_value.get_seq_length(self.layer_idx)
823
+ max_seqlen = query_states.shape[1] + seqlen_offset
824
+
825
+ if attention_mask is not None:
826
+ # to account for the offsets of padding tokens
827
+ seqlen_offset = (
828
+ seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
829
+ max_seqlen = query_states.shape[1] + max(seqlen_offset)
830
+
831
+ if self.max_position_embeddings is not None:
832
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
833
+ # query_states, key_states = self.rotary_emb(
834
+ # query_states, key_states, seqlen_offset, max_seqlen)
835
+ cos, sin = self.rotary_emb(value_states, position_ids)
836
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
837
+ if past_key_value is not None:
838
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
839
+ '''Changed rotary
840
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
841
+ '''
842
+ cache_kwargs = {"cache_position": cache_position}
843
+
844
+ key_states, value_states = past_key_value.update(
845
+ key_states, value_states, self.layer_idx, cache_kwargs)
846
+
847
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
848
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
849
+
850
+ causal_mask = attention_mask
851
+ # if attention_mask is not None:
852
+ # causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
853
+
854
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with
855
+ # custom attn_mask, Reference: https://github.com/pytorch/pytorch/issues/112577.
856
+ if query_states.device.type == "cuda" and causal_mask is not None:
857
+ query_states = query_states.contiguous()
858
+ key_states = key_states.contiguous()
859
+ value_states = value_states.contiguous()
860
+
861
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of
862
+ # an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph
863
+ # options. An inline conditional prevents dynamic shapes from compiling.
864
+ # is_causal = bool(causal_mask is None and q_len > 1)
865
+
866
+ attn_output = torch.nn.functional.scaled_dot_product_attention( # pylint: disable=E1102
867
+ query_states,
868
+ key_states,
869
+ value_states,
870
+ attn_mask=None,
871
+ dropout_p=0.0,
872
+ is_causal=False,
873
+ )
874
+
875
+ attn_output = attn_output.transpose(1, 2).contiguous()
876
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
877
+
878
+ attn_output = self.wo(attn_output)
879
+
880
+ return attn_output, None, past_key_value
881
+
882
+
883
+ INTERNLM2_ATTENTION_CLASSES = {
884
+ "eager": InternLM2Attention,
885
+ "flash_attention_2": InternLM2FlashAttention2,
886
+ "sdpa": InternLM2SdpaAttention,
887
+ }
888
+
889
+
890
+ # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM2
891
+ class InternLM2DecoderLayer(nn.Module):
892
+ """InternLM2 Decoder Layer. This module is a single layer of the InternLM2 model."""
893
+
894
+ def __init__(self, config: InternLM2Config, layer_idx: int):
895
+ super().__init__()
896
+ self.hidden_size = config.hidden_size
897
+ self.layer_idx = layer_idx
898
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](
899
+ config=config, layer_idx=layer_idx)
900
+
901
+ self.feed_forward = InternLM2MLP(config)
902
+ '''
903
+ # Modified RMSNorm. The original code was:
904
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
905
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
906
+ '''
907
+ self.attention_norm = RMSNorm(
908
+ config.hidden_size, eps=config.rms_norm_eps)
909
+ self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
910
+
911
+ def forward(
912
+ self,
913
+ hidden_states: torch.Tensor,
914
+ attention_mask: Optional[torch.Tensor] = None,
915
+ position_ids: Optional[torch.LongTensor] = None,
916
+ past_key_value: Optional[Cache] = None,
917
+ output_attentions: Optional[bool] = False,
918
+ use_cache: Optional[bool] = False,
919
+ cache_position: Optional[torch.LongTensor] = None,
920
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
921
+ """
922
+ Args:
923
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
924
+ attention_mask (`torch.FloatTensor`, *optional*):
925
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
926
+ query_sequence_length, key_sequence_length)` if default attention is used.
927
+ output_attentions (`bool`, *optional*):
928
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
929
+ returned tensors for more detail.
930
+ use_cache (`bool`, *optional*):
931
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
932
+ (see `past_key_values`).
933
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
934
+ """
935
+
936
+ residual = hidden_states
937
+
938
+ hidden_states = self.attention_norm(hidden_states)
939
+ # Self Attention
940
+ hidden_states, self_attn_weights, present_key_value = self.attention(
941
+ hidden_states=hidden_states,
942
+ attention_mask=attention_mask,
943
+ position_ids=position_ids,
944
+ past_key_value=past_key_value,
945
+ output_attentions=output_attentions,
946
+ use_cache=use_cache,
947
+ cache_position=cache_position,
948
+ )
949
+ hidden_states = residual + hidden_states
950
+
951
+ # Fully Connected
952
+ residual = hidden_states
953
+ hidden_states = self.ffn_norm(hidden_states)
954
+ hidden_states = self.feed_forward(hidden_states)
955
+ hidden_states = residual + hidden_states
956
+ outputs = (hidden_states,)
957
+
958
+ if output_attentions:
959
+ outputs += (self_attn_weights,)
960
+
961
+ if use_cache:
962
+ outputs += (present_key_value,)
963
+
964
+ return outputs
965
+
966
+
967
+ InternLM2_START_DOCSTRING = r"""
968
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
969
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
970
+ etc.)
971
+
972
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
973
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
974
+ and behavior.
975
+
976
+ Parameters:
977
+ config ([`InternLM2Config`]):
978
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
979
+ load the weights associated with the model, only the configuration. Check out the
980
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
981
+ """
982
+
983
+
984
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
985
+ @add_start_docstrings(
986
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
987
+ InternLM2_START_DOCSTRING,
988
+ )
989
+ class InternLM2PreTrainedModel(PreTrainedModel):
990
+ """
991
+ InternLM2 pretrained model's base class.
992
+ """
993
+
994
+ config_class = InternLM2Config
995
+ base_model_prefix = "model"
996
+ supports_gradient_checkpointing = True
997
+ _no_split_modules = ["InternLM2DecoderLayer"]
998
+ _skip_keys_device_placement = ["past_key_values"]
999
+ _supports_flash_attn_2 = True
1000
+ _supports_sdpa = True
1001
+ _supports_cache_class = True
1002
+ _supports_quantized_cache = True
1003
+ _supports_static_cache = True
1004
+
1005
+ def _init_weights(
1006
+ self,
1007
+ module,
1008
+ rescale_prenorm_residual: bool = True,
1009
+ num_residuals_per_layer: int = 2,
1010
+ ):
1011
+ # Reference: https://github.com/fla-org/flash-linear-attention/blob/f0e66517b46f062dd2212b68b39dc9fbf3cd52de/fla/models/transformer/modeling_transformer.py#L159
1012
+ # Reference: the std in https://huggingface.co/yulan-team/YuLan-Mini/blob/main/config.json, though it may be too small
1013
+ # Reference: DeepSeek-V3 uses std=6e-3
1014
+ std = self.config.initializer_range
1015
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
1016
+ # Slightly different from the TF version which uses truncated_normal for initialization
1017
+ # cf https://github.com/pytorch/pytorch/pull/5617
1018
+ nn.init.normal_(module.weight, mean=0.0, std=std)
1019
+ if module.bias is not None:
1020
+ nn.init.zeros_(module.bias)
1021
+ elif isinstance(module, nn.Embedding):
1022
+ nn.init.normal_(module.weight, mean=0.0, std=std)
1023
+ if module.padding_idx is not None:
1024
+ nn.init.zeros_(module.weight[module.padding_idx])
1025
+
1026
+ if rescale_prenorm_residual:
1027
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
1028
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
1029
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
1030
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
1031
+ #
1032
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
1033
+ for name, p in module.named_parameters():
1034
+ if name in ["wo.weight", "w2.weight"]:
1035
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
1036
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
1037
+ # We need to reinit p since this code could be called multiple times
1038
+ # Having just p *= scale would repeatedly scale it down
1039
+ with torch.no_grad():
1040
+ # If rescaling is needed, assigning to the parameter directly has no effect; params.data must be modified
1041
+ # This is odd: in theory, assigning to p directly should work
1042
+ p.data /= math.sqrt(num_residuals_per_layer *
1043
+ self.config.num_hidden_layers)
1044
+
1045
+
1046
+ InternLM2_INPUTS_DOCSTRING = r"""
1047
+ Args:
1048
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1049
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1050
+ it.
1051
+
1052
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1053
+ [`PreTrainedTokenizer.__call__`] for details.
1054
+
1055
+ [What are input IDs?](../glossary#input-ids)
1056
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1057
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1058
+
1059
+ - 1 for tokens that are **not masked**,
1060
+ - 0 for tokens that are **masked**.
1061
+
1062
+ [What are attention masks?](../glossary#attention-mask)
1063
+
1064
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1065
+ [`PreTrainedTokenizer.__call__`] for details.
1066
+
1067
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1068
+ `past_key_values`).
1069
+
1070
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1071
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1072
+ information on the default strategy.
1073
+
1074
+ - 1 indicates the head is **not masked**,
1075
+ - 0 indicates the head is **masked**.
1076
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1077
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1078
+ config.n_positions - 1]`.
1079
+
1080
+ [What are position IDs?](../glossary#position-ids)
1081
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1082
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1083
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1084
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1085
+
1086
+ Two formats are allowed:
1087
+ - a [`~cache_utils.Cache`] instance;
1088
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1089
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1090
+ cache format.
1091
+
1092
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1093
+ legacy cache format will be returned.
1094
+
1095
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1096
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1097
+ of shape `(batch_size, sequence_length)`.
1098
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1099
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1100
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1101
+ model's internal embedding lookup matrix.
1102
+ use_cache (`bool`, *optional*):
1103
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1104
+ `past_key_values`).
1105
+ output_attentions (`bool`, *optional*):
1106
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1107
+ tensors for more detail.
1108
+ output_hidden_states (`bool`, *optional*):
1109
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1110
+ more detail.
1111
+ return_dict (`bool`, *optional*):
1112
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1113
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1114
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
1115
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
1116
+ the complete sequence length.
1117
+ """
1118
+
1119
+
1120
+ # Modified from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM2
1121
+ @add_start_docstrings(
1122
+ "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
1123
+ InternLM2_START_DOCSTRING,
1124
+ )
1125
+ class InternLM2Model(InternLM2PreTrainedModel):
1126
+ """
1127
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
1128
+
1129
+ Args:
1130
+ config: InternLM2Config
1131
+ """
1132
+
1133
+ _auto_class = "AutoModel"
1134
+
1135
+ def __init__(self, config: InternLM2Config):
1136
+ super().__init__(config)
1137
+ self.padding_idx = config.pad_token_id
1138
+ self.vocab_size = config.vocab_size
1139
+ self.config = config
1140
+
1141
+ self.tok_embeddings = nn.Embedding(
1142
+ config.vocab_size, config.hidden_size, self.padding_idx)
1143
+
1144
+ self.layers = nn.ModuleList(
1145
+ [InternLM2DecoderLayer(config, layer_idx)
1146
+ for layer_idx in range(config.num_hidden_layers)]
1147
+ )
1148
+ '''
1149
+ # Modified RMSNorm. The original code was:
1150
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1151
+ '''
1152
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1153
+
1154
+ self.gradient_checkpointing = False
1155
+ # Initialize weights and apply final processing
1156
+ self.post_init()
1157
+
1158
+ def get_input_embeddings(self):
1159
+ return self.tok_embeddings
1160
+
1161
+ def set_input_embeddings(self, value):
1162
+ self.tok_embeddings = value
1163
+
1164
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1165
+ def forward(
1166
+ self,
1167
+ input_ids: torch.LongTensor = None,
1168
+ attention_mask: Optional[torch.Tensor] = None,
1169
+ position_ids: Optional[torch.LongTensor] = None,
1170
+ past_key_values: Optional[Union[Cache,
1171
+ List[torch.FloatTensor]]] = None,
1172
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1173
+ use_cache: Optional[bool] = None,
1174
+ output_attentions: Optional[bool] = None,
1175
+ output_hidden_states: Optional[bool] = None,
1176
+ return_dict: Optional[bool] = None,
1177
+ cache_position: Optional[torch.LongTensor] = None,
1178
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1179
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1180
+ output_hidden_states = (
1181
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1182
+ )
1183
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1184
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1185
+
1186
+ if (input_ids is None) ^ (inputs_embeds is not None):
1187
+ raise ValueError(
1188
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1189
+ )
1190
+
1191
+ if self.gradient_checkpointing and self.training and use_cache:
1192
+ logger.warning_once(
1193
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1194
+ )
1195
+ use_cache = False
1196
+
1197
+ if inputs_embeds is None:
1198
+ inputs_embeds = self.tok_embeddings(input_ids)
1199
+
1200
+ return_legacy_cache = False
1201
+ # kept for BC (non `Cache` `past_key_values` inputs)
1202
+ if use_cache and not isinstance(past_key_values, Cache):
1203
+ return_legacy_cache = True
1204
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1205
+
1206
+ if cache_position is None:
1207
+ past_seen_tokens = past_key_values.get_seq_length(
1208
+ ) if past_key_values is not None else 0
1209
+ cache_position = torch.arange(
1210
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1211
+ )
1212
+ if position_ids is None:
1213
+ position_ids = cache_position.unsqueeze(0)
1214
+
1215
+ ''' # InternLM originally initializes causal_mask here:
1216
+ causal_mask = self._update_causal_mask(
1217
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1218
+ )
1219
+ '''
1220
+ if self.config.is_causal:
1221
+ causal_mask = self._update_causal_mask(
1222
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1223
+ )
1224
+ else:
1225
+ causal_mask = None
1226
+
1227
+ # embed positions
1228
+ hidden_states = inputs_embeds
1229
+
1230
+ # decoder layers
1231
+ all_hidden_states = () if output_hidden_states else None
1232
+ all_self_attns = () if output_attentions else None
1233
+ next_decoder_cache = None
1234
+
1235
+ for decoder_layer in self.layers:
1236
+ if output_hidden_states:
1237
+ all_hidden_states += (hidden_states,)
1238
+
1239
+ if self.gradient_checkpointing and self.training:
1240
+ layer_outputs = self._gradient_checkpointing_func(
1241
+ decoder_layer.__call__,
1242
+ hidden_states,
1243
+ attention_mask,
1244
+ position_ids,
1245
+ past_key_values,
1246
+ output_attentions,
1247
+ use_cache,
1248
+ cache_position,
1249
+ )
1250
+ else:
1251
+ layer_outputs = decoder_layer(
1252
+ hidden_states,
1253
+ attention_mask=attention_mask,
1254
+ position_ids=position_ids,
1255
+ past_key_value=past_key_values,
1256
+ output_attentions=output_attentions,
1257
+ use_cache=use_cache,
1258
+ cache_position=cache_position,
1259
+ )
1260
+
1261
+ hidden_states = layer_outputs[0]
1262
+
1263
+ if use_cache:
1264
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1265
+
1266
+ if output_attentions:
1267
+ all_self_attns += (layer_outputs[1],)
1268
+
1269
+ hidden_states = self.norm(hidden_states)
1270
+
1271
+ # add hidden states from the last decoder layer
1272
+ if output_hidden_states:
1273
+ all_hidden_states += (hidden_states,)
1274
+
1275
+ next_cache = next_decoder_cache if use_cache else None
1276
+ if return_legacy_cache:
1277
+ next_cache = next_cache.to_legacy_cache()
1278
+
1279
+ if not return_dict:
1280
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1281
+ return BaseModelOutputWithPast(
1282
+ last_hidden_state=hidden_states,
1283
+ past_key_values=next_cache,
1284
+ hidden_states=all_hidden_states,
1285
+ attentions=all_self_attns,
1286
+ )
1287
+
1288
+ def _update_causal_mask(
1289
+ self,
1290
+ attention_mask: torch.Tensor,
1291
+ input_tensor: torch.Tensor,
1292
+ cache_position: torch.Tensor,
1293
+ past_key_values: Cache,
1294
+ output_attentions: bool,
1295
+ ):
1296
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length
1297
+ # even when the static KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at
1298
+ # each decode steps due to the dynamic shapes. (`recording cudagraph tree for symint key 13`, etc.), which is
1299
+ # VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`.
1300
+ # See more context in https://github.com/huggingface/transformers/pull/29114
1301
+
1302
+ if self.config.attn_implementation == "flash_attention_2":
1303
+ if attention_mask is not None and 0.0 in attention_mask:
1304
+ return attention_mask
1305
+ return None
1306
+
1307
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1308
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1309
+ # to infer the attention mask.
1310
+ past_seen_tokens = past_key_values.get_seq_length(
1311
+ ) if past_key_values is not None else 0
1312
+ using_static_cache = isinstance(past_key_values, StaticCache)
1313
+
1314
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1315
+ if self.config.attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1316
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1317
+ attention_mask,
1318
+ inputs_embeds=input_tensor,
1319
+ past_key_values_length=past_seen_tokens,
1320
+ is_training=self.training,
1321
+ ):
1322
+ return None
1323
+
1324
+ dtype, device = input_tensor.dtype, input_tensor.device
1325
+ min_dtype = torch.finfo(dtype).min
1326
+ sequence_length = input_tensor.shape[1]
1327
+ if using_static_cache:
1328
+ target_length = past_key_values.get_max_length()
1329
+ else:
1330
+ target_length = (
1331
+ attention_mask.shape[-1]
1332
+ if isinstance(attention_mask, torch.Tensor)
1333
+ else past_seen_tokens + sequence_length + 1
1334
+ )
1335
+
1336
+ if attention_mask is not None and attention_mask.dim() == 4:
1337
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
1338
+ if attention_mask.max() != 0:
1339
+ raise ValueError(
1340
+ "Custom 4D attention mask should be passed in inverted form with max==0`")
1341
+ causal_mask = attention_mask
1342
+ else:
1343
+ causal_mask = torch.full(
1344
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1345
+ if sequence_length != 1:
1346
+ if support_bf16_triu or dtype == torch.float32:
1347
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1348
+ else:
1349
+ triu_mask = torch.triu(torch.ones(
1350
+ causal_mask.size(), device=device), diagonal=1).bool()
1351
+ causal_mask.masked_fill_(~triu_mask, 0)
1352
+ causal_mask *= torch.arange(target_length,
1353
+ device=device) > cache_position.reshape(-1, 1)
1354
+ causal_mask = causal_mask[None, None, :, :].expand(
1355
+ input_tensor.shape[0], 1, -1, -1)
1356
+ if attention_mask is not None:
1357
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1358
+ mask_length = attention_mask.shape[-1]
1359
+ padding_mask = causal_mask[:, :, :,
1360
+ :mask_length] + attention_mask[:, None, None, :]
1361
+ padding_mask = padding_mask == 0
1362
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1363
+ padding_mask, min_dtype
1364
+ )
1365
+ if (
1366
+ self.config.attn_implementation == "sdpa"
1367
+ and attention_mask is not None
1368
+ and attention_mask.device.type == "cuda"
1369
+ and not output_attentions
1370
+ ):
1371
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1372
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1373
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1374
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1375
+ causal_mask, min_dtype) # pylint: disable=E1120
1376
+
1377
+ return causal_mask
1378
+
1379
+
1380
+ # Modified from transformers.models.llama.modeling_llama.LlamaForCausalLM
1381
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1382
+ """Causal language model (CLM) for InternLM2."""
1383
+
1384
+ _auto_class = "AutoModelForCausalLM"
1385
+ _tied_weights_keys = ["output.weight"]
1386
+
1387
+ def __init__(self, config):
1388
+ super().__init__(config)
1389
+ self.model = InternLM2Model(config)
1390
+ self.vocab_size = config.vocab_size
1391
+ self.output = nn.Linear(
1392
+ config.hidden_size, config.vocab_size, bias=False)
1393
+
1394
+ # Initialize weights and apply final processing
1395
+ self.post_init()
1396
+
1397
+ def get_input_embeddings(self):
1398
+ return self.model.tok_embeddings
1399
+
1400
+ def set_input_embeddings(self, value):
1401
+ self.model.tok_embeddings = value
1402
+
1403
+ def get_output_embeddings(self):
1404
+ return self.output
1405
+
1406
+ def set_output_embeddings(self, new_embeddings):
1407
+ self.output = new_embeddings
1408
+
1409
+ def set_decoder(self, decoder):
1410
+ self.model = decoder
1411
+
1412
+ def get_decoder(self):
1413
+ return self.model
1414
+
1415
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1416
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1417
+ def forward(
1418
+ self,
1419
+ input_ids: torch.LongTensor = None,
1420
+ attention_mask: Optional[torch.Tensor] = None,
1421
+ position_ids: Optional[torch.LongTensor] = None,
1422
+ past_key_values: Optional[Union[Cache,
1423
+ List[torch.FloatTensor]]] = None,
1424
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1425
+ labels: Optional[torch.LongTensor] = None,
1426
+ p_mask: Optional[torch.Tensor] = None,
1427
+ use_cache: Optional[bool] = None,
1428
+ output_attentions: Optional[bool] = None,
1429
+ output_hidden_states: Optional[bool] = None,
1430
+ return_dict: Optional[bool] = None,
1431
+ cache_position: Optional[torch.LongTensor] = None
1432
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1433
+ r"""
1434
+ Args:
1435
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1436
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1437
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1438
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1439
+
1440
+ Returns:
1441
+
1442
+ Example:
1443
+
1444
+ ```python
1445
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1446
+
1447
+ >>> model = InternLM2ForCausalLM.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
1448
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-InternLM2/InternLM2-2-7b-hf")
1449
+
1450
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1451
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1452
+
1453
+ >>> # Generate
1454
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1455
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1456
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1457
+ ```"""
1458
+
1459
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1460
+ output_hidden_states = (
1461
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1462
+ )
1463
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1464
+
1465
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1466
+ outputs = self.model(
1467
+ input_ids=input_ids,
1468
+ attention_mask=attention_mask,
1469
+ position_ids=position_ids,
1470
+ past_key_values=past_key_values,
1471
+ inputs_embeds=inputs_embeds,
1472
+ use_cache=use_cache,
1473
+ output_attentions=output_attentions,
1474
+ output_hidden_states=output_hidden_states,
1475
+ return_dict=return_dict,
1476
+ cache_position=cache_position,
1477
+ )
1478
+
1479
+ hidden_states = outputs[0]
1480
+
1481
+ '''Modified loss computation'''
1482
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
1483
+
1484
+ '''
1485
+ Logits are only computed when fuse_linear_and_cross_entropy is False,
1486
+ i.e. when config.fuse_cross_entropy is set to False, or when the model is not in model.train() mode.
1487
+ The computation logic (the else branch below) is unchanged.
1488
+
1489
+ ############################
1490
+ hidden_states: (bs, seq_len, d_model)
1491
+ logits (hidden_states after the final linear layer): (bs, seq_len, vocab_size)
1492
+ labels: (bs, seq_len)
1493
+ ############################
1494
+ '''
1495
+
1496
+ if fuse_linear_and_cross_entropy:
1497
+ logits = None
1498
+ else: # TODO: change this back!
1499
+ # print("这里计算了 logits")
1500
+ '''The official InternLM computation path is kept here'''
1501
+ if self.config.pretraining_tp > 1:
1502
+ output_slices = self.output.weight.split(
1503
+ self.vocab_size // self.config.pretraining_tp, dim=0)
1504
+ logits = [
1505
+ F.linear(hidden_states,
1506
+ output_slices[i]) # pylint: disable=not-callable
1507
+ for i in range(self.config.pretraining_tp)
1508
+ ]
1509
+ logits = torch.cat(logits, dim=-1)
1510
+ else:
1511
+ logits = self.output(hidden_states)
1512
+ logits = logits.float()
1513
+
1514
+ loss = loss_tuple = None
1515
+ if labels is not None:
1516
+ # if self.config.fuse_cross_entropy:
1517
+ # if fuse_linear_and_cross_entropy:
1518
+ # # 'sum' must be used here rather than 'mean', because in the fla kernel:
1519
+ # # For 'mean' reduction, gradients are normalized by number of *non-ignored* elements
1520
+ # # the unreduced_loss returned in that case is also not what we need: unreduced_loss.sum() == mean_loss,
1521
+ # # mean_loss = sum_loss / num_non_ignored_tokens, instead of all tokens
1522
+ # # unreduced_loss: w/ scaled
1523
+ # # For the diffusion process, however, one needs to divide by the total number of samples, so this would be biased.
1524
+ # # Hence reduction='sum' is used here; the corresponding unreduced_loss is then the true per-token loss, w/o scaling
1525
+ # loss_fct = FusedLinearDiffusionCrossEntropyLoss(
1526
+ # reduction='sum')
1527
+ # else:
1528
+ # loss_fct = FusedCrossEntropyLoss(
1529
+ # reduction='sum', inplace_backward=True)
1530
+ # else:
1531
+ loss_fct = CrossEntropyLoss() # nn.CE
1532
+
1533
+ # you don't have to shift labels
1534
+ # labels = labels.to(hidden_states.device)
1535
+ # labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
1536
+ if fuse_linear_and_cross_entropy:
1537
+ loss_tuple = loss_fct(
1538
+ x=hidden_states.view(-1, self.config.hidden_size),
1539
+ target=labels.view(-1),
1540
+ weight=self.output.weight,
1541
+ bias=self.output.bias,
1542
+ p_mask=p_mask,
1543
+ )
1544
+ loss = loss_tuple
1545
+ else:
1546
+ loss = loss_fct(
1547
+ logits.view(-1, self.config.vocab_size), labels.view(-1))
1548
+ loss_tuple = loss
1549
+
1550
+ # The code below is the official implementation, unchanged
1551
+ if not return_dict:
1552
+ output = (logits,) + outputs[1:]
1553
+ return (loss,) + output if loss is not None else output
1554
+
1555
+ return CausalLMOutputWithPast(
1556
+ loss=loss_tuple,
1557
+ logits=logits,
1558
+ past_key_values=outputs.past_key_values,
1559
+ hidden_states=outputs.hidden_states,
1560
+ attentions=outputs.attentions,
1561
+ )
1562
+
1563
+ def prepare_inputs_for_generation(
1564
+ self,
1565
+ input_ids,
1566
+ past_key_values=None,
1567
+ attention_mask=None,
1568
+ inputs_embeds=None,
1569
+ cache_position=None,
1570
+ use_cache=True,
1571
+ **kwargs,
1572
+ ):
1573
+ past_length = 0
1574
+ if past_key_values is not None:
1575
+ if isinstance(past_key_values, Cache):
1576
+ past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length(
1577
+ )
1578
+ max_cache_length = (
1579
+ torch.tensor(past_key_values.get_max_length(),
1580
+ device=input_ids.device)
1581
+ if past_key_values.get_max_length() is not None
1582
+ else None
1583
+ )
1584
+ cache_length = past_length if max_cache_length is None else torch.min(
1585
+ max_cache_length, past_length)
1586
+ # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
1587
+ else:
1588
+ cache_length = past_length = past_key_values[0][0].shape[2]
1589
+ max_cache_length = None
1590
+
1591
+ # Keep only the unprocessed tokens:
1592
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1593
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
1594
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1595
+ input_ids = input_ids[:, -
1596
+ (attention_mask.shape[1] - past_length):]
1597
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1598
+ # input_ids based on the past_length.
1599
+ elif past_length < input_ids.shape[1]:
1600
+ input_ids = input_ids[:, past_length:]
1601
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1602
+
1603
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1604
+ if (
1605
+ max_cache_length is not None
1606
+ and attention_mask is not None
1607
+ and cache_length + input_ids.shape[1] > max_cache_length
1608
+ ):
1609
+ attention_mask = attention_mask[:, -
1610
+ max_cache_length:] # pylint: disable=E1130
1611
+
1612
+ position_ids = kwargs.get("position_ids", None)
1613
+ if attention_mask is not None and position_ids is None:
1614
+ # create position_ids on the fly for batch generation
1615
+ position_ids = attention_mask.long().cumsum(-1) - 1
1616
+ position_ids.masked_fill_(attention_mask == 0, 1)
1617
+ if past_key_values:
1618
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1619
+
1620
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1621
+ if inputs_embeds is not None and past_key_values is None:
1622
+ model_inputs = {"inputs_embeds": inputs_embeds}
1623
+ else:
1624
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1625
+ # recompiles graphs as the stride of the inputs is a guard.
1626
+ # Ref: https://github.com/huggingface/transformers/pull/29114
1627
+ # TODO: use `next_tokens` directly instead.
1628
+ model_inputs = {"input_ids": input_ids.contiguous()}
1629
+
1630
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
1631
+ if cache_position is None:
1632
+ cache_position = torch.arange(
1633
+ past_length, past_length + input_length, device=input_ids.device)
1634
+ elif use_cache:
1635
+ cache_position = cache_position[-input_length:]
1636
+
1637
+ model_inputs.update(
1638
+ {
1639
+ "position_ids": position_ids,
1640
+ "cache_position": cache_position,
1641
+ "past_key_values": past_key_values,
1642
+ "use_cache": use_cache,
1643
+ "attention_mask": attention_mask,
1644
+ }
1645
+ )
1646
+ return model_inputs
1647
+
1648
+ @staticmethod
1649
+ def _reorder_cache(past_key_values, beam_idx):
1650
+ reordered_past = ()
1651
+ for layer_past in past_key_values:
1652
+ reordered_past += (
1653
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device))
1654
+ for past_state in layer_past),
1655
+ )
1656
+ return reordered_past
1657
+
1658
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, meta_instruction=""):
1659
+ if history is None:
1660
+ history = []
1661
+ if tokenizer.add_bos_token:
1662
+ prompt = ""
1663
+ else:
1664
+ prompt = tokenizer.bos_token
1665
+ if meta_instruction:
1666
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1667
+ for record in history:
1668
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1669
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1670
+ return tokenizer([prompt], return_tensors="pt")
1671
+
1672
+ @torch.no_grad()
1673
+ def chat(
1674
+ self,
1675
+ tokenizer,
1676
+ query: str,
1677
+ history: Optional[List[Tuple[str, str]]] = None,
1678
+ streamer: Optional[BaseStreamer] = None,
1679
+ max_new_tokens: int = 1024,
1680
+ do_sample: bool = True,
1681
+ temperature: float = 0.8,
1682
+ top_p: float = 0.8,
1683
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1684
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
1685
+ "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1686
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such "
1687
+ "as English and 中文.",
1688
+ **kwargs,
1689
+ ):
1690
+ if history is None:
1691
+ history = []
1692
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1693
+ inputs = {k: v.to(self.device)
1694
+ for k, v in inputs.items() if torch.is_tensor(v)}
1695
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1696
+ eos_token_id = [tokenizer.eos_token_id,
1697
+ tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1698
+ outputs = self.generate(
1699
+ **inputs,
1700
+ streamer=streamer,
1701
+ max_new_tokens=max_new_tokens,
1702
+ do_sample=do_sample,
1703
+ temperature=temperature,
1704
+ top_p=top_p,
1705
+ eos_token_id=eos_token_id,
1706
+ **kwargs,
1707
+ )
1708
+ outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
1709
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1710
+ response = response.split("<|im_end|>")[0]
1711
+ history = history + [(query, response)]
1712
+ return response, history
1713
+
1714
+ @torch.no_grad()
1715
+ def stream_chat(
1716
+ self,
1717
+ tokenizer,
1718
+ query: str,
1719
+ history: List[Tuple[str, str]] = None,
1720
+ max_new_tokens: int = 1024,
1721
+ do_sample: bool = True,
1722
+ temperature: float = 0.8,
1723
+ top_p: float = 0.8,
1724
+ **kwargs,
1725
+ ):
1726
+ if history is None:
1727
+ history = []
1728
+ """
1729
+ Return a generator in format: (response, history)
1730
+ Eg.
1731
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1732
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1733
+ """
1734
+ if BaseStreamer is None:
1735
+ raise ModuleNotFoundError(
1736
+ "The version of `transformers` is too low. Please make sure "
1737
+ "that you have installed `transformers>=4.28.0`."
1738
+ )
1739
+
1740
+ response_queue = queue.Queue(maxsize=20)
1741
+
1742
+ class ChatStreamer(BaseStreamer):
1743
+ """
1744
+ Streamer used in generate to print words one by one.
1745
+ """
1746
+
1747
+ def __init__(self, tokenizer) -> None:
1748
+ super().__init__()
1749
+ self.tokenizer = tokenizer
1750
+ self.queue = response_queue
1751
+ self.query = query
1752
+ self.history = history
1753
+ self.response = ""
1754
+ self.cache = []
1755
+ self.received_inputs = False
1756
+ self.queue.put((self.response, history +
1757
+ [(self.query, self.response)]))
1758
+
1759
+ def put(self, value):
1760
+ if len(value.shape) > 1 and value.shape[0] > 1:
1761
+ raise ValueError("ChatStreamer only supports batch size 1")
1762
+ elif len(value.shape) > 1:
1763
+ value = value[0]
1764
+
1765
+ if not self.received_inputs:
1766
+ # The first received value is input_ids, ignore here
1767
+ self.received_inputs = True
1768
+ return
1769
+
1770
+ self.cache.extend(value.tolist())
1771
+ token = self.tokenizer.decode(
1772
+ self.cache, skip_special_tokens=True)
1773
+ if token.strip() != "<|im_end|>":
1774
+ self.response = self.response + token
1775
+ history = self.history + [(self.query, self.response)]
1776
+ self.queue.put((self.response, history))
1777
+ self.cache = []
1778
+ else:
1779
+ self.end()
1780
+
1781
+ def end(self):
1782
+ self.queue.put(None)
1783
+
1784
+ def stream_producer():
1785
+ return self.chat(
1786
+ tokenizer=tokenizer,
1787
+ query=query,
1788
+ streamer=ChatStreamer(tokenizer=tokenizer),
1789
+ history=history,
1790
+ max_new_tokens=max_new_tokens,
1791
+ do_sample=do_sample,
1792
+ temperature=temperature,
1793
+ top_p=top_p,
1794
+ **kwargs,
1795
+ )
1796
+
1797
+ def consumer():
1798
+ producer = threading.Thread(target=stream_producer)
1799
+ producer.start()
1800
+ while True:
1801
+ res = response_queue.get()
1802
+ if res is None:
1803
+ return
1804
+ yield res
1805
+
1806
+ return consumer()
1807
+
1808
+
1809
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1810
+ @add_start_docstrings(
1811
+ """
1812
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1813
+
1814
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1815
+ (e.g. GPT-2) do.
1816
+
1817
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1818
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1819
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1820
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1821
+ each row of the batch).
1822
+ """,
1823
+ InternLM2_START_DOCSTRING,
1824
+ )
1825
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1826
+ """Sequence Classification Head for InternLM2 Model."""
1827
+
1828
+ def __init__(self, config):
1829
+ super().__init__(config)
1830
+ self.num_labels = config.num_labels
1831
+ self.model = InternLM2Model(config)
1832
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1833
+
1834
+ # Initialize weights and apply final processing
1835
+ self.post_init()
1836
+
1837
+ def get_input_embeddings(self):
1838
+ return self.model.tok_embeddings
1839
+
1840
+ def set_input_embeddings(self, value):
1841
+ self.model.tok_embeddings = value
1842
+
1843
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1844
+ def forward(
1845
+ self,
1846
+ input_ids: torch.LongTensor = None,
1847
+ attention_mask: Optional[torch.Tensor] = None,
1848
+ position_ids: Optional[torch.LongTensor] = None,
1849
+ past_key_values: Optional[Union[Cache,
1850
+ List[torch.FloatTensor]]] = None,
1851
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1852
+ labels: Optional[torch.LongTensor] = None,
1853
+ use_cache: Optional[bool] = None,
1854
+ output_attentions: Optional[bool] = None,
1855
+ output_hidden_states: Optional[bool] = None,
1856
+ return_dict: Optional[bool] = None,
1857
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1858
+ r"""
1859
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1860
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1861
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1862
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1863
+ """
1864
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1865
+
1866
+ transformer_outputs = self.model(
1867
+ input_ids,
1868
+ attention_mask=attention_mask,
1869
+ position_ids=position_ids,
1870
+ past_key_values=past_key_values,
1871
+ inputs_embeds=inputs_embeds,
1872
+ use_cache=use_cache,
1873
+ output_attentions=output_attentions,
1874
+ output_hidden_states=output_hidden_states,
1875
+ return_dict=return_dict,
1876
+ )
1877
+ hidden_states = transformer_outputs[0]
1878
+ logits = self.score(hidden_states)
1879
+
1880
+ if input_ids is not None:
1881
+ batch_size = input_ids.shape[0]
1882
+ else:
1883
+ batch_size = inputs_embeds.shape[0]
1884
+
1885
+ if self.config.pad_token_id is None and batch_size != 1:
1886
+ raise ValueError(
1887
+ "Cannot handle batch sizes > 1 if no padding token is defined.")
1888
+ if self.config.pad_token_id is None:
1889
+ sequence_lengths = -1
1890
+ else:
1891
+ if input_ids is not None:
1892
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1893
+ sequence_lengths = torch.eq(
1894
+ input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1895
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1896
+ sequence_lengths = sequence_lengths.to(logits.device)
1897
+ else:
1898
+ sequence_lengths = -1
1899
+
1900
+ pooled_logits = logits[torch.arange(
1901
+ batch_size, device=logits.device), sequence_lengths]
1902
+
1903
+ loss = None
1904
+ if labels is not None:
1905
+ labels = labels.to(logits.device)
1906
+ if self.config.problem_type is None:
1907
+ if self.num_labels == 1:
1908
+ self.config.problem_type = "regression"
1909
+ elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
1910
+ self.config.problem_type = "single_label_classification"
1911
+ else:
1912
+ self.config.problem_type = "multi_label_classification"
1913
+
1914
+ if self.config.problem_type == "regression":
1915
+ loss_fct = MSELoss()
1916
+ if self.num_labels == 1:
1917
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1918
+ else:
1919
+ loss = loss_fct(pooled_logits, labels)
1920
+ elif self.config.problem_type == "single_label_classification":
1921
+ loss_fct = CrossEntropyLoss()
1922
+ loss = loss_fct(
1923
+ pooled_logits.view(-1, self.num_labels), labels.view(-1))
1924
+ elif self.config.problem_type == "multi_label_classification":
1925
+ loss_fct = BCEWithLogitsLoss()
1926
+ loss = loss_fct(pooled_logits, labels)
1927
+ if not return_dict:
1928
+ output = (pooled_logits,) + transformer_outputs[1:]
1929
+ return ((loss,) + output) if loss is not None else output
1930
+
1931
+ return SequenceClassifierOutputWithPast(
1932
+ loss=loss,
1933
+ logits=pooled_logits,
1934
+ past_key_values=transformer_outputs.past_key_values,
1935
+ hidden_states=transformer_outputs.hidden_states,
1936
+ attentions=transformer_outputs.attentions,
1937
+ )
1938
+
1939
+
1940
+ # Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->InternLM2
1941
+ @add_start_docstrings(
1942
+ """
1943
+ The InternLM2 Model transformer with a span classification head on top for extractive question-answering tasks like
1944
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1945
+ """,
1946
+ InternLM2_START_DOCSTRING,
1947
+ )
1948
+ class InternLM2ForQuestionAnswering(InternLM2PreTrainedModel):
1949
+ """Question Answering model for InternLM2."""
1950
+
1951
+ base_model_prefix = "transformer"
1952
+
1953
+ def __init__(self, config):
1954
+ super().__init__(config)
1955
+ self.transformer = InternLM2Model(config)
1956
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1957
+
1958
+ # Initialize weights and apply final processing
1959
+ self.post_init()
1960
+
1961
+ def get_input_embeddings(self):
1962
+ return self.transformer.tok_embeddings
1963
+
1964
+ def set_input_embeddings(self, value):
1965
+ self.transformer.tok_embeddings = value
1966
+
1967
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1968
+ def forward(
1969
+ self,
1970
+ input_ids: Optional[torch.LongTensor] = None,
1971
+ attention_mask: Optional[torch.FloatTensor] = None,
1972
+ position_ids: Optional[torch.LongTensor] = None,
1973
+ past_key_values: Optional[Union[Cache,
1974
+ List[torch.FloatTensor]]] = None,
1975
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1976
+ start_positions: Optional[torch.LongTensor] = None,
1977
+ end_positions: Optional[torch.LongTensor] = None,
1978
+ output_attentions: Optional[bool] = None,
1979
+ output_hidden_states: Optional[bool] = None,
1980
+ return_dict: Optional[bool] = None,
1981
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1982
+ r"""
1983
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1984
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1985
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1986
+ are not taken into account for computing the loss.
1987
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1988
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1989
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1990
+ are not taken into account for computing the loss.
1991
+ """
1992
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1993
+
1994
+ outputs = self.transformer(
1995
+ input_ids,
1996
+ attention_mask=attention_mask,
1997
+ position_ids=position_ids,
1998
+ past_key_values=past_key_values,
1999
+ inputs_embeds=inputs_embeds,
2000
+ output_attentions=output_attentions,
2001
+ output_hidden_states=output_hidden_states,
2002
+ return_dict=return_dict,
2003
+ )
2004
+
2005
+ sequence_output = outputs[0]
2006
+
2007
+ logits = self.qa_outputs(sequence_output)
2008
+ start_logits, end_logits = logits.split(1, dim=-1)
2009
+ start_logits = start_logits.squeeze(-1).contiguous()
2010
+ end_logits = end_logits.squeeze(-1).contiguous()
2011
+
2012
+ total_loss = None
2013
+ if start_positions is not None and end_positions is not None:
2014
+ # If we are on multi-GPU, split add a dimension
2015
+ if len(start_positions.size()) > 1:
2016
+ start_positions = start_positions.squeeze(
2017
+ -1).to(start_logits.device)
2018
+ if len(end_positions.size()) > 1:
2019
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
2020
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
2021
+ ignored_index = start_logits.size(1)
2022
+ start_positions = start_positions.clamp(0, ignored_index)
2023
+ end_positions = end_positions.clamp(0, ignored_index)
2024
+
2025
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
2026
+ start_loss = loss_fct(start_logits, start_positions)
2027
+ end_loss = loss_fct(end_logits, end_positions)
2028
+ total_loss = (start_loss + end_loss) / 2
2029
+
2030
+ if not return_dict:
2031
+ output = (start_logits, end_logits) + outputs[2:]
2032
+ return ((total_loss,) + output) if total_loss is not None else output
2033
+
2034
+ return QuestionAnsweringModelOutput(
2035
+ loss=total_loss,
2036
+ start_logits=start_logits,
2037
+ end_logits=end_logits,
2038
+ hidden_states=outputs.hidden_states,
2039
+ attentions=outputs.attentions,
2040
+ )
2041
+
2042
+
2043
+ # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->InternLM2
2044
+ @add_start_docstrings(
2045
+ """
2046
+ The InternLM2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
2047
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
2048
+ """,
2049
+ InternLM2_START_DOCSTRING,
2050
+ )
2051
+ class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
2052
+ """Token classification model for InternLM2."""
2053
+
2054
+ def __init__(self, config):
2055
+ super().__init__(config)
2056
+ self.num_labels = config.num_labels
2057
+ self.model = InternLM2Model(config)
2058
+ if getattr(config, "classifier_dropout", None) is not None:
2059
+ classifier_dropout = config.classifier_dropout
2060
+ elif getattr(config, "hidden_dropout", None) is not None:
2061
+ classifier_dropout = config.hidden_dropout
2062
+ else:
2063
+ classifier_dropout = 0.1
2064
+ self.dropout = nn.Dropout(classifier_dropout)
2065
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
2066
+
2067
+ # Initialize weights and apply final processing
2068
+ self.post_init()
2069
+
2070
+ def get_input_embeddings(self):
2071
+ return self.model.tok_embeddings
2072
+
2073
+ def set_input_embeddings(self, value):
2074
+ self.model.tok_embeddings = value
2075
+
2076
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
2077
+ def forward(
2078
+ self,
2079
+ input_ids: torch.LongTensor = None,
2080
+ attention_mask: Optional[torch.Tensor] = None,
2081
+ position_ids: Optional[torch.LongTensor] = None,
2082
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
2083
+ inputs_embeds: Optional[torch.FloatTensor] = None,
2084
+ labels: Optional[torch.LongTensor] = None,
2085
+ use_cache: Optional[bool] = None,
2086
+ output_attentions: Optional[bool] = None,
2087
+ output_hidden_states: Optional[bool] = None,
2088
+ return_dict: Optional[bool] = None,
2089
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
2090
+ r"""
2091
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
2092
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
2093
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
2094
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
2095
+ """
2096
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2097
+
2098
+ outputs = self.model(
2099
+ input_ids,
2100
+ attention_mask=attention_mask,
2101
+ position_ids=position_ids,
2102
+ past_key_values=past_key_values,
2103
+ inputs_embeds=inputs_embeds,
2104
+ use_cache=use_cache,
2105
+ output_attentions=output_attentions,
2106
+ output_hidden_states=output_hidden_states,
2107
+ return_dict=return_dict,
2108
+ )
2109
+ sequence_output = outputs[0]
2110
+ sequence_output = self.dropout(sequence_output)
2111
+ logits = self.score(sequence_output)
2112
+
2113
+ loss = None
2114
+ if labels is not None:
2115
+ loss_fct = CrossEntropyLoss()
2116
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
2117
+
2118
+ if not return_dict:
2119
+ output = (logits,) + outputs[2:]
2120
+ return ((loss,) + output) if loss is not None else output
2121
+
2122
+ return TokenClassifierOutput(
2123
+ loss=loss,
2124
+ logits=logits,
2125
+ hidden_states=outputs.hidden_states,
2126
+ attentions=outputs.attentions,
2127
+ )
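A minimal usage sketch for the `chat` helper defined above, assuming the checkpoint directory containing these files loads through `trust_remote_code` (the local path `./MDM-1.7B` and the sampling settings are illustrative placeholders, not part of this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "./MDM-1.7B"  # placeholder: local directory holding these model files
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, trust_remote_code=True, torch_dtype="auto"
).eval()

# `chat` wraps build_inputs (the <|im_start|>/<|im_end|> template) and `generate`,
# then strips everything after the first <|im_end|> in the decoded reply.
response, history = model.chat(
    tokenizer, "Hello!", history=[], max_new_tokens=64, do_sample=True, top_p=0.8
)
print(response)
```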
special_tokens_map.json ADDED
@@ -0,0 +1,57 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|action_start|>",
6
+ "<|action_end|>",
7
+ "<|interpreter|>",
8
+ "<|plugin|>",
9
+ "<restate>",
10
+ "</restate>",
11
+ "<planning>",
12
+ "</planning>",
13
+ "<recollect>",
14
+ "</recollect>",
15
+ "<execution>",
16
+ "</execution>",
17
+ "<review>",
18
+ "</review>",
19
+ "<summarize>",
20
+ "</summarize>",
21
+ "<retry>",
22
+ "</retry>",
23
+ "<conclude>",
24
+ "</conclude>",
25
+ "<MASK>",
26
+ "<think>",
27
+ "</think>"
28
+ ],
29
+ "bos_token": {
30
+ "content": "<s>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ },
36
+ "eos_token": {
37
+ "content": "</s>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false
42
+ },
43
+ "pad_token": {
44
+ "content": "</s>",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false
49
+ },
50
+ "unk_token": {
51
+ "content": "<unk>",
52
+ "lstrip": false,
53
+ "normalized": false,
54
+ "rstrip": false,
55
+ "single_word": false
56
+ }
57
+ }
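Among the additional special tokens registered above are `<MASK>` (presumably the mask token for the diffusion objective) and the `<|im_start|>`/`<|im_end|>` turn delimiters used by `build_inputs` and `chat`. A short sketch of looking up their ids, assuming the same placeholder checkpoint path as above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./MDM-1.7B", trust_remote_code=True)  # placeholder path

mask_id = tokenizer.convert_tokens_to_ids("<MASK>")        # additional special token from the map above
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")  # also appended to eos_token_id in chat()
print(mask_id, im_end_id)
```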
tokenization_internlm3.py ADDED
@@ -0,0 +1,294 @@
1
+ import os
2
+ from shutil import copyfile
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
4
+
5
+ import sentencepiece as spm
6
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
7
+ from transformers.utils import logging
8
+
9
+ if TYPE_CHECKING:
10
+ from transformers.tokenization_utils_base import TextInput
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
15
+
16
+ SPIECE_UNDERLINE = "▁"
17
+
18
+
19
+ class InternLM3Tokenizer(PreTrainedTokenizer):
20
+ """
21
+ Construct a InternLM3 tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
22
+ no padding token in the original model.
23
+
24
+ Args:
25
+ vocab_file (`str`):
26
+ Path to the vocabulary file.
27
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
28
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
29
+ token instead.
30
+ bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
31
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
32
+ eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
33
+ The end of sequence token.
34
+ pad_token (`str` or `tokenizers.AddedToken`, *optional*):
35
+ A special token used to make arrays of tokens the same size for batching purposes. It will then be ignored by
36
+ attention mechanisms or loss computation.
37
+ sp_model_kwargs (`Dict[str, Any]`, *optional*):
38
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
39
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
40
+ to set:
41
+
42
+ - `enable_sampling`: Enable subword regularization.
43
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
44
+
45
+ - `nbest_size = {0,1}`: No sampling is performed.
46
+ - `nbest_size > 1`: samples from the nbest_size results.
47
+ - `nbest_size < 0`: assumes that nbest_size is infinite and samples from all hypotheses (lattice)
48
+ using forward-filtering-and-backward-sampling algorithm.
49
+
50
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
51
+ BPE-dropout.
52
+
53
+ add_bos_token (`bool`, *optional*, defaults to `True`):
54
+ Whether or not to add an `bos_token` at the start of sequences.
55
+ add_eos_token (`bool`, *optional*, defaults to `False`):
56
+ Whether or not to add an `eos_token` at the end of sequences.
57
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
58
+ Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
59
+ extra spaces.
60
+ use_default_system_prompt (`bool`, *optional*, defaults to `False`):
61
+ Whether or not the default system prompt for InternLM3 should be used.
62
+ spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
63
+ Whether or not to add spaces between special tokens.
64
+ spaces_for_interleaved_special_tokens (`bool`, *optional*, defaults to `False`):
65
+ Whether or not to add spaces between special tokens that are interleaved with normal tokens.
66
+ add_prefix_space (`bool`, *optional*, defaults to `True`):
67
+ Whether or not to add an initial space to the input. This allows the leading word to be treated just like any
68
+ other word. Again, this should be set with `from_slow=True` to make sure it's taken into account.
69
+ """
70
+
71
+ vocab_files_names = VOCAB_FILES_NAMES
72
+ model_input_names = ["input_ids", "attention_mask"]
73
+
74
+ def __init__(
75
+ self,
76
+ vocab_file,
77
+ unk_token="<unk>",
78
+ bos_token="<s>",
79
+ eos_token="</s>",
80
+ pad_token=None,
81
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
82
+ add_bos_token=True,
83
+ add_eos_token=False,
84
+ clean_up_tokenization_spaces=False,
85
+ use_default_system_prompt=False,
86
+ spaces_between_special_tokens=False,
87
+ spaces_for_interleaved_special_tokens=False,
88
+ add_prefix_space=True,
89
+ **kwargs,
90
+ ):
91
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
92
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
93
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
94
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
95
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
96
+
97
+ self.vocab_file = vocab_file
98
+ self.add_bos_token = add_bos_token
99
+ self.add_eos_token = add_eos_token
100
+ self.use_default_system_prompt = use_default_system_prompt
101
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
102
+ self.sp_model.Load(vocab_file)
103
+ self.add_prefix_space = add_prefix_space
104
+ self.spaces_for_interleaved_special_tokens = spaces_for_interleaved_special_tokens
105
+
106
+ vocab_size = self.sp_model.get_piece_size()
107
+ self.decoder = {i: self.sp_model.id_to_piece(i) for i in range(vocab_size)}
108
+
109
+ super().__init__(
110
+ bos_token=bos_token,
111
+ eos_token=eos_token,
112
+ unk_token=unk_token,
113
+ pad_token=pad_token,
114
+ add_bos_token=add_bos_token,
115
+ add_eos_token=add_eos_token,
116
+ sp_model_kwargs=sp_model_kwargs,
117
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
118
+ use_default_system_prompt=use_default_system_prompt,
119
+ spaces_between_special_tokens=spaces_between_special_tokens,
120
+ add_prefix_space=add_prefix_space,
121
+ **kwargs,
122
+ )
123
+
124
+ def __getstate__(self):
125
+ state = self.__dict__.copy()
126
+ state["sp_model"] = None
127
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
128
+ return state
129
+
130
+ def __setstate__(self, d):
131
+ self.__dict__.update(d)
132
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
133
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
134
+
135
+ @property
136
+ def vocab_size(self):
137
+ """Returns vocab size"""
138
+ return self.sp_model.get_piece_size()
139
+
140
+ def get_vocab(self):
141
+ """Returns vocab as a dict"""
142
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
143
+ vocab.update(self.added_tokens_encoder)
144
+ return vocab
145
+
146
+ def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
147
+ """
148
+ Args:
149
+ text: TextInput
150
+ Simply calls PreTrainedTokenizer's method
151
+ """
152
+ return super().tokenize(text, **kwargs)
153
+
154
+ def _tokenize(self, text, **kwargs):
155
+ """
156
+ Args:
157
+ text: TextInput
158
+ Returns a tokenized string produced by the underlying SentencePiece model.
159
+ """
160
+ return self.sp_model.encode(text, out_type=str)
161
+
162
+ def _convert_token_to_id(self, token):
163
+ """Converts a token (str) in an id using the vocab."""
164
+ return self.sp_model.piece_to_id(token)
165
+
166
+ def _convert_id_to_token(self, index):
167
+ """Converts an index (integer) in a token (str) using the vocab."""
168
+ return self.decoder.get(index, "")
169
+
170
+ def convert_tokens_to_string(self, tokens):
171
+ """Converts a sequence of tokens (string) in a single string."""
172
+ # since we manually add the prefix space, we have to remove it when decoding
173
+ if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space:
174
+ tokens[0] = tokens[0][1:]
175
+
176
+ current_sub_tokens = []
177
+ out_string = ""
178
+ prev_is_special = False
179
+ for i, token in enumerate(tokens):
180
+ # make sure that special tokens are not decoded using sentencepiece model
181
+ if token in self.all_special_tokens:
182
+ if not prev_is_special and i != 0 and self.spaces_for_interleaved_special_tokens:
183
+ out_string += " "
184
+ out_string += self.sp_model.decode(current_sub_tokens) + token
185
+ prev_is_special = True
186
+ current_sub_tokens = []
187
+ else:
188
+ if (
189
+ prev_is_special
190
+ and i == 1
191
+ and self.add_prefix_space
192
+ and not token.startswith(SPIECE_UNDERLINE)
193
+ and self.spaces_for_interleaved_special_tokens
194
+ ):
195
+ out_string += " "
196
+ current_sub_tokens.append(token)
197
+ prev_is_special = False
198
+ out_string += self.sp_model.decode(current_sub_tokens)
199
+ return out_string
200
+
201
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
202
+ """
203
+ Save the vocabulary and special tokens file to a directory.
204
+
205
+ Args:
206
+ save_directory (`str`):
207
+ The directory in which to save the vocabulary.
208
+
209
+ Returns:
210
+ `Tuple[str]`: Paths to the files saved.
211
+ """
212
+ if not os.path.isdir(save_directory):
213
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
214
+ return
215
+ out_vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"])
216
+
217
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
218
+ copyfile(self.vocab_file, out_vocab_file)
219
+ elif not os.path.isfile(self.vocab_file):
220
+ with open(out_vocab_file, "wb") as fi:
221
+ content_spiece_model = self.sp_model.serialized_model_proto()
222
+ fi.write(content_spiece_model)
223
+
224
+ return (out_vocab_file,)
225
+
226
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
227
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
228
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
229
+
230
+ output = bos_token_id + token_ids_0 + eos_token_id
231
+
232
+ if token_ids_1 is not None:
233
+ output = output + bos_token_id + token_ids_1 + eos_token_id
234
+
235
+ return output
236
+
237
+ def get_special_tokens_mask(
238
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
239
+ ) -> List[int]:
240
+ """
241
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
242
+ special tokens using the tokenizer `prepare_for_model` method.
243
+
244
+ Args:
245
+ token_ids_0 (`List[int]`):
246
+ List of IDs.
247
+ token_ids_1 (`List[int]`, *optional*):
248
+ Optional second list of IDs for sequence pairs.
249
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
250
+ Whether or not the token list is already formatted with special tokens for the model.
251
+
252
+ Returns:
253
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
254
+ """
255
+ if already_has_special_tokens:
256
+ return super().get_special_tokens_mask(token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True)
257
+
258
+ bos_token_id = [1] if self.add_bos_token else []
259
+ eos_token_id = [1] if self.add_eos_token else []
260
+
261
+ if token_ids_1 is None:
262
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
263
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id
264
+
265
+ def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
266
+ """
267
+ Creates a mask from the two sequences passed, to be used in a sequence-pair classification task. A sequence-pair
268
+ mask has the following format:
269
+
270
+ ```
271
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
272
+ | first sequence | second sequence |
273
+ ```
274
+
275
+ If `token_ids_1` is `None`, only the first portion of the mask (0s) is returned.
276
+
277
+ Args:
278
+ token_ids_0 (`List[int]`):
279
+ List of IDs.
280
+ token_ids_1 (`List[int]`, *optional*):
281
+ Optional second list of IDs for sequence pairs.
282
+
283
+ Returns:
284
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
285
+ """
286
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
287
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
288
+
289
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
290
+
291
+ if token_ids_1 is not None:
292
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
293
+
294
+ return output
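For reference, a short usage sketch of the tokenizer defined above, exercising `build_inputs_with_special_tokens` and `get_special_tokens_mask` under the defaults `add_bos_token=True` / `add_eos_token=False`; the local path is a placeholder for wherever this repository is checked out:

```python
from transformers import AutoTokenizer

# Placeholder path; requires the `sentencepiece` package and trust_remote_code.
tok = AutoTokenizer.from_pretrained("./MDM-1.7B", trust_remote_code=True)

ids = tok("Hello world")["input_ids"]
print(ids[0] == tok.bos_token_id)      # True: add_bos_token defaults to True
print(tok.convert_ids_to_tokens(ids))  # e.g. ['<s>', '▁Hello', '▁world']

# For a pair of pre-tokenized sequences, BOS is prepended to each segment
# and no EOS is appended (add_eos_token=False by default).
print(tok.build_inputs_with_special_tokens([10, 11], [12, 13]))  # [bos, 10, 11, bos, 12, 13]
print(tok.get_special_tokens_mask([10, 11], [12, 13]))           # [1, 0, 0, 1, 0, 0]
```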
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc08ac7d6a2c6183ccc63f98a90ec6a04e30249b831d2f30773b8e1f89b32c6b
3
+ size 2474916
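`tokenizer.model` is stored as a Git LFS pointer, so the three lines above record only the object's SHA-256 and byte size rather than the SentencePiece model itself. A small sketch for verifying a downloaded copy against the pointer, to be run after `git lfs pull` or an equivalent download:

```python
import hashlib
import os

path = "tokenizer.model"  # local copy of the LFS object
expected_oid = "dc08ac7d6a2c6183ccc63f98a90ec6a04e30249b831d2f30773b8e1f89b32c6b"
expected_size = 2474916

with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print(os.path.getsize(path) == expected_size)  # size matches the pointer
print(digest == expected_oid)                  # content matches the recorded oid
```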
tokenizer_config.json ADDED
@@ -0,0 +1,277 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": true,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "128108": {
31
+ "content": "<MASK>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "128109": {
39
+ "content": "<think>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "128110": {
47
+ "content": "</think>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "128111": {
55
+ "content": "<restate>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "128112": {
63
+ "content": "</restate>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "128113": {
71
+ "content": "<planning>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "128114": {
79
+ "content": "</planning>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "128115": {
87
+ "content": "<recollect>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ },
94
+ "128116": {
95
+ "content": "</recollect>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": true
101
+ },
102
+ "128117": {
103
+ "content": "<execution>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ },
110
+ "128118": {
111
+ "content": "</execution>",
112
+ "lstrip": false,
113
+ "normalized": false,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": true
117
+ },
118
+ "128119": {
119
+ "content": "<review>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": true
125
+ },
126
+ "128120": {
127
+ "content": "</review>",
128
+ "lstrip": false,
129
+ "normalized": false,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": true
133
+ },
134
+ "128121": {
135
+ "content": "<summarize>",
136
+ "lstrip": false,
137
+ "normalized": false,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": true
141
+ },
142
+ "128122": {
143
+ "content": "</summarize>",
144
+ "lstrip": false,
145
+ "normalized": false,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": true
149
+ },
150
+ "128123": {
151
+ "content": "<retry>",
152
+ "lstrip": false,
153
+ "normalized": false,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": true
157
+ },
158
+ "128124": {
159
+ "content": "</retry>",
160
+ "lstrip": false,
161
+ "normalized": false,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": true
165
+ },
166
+ "128125": {
167
+ "content": "<conclude>",
168
+ "lstrip": false,
169
+ "normalized": false,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": true
173
+ },
174
+ "128126": {
175
+ "content": "</conclude>",
176
+ "lstrip": false,
177
+ "normalized": false,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": true
181
+ },
182
+ "128127": {
183
+ "content": "<|plugin|>",
184
+ "lstrip": false,
185
+ "normalized": false,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": true
189
+ },
190
+ "128128": {
191
+ "content": "<|interpreter|>",
192
+ "lstrip": false,
193
+ "normalized": false,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": true
197
+ },
198
+ "128129": {
199
+ "content": "<|action_end|>",
200
+ "lstrip": false,
201
+ "normalized": false,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": true
205
+ },
206
+ "128130": {
207
+ "content": "<|action_start|>",
208
+ "lstrip": false,
209
+ "normalized": false,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": true
213
+ },
214
+ "128131": {
215
+ "content": "<|im_end|>",
216
+ "lstrip": false,
217
+ "normalized": false,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": true
221
+ },
222
+ "128132": {
223
+ "content": "<|im_start|>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ }
230
+ },
231
+ "additional_special_tokens": [
232
+ "<|im_start|>",
233
+ "<|im_end|>",
234
+ "<|action_start|>",
235
+ "<|action_end|>",
236
+ "<|interpreter|>",
237
+ "<|plugin|>",
238
+ "<restate>",
239
+ "</restate>",
240
+ "<planning>",
241
+ "</planning>",
242
+ "<recollect>",
243
+ "</recollect>",
244
+ "<execution>",
245
+ "</execution>",
246
+ "<review>",
247
+ "</review>",
248
+ "<summarize>",
249
+ "</summarize>",
250
+ "<retry>",
251
+ "</retry>",
252
+ "<conclude>",
253
+ "</conclude>",
254
+ "<MASK>",
255
+ "<think>",
256
+ "</think>"
257
+ ],
258
+ "auto_map": {
259
+ "AutoTokenizer": [
260
+ "tokenization_internlm3.InternLM3Tokenizer",
261
+ null
262
+ ]
263
+ },
264
+ "bos_token": "<s>",
265
+ "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
266
+ "clean_up_tokenization_spaces": false,
267
+ "eos_token": "</s>",
268
+ "extra_special_tokens": {},
269
+ "model_max_length": 1000000000000000019884624838656,
270
+ "pad_token": "</s>",
271
+ "padding_side": "right",
272
+ "sp_model_kwargs": {},
273
+ "spaces_between_special_tokens": false,
274
+ "tokenizer_class": "InternLM3Tokenizer",
275
+ "unk_token": "<unk>",
276
+ "use_default_system_prompt": false
277
+ }
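The `chat_template` in `tokenizer_config.json` renders ChatML-style turns: the BOS token, then one `<|im_start|>{role}\n{content}<|im_end|>\n` block per message, plus an `<|im_start|>assistant\n` header when a generation prompt is requested. A hedged sketch of rendering it (placeholder path, `trust_remote_code=True` assumed):

```python
from transformers import AutoTokenizer

# Placeholder path; substitute the hub id or a local checkout of this repo.
tok = AutoTokenizer.from_pretrained("./MDM-1.7B", trust_remote_code=True)

messages = [{"role": "user", "content": "Hello"}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# Expected shape of the rendered prompt, per the template above:
# <s><|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant
```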