nvedant07 commited on
Commit
ab6170d
·
verified ·
1 Parent(s): fe389eb

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. config.py +51 -10
  2. model.py +110 -231
  3. norm.py +21 -0
  4. splitter.py +4 -10
  5. transformer_backbone.py +1553 -0
config.py CHANGED
@@ -23,6 +23,9 @@ class TransformerHATModelConfig(LlamaConfig):
23
  sliding_window: int | None = None,
24
  vocab_size: int = 0,
25
  hidden_act: str = "silu",
 
 
 
26
  **kwargs,
27
  ):
28
  super().__init__(
@@ -43,6 +46,9 @@ class TransformerHATModelConfig(LlamaConfig):
43
  )
44
 
45
  self.sliding_window = sliding_window
 
 
 
46
 
47
  def to_dict(self):
48
  config_dict = {
@@ -60,6 +66,9 @@ class TransformerHATModelConfig(LlamaConfig):
60
  "use_cache": self.use_cache,
61
  "sliding_window": self.sliding_window,
62
  "transformers_version": self.transformers_version,
 
 
 
63
  }
64
  return config_dict
65
 
@@ -74,6 +83,8 @@ class CrossAttentionConfig:
74
  num_attention_heads: int,
75
  attention_num_kv_heads: int,
76
  word_window_size: int,
 
 
77
  ):
78
  self.hidden_size = hidden_size
79
  self.hidden_size_q = hidden_size_q
@@ -81,6 +92,8 @@ class CrossAttentionConfig:
81
  self.num_attention_heads = num_attention_heads
82
  self.attention_num_kv_heads = attention_num_kv_heads
83
  self.word_window_size = word_window_size
 
 
84
 
85
  def to_dict(self):
86
  return {
@@ -90,6 +103,8 @@ class CrossAttentionConfig:
90
  "num_attention_heads": self.num_attention_heads,
91
  "attention_num_kv_heads": self.attention_num_kv_heads,
92
  "word_window_size": self.word_window_size,
 
 
93
  }
94
 
95
 
@@ -158,17 +173,19 @@ class EncoderHATModelConfig(TransformerHATModelConfig):
158
 
159
  @dataclass
160
  class HATArchitectureConfig(PretrainedConfig):
161
- model_type: str
162
 
163
  def __init__(
164
  self,
165
- special_token_dict : dict | None = None,
166
  encoder_config: EncoderHATModelConfig | None = None,
167
  backbone_config: TransformerHATModelConfig | None = None,
168
  decoder_config: DecoderHATModelConfig | None = None,
169
  model_type: str = "hierarchical_autoregressive_transformer",
170
  eos_token_id: int = 192,
171
  max_word_size: int = 100,
 
 
172
  **kwargs,
173
  ):
174
  super().__init__(**kwargs)
@@ -181,6 +198,12 @@ class HATArchitectureConfig(PretrainedConfig):
181
  self.special_token_dict = special_token_dict
182
  self.transformers_version = "4.46.3"
183
 
 
 
 
 
 
 
184
  @classmethod
185
  def from_dict(cls, config_dict, **kwargs):
186
  """
@@ -202,14 +225,25 @@ class HATArchitectureConfig(PretrainedConfig):
202
  decoder_config = DecoderHATModelConfig.from_dict(decoder_dict) if decoder_dict else None
203
  special_token_dict = config_dict.pop("special_token_dict", {"<|eot_id|>": 192})
204
  max_word_size = config_dict.pop("max_word_size", 100)
205
- return cls(
206
- encoder_config=encoder_config,
207
- backbone_config=backbone_config,
208
- decoder_config=decoder_config,
209
- special_token_dict=special_token_dict,
210
- max_word_size=max_word_size,
211
- **config_dict,
212
- ), {}
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  def to_dict(self):
215
  config_dict = {}
@@ -223,6 +257,13 @@ class HATArchitectureConfig(PretrainedConfig):
223
  config_dict["transformers_version"] = self.transformers_version
224
  config_dict["auto_map"] = {"AutoConfig": "config.HATArchitectureConfig", "AutoModelForCausalLM": "model.HATForCausalLM"}
225
  config_dict["special_token_dict"] = self.special_token_dict
 
 
 
 
 
 
 
226
  return config_dict
227
 
228
 
 
23
  sliding_window: int | None = None,
24
  vocab_size: int = 0,
25
  hidden_act: str = "silu",
26
+ key_query_norm: bool = False,
27
+ key_query_norm_per_head: bool = False,
28
+ is_neox_style: bool = True,
29
  **kwargs,
30
  ):
31
  super().__init__(
 
46
  )
47
 
48
  self.sliding_window = sliding_window
49
+ self.key_query_norm = key_query_norm
50
+ self.key_query_norm_per_head = key_query_norm_per_head
51
+ self.is_neox_style = is_neox_style
52
 
53
  def to_dict(self):
54
  config_dict = {
 
66
  "use_cache": self.use_cache,
67
  "sliding_window": self.sliding_window,
68
  "transformers_version": self.transformers_version,
69
+ "key_query_norm": self.key_query_norm,
70
+ "key_query_norm_per_head": self.key_query_norm_per_head,
71
+ "is_neox_style": self.is_neox_style,
72
  }
73
  return config_dict
74
 
 
83
  num_attention_heads: int,
84
  attention_num_kv_heads: int,
85
  word_window_size: int,
86
+ key_query_norm: bool,
87
+ key_query_norm_per_head: bool,
88
  ):
89
  self.hidden_size = hidden_size
90
  self.hidden_size_q = hidden_size_q
 
92
  self.num_attention_heads = num_attention_heads
93
  self.attention_num_kv_heads = attention_num_kv_heads
94
  self.word_window_size = word_window_size
95
+ self.key_query_norm = key_query_norm
96
+ self.key_query_norm_per_head = key_query_norm_per_head
97
 
98
  def to_dict(self):
99
  return {
 
103
  "num_attention_heads": self.num_attention_heads,
104
  "attention_num_kv_heads": self.attention_num_kv_heads,
105
  "word_window_size": self.word_window_size,
106
+ "key_query_norm": self.key_query_norm,
107
+ "key_query_norm_per_head": self.key_query_norm_per_head,
108
  }
109
 
110
 
 
173
 
174
  @dataclass
175
  class HATArchitectureConfig(PretrainedConfig):
176
+ model_type: str = "hierarchical_autoregressive_transformer"
177
 
178
  def __init__(
179
  self,
180
+ special_token_dict: dict | None = None,
181
  encoder_config: EncoderHATModelConfig | None = None,
182
  backbone_config: TransformerHATModelConfig | None = None,
183
  decoder_config: DecoderHATModelConfig | None = None,
184
  model_type: str = "hierarchical_autoregressive_transformer",
185
  eos_token_id: int = 192,
186
  max_word_size: int = 100,
187
+ sliding_window: int = 768,
188
+ max_position_embeddings: int = 262144,
189
  **kwargs,
190
  ):
191
  super().__init__(**kwargs)
 
198
  self.special_token_dict = special_token_dict
199
  self.transformers_version = "4.46.3"
200
 
201
+ # set these for out of the box vllm inference
202
+ self.architectures = ["HATDecoderForCausalLM"]
203
+ self.sliding_window = sliding_window
204
+ self.max_position_embeddings = max_position_embeddings
205
+ self.torch_dtype = "bfloat16"
206
+
207
  @classmethod
208
  def from_dict(cls, config_dict, **kwargs):
209
  """
 
225
  decoder_config = DecoderHATModelConfig.from_dict(decoder_dict) if decoder_dict else None
226
  special_token_dict = config_dict.pop("special_token_dict", {"<|eot_id|>": 192})
227
  max_word_size = config_dict.pop("max_word_size", 100)
228
+ return_unused_kwargs = config_dict.pop("return_unused_kwargs", False)
229
+ if return_unused_kwargs:
230
+ return cls(
231
+ encoder_config=encoder_config,
232
+ backbone_config=backbone_config,
233
+ decoder_config=decoder_config,
234
+ special_token_dict=special_token_dict,
235
+ max_word_size=max_word_size,
236
+ **config_dict,
237
+ ), {}
238
+ else:
239
+ return cls(
240
+ encoder_config=encoder_config,
241
+ backbone_config=backbone_config,
242
+ decoder_config=decoder_config,
243
+ special_token_dict=special_token_dict,
244
+ max_word_size=max_word_size,
245
+ **config_dict,
246
+ )
247
 
248
  def to_dict(self):
249
  config_dict = {}
 
257
  config_dict["transformers_version"] = self.transformers_version
258
  config_dict["auto_map"] = {"AutoConfig": "config.HATArchitectureConfig", "AutoModelForCausalLM": "model.HATForCausalLM"}
259
  config_dict["special_token_dict"] = self.special_token_dict
260
+
261
+ # print these out to the config for vllm
262
+ config_dict["max_word_size"] = self.max_word_size
263
+ config_dict["sliding_window"] = self.sliding_window
264
+ config_dict["max_position_embeddings"] = self.max_position_embeddings
265
+ config_dict["torch_dtype"] = self.torch_dtype
266
+ config_dict["architectures"] = self.architectures
267
  return config_dict
268
 
269
 
model.py CHANGED
@@ -10,10 +10,6 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_func
10
  from transformers import PreTrainedModel
11
  from transformers.cache_utils import Cache, DynamicCache
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
13
- from transformers.models.llama.modeling_llama import (
14
- LlamaDecoderLayer,
15
- LlamaRotaryEmbedding,
16
- )
17
  from transformers.utils import ModelOutput
18
 
19
  from .config import (
@@ -24,6 +20,11 @@ from .config import (
24
  TransformerHATModelConfig,
25
  )
26
  from .splitter import HATSplitter
 
 
 
 
 
27
 
28
  try:
29
  transformers_version = version("transformers")
@@ -37,10 +38,9 @@ def sample_argmax(logits: torch.Tensor) -> torch.Tensor:
37
  return torch.argmax(logits, dim=-1)[:, -1]
38
 
39
 
40
- LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n
41
- You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n
42
- {input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
43
-
44
 
45
 
46
  class HATCache(Cache):
@@ -173,26 +173,6 @@ class HATDecoderConnector(nn.Module):
173
  return activations
174
 
175
 
176
- class RMSNorm(nn.Module):
177
- def __init__(self, dimensions: int, eps: float, device: torch.device, dtype: torch.dtype = torch.bfloat16, norm_in_fp32: bool = False):
178
- super().__init__()
179
- self.eps = eps
180
- self.weight = torch.nn.Parameter(torch.ones(dimensions, dtype=dtype).to(device))
181
- self.norm_in_fp32 = norm_in_fp32
182
-
183
- def forward(self, x: torch.Tensor) -> torch.Tensor:
184
- original_dtype = x.dtype
185
- if self.norm_in_fp32:
186
- x = x.float()
187
-
188
- out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
189
-
190
- if out.dtype != original_dtype:
191
- out = out.to(original_dtype)
192
-
193
- return out * self.weight
194
-
195
-
196
  class HATDecoderBlock(nn.Module):
197
  def __init__(
198
  self,
@@ -354,6 +334,8 @@ class HATCrossAttention(nn.Module):
354
  self.num_key_value_heads = cross_attention_config.attention_num_kv_heads
355
  self.num_repeat_kv = cross_attention_config.num_attention_heads // cross_attention_config.attention_num_kv_heads
356
  self.head_dim = hidden_size // self.num_heads
 
 
357
 
358
  self.q_proj = nn.Linear(
359
  in_features=hidden_size_q,
@@ -376,6 +358,30 @@ class HATCrossAttention(nn.Module):
376
  bias=False,
377
  )
378
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  self.o_proj = nn.Linear(in_features=hidden_size, out_features=hidden_size_q, dtype=dtype, bias=False)
380
 
381
  rope_theta = config.rope_theta
@@ -402,6 +408,34 @@ class HATCrossAttention(nn.Module):
402
  key_states = self.k_proj(kv_activations)
403
  value_states = self.v_proj(kv_activations)
404
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
  # TODO get rid of the double rearrange, this is just for compatibility with scaling
406
  query_states = rearrange(query_states, "bsz seq_len (h d) -> bsz h seq_len d", h=self.num_heads)
407
  key_states = rearrange(
@@ -489,7 +523,6 @@ class HATEncoderConnector(nn.Module):
489
  device=self.latent_query.device,
490
  dtype=torch.int32,
491
  )
492
-
493
  word_embeddings = self.cross_attention_encoder_connector.forward(
494
  q_activations=latent_query_repeated,
495
  kv_activations=hidden_states,
@@ -609,7 +642,7 @@ class HATForCausalLM(PreTrainedModel):
609
  backbone_past_key_values = past_key_values.get_backbone_cache() if past_key_values is not None else None
610
  decoder_past_key_values = past_key_values.get_decoder_cache() if past_key_values is not None else None
611
 
612
- encoder_output: BaseModelOutputWithPast = self.encoder.forward(
613
  input_ids=input_ids,
614
  cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
615
  byte_position_ids=byte_position_ids,
@@ -619,13 +652,13 @@ class HATForCausalLM(PreTrainedModel):
619
  )
620
  byte_level_activations = encoder_output.hidden_states
621
 
622
- encoder_connector_output = self.encoder_connector.forward(
623
  byte_level_activations,
624
  cumulative_seq_lengths_per_word,
625
  word_position_ids,
626
  byte_position_ids,
627
  )
628
- backbone_output: CausalLMOutputWithPast = self.backbone.forward(
629
  hidden_states=encoder_connector_output,
630
  position_ids=word_position_ids,
631
  past_key_values=backbone_past_key_values,
@@ -660,7 +693,7 @@ class HATForCausalLM(PreTrainedModel):
660
  def _append_byte(self, words: list[list[int]], token: int) -> list[list[int]]:
661
  extended_last_word = words.pop() + [token]
662
  try:
663
- text = self.splitter.decode(extended_last_word, errors="strict", skip_special_tokens=False)
664
  list_of_bytes = self.splitter.encode(text)
665
  words.extend([list(word_in_bytes) for word_in_bytes in list_of_bytes])
666
  except UnicodeDecodeError:
@@ -669,70 +702,20 @@ class HATForCausalLM(PreTrainedModel):
669
  words.append(extended_last_word)
670
  return words
671
 
672
- def _split_encoder_activations(
673
- self,
674
- byte_encoder_activations: torch.Tensor,
675
- words: list[list[int]],
676
- previous_encoder_activations: torch.Tensor | None = None,
677
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
678
- """Split encoder activations between first word and next word.
679
-
680
- Args:
681
- byte_encoder_activations: Tensor of shape [batch_size, seq_len, hidden_size] containing all encoder activations which were computed in the current iteration
682
- words: List of word byte sequences which were completed in previous iteration and current iteration
683
- previous_encoder_activations: Optional tensor of shape [batch_size, prev_seq_len, hidden_size] containing precomputed activations from the previous iteration
684
-
685
- Returns:
686
- tuple containing:
687
- - first_word_encoder_activations: Tensor of shape [batch_size, first_word_len, hidden_size]
688
- - next_word_encoder_activations: Tensor of shape [batch_size, remaining_len, hidden_size]
689
- """
690
-
691
- assert sum(len(word) for word in words) - 1 == byte_encoder_activations.shape[1] + (previous_encoder_activations.shape[1] if previous_encoder_activations is not None else 0), "Length of (words - 1) must match the sum of byte_encoder_activations and previous_encoder_activations dimensions"
692
-
693
- next_word_encoder_activations = None
694
- if previous_encoder_activations is not None:
695
- # We have already precomputed first word's encoder activations partially in the previous iteration
696
- new_bytes_of_first_words = len(words[0]) - previous_encoder_activations.shape[1]
697
- # Concatenate the precomputed activations with the new activations that still belong to the first word
698
- first_word_encoder_activations = torch.cat([previous_encoder_activations, byte_encoder_activations[:, :new_bytes_of_first_words]], dim=1)
699
- if len(words[1]) > 1:
700
- # The remaining activations that belong to the next word
701
- next_word_encoder_activations = byte_encoder_activations[:, new_bytes_of_first_words:]
702
- else:
703
- next_word_encoder_activations = None
704
- else:
705
- # We have not precomputed any activations for the first word previously
706
- first_word_encoder_activations = byte_encoder_activations[:, : len(words[0])]
707
-
708
- if len(words[1]) > 1:
709
- next_word_encoder_activations = byte_encoder_activations[:, len(words[0]) :]
710
- else:
711
- next_word_encoder_activations = None
712
-
713
- return first_word_encoder_activations, next_word_encoder_activations
714
-
715
  def _complete_word(
716
  self,
717
  input_ids: torch.Tensor,
718
  byte_position_ids: torch.Tensor,
719
- predictive_word_embeddings: torch.Tensor,
720
  word_position_id: torch.Tensor,
721
  encoder_cache: DynamicCache,
722
  decoder_cache: DynamicCache,
723
  sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
724
- previous_encoder_activations: torch.Tensor | None = None,
725
  ):
726
  """Generate byte tokens until we hit the first byte of a new word."""
727
- words: list[list[int]] = [input_ids.squeeze(0).tolist()]
728
- byte_encoder_activations: list[torch.Tensor] = []
729
- completion_logits: list[torch.Tensor] = []
730
-
731
- if previous_encoder_activations is not None:
732
- # we need to pass all inputs in order to get the correct encoding/decoding by the splitter
733
- # but only the last byte is used for the generation
734
- # since the cache is already populated with the first word's activations
735
- input_ids = input_ids[:, -1:]
736
 
737
  while True:
738
  encoder_output = self.encoder.forward(
@@ -744,7 +727,7 @@ class HATForCausalLM(PreTrainedModel):
744
  )
745
  byte_encoder_activations.append(encoder_output.hidden_states)
746
  decoder_output = self.decoder.forward(
747
- predictive_word_embeddings,
748
  encoder_output.hidden_states,
749
  byte_position_ids=None,
750
  word_position_ids=word_position_id,
@@ -757,112 +740,22 @@ class HATForCausalLM(PreTrainedModel):
757
  next_byte = int(sample_fn(logits).item())
758
  words = self._append_byte(words, next_byte)
759
  if len(words) > 1 or next_byte == self.eos_token_id:
760
- byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
761
- first_word_encoder_activations, next_word_encoder_activations = self._split_encoder_activations(
762
- byte_encoder_activations,
763
- words,
764
- previous_encoder_activations,
765
- )
766
  break
767
  input_ids = torch.tensor([[next_byte]], dtype=input_ids.dtype, device=input_ids.device)
768
 
 
769
  num_kv = encoder_cache.get_seq_length()
770
-
771
- completion = sum(words, [])[-len(completion_logits) :]
772
- if next_word_encoder_activations is not None:
773
- start_idx = num_kv - first_word_encoder_activations.shape[1] - next_word_encoder_activations.shape[1]
774
- end_idx = num_kv - next_word_encoder_activations.shape[1]
775
- # We do not want to return the logits for the second word went into the mulitbyte starting character case
776
- # When that happens we remove the logits and post-hoc fix the decoder cache and compute new logits
777
- # This is breaking causality but we want to imitate uncached generation/training behavior
778
- completion_logits = completion_logits[:-next_word_encoder_activations.shape[1]]
779
- else:
780
- start_idx = num_kv - first_word_encoder_activations.shape[1]
781
- end_idx = num_kv
782
-
783
- byte_position_ids = torch.arange(start_idx, end_idx, device=input_ids.device, dtype=torch.long).unsqueeze(0)
784
  completed_word_embedding = self.encoder_connector.forward(
785
- first_word_encoder_activations,
786
- cumulative_seq_lengths_per_word=torch.tensor([0, first_word_encoder_activations.size(1)], dtype=torch.int32, device=input_ids.device),
787
  word_position_ids=word_position_id,
788
  byte_position_ids=byte_position_ids,
789
  )
790
 
791
- bytes_of_next_word = words[1]
792
-
793
- return (
794
- completion,
795
- completed_word_embedding,
796
- bytes_of_next_word,
797
- byte_position_ids[:, -1].item() + 1,
798
- completion_logits,
799
- next_word_encoder_activations,
800
- )
801
-
802
- def _populate_cache(
803
- self,
804
- input_ids: torch.Tensor,
805
- cumulative_seq_lengths_per_word: torch.Tensor,
806
- byte_position_ids: torch.Tensor,
807
- word_position_ids: torch.Tensor,
808
- ):
809
- last_word_start = cumulative_seq_lengths_per_word[-2]
810
- last_word_end = cumulative_seq_lengths_per_word[-1]
811
-
812
- # Populate cache with everything except last word
813
- initial_forward_output = self.forward(
814
- input_ids=input_ids[:, :last_word_start],
815
- cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word[:-1],
816
- byte_position_ids=byte_position_ids[:, :last_word_start],
817
- word_position_ids=word_position_ids[:, :-1],
818
- past_key_values=None,
819
- use_cache=True,
820
- )
821
- return initial_forward_output, last_word_start, last_word_end
822
-
823
- def _initialize_generation_state(
824
- self,
825
- input_ids: torch.Tensor,
826
- max_new_tokens: int,
827
- cumulative_seq_lengths_per_word: torch.Tensor,
828
- byte_position_ids: torch.Tensor | None = None,
829
- word_position_ids: torch.Tensor | None = None,
830
- ):
831
- max_total_bytes = max_new_tokens + input_ids.shape[1]
832
- if byte_position_ids is None:
833
- byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
834
-
835
- if word_position_ids is None:
836
- word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
837
-
838
- initial_forward_output, last_word_start, last_word_end = self._populate_cache(
839
- input_ids=input_ids,
840
- cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
841
- byte_position_ids=byte_position_ids,
842
- word_position_ids=word_position_ids,
843
- )
844
-
845
- completion_bytes: list[int] = []
846
- completion_logits: list[torch.Tensor] = []
847
- # Slice input_ids and byte_position_ids to only contain the last word for the generation loop
848
- current_input_ids = input_ids[:, last_word_start:last_word_end]
849
- next_byte_id = last_word_end.item() # Ensure this is an int
850
- current_byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
851
- current_word_position_id = word_position_ids[:, -1].unsqueeze(-1)
852
- backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]
853
- next_word_encoder_activations = None
854
- return (
855
- initial_forward_output,
856
- completion_bytes,
857
- completion_logits,
858
- current_input_ids,
859
- next_byte_id,
860
- current_byte_position_ids,
861
- current_word_position_id,
862
- backbone_last_hidden_state,
863
- next_word_encoder_activations,
864
- max_total_bytes,
865
- )
866
 
867
  def generate(
868
  self,
@@ -898,20 +791,6 @@ class HATForCausalLM(PreTrainedModel):
898
  completion_logits=completion_logits,
899
  )
900
 
901
- def _fix_decoder_cache(self, predictive_word_embeddings: torch.Tensor, encoder_activions: torch.Tensor, decoder_cache: DynamicCache, word_position_id: torch.Tensor):
902
- decoder_cache.crop(decoder_cache.get_seq_length() - encoder_activions.shape[1])
903
- real_decoder_logits = self.decoder.forward(
904
- predictive_word_embeddings,
905
- encoder_activions,
906
- byte_position_ids=None,
907
- word_position_ids=word_position_id,
908
- past_key_values=decoder_cache,
909
- ).last_hidden_state
910
-
911
- decoder_output = self.layer_norm(real_decoder_logits)
912
- logits = self.lm_head(decoder_output)
913
- return logits
914
-
915
  @torch.no_grad()
916
  def _generate_cached(
917
  self,
@@ -923,35 +802,43 @@ class HATForCausalLM(PreTrainedModel):
923
  sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
924
  stop_sequences: Sequence[str] | None = None,
925
  ):
926
- (
927
- initial_forward_output,
928
- completion_bytes, # empty list
929
- completion_logits, # empty list
930
- input_ids, # This is now the sliced input_ids for the last word
931
- next_byte_id,
932
- byte_position_ids, # This is now the sliced byte_position_ids for the last word
933
- word_position_id,
934
- backbone_last_hidden_state,
935
- next_word_encoder_activations, # None for the first iteration
936
- max_total_bytes,
937
- ) = self._initialize_generation_state(
938
- input_ids=input_ids,
939
- max_new_tokens=max_new_tokens,
940
- cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
941
- byte_position_ids=byte_position_ids,
942
- word_position_ids=word_position_ids,
 
 
943
  )
944
 
 
 
 
 
 
 
 
945
  while next_byte_id < max_total_bytes:
946
- completion, completed_word_embedding, bytes_of_next_word, next_byte_id, next_completion_logits, next_word_encoder_activations = self._complete_word(
947
  input_ids=input_ids,
948
  byte_position_ids=byte_position_ids,
949
- predictive_word_embeddings=backbone_last_hidden_state,
950
  word_position_id=word_position_id,
951
  encoder_cache=initial_forward_output.past_key_values.get_encoder_cache(),
952
  decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
953
  sample_fn=sample_fn,
954
- previous_encoder_activations=next_word_encoder_activations,
955
  )
956
  completion_logits.extend(next_completion_logits)
957
  completion_bytes.extend(completion)
@@ -976,19 +863,11 @@ class HATForCausalLM(PreTrainedModel):
976
  )
977
  backbone_last_hidden_state = backbone_output.hidden_states[:, -1, :].unsqueeze(1)
978
 
979
- word_position_id = word_position_id + 1
980
- if len(bytes_of_next_word) > 1:
981
- real_decoder_logits = self._fix_decoder_cache(
982
- predictive_word_embeddings=backbone_last_hidden_state,
983
- encoder_activions=next_word_encoder_activations,
984
- decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
985
- word_position_id=word_position_id,
986
- )
987
- completion_logits.extend(real_decoder_logits)
988
-
989
- input_ids = torch.tensor([bytes_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
990
  byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)
 
991
 
 
992
  completion_bytes = completion_bytes[:max_new_tokens]
993
  completion_logits = torch.cat(completion_logits[:max_new_tokens], dim=0)
994
  completion_text = self.splitter.decode(completion_bytes)
@@ -1003,7 +882,7 @@ class HATForCausalLM(PreTrainedModel):
1003
  cumulative_seq_lengths_per_word: torch.Tensor,
1004
  byte_position_ids: torch.Tensor | None = None,
1005
  word_position_ids: torch.Tensor | None = None,
1006
- sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
1007
  stop_sequences: Sequence[str] | None = None,
1008
  ):
1009
  if byte_position_ids is None:
 
10
  from transformers import PreTrainedModel
11
  from transformers.cache_utils import Cache, DynamicCache
12
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 
 
 
13
  from transformers.utils import ModelOutput
14
 
15
  from .config import (
 
20
  TransformerHATModelConfig,
21
  )
22
  from .splitter import HATSplitter
23
+ from .norm import RMSNorm
24
+ from .transformer_backbone import (
25
+ LlamaDecoderLayer,
26
+ LlamaRotaryEmbedding,
27
+ )
28
 
29
  try:
30
  transformers_version = version("transformers")
 
38
  return torch.argmax(logits, dim=-1)[:, -1]
39
 
40
 
41
+ LLAMA_TEMPLATE = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
42
+ You are a helpful assistant. You give engaging, well-structured answers to user inquiries.<|eot_id|><|start_header_id|>user<|end_header_id|>
43
+ {input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
 
44
 
45
 
46
  class HATCache(Cache):
 
173
  return activations
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  class HATDecoderBlock(nn.Module):
177
  def __init__(
178
  self,
 
334
  self.num_key_value_heads = cross_attention_config.attention_num_kv_heads
335
  self.num_repeat_kv = cross_attention_config.num_attention_heads // cross_attention_config.attention_num_kv_heads
336
  self.head_dim = hidden_size // self.num_heads
337
+ self.key_query_norm = cross_attention_config.key_query_norm
338
+ self.key_query_norm_per_head = cross_attention_config.key_query_norm_per_head
339
 
340
  self.q_proj = nn.Linear(
341
  in_features=hidden_size_q,
 
358
  bias=False,
359
  )
360
 
361
+ if self.key_query_norm:
362
+ if self.key_query_norm_per_head:
363
+ # Both query and key have head dim equal to self.hidden_size_per_attention_head
364
+ query_norm_dimensions = self.head_dim
365
+ key_norm_dimensions = self.head_dim
366
+ else:
367
+ # Query dimensions across head is equal to hidden_size but key dimensions are divided
368
+ # by self.num_repeat_kv
369
+ query_norm_dimensions = self.hidden_size
370
+ key_norm_dimensions = self.hidden_size // self.num_repeat_kv
371
+
372
+ self.norm_query = RMSNorm(
373
+ dimensions=query_norm_dimensions,
374
+ eps=config.rms_norm_eps,
375
+ device=self.q_proj.weight.device,
376
+ dtype=dtype,
377
+ )
378
+ self.norm_key = RMSNorm(
379
+ dimensions=key_norm_dimensions,
380
+ eps=config.rms_norm_eps,
381
+ device=self.q_proj.weight.device,
382
+ dtype=dtype,
383
+ )
384
+
385
  self.o_proj = nn.Linear(in_features=hidden_size, out_features=hidden_size_q, dtype=dtype, bias=False)
386
 
387
  rope_theta = config.rope_theta
 
408
  key_states = self.k_proj(kv_activations)
409
  value_states = self.v_proj(kv_activations)
410
 
411
+ if self.key_query_norm:
412
+ assert self.norm_query is not None
413
+ assert self.norm_key is not None
414
+ # query_states and key_states are bsz seq_len (h d)
415
+ if self.key_query_norm_per_head:
416
+ # for per head qk norm we need head dim to be the last dim
417
+ query_states = rearrange(
418
+ query_states,
419
+ "bsz seq_len (h d) -> bsz seq_len h d",
420
+ h=self.num_heads,
421
+ )
422
+ key_states = rearrange(
423
+ key_states,
424
+ "bsz seq_len (h d) -> bsz seq_len h d",
425
+ h=self.num_key_value_heads,
426
+ )
427
+ query_states = self.norm_query(query_states)
428
+ key_states = self.norm_key(key_states)
429
+ if self.key_query_norm_per_head:
430
+ query_states = rearrange(
431
+ query_states,
432
+ "bsz seq_len h d -> bsz seq_len (h d)",
433
+ )
434
+ key_states = rearrange(
435
+ key_states,
436
+ "bsz seq_len h d -> bsz seq_len (h d)",
437
+ )
438
+
439
  # TODO get rid of the double rearrange, this is just for compatibility with scaling
440
  query_states = rearrange(query_states, "bsz seq_len (h d) -> bsz h seq_len d", h=self.num_heads)
441
  key_states = rearrange(
 
523
  device=self.latent_query.device,
524
  dtype=torch.int32,
525
  )
 
526
  word_embeddings = self.cross_attention_encoder_connector.forward(
527
  q_activations=latent_query_repeated,
528
  kv_activations=hidden_states,
 
642
  backbone_past_key_values = past_key_values.get_backbone_cache() if past_key_values is not None else None
643
  decoder_past_key_values = past_key_values.get_decoder_cache() if past_key_values is not None else None
644
 
645
+ encoder_output: BaseModelOutputWithPast = self.encoder(
646
  input_ids=input_ids,
647
  cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word,
648
  byte_position_ids=byte_position_ids,
 
652
  )
653
  byte_level_activations = encoder_output.hidden_states
654
 
655
+ encoder_connector_output = self.encoder_connector(
656
  byte_level_activations,
657
  cumulative_seq_lengths_per_word,
658
  word_position_ids,
659
  byte_position_ids,
660
  )
661
+ backbone_output: CausalLMOutputWithPast = self.backbone(
662
  hidden_states=encoder_connector_output,
663
  position_ids=word_position_ids,
664
  past_key_values=backbone_past_key_values,
 
693
  def _append_byte(self, words: list[list[int]], token: int) -> list[list[int]]:
694
  extended_last_word = words.pop() + [token]
695
  try:
696
+ text = self.splitter.decode(extended_last_word, errors='strict', skip_special_tokens=False)
697
  list_of_bytes = self.splitter.encode(text)
698
  words.extend([list(word_in_bytes) for word_in_bytes in list_of_bytes])
699
  except UnicodeDecodeError:
 
702
  words.append(extended_last_word)
703
  return words
704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
  def _complete_word(
706
  self,
707
  input_ids: torch.Tensor,
708
  byte_position_ids: torch.Tensor,
709
+ backbone_word_prediction: torch.Tensor,
710
  word_position_id: torch.Tensor,
711
  encoder_cache: DynamicCache,
712
  decoder_cache: DynamicCache,
713
  sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
 
714
  ):
715
  """Generate byte tokens until we hit the first byte of a new word."""
716
+ words = [input_ids.squeeze(0).tolist()]
717
+ byte_encoder_activations = []
718
+ completion_logits = []
 
 
 
 
 
 
719
 
720
  while True:
721
  encoder_output = self.encoder.forward(
 
727
  )
728
  byte_encoder_activations.append(encoder_output.hidden_states)
729
  decoder_output = self.decoder.forward(
730
+ backbone_word_prediction,
731
  encoder_output.hidden_states,
732
  byte_position_ids=None,
733
  word_position_ids=word_position_id,
 
740
  next_byte = int(sample_fn(logits).item())
741
  words = self._append_byte(words, next_byte)
742
  if len(words) > 1 or next_byte == self.eos_token_id:
 
 
 
 
 
 
743
  break
744
  input_ids = torch.tensor([[next_byte]], dtype=input_ids.dtype, device=input_ids.device)
745
 
746
+ byte_encoder_activations = torch.cat(byte_encoder_activations, dim=1)
747
  num_kv = encoder_cache.get_seq_length()
748
+ byte_position_ids = torch.arange(num_kv + 1 - byte_encoder_activations.shape[1], num_kv + 1, device=input_ids.device, dtype=torch.long).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
749
  completed_word_embedding = self.encoder_connector.forward(
750
+ byte_encoder_activations,
751
+ cumulative_seq_lengths_per_word=torch.tensor([0, byte_encoder_activations.size(1)], dtype=torch.int32, device=input_ids.device),
752
  word_position_ids=word_position_id,
753
  byte_position_ids=byte_position_ids,
754
  )
755
 
756
+ completion = sum(words, [])[-len(completion_logits) :]
757
+ first_byte_of_next_word = words[1]
758
+ return completion, completed_word_embedding, first_byte_of_next_word, byte_position_ids[:, -1].item() + 1, completion_logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
759
 
760
  def generate(
761
  self,
 
791
  completion_logits=completion_logits,
792
  )
793
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
  @torch.no_grad()
795
  def _generate_cached(
796
  self,
 
802
  sample_fn: Callable[[torch.Tensor], torch.Tensor] = sample_argmax,
803
  stop_sequences: Sequence[str] | None = None,
804
  ):
805
+ max_total_bytes = max_new_tokens + input_ids.shape[1]
806
+ if byte_position_ids is None:
807
+ byte_position_ids = torch.arange(0, cumulative_seq_lengths_per_word[-1].item(), device=input_ids.device, dtype=torch.int32).unsqueeze(0)
808
+
809
+ if word_position_ids is None:
810
+ word_position_ids = torch.arange(0, cumulative_seq_lengths_per_word.shape[0] - 1, device=input_ids.device, dtype=torch.int32).unsqueeze(0)
811
+
812
+ last_word_start, last_word_end = (
813
+ cumulative_seq_lengths_per_word[-2],
814
+ cumulative_seq_lengths_per_word[-1],
815
+ )
816
+ # Populate cache with everything except last word
817
+ initial_forward_output = self.forward(
818
+ input_ids=input_ids[:, :last_word_start],
819
+ cumulative_seq_lengths_per_word=cumulative_seq_lengths_per_word[:-1],
820
+ byte_position_ids=byte_position_ids[:, :last_word_start],
821
+ word_position_ids=word_position_ids[:, :-1],
822
+ past_key_values=None,
823
+ use_cache=True,
824
  )
825
 
826
+ completion_bytes = []
827
+ completion_logits = []
828
+ input_ids = input_ids[:, last_word_start:last_word_end]
829
+ next_byte_id = last_word_end
830
+ byte_position_ids = byte_position_ids[:, last_word_start:last_word_end]
831
+ word_position_id = word_position_ids[:, -1].unsqueeze(-1)
832
+ backbone_last_hidden_state = initial_forward_output.hidden_states[:, -1:, :]
833
  while next_byte_id < max_total_bytes:
834
+ completion, completed_word_embedding, first_byte_of_next_word, next_byte_id, next_completion_logits = self._complete_word(
835
  input_ids=input_ids,
836
  byte_position_ids=byte_position_ids,
837
+ backbone_word_prediction=backbone_last_hidden_state,
838
  word_position_id=word_position_id,
839
  encoder_cache=initial_forward_output.past_key_values.get_encoder_cache(),
840
  decoder_cache=initial_forward_output.past_key_values.get_decoder_cache(),
841
  sample_fn=sample_fn,
 
842
  )
843
  completion_logits.extend(next_completion_logits)
844
  completion_bytes.extend(completion)
 
863
  )
864
  backbone_last_hidden_state = backbone_output.hidden_states[:, -1, :].unsqueeze(1)
865
 
866
+ input_ids = torch.tensor([first_byte_of_next_word], dtype=input_ids.dtype, device=input_ids.device)
 
 
 
 
 
 
 
 
 
 
867
  byte_position_ids = torch.tensor([[next_byte_id]], dtype=input_ids.dtype, device=input_ids.device)
868
+ word_position_id = word_position_id + 1
869
 
870
+ completion_bytes.extend(first_byte_of_next_word)
871
  completion_bytes = completion_bytes[:max_new_tokens]
872
  completion_logits = torch.cat(completion_logits[:max_new_tokens], dim=0)
873
  completion_text = self.splitter.decode(completion_bytes)
 
882
  cumulative_seq_lengths_per_word: torch.Tensor,
883
  byte_position_ids: torch.Tensor | None = None,
884
  word_position_ids: torch.Tensor | None = None,
885
+ sample_fn=sample_argmax,
886
  stop_sequences: Sequence[str] | None = None,
887
  ):
888
  if byte_position_ids is None:
norm.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class RMSNorm(nn.Module):
5
+ def __init__(self, dimensions: int, eps: float, device: torch.device, dtype: torch.dtype = torch.bfloat16, norm_in_fp32: bool = False):
6
+ super().__init__()
7
+ self.eps = eps
8
+ self.weight = torch.nn.Parameter(torch.ones(dimensions, dtype=dtype).to(device))
9
+ self.norm_in_fp32 = norm_in_fp32
10
+
11
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
12
+ original_dtype = x.dtype
13
+ if self.norm_in_fp32:
14
+ x = x.float()
15
+
16
+ out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
17
+
18
+ if out.dtype != original_dtype:
19
+ out = out.to(original_dtype)
20
+
21
+ return out * self.weight
splitter.py CHANGED
@@ -1,4 +1,5 @@
1
  import re
 
2
  from hat_splitter import HATSplitter as RustHATSplitter
3
 
4
 
@@ -7,15 +8,8 @@ class HATSplitter:
7
  self.hat_splitter = RustHATSplitter()
8
  self.max_word_size = max_word_size
9
  self.special_token_dict = special_token_dict
10
- self.special_token_replace: dict[int, list[int]] = {
11
- token: list(text.encode("utf-8")) for text, token in self.special_token_dict.items()
12
- }
13
- self.special_token_pattern = (
14
- re.compile(rf"({'|'.join(map(re.escape, special_token_dict.keys()))})")
15
- if special_token_dict
16
- else re.compile(r"(?!)")
17
- )
18
-
19
 
20
  def encode(self, text: str) -> list[list[int]]:
21
  chunks = []
@@ -26,7 +20,7 @@ class HATSplitter:
26
  else:
27
  chunks.extend(list(chunk) for chunk in self.hat_splitter.split_with_limit(str_chunk, self.max_word_size))
28
  return chunks
29
-
30
  def decode(self, token_ids: list[int], errors: str = "replace", skip_special_tokens: bool = False) -> str:
31
  assert isinstance(token_ids, list), "token_ids must be a list"
32
  assert all(isinstance(token_id, int) for token_id in token_ids), "token_ids must be a list of integers"
 
1
  import re
2
+
3
  from hat_splitter import HATSplitter as RustHATSplitter
4
 
5
 
 
8
  self.hat_splitter = RustHATSplitter()
9
  self.max_word_size = max_word_size
10
  self.special_token_dict = special_token_dict
11
+ self.special_token_replace: dict[int, list[int]] = {token: list(text.encode("utf-8")) for text, token in self.special_token_dict.items()}
12
+ self.special_token_pattern = re.compile(rf"({'|'.join(map(re.escape, special_token_dict.keys()))})") if special_token_dict else re.compile(r"(?!)")
 
 
 
 
 
 
 
13
 
14
  def encode(self, text: str) -> list[list[int]]:
15
  chunks = []
 
20
  else:
21
  chunks.extend(list(chunk) for chunk in self.hat_splitter.split_with_limit(str_chunk, self.max_word_size))
22
  return chunks
23
+
24
  def decode(self, token_ids: list[int], errors: str = "replace", skip_special_tokens: bool = False) -> str:
25
  assert isinstance(token_ids, list), "token_ids must be a list"
26
  assert all(isinstance(token_id, int) for token_id in token_ids), "token_ids must be a list of integers"
transformer_backbone.py ADDED
@@ -0,0 +1,1553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hacked in QK norm in LlamaDecoderLayer in transformers; keeping the version from
2
+ # transformers==4.46.3 since this keeps rotary within the decoder layer
3
+ # Source: https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/llama/modeling_llama.py#L400
4
+
5
+ # coding=utf-8
6
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
7
+ #
8
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
9
+ # and OPT implementations in this library. It has been modified from its
10
+ # original forms to accommodate minor architectural differences compared
11
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
12
+ #
13
+ # Licensed under the Apache License, Version 2.0 (the "License");
14
+ # you may not use this file except in compliance with the License.
15
+ # You may obtain a copy of the License at
16
+ #
17
+ # http://www.apache.org/licenses/LICENSE-2.0
18
+ #
19
+ # Unless required by applicable law or agreed to in writing, software
20
+ # distributed under the License is distributed on an "AS IS" BASIS,
21
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22
+ # See the License for the specific language governing permissions and
23
+ # limitations under the License.
24
+ import math
25
+ from typing import List, Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
34
+ from transformers.generation import GenerationMixin
35
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ CausalLMOutputWithPast,
40
+ QuestionAnsweringModelOutput,
41
+ SequenceClassifierOutputWithPast,
42
+ TokenClassifierOutput,
43
+ )
44
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
45
+ from transformers.modeling_utils import PreTrainedModel
46
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
47
+ from transformers.utils import (
48
+ add_code_sample_docstrings,
49
+ add_start_docstrings,
50
+ add_start_docstrings_to_model_forward,
51
+ is_flash_attn_greater_or_equal_2_10,
52
+ logging,
53
+ replace_return_docstrings,
54
+ )
55
+ from transformers.models.llama.configuration_llama import LlamaConfig
56
+
57
+ from .norm import RMSNorm
58
+
59
+ logger = logging.get_logger(__name__)
60
+
61
+ _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
62
+ _CONFIG_FOR_DOC = "LlamaConfig"
63
+
64
+
65
+ class LlamaRMSNorm(nn.Module):
66
+ def __init__(self, hidden_size, eps=1e-6):
67
+ """
68
+ LlamaRMSNorm is equivalent to T5LayerNorm
69
+ """
70
+ super().__init__()
71
+ self.weight = nn.Parameter(torch.ones(hidden_size))
72
+ self.variance_epsilon = eps
73
+
74
+ def forward(self, hidden_states):
75
+ input_dtype = hidden_states.dtype
76
+ hidden_states = hidden_states.to(torch.float32)
77
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
78
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
79
+ return self.weight * hidden_states.to(input_dtype)
80
+
81
+ def extra_repr(self):
82
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
83
+
84
+
85
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
86
+
87
+
88
+ class LlamaRotaryEmbedding(nn.Module):
89
+ def __init__(
90
+ self,
91
+ dim=None,
92
+ max_position_embeddings=2048,
93
+ base=10000,
94
+ device=None,
95
+ scaling_factor=1.0,
96
+ rope_type="default",
97
+ config: Optional[LlamaConfig] = None,
98
+ ):
99
+ super().__init__()
100
+ # TODO (joao): remove the `if` below, only used for BC
101
+ self.rope_kwargs = {}
102
+ if config is None:
103
+ logger.warning_once(
104
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
105
+ "`config` argument. All other arguments will be removed in v4.46"
106
+ )
107
+ self.rope_kwargs = {
108
+ "rope_type": rope_type,
109
+ "factor": scaling_factor,
110
+ "dim": dim,
111
+ "base": base,
112
+ "max_position_embeddings": max_position_embeddings,
113
+ }
114
+ self.rope_type = rope_type
115
+ self.max_seq_len_cached = max_position_embeddings
116
+ self.original_max_seq_len = max_position_embeddings
117
+ else:
118
+ # BC: "rope_type" was originally "type"
119
+ if config.rope_scaling is not None:
120
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
121
+ else:
122
+ self.rope_type = "default"
123
+ self.max_seq_len_cached = config.max_position_embeddings
124
+ self.original_max_seq_len = config.max_position_embeddings
125
+
126
+ self.config = config
127
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
128
+
129
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
130
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
131
+ self.original_inv_freq = self.inv_freq
132
+
133
+ def _dynamic_frequency_update(self, position_ids, device):
134
+ """
135
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
136
+ 1 - growing beyond the cached sequence length (allow scaling)
137
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
138
+ """
139
+ seq_len = torch.max(position_ids) + 1
140
+ if seq_len > self.max_seq_len_cached: # growth
141
+ inv_freq, self.attention_scaling = self.rope_init_fn(
142
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
143
+ )
144
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
145
+ self.max_seq_len_cached = seq_len
146
+
147
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
148
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
149
+ self.max_seq_len_cached = self.original_max_seq_len
150
+
151
+ @torch.no_grad()
152
+ def forward(self, x, position_ids):
153
+ if "dynamic" in self.rope_type:
154
+ self._dynamic_frequency_update(position_ids, device=x.device)
155
+
156
+ # Core RoPE block
157
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
158
+ position_ids_expanded = position_ids[:, None, :].float()
159
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
160
+ device_type = x.device.type
161
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
162
+ with torch.autocast(device_type=device_type, enabled=False):
163
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
164
+ emb = torch.cat((freqs, freqs), dim=-1)
165
+ cos = emb.cos()
166
+ sin = emb.sin()
167
+
168
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
169
+ cos = cos * self.attention_scaling
170
+ sin = sin * self.attention_scaling
171
+
172
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
173
+
174
+
175
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
176
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
177
+
178
+ def __init__(self, *args, **kwargs):
179
+ logger.warning_once(
180
+ "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
181
+ "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
182
+ )
183
+ kwargs["rope_type"] = "linear"
184
+ super().__init__(*args, **kwargs)
185
+
186
+
187
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
188
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
189
+
190
+ def __init__(self, *args, **kwargs):
191
+ logger.warning_once(
192
+ "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
193
+ "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
194
+ "__init__)."
195
+ )
196
+ kwargs["rope_type"] = "dynamic"
197
+ super().__init__(*args, **kwargs)
198
+
199
+
200
+ def rotate_half(x):
201
+ """Rotates half the hidden dims of the input."""
202
+ x1 = x[..., : x.shape[-1] // 2]
203
+ x2 = x[..., x.shape[-1] // 2 :]
204
+ return torch.cat((-x2, x1), dim=-1)
205
+
206
+
207
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
208
+ """Applies Rotary Position Embedding to the query and key tensors.
209
+
210
+ Args:
211
+ q (`torch.Tensor`): The query tensor.
212
+ k (`torch.Tensor`): The key tensor.
213
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
214
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
215
+ position_ids (`torch.Tensor`, *optional*):
216
+ Deprecated and unused.
217
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
218
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
219
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
220
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
221
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
222
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
223
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
224
+ Returns:
225
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
226
+ """
227
+ cos = cos.unsqueeze(unsqueeze_dim)
228
+ sin = sin.unsqueeze(unsqueeze_dim)
229
+ q_embed = (q * cos) + (rotate_half(q) * sin)
230
+ k_embed = (k * cos) + (rotate_half(k) * sin)
231
+ return q_embed, k_embed
232
+
233
+
234
+ class LlamaMLP(nn.Module):
235
+ def __init__(self, config):
236
+ super().__init__()
237
+ self.config = config
238
+ self.hidden_size = config.hidden_size
239
+ self.intermediate_size = config.intermediate_size
240
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
241
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
242
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
243
+ self.act_fn = ACT2FN[config.hidden_act]
244
+
245
+ def forward(self, x):
246
+ if self.config.pretraining_tp > 1:
247
+ slice = self.intermediate_size // self.config.pretraining_tp
248
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
249
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
250
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
251
+
252
+ gate_proj = torch.cat(
253
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
254
+ )
255
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
256
+
257
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
258
+ down_proj = [
259
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
260
+ ]
261
+ down_proj = sum(down_proj)
262
+ else:
263
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
264
+
265
+ return down_proj
266
+
267
+
268
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
269
+ """
270
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
271
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
272
+ """
273
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
274
+ if n_rep == 1:
275
+ return hidden_states
276
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
277
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
278
+
279
+
280
+ class LlamaAttention(nn.Module):
281
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
282
+
283
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
284
+ super().__init__()
285
+ self.config = config
286
+ self.layer_idx = layer_idx
287
+ if layer_idx is None:
288
+ logger.warning_once(
289
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
290
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
291
+ "when creating this class."
292
+ )
293
+
294
+ self.attention_dropout = config.attention_dropout
295
+ self.hidden_size = config.hidden_size
296
+ self.num_heads = config.num_attention_heads
297
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
298
+ self.num_key_value_heads = config.num_key_value_heads
299
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
300
+ self.max_position_embeddings = config.max_position_embeddings
301
+ self.rope_theta = config.rope_theta
302
+ self.is_causal = True
303
+ self.key_query_norm = config.key_query_norm
304
+ self.key_query_norm_per_head = config.key_query_norm_per_head
305
+
306
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
307
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
308
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
309
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
310
+
311
+ if self.key_query_norm:
312
+ if self.key_query_norm_per_head:
313
+ # Both query and key have a per-head dimension equal to self.head_dim
+ query_norm_dimensions = self.head_dim
+ key_norm_dimensions = self.head_dim
+ else:
+ # The query dimension across heads is equal to hidden_size, but the key dimension is divided
+ # by the number of key/value groups (self.num_key_value_groups)
+ query_norm_dimensions = self.hidden_size
+ key_norm_dimensions = self.hidden_size // self.num_key_value_groups
321
+
322
+ # For numerical compatibility we use RMSNorm as it was used during training
323
+ self.norm_query = RMSNorm(
324
+ dimensions=query_norm_dimensions,
325
+ eps=config.rms_norm_eps,
326
+ device=self.q_proj.weight.device,
327
+ dtype=self.q_proj.weight.dtype,
328
+ )
329
+ self.norm_key = RMSNorm(
330
+ dimensions=key_norm_dimensions,
331
+ eps=config.rms_norm_eps,
332
+ device=self.q_proj.weight.device,
333
+ dtype=self.q_proj.weight.dtype,
334
+ )
335
+
336
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
337
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
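# Illustrative sketch (not part of this diff): the two QK-norm variants configured above differ only
# in which trailing dimension RMSNorm is applied over -- the full projection width (before the heads
# are split out) or each head's head_dim (after the view). Toy sizes and a weightless RMSNorm assumed.
import torch

bsz, seqlen, num_heads, head_dim = 2, 5, 8, 32
hidden_size = num_heads * head_dim
q = torch.randn(bsz, seqlen, hidden_size)

def rms_norm(x, eps=1e-5):                                      # weightless RMSNorm for illustration
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

full_dim = rms_norm(q)                                          # norm over hidden_size
per_head = rms_norm(q.view(bsz, seqlen, num_heads, head_dim))   # norm over each head's head_dim
print(full_dim.shape, per_head.shape)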
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor,
342
+ attention_mask: Optional[torch.Tensor] = None,
343
+ position_ids: Optional[torch.LongTensor] = None,
344
+ past_key_value: Optional[Cache] = None,
345
+ output_attentions: bool = False,
346
+ use_cache: bool = False,
347
+ cache_position: Optional[torch.LongTensor] = None,
348
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
349
+ **kwargs,
350
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
351
+ bsz, q_len, _ = hidden_states.size()
352
+
353
+ if self.key_query_norm:
354
+ raise ValueError("QK norm not supported for eager attention, use flash_attention_2!")
355
+
356
+ if self.config.pretraining_tp > 1:
357
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
358
+ query_slices = self.q_proj.weight.split(
359
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
360
+ )
361
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
362
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
363
+
364
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
365
+ query_states = torch.cat(query_states, dim=-1)
366
+
367
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
368
+ key_states = torch.cat(key_states, dim=-1)
369
+
370
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
371
+ value_states = torch.cat(value_states, dim=-1)
372
+
373
+ else:
374
+ query_states = self.q_proj(hidden_states)
375
+ key_states = self.k_proj(hidden_states)
376
+ value_states = self.v_proj(hidden_states)
377
+
378
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
379
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
380
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
381
+
382
+ if position_embeddings is None:
383
+ logger.warning_once(
384
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
385
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
386
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
387
+ "removed and `position_embeddings` will be mandatory."
388
+ )
389
+ cos, sin = self.rotary_emb(value_states, position_ids)
390
+ else:
391
+ cos, sin = position_embeddings
392
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
393
+
394
+ if past_key_value is not None:
395
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
396
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
397
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
398
+
399
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
400
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
401
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
402
+
403
+ if attention_mask is not None: # no matter the length, we just slice it
404
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
405
+ attn_weights = attn_weights + causal_mask
406
+
407
+ # upcast attention to fp32
408
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
409
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
410
+ attn_output = torch.matmul(attn_weights, value_states)
411
+
412
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
413
+ raise ValueError(
414
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
415
+ f" {attn_output.size()}"
416
+ )
417
+
418
+ attn_output = attn_output.transpose(1, 2).contiguous()
419
+
420
+ attn_output = attn_output.reshape(bsz, q_len, -1)
421
+
422
+ if self.config.pretraining_tp > 1:
423
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
424
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
425
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
426
+ else:
427
+ attn_output = self.o_proj(attn_output)
428
+
429
+ if not output_attentions:
430
+ attn_weights = None
431
+
432
+ return attn_output, attn_weights, past_key_value
433
+
434
+
435
+ class LlamaFlashAttention2(LlamaAttention):
436
+ """
437
+ Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
439
+ flash attention and deal with padding tokens in case the input contains any of them.
440
+ """
441
+
442
+ def __init__(self, *args, **kwargs):
443
+ super().__init__(*args, **kwargs)
444
+
445
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
448
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
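# Illustrative sketch (not part of this diff): for q_len != k_len, a "top-left" causal mask lets query i
# attend only to keys <= i, while the "bottom-right" alignment (flash_attn >= 2.1) lets it attend to keys
# <= i + (k_len - q_len), i.e. to the whole cached prefix, which is what decoding expects.
import torch

q_len, k_len = 2, 5
top_left = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool))
bottom_right = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool), diagonal=k_len - q_len)
print(top_left.int())      # only the first q_len keys are visible
print(bottom_right.int())  # the full cached prefix is visible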
449
+
450
+ def forward(
451
+ self,
452
+ hidden_states: torch.Tensor,
453
+ attention_mask: Optional[torch.LongTensor] = None,
454
+ position_ids: Optional[torch.LongTensor] = None,
455
+ past_key_value: Optional[Cache] = None,
456
+ output_attentions: bool = False,
457
+ use_cache: bool = False,
458
+ cache_position: Optional[torch.LongTensor] = None,
459
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
460
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
461
+ if isinstance(past_key_value, StaticCache):
462
+ raise ValueError(
463
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
464
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
465
+ )
466
+
467
+ output_attentions = False
468
+
469
+ bsz, q_len, _ = hidden_states.size()
470
+
471
+ query_states = self.q_proj(hidden_states)
472
+ key_states = self.k_proj(hidden_states)
473
+ value_states = self.v_proj(hidden_states)
474
+
475
+ if self.key_query_norm:
476
+ if not self.key_query_norm_per_head:
477
+ # norm over the full hidden dim
478
+ query_states = self.norm_query(query_states)
479
+ key_states = self.norm_key(key_states)
480
+
481
+
482
+ # Flash attention requires the input to have the shape
483
+ # batch_size x seq_length x head_dim x hidden_dim
484
+ # therefore we just need to keep the original shape
485
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
486
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
487
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
488
+
489
+ if self.key_query_norm:
490
+ if self.key_query_norm_per_head:
491
+ # norm each head (with shared weights)
492
+ query_states = self.norm_query(query_states)
493
+ key_states = self.norm_key(key_states)
494
+
495
+ if position_embeddings is None:
496
+ logger.warning_once(
497
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
498
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
499
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
500
+ "removed and `position_embeddings` will be mandatory."
501
+ )
502
+ cos, sin = self.rotary_emb(value_states, position_ids)
503
+ else:
504
+ cos, sin = position_embeddings
505
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
506
+
507
+ if past_key_value is not None:
508
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
509
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
510
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
511
+
512
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
513
+ # to be able to avoid many of these transpose/reshape/view.
514
+ query_states = query_states.transpose(1, 2)
515
+ key_states = key_states.transpose(1, 2)
516
+ value_states = value_states.transpose(1, 2)
517
+
518
+ dropout_rate = self.attention_dropout if self.training else 0.0
519
+
520
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # so the input hidden states may get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly.)
525
+
526
+ input_dtype = query_states.dtype
527
+ if input_dtype == torch.float32:
528
+ if torch.is_autocast_enabled():
529
+ target_dtype = torch.get_autocast_gpu_dtype()
530
+ # Handle the case where the model is quantized
531
+ elif hasattr(self.config, "_pre_quantization_dtype"):
532
+ target_dtype = self.config._pre_quantization_dtype
533
+ else:
534
+ target_dtype = self.q_proj.weight.dtype
535
+
536
+ logger.warning_once(
537
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
538
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
539
+ f" {target_dtype}."
540
+ )
541
+
542
+ query_states = query_states.to(target_dtype)
543
+ key_states = key_states.to(target_dtype)
544
+ value_states = value_states.to(target_dtype)
545
+
546
+ attn_output = _flash_attention_forward(
547
+ query_states,
548
+ key_states,
549
+ value_states,
550
+ attention_mask,
551
+ q_len,
552
+ position_ids=position_ids,
553
+ dropout=dropout_rate,
554
+ sliding_window=getattr(self, "sliding_window", None),
555
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
556
+ is_causal=self.is_causal,
557
+ )
558
+
559
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
560
+ attn_output = self.o_proj(attn_output)
561
+
562
+ if not output_attentions:
563
+ attn_weights = None
564
+
565
+ return attn_output, attn_weights, past_key_value
566
+
567
+
568
+ class LlamaSdpaAttention(LlamaAttention):
569
+ """
570
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
571
+ `LlamaAttention`, as the weights of the module stay untouched. The only changes are in the forward pass, to adapt to the
572
+ SDPA API.
573
+ """
574
+
575
+ # Adapted from LlamaAttention.forward
576
+ def forward(
577
+ self,
578
+ hidden_states: torch.Tensor,
579
+ attention_mask: Optional[torch.Tensor] = None,
580
+ position_ids: Optional[torch.LongTensor] = None,
581
+ past_key_value: Optional[Cache] = None,
582
+ output_attentions: bool = False,
583
+ use_cache: bool = False,
584
+ cache_position: Optional[torch.LongTensor] = None,
585
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
586
+ **kwargs,
587
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
588
+ if output_attentions:
589
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
590
+ logger.warning_once(
591
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
592
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
593
+ )
594
+ return super().forward(
595
+ hidden_states=hidden_states,
596
+ attention_mask=attention_mask,
597
+ position_ids=position_ids,
598
+ past_key_value=past_key_value,
599
+ output_attentions=output_attentions,
600
+ use_cache=use_cache,
601
+ cache_position=cache_position,
602
+ position_embeddings=position_embeddings,
603
+ )
604
+
605
+ bsz, q_len, _ = hidden_states.size()
606
+
607
+ query_states = self.q_proj(hidden_states)
608
+ key_states = self.k_proj(hidden_states)
609
+ value_states = self.v_proj(hidden_states)
610
+
611
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
612
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
613
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
614
+
615
+ if position_embeddings is None:
616
+ logger.warning_once(
617
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
618
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
619
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
620
+ "removed and `position_embeddings` will be mandatory."
621
+ )
622
+ cos, sin = self.rotary_emb(value_states, position_ids)
623
+ else:
624
+ cos, sin = position_embeddings
625
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
626
+
627
+ if past_key_value is not None:
628
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
629
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
630
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
631
+
632
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
633
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
634
+
635
+ causal_mask = attention_mask
636
+ if attention_mask is not None:
637
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
638
+
639
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
640
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
641
+ if query_states.device.type == "cuda" and causal_mask is not None:
642
+ query_states = query_states.contiguous()
643
+ key_states = key_states.contiguous()
644
+ value_states = value_states.contiguous()
645
+
646
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
647
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
648
+ is_causal = True if causal_mask is None and q_len > 1 else False
649
+
650
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
651
+ query_states,
652
+ key_states,
653
+ value_states,
654
+ attn_mask=causal_mask,
655
+ dropout_p=self.attention_dropout if self.training else 0.0,
656
+ is_causal=is_causal,
657
+ )
658
+
659
+ attn_output = attn_output.transpose(1, 2).contiguous()
660
+ attn_output = attn_output.view(bsz, q_len, -1)
661
+
662
+ attn_output = self.o_proj(attn_output)
663
+
664
+ return attn_output, None, past_key_value
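# Illustrative sketch (not part of this diff): the SDPA path above relies on is_causal=True (or an
# explicit additive mask) instead of materializing attention probabilities. For small shapes the
# result matches the eager softmax(QK^T / sqrt(d) + mask) @ V computation.
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)                     # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
scores = scores + torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
eager = torch.softmax(scores, dim=-1) @ v
assert torch.allclose(sdpa, eager, atol=1e-4)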
665
+
666
+
667
+ LLAMA_ATTENTION_CLASSES = {
668
+ "eager": LlamaAttention,
669
+ "flash_attention_2": LlamaFlashAttention2,
670
+ "sdpa": LlamaSdpaAttention,
671
+ }
672
+
673
+
674
+ class LlamaDecoderLayer(nn.Module):
675
+ def __init__(self, config: LlamaConfig, layer_idx: int):
676
+ super().__init__()
677
+ self.hidden_size = config.hidden_size
678
+
679
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
680
+
681
+ self.mlp = LlamaMLP(config)
682
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
683
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
684
+
685
+ def forward(
686
+ self,
687
+ hidden_states: torch.Tensor,
688
+ attention_mask: Optional[torch.Tensor] = None,
689
+ position_ids: Optional[torch.LongTensor] = None,
690
+ past_key_value: Optional[Cache] = None,
691
+ output_attentions: Optional[bool] = False,
692
+ use_cache: Optional[bool] = False,
693
+ cache_position: Optional[torch.LongTensor] = None,
694
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
695
+ **kwargs,
696
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
697
+ """
698
+ Args:
699
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
700
+ attention_mask (`torch.FloatTensor`, *optional*):
701
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
702
+ query_sequence_length, key_sequence_length)` if default attention is used.
703
+ output_attentions (`bool`, *optional*):
704
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
705
+ returned tensors for more detail.
706
+ use_cache (`bool`, *optional*):
707
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
708
+ (see `past_key_values`).
709
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
710
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
711
+ Indices depicting the position of the input sequence tokens in the sequence
712
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
713
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
714
+ with `head_dim` being the embedding dimension of each attention head.
715
+ kwargs (`dict`, *optional*):
716
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
717
+ into the model
718
+ """
719
+ residual = hidden_states
720
+
721
+ hidden_states = self.input_layernorm(hidden_states)
722
+
723
+ # Self Attention
724
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
725
+ hidden_states=hidden_states,
726
+ attention_mask=attention_mask,
727
+ position_ids=position_ids,
728
+ past_key_value=past_key_value,
729
+ output_attentions=output_attentions,
730
+ use_cache=use_cache,
731
+ cache_position=cache_position,
732
+ position_embeddings=position_embeddings,
733
+ **kwargs,
734
+ )
735
+ hidden_states = residual + hidden_states
736
+
737
+ # Fully Connected
738
+ residual = hidden_states
739
+ hidden_states = self.post_attention_layernorm(hidden_states)
740
+ hidden_states = self.mlp(hidden_states)
741
+ hidden_states = residual + hidden_states
742
+
743
+ outputs = (hidden_states,)
744
+
745
+ if output_attentions:
746
+ outputs += (self_attn_weights,)
747
+
748
+ if use_cache:
749
+ outputs += (present_key_value,)
750
+
751
+ return outputs
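# Illustrative sketch (not part of this diff): the decoder layer above follows the standard pre-norm
# residual pattern -- normalize, transform, then add back the residual -- once for self-attention and
# once for the MLP. Schematically (sublayer stands for self_attn or mlp):
def pre_norm_block(hidden_states, norm, sublayer):
    residual = hidden_states
    return residual + sublayer(norm(hidden_states))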
752
+
753
+
754
+ LLAMA_START_DOCSTRING = r"""
755
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
756
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
757
+ etc.)
758
+
759
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
760
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
761
+ and behavior.
762
+
763
+ Parameters:
764
+ config ([`LlamaConfig`]):
765
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
766
+ load the weights associated with the model, only the configuration. Check out the
767
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
768
+ """
769
+
770
+
771
+ @add_start_docstrings(
772
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
773
+ LLAMA_START_DOCSTRING,
774
+ )
775
+ class LlamaPreTrainedModel(PreTrainedModel):
776
+ config_class = LlamaConfig
777
+ base_model_prefix = "model"
778
+ supports_gradient_checkpointing = True
779
+ _no_split_modules = ["LlamaDecoderLayer"]
780
+ _skip_keys_device_placement = ["past_key_values"]
781
+ _supports_flash_attn_2 = True
782
+ _supports_sdpa = True
783
+ _supports_cache_class = True
784
+ _supports_quantized_cache = True
785
+ _supports_static_cache = True
786
+
787
+ def _init_weights(self, module):
788
+ std = self.config.initializer_range
789
+ if isinstance(module, nn.Linear):
790
+ module.weight.data.normal_(mean=0.0, std=std)
791
+ if module.bias is not None:
792
+ module.bias.data.zero_()
793
+ elif isinstance(module, nn.Embedding):
794
+ module.weight.data.normal_(mean=0.0, std=std)
795
+ if module.padding_idx is not None:
796
+ module.weight.data[module.padding_idx].zero_()
797
+
798
+
799
+ LLAMA_INPUTS_DOCSTRING = r"""
800
+ Args:
801
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
802
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
803
+ it.
804
+
805
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
806
+ [`PreTrainedTokenizer.__call__`] for details.
807
+
808
+ [What are input IDs?](../glossary#input-ids)
809
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
810
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
811
+
812
+ - 1 for tokens that are **not masked**,
813
+ - 0 for tokens that are **masked**.
814
+
815
+ [What are attention masks?](../glossary#attention-mask)
816
+
817
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
818
+ [`PreTrainedTokenizer.__call__`] for details.
819
+
820
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
821
+ `past_key_values`).
822
+
823
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
824
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
825
+ information on the default strategy.
826
+
827
+ - 1 indicates the head is **not masked**,
828
+ - 0 indicates the head is **masked**.
829
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
830
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
831
+ config.n_positions - 1]`.
832
+
833
+ [What are position IDs?](../glossary#position-ids)
834
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
835
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
836
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
837
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
838
+
839
+ Two formats are allowed:
840
+ - a [`~cache_utils.Cache`] instance, see our
841
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
842
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
843
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
844
+ cache format.
845
+
846
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
847
+ legacy cache format will be returned.
848
+
849
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
850
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
851
+ of shape `(batch_size, sequence_length)`.
852
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
853
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
854
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
855
+ model's internal embedding lookup matrix.
856
+ use_cache (`bool`, *optional*):
857
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
858
+ `past_key_values`).
859
+ output_attentions (`bool`, *optional*):
860
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
861
+ tensors for more detail.
862
+ output_hidden_states (`bool`, *optional*):
863
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
864
+ more detail.
865
+ return_dict (`bool`, *optional*):
866
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
867
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
868
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
869
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
870
+ the complete sequence length.
871
+ """
872
+
873
+
874
+ @add_start_docstrings(
875
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
876
+ LLAMA_START_DOCSTRING,
877
+ )
878
+ class LlamaModel(LlamaPreTrainedModel):
879
+ """
880
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
881
+
882
+ Args:
883
+ config: LlamaConfig
884
+ """
885
+
886
+ def __init__(self, config: LlamaConfig):
887
+ super().__init__(config)
888
+ self.padding_idx = config.pad_token_id
889
+ self.vocab_size = config.vocab_size
890
+
891
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
892
+ self.layers = nn.ModuleList(
893
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
894
+ )
895
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
896
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
897
+ self.gradient_checkpointing = False
898
+
899
+ # Initialize weights and apply final processing
900
+ self.post_init()
901
+
902
+ def get_input_embeddings(self):
903
+ return self.embed_tokens
904
+
905
+ def set_input_embeddings(self, value):
906
+ self.embed_tokens = value
907
+
908
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
909
+ def forward(
910
+ self,
911
+ input_ids: torch.LongTensor = None,
912
+ attention_mask: Optional[torch.Tensor] = None,
913
+ position_ids: Optional[torch.LongTensor] = None,
914
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
915
+ inputs_embeds: Optional[torch.FloatTensor] = None,
916
+ use_cache: Optional[bool] = None,
917
+ output_attentions: Optional[bool] = None,
918
+ output_hidden_states: Optional[bool] = None,
919
+ return_dict: Optional[bool] = None,
920
+ cache_position: Optional[torch.LongTensor] = None,
921
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
922
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
923
+ output_hidden_states = (
924
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
925
+ )
926
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
927
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
928
+
929
+ if (input_ids is None) ^ (inputs_embeds is not None):
930
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
931
+
932
+ if self.gradient_checkpointing and self.training and use_cache:
933
+ logger.warning_once(
934
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
935
+ )
936
+ use_cache = False
937
+
938
+ if inputs_embeds is None:
939
+ inputs_embeds = self.embed_tokens(input_ids)
940
+
941
+ # kept for BC (non `Cache` `past_key_values` inputs)
942
+ return_legacy_cache = False
943
+ if use_cache and not isinstance(past_key_values, Cache):
944
+ return_legacy_cache = True
945
+ if past_key_values is None:
946
+ past_key_values = DynamicCache()
947
+ else:
948
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
949
+ logger.warning_once(
950
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
951
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
952
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
953
+ )
954
+
955
+ if cache_position is None:
956
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
957
+ cache_position = torch.arange(
958
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
959
+ )
960
+ if position_ids is None:
961
+ position_ids = cache_position.unsqueeze(0)
962
+
963
+ causal_mask = self._update_causal_mask(
964
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
965
+ )
966
+ hidden_states = inputs_embeds
967
+
968
+ # create position embeddings to be shared across the decoder layers
969
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
970
+
971
+ # decoder layers
972
+ all_hidden_states = () if output_hidden_states else None
973
+ all_self_attns = () if output_attentions else None
974
+ next_decoder_cache = None
975
+
976
+ for decoder_layer in self.layers:
977
+ if output_hidden_states:
978
+ all_hidden_states += (hidden_states,)
979
+
980
+ if self.gradient_checkpointing and self.training:
981
+ layer_outputs = self._gradient_checkpointing_func(
982
+ decoder_layer.__call__,
983
+ hidden_states,
984
+ causal_mask,
985
+ position_ids,
986
+ past_key_values,
987
+ output_attentions,
988
+ use_cache,
989
+ cache_position,
990
+ position_embeddings,
991
+ )
992
+ else:
993
+ layer_outputs = decoder_layer(
994
+ hidden_states,
995
+ attention_mask=causal_mask,
996
+ position_ids=position_ids,
997
+ past_key_value=past_key_values,
998
+ output_attentions=output_attentions,
999
+ use_cache=use_cache,
1000
+ cache_position=cache_position,
1001
+ position_embeddings=position_embeddings,
1002
+ )
1003
+
1004
+ hidden_states = layer_outputs[0]
1005
+
1006
+ if use_cache:
1007
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1008
+
1009
+ if output_attentions:
1010
+ all_self_attns += (layer_outputs[1],)
1011
+
1012
+ hidden_states = self.norm(hidden_states)
1013
+
1014
+ # add hidden states from the last decoder layer
1015
+ if output_hidden_states:
1016
+ all_hidden_states += (hidden_states,)
1017
+
1018
+ next_cache = next_decoder_cache if use_cache else None
1019
+ if return_legacy_cache:
1020
+ next_cache = next_cache.to_legacy_cache()
1021
+
1022
+ if not return_dict:
1023
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1024
+ return BaseModelOutputWithPast(
1025
+ last_hidden_state=hidden_states,
1026
+ past_key_values=next_cache,
1027
+ hidden_states=all_hidden_states,
1028
+ attentions=all_self_attns,
1029
+ )
1030
+
1031
+ def _update_causal_mask(
1032
+ self,
1033
+ attention_mask: torch.Tensor,
1034
+ input_tensor: torch.Tensor,
1035
+ cache_position: torch.Tensor,
1036
+ past_key_values: Cache,
1037
+ output_attentions: bool,
1038
+ ):
1039
+ if self.config._attn_implementation == "flash_attention_2":
1040
+ if attention_mask is not None and 0.0 in attention_mask:
1041
+ return attention_mask
1042
+ return None
1043
+
1044
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1045
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1046
+ # to infer the attention mask.
1047
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1048
+ using_static_cache = isinstance(past_key_values, StaticCache)
1049
+
1050
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1051
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1052
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1053
+ attention_mask,
1054
+ inputs_embeds=input_tensor,
1055
+ past_key_values_length=past_seen_tokens,
1056
+ is_training=self.training,
1057
+ ):
1058
+ return None
1059
+
1060
+ dtype, device = input_tensor.dtype, input_tensor.device
1061
+ sequence_length = input_tensor.shape[1]
1062
+ if using_static_cache:
1063
+ target_length = past_key_values.get_max_cache_shape()
1064
+ else:
1065
+ target_length = (
1066
+ attention_mask.shape[-1]
1067
+ if isinstance(attention_mask, torch.Tensor)
1068
+ else past_seen_tokens + sequence_length + 1
1069
+ )
1070
+
1071
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1072
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1073
+ attention_mask,
1074
+ sequence_length=sequence_length,
1075
+ target_length=target_length,
1076
+ dtype=dtype,
1077
+ device=device,
1078
+ cache_position=cache_position,
1079
+ batch_size=input_tensor.shape[0],
1080
+ )
1081
+
1082
+ if (
1083
+ self.config._attn_implementation == "sdpa"
1084
+ and attention_mask is not None
1085
+ and attention_mask.device.type == "cuda"
1086
+ and not output_attentions
1087
+ ):
1088
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1089
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1090
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1091
+ min_dtype = torch.finfo(dtype).min
1092
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1093
+
1094
+ return causal_mask
1095
+
1096
+ @staticmethod
1097
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1098
+ attention_mask: torch.Tensor,
1099
+ sequence_length: int,
1100
+ target_length: int,
1101
+ dtype: torch.dtype,
1102
+ device: torch.device,
1103
+ cache_position: torch.Tensor,
1104
+ batch_size: int,
1105
+ **kwargs,
1106
+ ):
1107
+ """
1108
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1109
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1110
+
1111
+ Args:
1112
+ attention_mask (`torch.Tensor`):
1113
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1114
+ `(batch_size, 1, query_length, key_value_length)`.
1115
+ sequence_length (`int`):
1116
+ The sequence length being processed.
1117
+ target_length (`int`):
1118
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1119
+ to account for the 0 padding, the part of the cache that is not filled yet.
1120
+ dtype (`torch.dtype`):
1121
+ The dtype to use for the 4D attention mask.
1122
+ device (`torch.device`):
1123
+ The device to place the 4D attention mask on.
1124
+ cache_position (`torch.Tensor`):
1125
+ Indices depicting the position of the input sequence tokens in the sequence.
1126
+ batch_size (`int`):
1127
+ Batch size.
1128
+ """
1129
+ if attention_mask is not None and attention_mask.dim() == 4:
1130
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1131
+ causal_mask = attention_mask
1132
+ else:
1133
+ min_dtype = torch.finfo(dtype).min
1134
+ causal_mask = torch.full(
1135
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1136
+ )
1137
+ if sequence_length != 1:
1138
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1139
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1140
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1141
+ if attention_mask is not None:
1142
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1143
+ mask_length = attention_mask.shape[-1]
1144
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1145
+ padding_mask = padding_mask == 0
1146
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1147
+ padding_mask, min_dtype
1148
+ )
1149
+
1150
+ return causal_mask
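# Illustrative sketch (not part of this diff): for a 2D padding mask, the helper above produces a 4D
# additive mask where disallowed positions hold the dtype's minimum value. A tiny hand-built analogue:
import torch

seq_len = 3
min_val = torch.finfo(torch.float32).min
causal = torch.triu(torch.full((seq_len, seq_len), min_val), diagonal=1)[None, None, :, :]  # (1, 1, q, k)
padding_2d = torch.tensor([[1, 1, 0]])                     # last key position is padding
mask_4d = causal.masked_fill(padding_2d[:, None, None, :] == 0, min_val)
print(mask_4d.shape)                                       # torch.Size([1, 1, 3, 3])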
1151
+
1152
+
1153
+ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
1154
+ _tied_weights_keys = ["lm_head.weight"]
1155
+
1156
+ def __init__(self, config):
1157
+ super().__init__(config)
1158
+ self.model = LlamaModel(config)
1159
+ self.vocab_size = config.vocab_size
1160
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1161
+
1162
+ # Initialize weights and apply final processing
1163
+ self.post_init()
1164
+
1165
+ def get_input_embeddings(self):
1166
+ return self.model.embed_tokens
1167
+
1168
+ def set_input_embeddings(self, value):
1169
+ self.model.embed_tokens = value
1170
+
1171
+ def get_output_embeddings(self):
1172
+ return self.lm_head
1173
+
1174
+ def set_output_embeddings(self, new_embeddings):
1175
+ self.lm_head = new_embeddings
1176
+
1177
+ def set_decoder(self, decoder):
1178
+ self.model = decoder
1179
+
1180
+ def get_decoder(self):
1181
+ return self.model
1182
+
1183
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1184
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1185
+ def forward(
1186
+ self,
1187
+ input_ids: torch.LongTensor = None,
1188
+ attention_mask: Optional[torch.Tensor] = None,
1189
+ position_ids: Optional[torch.LongTensor] = None,
1190
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1191
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1192
+ labels: Optional[torch.LongTensor] = None,
1193
+ use_cache: Optional[bool] = None,
1194
+ output_attentions: Optional[bool] = None,
1195
+ output_hidden_states: Optional[bool] = None,
1196
+ return_dict: Optional[bool] = None,
1197
+ cache_position: Optional[torch.LongTensor] = None,
1198
+ num_logits_to_keep: int = 0,
1199
+ **loss_kwargs,
1200
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1201
+ r"""
1202
+ Args:
1203
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1204
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1205
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1206
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1207
+
1208
+ num_logits_to_keep (`int`, *optional*):
1209
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1210
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1211
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1212
+
1213
+ Returns:
1214
+
1215
+ Example:
1216
+
1217
+ ```python
1218
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1219
+
1220
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1221
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1222
+
1223
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1224
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1225
+
1226
+ >>> # Generate
1227
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1228
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1229
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1230
+ ```"""
1231
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1232
+ output_hidden_states = (
1233
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1234
+ )
1235
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1236
+
1237
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1238
+ outputs = self.model(
1239
+ input_ids=input_ids,
1240
+ attention_mask=attention_mask,
1241
+ position_ids=position_ids,
1242
+ past_key_values=past_key_values,
1243
+ inputs_embeds=inputs_embeds,
1244
+ use_cache=use_cache,
1245
+ output_attentions=output_attentions,
1246
+ output_hidden_states=output_hidden_states,
1247
+ return_dict=return_dict,
1248
+ cache_position=cache_position,
1249
+ )
1250
+
1251
+ hidden_states = outputs[0]
1252
+ if self.config.pretraining_tp > 1:
1253
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1254
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1255
+ logits = torch.cat(logits, dim=-1)
1256
+ else:
1257
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1258
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1259
+
1260
+ loss = None
1261
+ if labels is not None:
1262
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
1263
+
1264
+ if not return_dict:
1265
+ output = (logits,) + outputs[1:]
1266
+ return (loss,) + output if loss is not None else output
1267
+
1268
+ return CausalLMOutputWithPast(
1269
+ loss=loss,
1270
+ logits=logits,
1271
+ past_key_values=outputs.past_key_values,
1272
+ hidden_states=outputs.hidden_states,
1273
+ attentions=outputs.attentions,
1274
+ )
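# Illustrative sketch (not part of this diff): with num_logits_to_keep=1, only the final position is
# projected through lm_head, which is all generation needs and avoids a (seq_len, vocab_size) logits
# tensor. The sizes below are made up for illustration.
import torch

hidden_states = torch.randn(1, 128, 512)                   # (batch, seq_len, hidden_size)
lm_head = torch.nn.Linear(512, 1000, bias=False)           # toy vocab projection
num_logits_to_keep = 1
logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])
print(logits.shape)                                        # torch.Size([1, 1, 1000])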
1275
+
1276
+
1277
+ @add_start_docstrings(
1278
+ """
1279
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1280
+
1281
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1282
+ (e.g. GPT-2) do.
1283
+
1284
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1285
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1286
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1287
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1288
+ each row of the batch).
1289
+ """,
1290
+ LLAMA_START_DOCSTRING,
1291
+ )
1292
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1293
+ def __init__(self, config):
1294
+ super().__init__(config)
1295
+ self.num_labels = config.num_labels
1296
+ self.model = LlamaModel(config)
1297
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1298
+
1299
+ # Initialize weights and apply final processing
1300
+ self.post_init()
1301
+
1302
+ def get_input_embeddings(self):
1303
+ return self.model.embed_tokens
1304
+
1305
+ def set_input_embeddings(self, value):
1306
+ self.model.embed_tokens = value
1307
+
1308
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1309
+ def forward(
1310
+ self,
1311
+ input_ids: Optional[torch.LongTensor] = None,
1312
+ attention_mask: Optional[torch.Tensor] = None,
1313
+ position_ids: Optional[torch.LongTensor] = None,
1314
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1315
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1316
+ labels: Optional[torch.LongTensor] = None,
1317
+ use_cache: Optional[bool] = None,
1318
+ output_attentions: Optional[bool] = None,
1319
+ output_hidden_states: Optional[bool] = None,
1320
+ return_dict: Optional[bool] = None,
1321
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1322
+ r"""
1323
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1324
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1325
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1326
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1327
+ """
1328
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1329
+
1330
+ transformer_outputs = self.model(
1331
+ input_ids,
1332
+ attention_mask=attention_mask,
1333
+ position_ids=position_ids,
1334
+ past_key_values=past_key_values,
1335
+ inputs_embeds=inputs_embeds,
1336
+ use_cache=use_cache,
1337
+ output_attentions=output_attentions,
1338
+ output_hidden_states=output_hidden_states,
1339
+ return_dict=return_dict,
1340
+ )
1341
+ hidden_states = transformer_outputs[0]
1342
+ logits = self.score(hidden_states)
1343
+
1344
+ if input_ids is not None:
1345
+ batch_size = input_ids.shape[0]
1346
+ else:
1347
+ batch_size = inputs_embeds.shape[0]
1348
+
1349
+ if self.config.pad_token_id is None and batch_size != 1:
1350
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1351
+ if self.config.pad_token_id is None:
1352
+ sequence_lengths = -1
1353
+ else:
1354
+ if input_ids is not None:
1355
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1356
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1357
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1358
+ sequence_lengths = sequence_lengths.to(logits.device)
1359
+ else:
1360
+ sequence_lengths = -1
1361
+
1362
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1363
+
1364
+ loss = None
1365
+ if labels is not None:
1366
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1367
+
1368
+ if not return_dict:
1369
+ output = (pooled_logits,) + transformer_outputs[1:]
1370
+ return ((loss,) + output) if loss is not None else output
1371
+
1372
+ return SequenceClassifierOutputWithPast(
1373
+ loss=loss,
1374
+ logits=pooled_logits,
1375
+ past_key_values=transformer_outputs.past_key_values,
1376
+ hidden_states=transformer_outputs.hidden_states,
1377
+ attentions=transformer_outputs.attentions,
1378
+ )
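# Illustrative sketch (not part of this diff): the pooling above locates the last non-padding token per
# row with an ONNX-friendly argmax trick: the index of the first pad token minus one, wrapped with a
# modulo so rows without padding fall back to the final position.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],                 # padded row -> last real token at index 2
                          [5, 6, 7, 8, 9]])                # unpadded row -> (-1) % 5 == 4
lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
lengths = lengths % input_ids.shape[-1]
print(lengths)                                             # tensor([2, 4])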
1379
+
1380
+
1381
+ @add_start_docstrings(
1382
+ """
1383
+ The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
1384
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1385
+ """,
1386
+ LLAMA_START_DOCSTRING,
1387
+ )
1388
+ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1389
+ base_model_prefix = "transformer"
1390
+
1391
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
1392
+ def __init__(self, config):
1393
+ super().__init__(config)
1394
+ self.transformer = LlamaModel(config)
1395
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1396
+
1397
+ # Initialize weights and apply final processing
1398
+ self.post_init()
1399
+
1400
+ def get_input_embeddings(self):
1401
+ return self.transformer.embed_tokens
1402
+
1403
+ def set_input_embeddings(self, value):
1404
+ self.transformer.embed_tokens = value
1405
+
1406
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1407
+ def forward(
1408
+ self,
1409
+ input_ids: Optional[torch.LongTensor] = None,
1410
+ attention_mask: Optional[torch.FloatTensor] = None,
1411
+ position_ids: Optional[torch.LongTensor] = None,
1412
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1413
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1414
+ start_positions: Optional[torch.LongTensor] = None,
1415
+ end_positions: Optional[torch.LongTensor] = None,
1416
+ output_attentions: Optional[bool] = None,
1417
+ output_hidden_states: Optional[bool] = None,
1418
+ return_dict: Optional[bool] = None,
1419
+ **kwargs,
1420
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1421
+ r"""
1422
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1423
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1424
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
1426
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1427
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1428
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
1430
+ """
1431
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1432
+
1433
+ outputs = self.transformer(
1434
+ input_ids,
1435
+ attention_mask=attention_mask,
1436
+ position_ids=position_ids,
1437
+ past_key_values=past_key_values,
1438
+ inputs_embeds=inputs_embeds,
1439
+ output_attentions=output_attentions,
1440
+ output_hidden_states=output_hidden_states,
1441
+ return_dict=return_dict,
1442
+ )
1443
+
1444
+ sequence_output = outputs[0]
1445
+
1446
+ logits = self.qa_outputs(sequence_output)
1447
+ start_logits, end_logits = logits.split(1, dim=-1)
1448
+ start_logits = start_logits.squeeze(-1).contiguous()
1449
+ end_logits = end_logits.squeeze(-1).contiguous()
1450
+
1451
+ loss = None
1452
+ if start_positions is not None and end_positions is not None:
1453
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1454
+
1455
+ if not return_dict:
1456
+ output = (start_logits, end_logits) + outputs[2:]
1457
+ return ((loss,) + output) if loss is not None else output
1458
+
1459
+ return QuestionAnsweringModelOutput(
1460
+ loss=loss,
1461
+ start_logits=start_logits,
1462
+ end_logits=end_logits,
1463
+ hidden_states=outputs.hidden_states,
1464
+ attentions=outputs.attentions,
1465
+ )
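# Illustrative sketch (not part of this diff): qa_outputs above maps each hidden state to two scalars;
# splitting the last dimension yields per-token start and end span logits.
import torch

sequence_output = torch.randn(2, 7, 16)                    # (batch, seq_len, hidden_size)
qa_outputs = torch.nn.Linear(16, 2)
start_logits, end_logits = qa_outputs(sequence_output).split(1, dim=-1)
print(start_logits.squeeze(-1).shape, end_logits.squeeze(-1).shape)   # (2, 7) each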
1466
+
1467
+
1468
+ @add_start_docstrings(
1469
+ """
1470
+ The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1471
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1472
+ """,
1473
+ LLAMA_START_DOCSTRING,
1474
+ )
1475
+ class LlamaForTokenClassification(LlamaPreTrainedModel):
1476
+ def __init__(self, config):
1477
+ super().__init__(config)
1478
+ self.num_labels = config.num_labels
1479
+ self.model = LlamaModel(config)
1480
+ if getattr(config, "classifier_dropout", None) is not None:
1481
+ classifier_dropout = config.classifier_dropout
1482
+ elif getattr(config, "hidden_dropout", None) is not None:
1483
+ classifier_dropout = config.hidden_dropout
1484
+ else:
1485
+ classifier_dropout = 0.1
1486
+ self.dropout = nn.Dropout(classifier_dropout)
1487
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1488
+
1489
+ # Initialize weights and apply final processing
1490
+ self.post_init()
1491
+
1492
+ def get_input_embeddings(self):
1493
+ return self.model.embed_tokens
1494
+
1495
+ def set_input_embeddings(self, value):
1496
+ self.model.embed_tokens = value
1497
+
1498
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1499
+ @add_code_sample_docstrings(
1500
+ checkpoint=_CHECKPOINT_FOR_DOC,
1501
+ output_type=TokenClassifierOutput,
1502
+ config_class=_CONFIG_FOR_DOC,
1503
+ )
1504
+ def forward(
1505
+ self,
1506
+ input_ids: Optional[torch.LongTensor] = None,
1507
+ attention_mask: Optional[torch.Tensor] = None,
1508
+ position_ids: Optional[torch.LongTensor] = None,
1509
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1510
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1511
+ labels: Optional[torch.LongTensor] = None,
1512
+ use_cache: Optional[bool] = None,
1513
+ output_attentions: Optional[bool] = None,
1514
+ output_hidden_states: Optional[bool] = None,
1515
+ return_dict: Optional[bool] = None,
1516
+ ) -> Union[Tuple, TokenClassifierOutput]:
1517
+ r"""
1518
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1519
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1520
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1521
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1522
+ """
1523
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1524
+
1525
+ outputs = self.model(
1526
+ input_ids,
1527
+ attention_mask=attention_mask,
1528
+ position_ids=position_ids,
1529
+ past_key_values=past_key_values,
1530
+ inputs_embeds=inputs_embeds,
1531
+ use_cache=use_cache,
1532
+ output_attentions=output_attentions,
1533
+ output_hidden_states=output_hidden_states,
1534
+ return_dict=return_dict,
1535
+ )
1536
+ sequence_output = outputs[0]
1537
+ sequence_output = self.dropout(sequence_output)
1538
+ logits = self.score(sequence_output)
1539
+
1540
+ loss = None
1541
+ if labels is not None:
1542
+ loss = self.loss_function(logits, labels, self.config)
1543
+
1544
+ if not return_dict:
1545
+ output = (logits,) + outputs[2:]
1546
+ return ((loss,) + output) if loss is not None else output
1547
+
1548
+ return TokenClassifierOutput(
1549
+ loss=loss,
1550
+ logits=logits,
1551
+ hidden_states=outputs.hidden_states,
1552
+ attentions=outputs.attentions,
1553
+ )