alibayram committed on
Commit 0ae8769 · verified · 1 Parent(s): 2228f0a

Upload folder using huggingface_hub

Files changed (2)
  1. config.json +1 -1
  2. modeling_deepseek.py +105 -214
config.json CHANGED
@@ -34,7 +34,7 @@
   "tie_word_embeddings": false,
   "torch_dtype": "float32",
   "auto_map": {
-    "AutoConfig": "configuration_deepseek.DeepSeekConfig",
+    "AutoConfig": "modeling_deepseek.DeepSeekConfig",
     "AutoModel": "modeling_deepseek.DeepSeekModel",
     "AutoModelForCausalLM": "modeling_deepseek.DeepSeekForCausalLM"
   }
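
The only change above repoints auto_map["AutoConfig"] from the separate configuration_deepseek module to modeling_deepseek, so both the config and the model classes now resolve from the single uploaded file. A minimal loading sketch, assuming a placeholder repo id (the actual repository name is not shown in this commit) and trust_remote_code=True, which is required for models with custom code on the Hub:

    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "user/repo-id"  # placeholder; substitute the actual Hub repository

    # auto_map directs both lookups to classes defined in modeling_deepseek.py
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
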
modeling_deepseek.py CHANGED
@@ -1,5 +1,5 @@
 """
-PyTorch DeepSeek model.
+PyTorch DeepSeek model - Standalone version for HuggingFace Hub
 """
 
 import math
@@ -8,10 +8,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from configuration_deepseek import DeepSeekConfig
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
+from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter, _prepare_4d_attention_mask,
     _prepare_4d_causal_attention_mask)
@@ -31,6 +31,94 @@ if is_flash_attn_2_available():
 
 logger = logging.get_logger(__name__)
 
+
+class DeepSeekConfig(PretrainedConfig):
+    """
+    Configuration class for DeepSeek model.
+    """
+    model_type = "deepseek"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=50256,
+        hidden_size=1024,
+        intermediate_size=4096,
+        moe_intermediate_size=704,
+        num_hidden_layers=6,
+        num_dense_layers=1,
+        num_attention_heads=8,
+        num_routed_experts=4,
+        num_shared_experts=2,
+        num_activated_experts=2,
+        num_expert_groups=1,
+        num_limited_groups=1,
+        max_position_embeddings=256,
+        max_batch_size=2,
+        q_lora_rank=0,
+        kv_lora_rank=256,
+        qk_nope_head_dim=64,
+        qk_rope_head_dim=32,
+        v_head_dim=64,
+        original_seq_len=512,
+        rope_theta=10000.0,
+        rope_factor=40,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1.0,
+        initializer_range=0.02,
+        rms_norm_eps=1e-3,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=2,
+        eos_token_id=3,
+        tie_word_embeddings=False,
+        output_attentions=False,
+        output_hidden_states=False,
+        use_return_dict=True,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_dense_layers = num_dense_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_routed_experts = num_routed_experts
+        self.num_shared_experts = num_shared_experts
+        self.num_activated_experts = num_activated_experts
+        self.num_expert_groups = num_expert_groups
+        self.num_limited_groups = num_limited_groups
+        self.max_position_embeddings = max_position_embeddings
+        self.max_batch_size = max_batch_size
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.original_seq_len = original_seq_len
+        self.rope_theta = rope_theta
+        self.rope_factor = rope_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.mscale = mscale
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
+        self.use_return_dict = use_return_dict
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
 _CONFIG_FOR_DOC = "DeepSeekConfig"
 
 
@@ -97,129 +185,6 @@ class DeepSeekRMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)
 
 
-class DeepSeekMLA(nn.Module):
-    """Multi-head Latent Attention (MLA) module."""
-
-    def __init__(self, config: DeepSeekConfig, layer_idx: Optional[int] = None):
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_attention_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
-        self.is_causal = True
-
-        # MLA specific parameters
-        self.q_lora_rank = config.q_lora_rank
-        self.kv_lora_rank = config.kv_lora_rank
-        self.qk_nope_head_dim = config.qk_nope_head_dim
-        self.qk_rope_head_dim = config.qk_rope_head_dim
-        self.v_head_dim = config.v_head_dim
-        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
-
-        if self.q_lora_rank == 0:
-            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
-        else:
-            self.q_a_proj = nn.Linear(self.hidden_size, self.q_lora_rank, bias=False)
-            self.q_a_layernorm = DeepSeekRMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
-
-        self.kv_a_proj_with_mqa = nn.Linear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False
-        )
-        self.kv_a_layernorm = DeepSeekRMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = nn.Linear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False
-        )
-        self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=False)
-
-        # Scaling
-        self.scaling = self.qk_head_dim ** -0.5
-        if config.max_position_embeddings > config.original_seq_len:
-            mscale = 0.1 * config.mscale * math.log(config.rope_factor) + 1.0
-            self.scaling = self.scaling * mscale * mscale
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-        bsz, q_len, _ = hidden_states.size()
-
-        # Query projection
-        if self.q_lora_rank == 0:
-            query_states = self.q_proj(hidden_states)
-        else:
-            query_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.qk_head_dim).transpose(1, 2)
-
-        # Split query into no-position-encoding and position-encoding parts
-        q_nope, q_pe = query_states.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
-        # Key-Value projection
-        kv_input = self.kv_a_proj_with_mqa(hidden_states)
-        compressed_kv, k_pe = kv_input.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-
-        # Apply RoPE to position-encoding parts
-        if position_ids is not None:
-            cos, sin = self.rotary_emb(hidden_states, position_ids)
-            q_pe = apply_rotary_pos_emb(q_pe, cos, sin)
-            k_pe = apply_rotary_pos_emb(k_pe.unsqueeze(2), cos, sin).squeeze(2)
-
-        # Compute key and value from compressed representation
-        kv_b_weight = self.kv_b_proj.weight.view(
-            self.num_heads, self.qk_nope_head_dim + self.v_head_dim, self.kv_lora_rank
-        )
-
-        # Project compressed KV to get keys and values
-        compressed_kv = self.kv_a_layernorm(compressed_kv)
-        key_states = torch.einsum('bld,hnd->bhln', compressed_kv, kv_b_weight[:, :self.qk_nope_head_dim, :])
-        value_states = torch.einsum('bld,hnd->bhln', compressed_kv, kv_b_weight[:, -self.v_head_dim:, :])
-
-        # Attention computation
-        attn_weights = torch.matmul(q_nope, key_states.transpose(-2, -1)) * self.scaling
-
-        # Add positional attention
-        if k_pe is not None:
-            pos_attn = torch.matmul(q_pe, k_pe.unsqueeze(1).transpose(-2, -1)) * self.scaling
-            attn_weights = attn_weights + pos_attn
-
-        if attention_mask is not None:
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
-        # Apply softmax
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-
-        # Apply attention to values
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, -1)
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
 class DeepSeekMLP(nn.Module):
     """Multi-Layer Perceptron for dense layers."""
 
@@ -238,23 +203,6 @@ class DeepSeekMLP(nn.Module):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
-class DeepSeekExpert(nn.Module):
-    """Single expert in MoE layer."""
-
-    def __init__(self, config: DeepSeekConfig):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.moe_intermediate_size
-
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN["silu"]
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
 DEEPSEEK_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -301,83 +249,31 @@ class DeepSeekPreTrainedModel(PreTrainedModel):
 DEEPSEEK_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-            it.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
+            Indices of input sequence tokens in the vocabulary.
         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
-            `past_key_values`).
-
-            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
-            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
-            information on the default strategy.
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
+            Mask to avoid performing attention on padding token indices.
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
-            config.n_positions - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
+            Indices of positions of each input sequence token in the position embeddings.
         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
-            Pre-computed hidden-states (key and value in the self-attention blocks and in the cross-attention blocks)
-            that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
-            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
-
-            Two formats are allowed:
-            - a [`~cache_utils.Cache`] instance;
-            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
-            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
-            cache format.
-
-            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
-            legacy cache format will be returned.
-
-            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
-            have their past key/value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
-            of shape `(batch_size, sequence_length)`.
+            Pre-computed hidden-states for sequential decoding.
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
+            Optionally pass an embedded representation instead of input_ids.
         use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
+            If set to `True`, `past_key_values` key value states are returned.
         output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
+            Whether or not to return the attentions tensors.
         output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
+            Whether or not to return the hidden states.
         return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
-            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
-            the complete sequence length.
+            Whether or not to return a [`~utils.ModelOutput`].
 """
 
 
 class DeepSeekModel(DeepSeekPreTrainedModel):
     """
-    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepSeekDecoderLayer`]
-
-    Args:
-        config: DeepSeekConfig
+    Simplified DeepSeek Model for demonstration purposes.
+    Note: This is a simplified implementation that preserves the model structure
+    but may not have all the advanced MLA and MoE features of the full implementation.
     """
 
     def __init__(self, config: DeepSeekConfig):
@@ -386,7 +282,6 @@ class DeepSeekModel(DeepSeekPreTrainedModel):
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        # Note: We'll implement layers in a separate method due to complexity
         self.norm = DeepSeekRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
@@ -549,23 +444,19 @@ class DeepSeekForCausalLM(DeepSeekPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
     ):
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Standard implementation for generation
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if inputs_embeds is not None:
                 input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+            elif input_ids.shape[1] != cache_position.shape[0]:
                 input_ids = input_ids[:, cache_position]
 
         if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and cache_position[0] == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
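
With DeepSeekConfig now defined inside modeling_deepseek.py, the file is self-contained and the classes can also be constructed directly. A minimal sketch using the defaults added in this commit; DeepSeekForCausalLM is taken from the auto_map entry and is assumed to follow the usual Hugging Face pattern of accepting the config as its only constructor argument:

    from modeling_deepseek import DeepSeekConfig, DeepSeekForCausalLM

    # defaults mirror the values introduced in this commit; override as needed
    config = DeepSeekConfig(hidden_size=1024, num_hidden_layers=6, num_attention_heads=8)
    model = DeepSeekForCausalLM(config)  # assumed signature, as with other HF causal LM classes
    print(sum(p.numel() for p in model.parameters()))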
 